BryanW commited on
Commit
9c8af91
·
verified ·
1 Parent(s): 5b18d32

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/functional.h +9 -0
  2. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/functional_base.h +480 -0
  3. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/functional_bfloat16.h +652 -0
  4. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/intrinsics.h +6 -0
  5. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec.h +62 -0
  6. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec128/vec128_half_neon.h +627 -0
  7. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec128/vec128_reduced_precision_common_neon.h +316 -0
  8. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec_base.h +1537 -0
  9. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec_convert.h +84 -0
  10. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec_half.h +123 -0
  11. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec_mask.h +318 -0
  12. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec_n.h +412 -0
  13. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec_quant.h +258 -0
  14. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/ADInterpreters.h +43 -0
  15. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/BatchedFallback.h +86 -0
  16. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/BatchedTensorImpl.h +181 -0
  17. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/DynamicLayer.h +129 -0
  18. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/FunctionalizeInterpreter.h +27 -0
  19. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/Interpreter.h +358 -0
  20. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/PlumbingHelper.h +68 -0
  21. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/TensorWrapper.h +108 -0
  22. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/ao_sparse/quantized/cpu/fbgemm_utils.h +102 -0
  23. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/ao_sparse/quantized/cpu/packed_params.h +78 -0
  24. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/ao_sparse/quantized/cpu/qnnpack_utils.h +95 -0
  25. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/AtomicAddFloat.h +42 -0
  26. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/ChannelShuffleKernel.h +19 -0
  27. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/DepthwiseConvKernel.h +26 -0
  28. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/Elu.h +79 -0
  29. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/GridSamplerKernel.h +39 -0
  30. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/IndexKernelUtils.h +90 -0
  31. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/Intrinsics.h +38 -0
  32. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/IsContiguous.h +69 -0
  33. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/LogSoftmaxKernelImpl.h +342 -0
  34. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/Loops.h +400 -0
  35. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/MaxUnpoolKernel.h +19 -0
  36. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/PixelShuffleKernel.h +19 -0
  37. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/ReduceUtils.h +242 -0
  38. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/ReducedPrecisionFloatGemvFastPathKernel.h +32 -0
  39. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/SampledAddmmKernel.h +17 -0
  40. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/SerialStackImpl.h +151 -0
  41. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/SoftmaxKernel.h +33 -0
  42. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/SpmmReduceKernel.h +27 -0
  43. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/StackKernel.h +17 -0
  44. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/UpSampleKernelAVXAntialias.h +1381 -0
  45. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/avx_mathfun.h +527 -0
  46. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/int_mm_kernel.h +43 -0
  47. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/mixed_data_type.h +46 -0
  48. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/utils.h +225 -0
  49. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/zmath.h +255 -0
  50. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/CUDAJitLoops.cuh +332 -0
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/functional.h ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <ATen/cpu/vec/functional_base.h>
5
+ #include <ATen/cpu/vec/functional_bfloat16.h>
6
+
7
+ #else
8
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
9
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/functional_base.h ADDED
@@ -0,0 +1,480 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ // DO NOT DEFINE STATIC DATA IN THIS HEADER!
5
+ // See Note [Do not compile initializers with AVX]
6
+
7
+ #include <ATen/cpu/vec/vec.h>
8
+ #include <c10/util/irange.h>
9
+
10
+ namespace at {
11
+ namespace detail {
12
+ // We prefer to convert through float for reduced-precision floating
13
+ // point types if we have a Vectorized specialization for float and we
14
+ // don't have one for the actual type in question.
15
+ template <typename T>
16
+ struct should_prefer_converting_through_float
17
+ : std::bool_constant<
18
+ is_reduced_floating_point_v<T> &&
19
+ vec::is_vec_specialized_for_v<float> &&
20
+ !vec::is_vec_specialized_for_v<T>> {};
21
+
22
+ template <typename T>
23
+ constexpr auto should_prefer_converting_through_float_v =
24
+ should_prefer_converting_through_float<T>::value;
25
+ } // namespace detail
26
+
27
+ namespace vec {
28
+ // slow path
29
+ template <typename scalar_t, typename Op>
30
+ inline scalar_t vec_reduce_all(
31
+ const Op& vec_fun,
32
+ vec::Vectorized<scalar_t> acc_vec,
33
+ int64_t size) {
34
+ using Vec = vec::Vectorized<scalar_t>;
35
+ scalar_t acc_arr[Vec::size()];
36
+ acc_vec.store(acc_arr);
37
+ for (const auto i : c10::irange(1, size)) {
38
+ std::array<scalar_t, Vec::size()> acc_arr_next = {0};
39
+ acc_arr_next[0] = acc_arr[i];
40
+ Vec acc_vec_next = Vec::loadu(acc_arr_next.data());
41
+ acc_vec = vec_fun(acc_vec, acc_vec_next);
42
+ }
43
+ acc_vec.store(acc_arr);
44
+ return acc_arr[0];
45
+ }
46
+
47
+ template <typename scalar_t, typename Op>
48
+ struct VecReduceAllSIMD {
49
+ static inline scalar_t apply(
50
+ const Op& vec_fun,
51
+ const Vectorized<scalar_t>& acc_vec) {
52
+ return vec_reduce_all(vec_fun, acc_vec, Vectorized<scalar_t>::size());
53
+ }
54
+ };
55
+
56
+ #if defined(__GNUC__) && (__GNUC__ > 5) && !defined(_MSC_VER) && \
57
+ !defined(C10_MOBILE)
58
+ #if defined(CPU_CAPABILITY_AVX2)
59
+ template <typename Op>
60
+ struct VecReduceAllSIMD<float, Op> {
61
+ static inline float apply(
62
+ const Op& vec_fun,
63
+ const Vectorized<float>& acc_vec) {
64
+ using Vec = Vectorized<float>;
65
+ Vec v = acc_vec;
66
+ // 128-bit shuffle
67
+ Vec v1 = _mm256_permute2f128_ps(v, v, 0x1);
68
+ v = vec_fun(v, v1);
69
+ // 64-bit shuffle
70
+ v1 = _mm256_shuffle_ps(v, v, 0x4E);
71
+ v = vec_fun(v, v1);
72
+ // 32-bit shuffle
73
+ v1 = _mm256_shuffle_ps(v, v, 0xB1);
74
+ v = vec_fun(v, v1);
75
+ return _mm256_cvtss_f32(v);
76
+ }
77
+ };
78
+ #endif // defined(CPU_CAPABILITY_AVX2)
79
+ #if defined(CPU_CAPABILITY_AVX512)
80
+ template <typename Op>
81
+ struct VecReduceAllSIMD<float, Op> {
82
+ static inline float apply(
83
+ const Op& vec_fun,
84
+ const Vectorized<float>& acc_vec) {
85
+ using Vec = Vectorized<float>;
86
+ Vec v = acc_vec;
87
+ // 256-bit shuffle
88
+ Vec v1 = _mm512_shuffle_f32x4(v, v, 0x4E);
89
+ v = vec_fun(v, v1);
90
+ // 128-bit shuffle
91
+ v1 = _mm512_shuffle_f32x4(v, v, 0xB1);
92
+ v = vec_fun(v, v1);
93
+ // 64-bit shuffle
94
+ v1 = _mm512_shuffle_ps(v, v, 0x4E);
95
+ v = vec_fun(v, v1);
96
+ // 32-bit shuffle
97
+ v1 = _mm512_shuffle_ps(v, v, 0xB1);
98
+ v = vec_fun(v, v1);
99
+ return _mm512_cvtss_f32(v);
100
+ }
101
+ };
102
+ #endif // defined(CPU_CAPABILITY_AVX512)
103
+ #endif // defined(__GNUC__) && (__GNUC__ > 5) && !defined(_MSC_VER) &&
104
+ // !defined(C10_MOBILE)
105
+
106
+ #if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && \
107
+ !defined(CPU_CAPABILITY_SVE)
108
+ template <typename Op>
109
+ struct VecReduceAllSIMD<float, Op> {
110
+ static inline float apply(
111
+ const Op& vec_fun,
112
+ const Vectorized<float>& acc_vec) {
113
+ using Vec = Vectorized<float>;
114
+ Vec v = acc_vec;
115
+
116
+ // 64-bit shuffle: [a1+a5, a2+a6, a3+a7, a4+a8, -, -, -, -] -> [a3+a7,
117
+ // a4+a8, a1+a5, a2+a6, -, -, -, -]
118
+ float32x4_t v1_1 = vextq_f32(v, v, 2);
119
+ Vec v1 = v1_1;
120
+ // [a1+a3+a5+a7, a2+a4+a6+a8, a1+a3+a5+a7, a2+a4+a6+a8, -, -, -, -]
121
+ v = vec_fun(v, v1);
122
+
123
+ // 32-bit shuffle: [a1+a3+a5+a7, a2+a4+a6+a8, a1+a3+a5+a7, a2+a4+a6+a8, -,
124
+ // -, -, -] -> [a2+a4+a6+a8, a1+a3+a5+a7, a2+a4+a6+a8, a1+a3+a5+a7, -, -, -,
125
+ // -]
126
+ v1_1 = vrev64q_f32(v);
127
+ v1 = v1_1;
128
+ // [a1+a2+a3+a4+a5+a6+a7+a8, a1+a2+a3+a4+a5+a6+a7+a8,
129
+ // a1+a2+a3+a4+a5+a6+a7+a8, a1+a2+a3+a4+a5+a6+a7+a8, -, -, -, -]
130
+ v = vec_fun(v, v1);
131
+
132
+ return v[0];
133
+ }
134
+ };
135
+
136
+ template <>
137
+ struct VecReduceAllSIMD<float, std::plus<Vectorized<float>>> {
138
+ static inline float apply(
139
+ const std::plus<Vectorized<float>>& vec_fun,
140
+ const Vectorized<float>& acc_vec) {
141
+ return vaddvq_f32(acc_vec);
142
+ }
143
+ };
144
+ #endif // defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__)
145
+ // && !defined(CPU_CAPABILITY_SVE)
146
+
147
+ #if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && \
148
+ defined(CPU_CAPABILITY_SVE256)
149
+ template <typename Op>
150
+ struct VecReduceAllSIMD<float, Op> {
151
+ static inline float apply(
152
+ const Op& vec_fun,
153
+ const Vectorized<float>& acc_vec) {
154
+ using Vec = Vectorized<float>;
155
+ Vec v = acc_vec;
156
+ // 128-bit shuffle
157
+ svuint32_t ind = svdupq_n_u32(4, 5, 6, 7);
158
+ Vec v1 = svtbl_f32(v, ind);
159
+ v = vec_fun(v, v1);
160
+ // 64-bit shuffle
161
+ ind = svdupq_n_u32(2, 3, 0, 1);
162
+ v1 = svtbl_f32(v, ind);
163
+ v = vec_fun(v, v1);
164
+ // 32-bit shuffle
165
+ ind = svdupq_n_u32(1, 0, 2, 3);
166
+ v1 = svtbl_f32(v, ind);
167
+ v = vec_fun(v, v1);
168
+ return svlasta(svpfalse(), v);
169
+ }
170
+ };
171
+ #endif // defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__)
172
+ // && defined(CPU_CAPABILITY_SVE256)
173
+
174
+ template <typename scalar_t, typename Op>
175
+ inline scalar_t vec_reduce_all(
176
+ const Op& vec_fun,
177
+ const Vectorized<scalar_t>& acc_vec) {
178
+ return VecReduceAllSIMD<scalar_t, Op>::apply(vec_fun, acc_vec);
179
+ }
180
+
181
+ template <
182
+ typename scalar_t,
183
+ typename Op,
184
+ typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, int> = 0>
185
+ inline scalar_t reduce_all(
186
+ const Op& vec_fun,
187
+ const scalar_t* data,
188
+ int64_t size) {
189
+ using Vec = vec::Vectorized<scalar_t>;
190
+ if (size < Vec::size())
191
+ return vec_reduce_all(vec_fun, Vec::loadu(data, size), size);
192
+ int64_t d = Vec::size();
193
+ Vec acc_vec = Vec::loadu(data);
194
+ for (; d < size - (size % Vec::size()); d += Vec::size()) {
195
+ Vec data_vec = Vec::loadu(data + d);
196
+ acc_vec = vec_fun(acc_vec, data_vec);
197
+ }
198
+ if (size - d > 0) {
199
+ Vec data_vec = Vec::loadu(data + d, size - d);
200
+ acc_vec = Vec::set(acc_vec, vec_fun(acc_vec, data_vec), size - d);
201
+ }
202
+ return vec_reduce_all(vec_fun, acc_vec);
203
+ }
204
+
205
+ // similar to reduce_all, but reduces into two outputs
206
+ template <
207
+ typename scalar_t,
208
+ typename Op1,
209
+ typename Op2,
210
+ typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, int> = 0>
211
+ inline std::pair<scalar_t, scalar_t> reduce2_all(
212
+ const Op1& vec_fun1,
213
+ const Op2& vec_fun2,
214
+ const scalar_t* data,
215
+ int64_t size) {
216
+ using Vec = vec::Vectorized<scalar_t>;
217
+ if (size < Vec::size()) {
218
+ auto loaded_data = Vec::loadu(data, size);
219
+ return std::pair<scalar_t, scalar_t>(
220
+ vec_reduce_all(vec_fun1, loaded_data, size),
221
+ vec_reduce_all(vec_fun2, loaded_data, size));
222
+ }
223
+ int64_t d = Vec::size();
224
+ Vec acc_vec1 = Vec::loadu(data);
225
+ Vec acc_vec2 = Vec::loadu(data);
226
+ for (; d < size - (size % Vec::size()); d += Vec::size()) {
227
+ Vec data_vec = Vec::loadu(data + d);
228
+ acc_vec1 = vec_fun1(acc_vec1, data_vec);
229
+ acc_vec2 = vec_fun2(acc_vec2, data_vec);
230
+ }
231
+ if (size - d > 0) {
232
+ Vec data_vec = Vec::loadu(data + d, size - d);
233
+ acc_vec1 = Vec::set(acc_vec1, vec_fun1(acc_vec1, data_vec), size - d);
234
+ acc_vec2 = Vec::set(acc_vec2, vec_fun2(acc_vec2, data_vec), size - d);
235
+ }
236
+ return std::pair<scalar_t, scalar_t>(
237
+ vec_reduce_all(vec_fun1, acc_vec1), vec_reduce_all(vec_fun2, acc_vec2));
238
+ }
239
+
240
+ template <
241
+ typename scalar_t,
242
+ typename MapOp,
243
+ typename ReduceOp,
244
+ typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, int> = 0>
245
+ inline scalar_t map_reduce_all(
246
+ const MapOp& map_fun,
247
+ const ReduceOp& red_fun,
248
+ const scalar_t* data,
249
+ int64_t size) {
250
+ using Vec = vec::Vectorized<scalar_t>;
251
+ if (size < Vec::size())
252
+ return vec_reduce_all(red_fun, map_fun(Vec::loadu(data, size)), size);
253
+ int64_t d = Vec::size();
254
+ Vec acc_vec = map_fun(Vec::loadu(data));
255
+ for (; d < size - (size % Vec::size()); d += Vec::size()) {
256
+ Vec data_vec = Vec::loadu(data + d);
257
+ data_vec = map_fun(data_vec);
258
+ acc_vec = red_fun(acc_vec, data_vec);
259
+ }
260
+ if (size - d > 0) {
261
+ Vec data_vec = Vec::loadu(data + d, size - d);
262
+ data_vec = map_fun(data_vec);
263
+ acc_vec = Vec::set(acc_vec, red_fun(acc_vec, data_vec), size - d);
264
+ }
265
+ return vec_reduce_all(red_fun, acc_vec);
266
+ }
267
+
268
+ template <
269
+ typename scalar_t,
270
+ typename MapOp,
271
+ typename ReduceOp,
272
+ typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, int> = 0>
273
+ inline scalar_t map2_reduce_all(
274
+ const MapOp& map_fun,
275
+ const ReduceOp& red_fun,
276
+ const scalar_t* data,
277
+ const scalar_t* data2,
278
+ int64_t size) {
279
+ using Vec = vec::Vectorized<scalar_t>;
280
+ if (size < Vec::size()) {
281
+ Vec data_vec = Vec::loadu(data, size);
282
+ Vec data2_vec = Vec::loadu(data2, size);
283
+ data_vec = map_fun(data_vec, data2_vec);
284
+ return vec_reduce_all(red_fun, data_vec, size);
285
+ }
286
+ int64_t d = Vec::size();
287
+ Vec acc_vec = map_fun(Vec::loadu(data), Vec::loadu(data2));
288
+ for (; d < size - (size % Vec::size()); d += Vec::size()) {
289
+ Vec data_vec = Vec::loadu(data + d);
290
+ Vec data2_vec = Vec::loadu(data2 + d);
291
+ data_vec = map_fun(data_vec, data2_vec);
292
+ acc_vec = red_fun(acc_vec, data_vec);
293
+ }
294
+ if (size - d > 0) {
295
+ Vec data_vec = Vec::loadu(data + d, size - d);
296
+ Vec data2_vec = Vec::loadu(data2 + d, size - d);
297
+ data_vec = map_fun(data_vec, data2_vec);
298
+ acc_vec = Vec::set(acc_vec, red_fun(acc_vec, data_vec), size - d);
299
+ }
300
+ return vec_reduce_all(red_fun, acc_vec);
301
+ }
302
+
303
+ template <
304
+ typename scalar_t,
305
+ typename MapOp,
306
+ typename ReduceOp,
307
+ typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, int> = 0>
308
+ inline scalar_t map3_reduce_all(
309
+ const MapOp& map_fun,
310
+ const ReduceOp& red_fun,
311
+ const scalar_t* data,
312
+ const scalar_t* data2,
313
+ const scalar_t* data3,
314
+ int64_t size) {
315
+ using Vec = vec::Vectorized<scalar_t>;
316
+ if (size < Vec::size()) {
317
+ Vec data_vec = Vec::loadu(data, size);
318
+ Vec data2_vec = Vec::loadu(data2, size);
319
+ Vec data3_vec = Vec::loadu(data3, size);
320
+ data_vec = map_fun(data_vec, data2_vec, data3_vec);
321
+ return vec_reduce_all(red_fun, data_vec, size);
322
+ }
323
+
324
+ int64_t d = Vec::size();
325
+ Vec acc_vec = map_fun(Vec::loadu(data), Vec::loadu(data2), Vec::loadu(data3));
326
+ for (; d < size - (size % Vec::size()); d += Vec::size()) {
327
+ Vec data_vec = Vec::loadu(data + d);
328
+ Vec data2_vec = Vec::loadu(data2 + d);
329
+ Vec data3_vec = Vec::loadu(data3 + d);
330
+ data_vec = map_fun(data_vec, data2_vec, data3_vec);
331
+ acc_vec = red_fun(acc_vec, data_vec);
332
+ }
333
+ if (size - d > 0) {
334
+ Vec data_vec = Vec::loadu(data + d, size - d);
335
+ Vec data2_vec = Vec::loadu(data2 + d, size - d);
336
+ Vec data3_vec = Vec::loadu(data3 + d, size - d);
337
+ data_vec = map_fun(data_vec, data2_vec, data3_vec);
338
+ acc_vec = Vec::set(acc_vec, red_fun(acc_vec, data_vec), size - d);
339
+ }
340
+ return vec_reduce_all(red_fun, acc_vec);
341
+ }
342
+
343
+ template <
344
+ typename scalar_t,
345
+ typename Op,
346
+ typename std::enable_if_t<
347
+ !detail::should_prefer_converting_through_float_v<scalar_t> &&
348
+ std::is_invocable_v<Op, vec::Vectorized<scalar_t>>,
349
+ int> = 0>
350
+ inline void map(
351
+ const Op& vec_fun,
352
+ scalar_t* output_data,
353
+ const scalar_t* input_data,
354
+ int64_t size) {
355
+ using Vec = vec::Vectorized<scalar_t>;
356
+ int64_t d = 0;
357
+ for (; d < size - (size % Vec::size()); d += Vec::size()) {
358
+ Vec output_vec = vec_fun(Vec::loadu(input_data + d));
359
+ output_vec.store(output_data + d);
360
+ }
361
+ if (size - d > 0) {
362
+ Vec output_vec = vec_fun(Vec::loadu(input_data + d, size - d));
363
+ output_vec.store(output_data + d, size - d);
364
+ }
365
+ }
366
+
367
+ template <
368
+ typename scalar_t,
369
+ typename Op,
370
+ typename std::enable_if_t<
371
+ !detail::should_prefer_converting_through_float_v<scalar_t> &&
372
+ std::is_invocable_v<
373
+ Op,
374
+ vec::Vectorized<scalar_t>,
375
+ vec::Vectorized<scalar_t>>,
376
+ int> = 0>
377
+ inline void map2(
378
+ const Op& vec_fun,
379
+ scalar_t* output_data,
380
+ const scalar_t* input_data,
381
+ const scalar_t* input_data2,
382
+ int64_t size) {
383
+ using Vec = vec::Vectorized<scalar_t>;
384
+ int64_t d = 0;
385
+ for (; d < size - (size % Vec::size()); d += Vec::size()) {
386
+ Vec data_vec = Vec::loadu(input_data + d);
387
+ Vec data_vec2 = Vec::loadu(input_data2 + d);
388
+ Vec output_vec = vec_fun(data_vec, data_vec2);
389
+ output_vec.store(output_data + d);
390
+ }
391
+ if (size - d > 0) {
392
+ Vec data_vec = Vec::loadu(input_data + d, size - d);
393
+ Vec data_vec2 = Vec::loadu(input_data2 + d, size - d);
394
+ Vec output_vec = vec_fun(data_vec, data_vec2);
395
+ output_vec.store(output_data + d, size - d);
396
+ }
397
+ }
398
+
399
+ template <
400
+ typename scalar_t,
401
+ typename Op,
402
+ typename std::enable_if_t<
403
+ !detail::should_prefer_converting_through_float_v<scalar_t> &&
404
+ std::is_invocable_v<
405
+ Op,
406
+ vec::Vectorized<scalar_t>,
407
+ vec::Vectorized<scalar_t>,
408
+ vec::Vectorized<scalar_t>>,
409
+ int> = 0>
410
+ inline void map3(
411
+ const Op& vec_fun,
412
+ scalar_t* output_data,
413
+ const scalar_t* input_data1,
414
+ const scalar_t* input_data2,
415
+ const scalar_t* input_data3,
416
+ int64_t size) {
417
+ using Vec = vec::Vectorized<scalar_t>;
418
+ int64_t d = 0;
419
+ for (; d < size - (size % Vec::size()); d += Vec::size()) {
420
+ Vec data_vec1 = Vec::loadu(input_data1 + d);
421
+ Vec data_vec2 = Vec::loadu(input_data2 + d);
422
+ Vec data_vec3 = Vec::loadu(input_data3 + d);
423
+ Vec output_vec = vec_fun(data_vec1, data_vec2, data_vec3);
424
+ output_vec.store(output_data + d);
425
+ }
426
+ if (size - d > 0) {
427
+ Vec data_vec1 = Vec::loadu(input_data1 + d, size - d);
428
+ Vec data_vec2 = Vec::loadu(input_data2 + d, size - d);
429
+ Vec data_vec3 = Vec::loadu(input_data3 + d, size - d);
430
+ Vec output_vec = vec_fun(data_vec1, data_vec2, data_vec3);
431
+ output_vec.store(output_data + d, size - d);
432
+ }
433
+ }
434
+
435
+ template <
436
+ typename scalar_t,
437
+ typename Op,
438
+ typename std::enable_if_t<
439
+ !detail::should_prefer_converting_through_float_v<scalar_t> &&
440
+ std::is_invocable_v<
441
+ Op,
442
+ vec::Vectorized<scalar_t>,
443
+ vec::Vectorized<scalar_t>,
444
+ vec::Vectorized<scalar_t>,
445
+ vec::Vectorized<scalar_t>>,
446
+ int> = 0>
447
+ inline void map4(
448
+ const Op& vec_fun,
449
+ scalar_t* output_data,
450
+ const scalar_t* input_data1,
451
+ const scalar_t* input_data2,
452
+ const scalar_t* input_data3,
453
+ const scalar_t* input_data4,
454
+ int64_t size) {
455
+ using Vec = vec::Vectorized<scalar_t>;
456
+ int64_t d = 0;
457
+ for (; d < size - (size % Vec::size()); d += Vec::size()) {
458
+ Vec data_vec1 = Vec::loadu(input_data1 + d);
459
+ Vec data_vec2 = Vec::loadu(input_data2 + d);
460
+ Vec data_vec3 = Vec::loadu(input_data3 + d);
461
+ Vec data_vec4 = Vec::loadu(input_data4 + d);
462
+ Vec output_vec = vec_fun(data_vec1, data_vec2, data_vec3, data_vec4);
463
+ output_vec.store(output_data + d);
464
+ }
465
+ if (size - d > 0) {
466
+ Vec data_vec1 = Vec::loadu(input_data1 + d, size - d);
467
+ Vec data_vec2 = Vec::loadu(input_data2 + d, size - d);
468
+ Vec data_vec3 = Vec::loadu(input_data3 + d, size - d);
469
+ Vec data_vec4 = Vec::loadu(input_data4 + d, size - d);
470
+ Vec output_vec = vec_fun(data_vec1, data_vec2, data_vec3, data_vec4);
471
+ output_vec.store(output_data + d, size - d);
472
+ }
473
+ }
474
+
475
+ } // namespace vec
476
+ } // namespace at
477
+
478
+ #else
479
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
480
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/functional_bfloat16.h ADDED
@@ -0,0 +1,652 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ // DO NOT DEFINE STATIC DATA IN THIS HEADER!
5
+ // See Note [Do not compile initializers with AVX]
6
+
7
+ #include <ATen/cpu/vec/vec.h>
8
+
9
+ namespace at::vec {
10
+ // BFloat16 specification
11
+ template <typename scalar_t>
12
+ struct VecScalarType {
13
+ using type = scalar_t;
14
+ };
15
+ template <>
16
+ struct VecScalarType<BFloat16> {
17
+ using type = float;
18
+ };
19
+ template <>
20
+ struct VecScalarType<Half> {
21
+ using type = float;
22
+ };
23
+
24
+ // This is different from at::acc_type since we only need to specialize BFloat16
25
+ template <typename scalar_t>
26
+ using vec_scalar_t = typename VecScalarType<scalar_t>::type;
27
+
28
+ // Vector conversion between float and bfloat16/half
29
+ template <>
30
+ inline std::tuple<Vectorized<float>, Vectorized<float>> convert_to_float<
31
+ BFloat16>(const Vectorized<BFloat16>& a) {
32
+ return convert_bfloat16_float(a);
33
+ }
34
+
35
+ template <>
36
+ inline std::tuple<Vectorized<float>, Vectorized<float>> convert_to_float<Half>(
37
+ const Vectorized<Half>& a) {
38
+ return convert_half_float(a);
39
+ }
40
+
41
+ template <>
42
+ inline Vectorized<BFloat16> convert_from_float<BFloat16>(
43
+ const Vectorized<float>& a,
44
+ const Vectorized<float>& b) {
45
+ return convert_float_bfloat16(a, b);
46
+ }
47
+
48
+ template <>
49
+ inline Vectorized<Half> convert_from_float<Half>(
50
+ const Vectorized<float>& a,
51
+ const Vectorized<float>& b) {
52
+ return convert_float_half(a, b);
53
+ }
54
+
55
+ template <
56
+ typename scalar_t,
57
+ typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
58
+ inline void load_to_float(
59
+ const scalar_t* data,
60
+ Vectorized<float>& out1,
61
+ Vectorized<float>& out2);
62
+
63
+ template <>
64
+ inline void load_to_float<BFloat16>(
65
+ const BFloat16* data,
66
+ Vectorized<float>& out1,
67
+ Vectorized<float>& out2) {
68
+ load_fp32_from_bf16(data, out1, out2);
69
+ }
70
+
71
+ template <>
72
+ inline void load_to_float<Half>(
73
+ const Half* data,
74
+ Vectorized<float>& out1,
75
+ Vectorized<float>& out2) {
76
+ load_fp32_from_fp16(data, out1, out2);
77
+ }
78
+
79
+ template <
80
+ typename scalar_t,
81
+ typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
82
+ inline void load_to_float(const scalar_t* data, Vectorized<float>& out);
83
+
84
+ template <>
85
+ inline void load_to_float<BFloat16>(
86
+ const BFloat16* data,
87
+ Vectorized<float>& out) {
88
+ load_fp32_from_bf16(data, out);
89
+ }
90
+
91
+ template <>
92
+ inline void load_to_float<Half>(const Half* data, Vectorized<float>& out) {
93
+ load_fp32_from_fp16(data, out);
94
+ }
95
+
96
+ // Note that we already have specialized member of Vectorized<scalar_t> for
97
+ // BFloat16 so the following functions would run smoothly:
98
+ // using Vec = Vectorized<BFloat16>;
99
+ // Vec one = Vec(BFloat16(1));
100
+ // vec::map([](Vec x) { return one / (one + x.exp()); }, y_ptr, x_ptr, N);
101
+ //
102
+ // Then why we still need to specialize "functional"?
103
+ // If we do specialization at Vectorized<> level, the above example would need
104
+ // 3 pairs of conversion of bf16->fp32/fp32->bf16, each for ".exp()", "+" and
105
+ // "/". If we do specialization at vec::map<>() level, we have only 1 pair of
106
+ // conversion of bf16->fp32/fp32->bf16, for the input and output BFloat16
107
+ // vector only.
108
+ //
109
+ // The following BFloat16 functionality will only do data type conversion for
110
+ // input and output vector (reduce functionality will only convert the final
111
+ // scalar back to bf16). Compared to Vectorized<> specialization,
112
+ // 1. better performance since we have less data type conversion;
113
+ // 2. less rounding error since immediate results are kept in fp32;
114
+ // 3. accumulation done on data type of fp32.
115
+ //
116
+ // If you plan to extend this file, please ensure adding unit tests at
117
+ // aten/src/ATen/test/vec_test_all_types.cpp
118
+ //
119
+ template <
120
+ typename scalar_t,
121
+ typename Op,
122
+ typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
123
+ inline float reduce_all(const Op& vec_fun, const scalar_t* data, int64_t size) {
124
+ using bVec = vec::Vectorized<scalar_t>;
125
+ using fVec = vec::Vectorized<float>;
126
+ if (size < bVec::size()) {
127
+ bVec data_bvec = bVec::loadu(data, size);
128
+ auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
129
+ if (size > fVec::size()) {
130
+ data_fvec0 = fVec::set(
131
+ data_fvec0, vec_fun(data_fvec0, data_fvec1), size - fVec::size());
132
+ return vec_reduce_all<float>(vec_fun, data_fvec0, fVec::size());
133
+ } else {
134
+ return vec_reduce_all<float>(vec_fun, data_fvec0, size);
135
+ }
136
+ }
137
+ int64_t d = bVec::size();
138
+ bVec acc_bvec = bVec::loadu(data);
139
+ auto [acc_fvec0, acc_fvec1] = convert_to_float<scalar_t>(acc_bvec);
140
+ for (; d < size - (size % bVec::size()); d += bVec::size()) {
141
+ bVec data_bvec = bVec::loadu(data + d);
142
+ auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
143
+ acc_fvec0 = vec_fun(acc_fvec0, data_fvec0);
144
+ acc_fvec1 = vec_fun(acc_fvec1, data_fvec1);
145
+ }
146
+ if (size - d > 0) {
147
+ bVec data_bvec = bVec::loadu(data + d, size - d);
148
+ auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
149
+ if (size - d > fVec::size()) {
150
+ acc_fvec0 = vec_fun(acc_fvec0, data_fvec0);
151
+ acc_fvec1 = fVec::set(
152
+ acc_fvec1, vec_fun(acc_fvec1, data_fvec1), size - d - fVec::size());
153
+ } else {
154
+ acc_fvec0 =
155
+ fVec::set(acc_fvec0, vec_fun(acc_fvec0, data_fvec0), size - d);
156
+ }
157
+ }
158
+ acc_fvec0 = vec_fun(acc_fvec0, acc_fvec1);
159
+ return vec_reduce_all<float>(vec_fun, acc_fvec0);
160
+ }
161
+
162
+ template <
163
+ typename scalar_t,
164
+ typename Op1,
165
+ typename Op2,
166
+ typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
167
+ inline std::pair<float, float> reduce2_all(
168
+ const Op1& vec_fun1,
169
+ const Op2& vec_fun2,
170
+ const scalar_t* data,
171
+ int64_t size) {
172
+ using bVec = vec::Vectorized<scalar_t>;
173
+ using fVec = vec::Vectorized<float>;
174
+ if (size < bVec::size()) {
175
+ bVec data_bvec = bVec::loadu(data, size);
176
+ auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
177
+ if (size > fVec::size()) {
178
+ fVec acc1_fvec = fVec::set(
179
+ data_fvec0, vec_fun1(data_fvec0, data_fvec1), size - fVec::size());
180
+ fVec acc2_fvec = fVec::set(
181
+ data_fvec0, vec_fun2(data_fvec0, data_fvec1), size - fVec::size());
182
+ return std::pair<scalar_t, scalar_t>(
183
+ vec_reduce_all<float>(vec_fun1, acc1_fvec, fVec::size()),
184
+ vec_reduce_all<float>(vec_fun2, acc2_fvec, fVec::size()));
185
+ } else {
186
+ return std::pair<scalar_t, scalar_t>(
187
+ vec_reduce_all<float>(vec_fun1, data_fvec0, size),
188
+ vec_reduce_all<float>(vec_fun2, data_fvec0, size));
189
+ }
190
+ }
191
+ int64_t d = bVec::size();
192
+ bVec acc_bvec = bVec::loadu(data);
193
+ auto [acc1_fvec0, acc1_fvec1] = convert_to_float<scalar_t>(acc_bvec);
194
+ auto [acc2_fvec0, acc2_fvec1] = convert_to_float<scalar_t>(acc_bvec);
195
+ for (; d < size - (size % bVec::size()); d += bVec::size()) {
196
+ bVec data_bvec = bVec::loadu(data + d);
197
+ auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
198
+ acc1_fvec0 = vec_fun1(acc1_fvec0, data_fvec0);
199
+ acc1_fvec1 = vec_fun1(acc1_fvec1, data_fvec1);
200
+ acc2_fvec0 = vec_fun2(acc2_fvec0, data_fvec0);
201
+ acc2_fvec1 = vec_fun2(acc2_fvec1, data_fvec1);
202
+ }
203
+ if (size - d > 0) {
204
+ bVec data_bvec = bVec::loadu(data + d, size - d);
205
+ auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
206
+ if (size - d > fVec::size()) {
207
+ acc1_fvec0 = vec_fun1(acc1_fvec0, data_fvec0);
208
+ acc1_fvec1 = fVec::set(
209
+ acc1_fvec1,
210
+ vec_fun1(acc1_fvec1, data_fvec1),
211
+ size - d - fVec::size());
212
+ acc2_fvec0 = vec_fun2(acc2_fvec0, data_fvec0);
213
+ acc2_fvec1 = fVec::set(
214
+ acc2_fvec1,
215
+ vec_fun2(acc2_fvec1, data_fvec1),
216
+ size - d - fVec::size());
217
+ } else {
218
+ acc1_fvec0 =
219
+ fVec::set(acc1_fvec0, vec_fun1(acc1_fvec0, data_fvec0), size - d);
220
+ acc2_fvec0 =
221
+ fVec::set(acc2_fvec0, vec_fun2(acc2_fvec0, data_fvec0), size - d);
222
+ }
223
+ }
224
+ acc1_fvec0 = vec_fun1(acc1_fvec0, acc1_fvec1);
225
+ acc2_fvec0 = vec_fun2(acc2_fvec0, acc2_fvec1);
226
+ return std::pair<scalar_t, scalar_t>(
227
+ vec_reduce_all<float>(vec_fun1, acc1_fvec0),
228
+ vec_reduce_all<float>(vec_fun2, acc2_fvec0));
229
+ }
230
+
231
+ template <
232
+ typename scalar_t,
233
+ typename MapOp,
234
+ typename ReduceOp,
235
+ typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
236
+ inline float map_reduce_all(
237
+ const MapOp& map_fun,
238
+ const ReduceOp& red_fun,
239
+ const scalar_t* data,
240
+ int64_t size) {
241
+ using bVec = vec::Vectorized<scalar_t>;
242
+ using fVec = vec::Vectorized<float>;
243
+ if (size < bVec::size()) {
244
+ bVec data_bvec = bVec::loadu(data, size);
245
+ auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
246
+ if (size > fVec::size()) {
247
+ data_fvec0 = map_fun(data_fvec0);
248
+ data_fvec1 = map_fun(data_fvec1);
249
+ data_fvec0 = fVec::set(
250
+ data_fvec0, red_fun(data_fvec0, data_fvec1), size - fVec::size());
251
+ return vec_reduce_all<float>(red_fun, data_fvec0, fVec::size());
252
+ } else {
253
+ data_fvec0 = map_fun(data_fvec0);
254
+ return vec_reduce_all<float>(red_fun, data_fvec0, size);
255
+ }
256
+ }
257
+ int64_t d = bVec::size();
258
+ bVec acc_bvec = bVec::loadu(data);
259
+ auto [acc_fvec0, acc_fvec1] = convert_to_float<scalar_t>(acc_bvec);
260
+ acc_fvec0 = map_fun(acc_fvec0);
261
+ acc_fvec1 = map_fun(acc_fvec1);
262
+ for (; d < size - (size % bVec::size()); d += bVec::size()) {
263
+ bVec data_bvec = bVec::loadu(data + d);
264
+ auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
265
+ data_fvec0 = map_fun(data_fvec0);
266
+ data_fvec1 = map_fun(data_fvec1);
267
+ acc_fvec0 = red_fun(acc_fvec0, data_fvec0);
268
+ acc_fvec1 = red_fun(acc_fvec1, data_fvec1);
269
+ }
270
+ if (size - d > 0) {
271
+ bVec data_bvec = bVec::loadu(data + d, size - d);
272
+ auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
273
+ if (size - d > fVec::size()) {
274
+ data_fvec0 = map_fun(data_fvec0);
275
+ data_fvec1 = map_fun(data_fvec1);
276
+ acc_fvec0 = red_fun(acc_fvec0, data_fvec0);
277
+ acc_fvec1 = fVec::set(
278
+ acc_fvec1, red_fun(acc_fvec1, data_fvec1), size - d - fVec::size());
279
+ } else {
280
+ data_fvec0 = map_fun(data_fvec0);
281
+ acc_fvec0 =
282
+ fVec::set(acc_fvec0, red_fun(acc_fvec0, data_fvec0), size - d);
283
+ }
284
+ }
285
+ acc_fvec0 = red_fun(acc_fvec0, acc_fvec1);
286
+ return vec_reduce_all<float>(red_fun, acc_fvec0);
287
+ }
288
+
289
+ template <
290
+ typename scalar_t,
291
+ typename MapOp,
292
+ typename ReduceOp,
293
+ typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
294
+ inline float map2_reduce_all(
295
+ const MapOp& map_fun,
296
+ const ReduceOp& red_fun,
297
+ const scalar_t* data,
298
+ const scalar_t* data2,
299
+ int64_t size) {
300
+ using bVec = vec::Vectorized<scalar_t>;
301
+ using fVec = vec::Vectorized<float>;
302
+ if (size < bVec::size()) {
303
+ bVec data_bvec = bVec::loadu(data, size);
304
+ auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
305
+ bVec data2_bvec = bVec::loadu(data2, size);
306
+ auto [data2_fvec0, data2_fvec1] = convert_to_float<scalar_t>(data2_bvec);
307
+ if (size > fVec::size()) {
308
+ data_fvec0 = map_fun(data_fvec0, data2_fvec0);
309
+ data_fvec1 = map_fun(data_fvec1, data2_fvec1);
310
+ data_fvec0 = fVec::set(
311
+ data_fvec0, red_fun(data_fvec0, data_fvec1), size - fVec::size());
312
+ return vec_reduce_all<float>(red_fun, data_fvec0, fVec::size());
313
+ } else {
314
+ data_fvec0 = map_fun(data_fvec0, data2_fvec0);
315
+ return vec_reduce_all<float>(red_fun, data_fvec0, size);
316
+ }
317
+ }
318
+ int64_t d = bVec::size();
319
+ bVec acc_bvec = bVec::loadu(data);
320
+ auto [acc_fvec0, acc_fvec1] = convert_to_float<scalar_t>(acc_bvec);
321
+ bVec acc2_bvec = bVec::loadu(data2);
322
+ auto [acc2_fvec0, acc2_fvec1] = convert_to_float<scalar_t>(acc2_bvec);
323
+ acc_fvec0 = map_fun(acc_fvec0, acc2_fvec0);
324
+ acc_fvec1 = map_fun(acc_fvec1, acc2_fvec1);
325
+ for (; d < size - (size % bVec::size()); d += bVec::size()) {
326
+ bVec data_bvec = bVec::loadu(data + d);
327
+ auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
328
+ bVec data2_bvec = bVec::loadu(data2 + d);
329
+ auto [data2_fvec0, data2_fvec1] = convert_to_float<scalar_t>(data2_bvec);
330
+ data_fvec0 = map_fun(data_fvec0, data2_fvec0);
331
+ data_fvec1 = map_fun(data_fvec1, data2_fvec1);
332
+ acc_fvec0 = red_fun(acc_fvec0, data_fvec0);
333
+ acc_fvec1 = red_fun(acc_fvec1, data_fvec1);
334
+ }
335
+ if (size - d > 0) {
336
+ bVec data_bvec = bVec::loadu(data + d, size - d);
337
+ auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
338
+ bVec data2_bvec = bVec::loadu(data2 + d, size - d);
339
+ auto [data2_fvec0, data2_fvec1] = convert_to_float<scalar_t>(data2_bvec);
340
+ if (size - d > fVec::size()) {
341
+ data_fvec0 = map_fun(data_fvec0, data2_fvec0);
342
+ data_fvec1 = map_fun(data_fvec1, data2_fvec1);
343
+ acc_fvec0 = red_fun(acc_fvec0, data_fvec0);
344
+ acc_fvec1 = fVec::set(
345
+ acc_fvec1, red_fun(acc_fvec1, data_fvec1), size - d - fVec::size());
346
+ } else {
347
+ data_fvec0 = map_fun(data_fvec0, data2_fvec0);
348
+ acc_fvec0 =
349
+ fVec::set(acc_fvec0, red_fun(acc_fvec0, data_fvec0), size - d);
350
+ }
351
+ }
352
+ acc_fvec0 = red_fun(acc_fvec0, acc_fvec1);
353
+ return vec_reduce_all<float>(red_fun, acc_fvec0);
354
+ }
355
+
356
+ template <
357
+ typename scalar_t,
358
+ typename MapOp,
359
+ typename ReduceOp,
360
+ typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
361
+ inline float map3_reduce_all(
362
+ const MapOp& map_fun,
363
+ const ReduceOp& red_fun,
364
+ const scalar_t* data,
365
+ const scalar_t* data2,
366
+ const scalar_t* data3,
367
+ int64_t size) {
368
+ using bVec = vec::Vectorized<scalar_t>;
369
+ using fVec = vec::Vectorized<float>;
370
+ if (size < bVec::size()) {
371
+ bVec data_bvec = bVec::loadu(data, size);
372
+ auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
373
+ bVec data2_bvec = bVec::loadu(data2, size);
374
+ auto [data2_fvec0, data2_fvec1] = convert_to_float<scalar_t>(data2_bvec);
375
+ bVec data3_bvec = bVec::loadu(data3, size);
376
+ auto [data3_fvec0, data3_fvec1] = convert_to_float<scalar_t>(data3_bvec);
377
+ if (size > fVec::size()) {
378
+ data_fvec0 = map_fun(data_fvec0, data2_fvec0, data3_fvec0);
379
+ data_fvec1 = map_fun(data_fvec1, data2_fvec1, data3_fvec1);
380
+ data_fvec0 = fVec::set(
381
+ data_fvec0, red_fun(data_fvec0, data_fvec1), size - fVec::size());
382
+ return vec_reduce_all<float>(red_fun, data_fvec0, fVec::size());
383
+ } else {
384
+ data_fvec0 = map_fun(data_fvec0, data2_fvec0, data3_fvec0);
385
+ return vec_reduce_all<float>(red_fun, data_fvec0, size);
386
+ }
387
+ }
388
+ int64_t d = bVec::size();
389
+ bVec acc_bvec = bVec::loadu(data);
390
+ auto [acc_fvec0, acc_fvec1] = convert_to_float<scalar_t>(acc_bvec);
391
+ bVec acc2_bvec = bVec::loadu(data2);
392
+ auto [acc2_fvec0, acc2_fvec1] = convert_to_float<scalar_t>(acc2_bvec);
393
+ bVec acc3_bvec = bVec::loadu(data3);
394
+ auto [acc3_fvec0, acc3_fvec1] = convert_to_float<scalar_t>(acc3_bvec);
395
+ acc_fvec0 = map_fun(acc_fvec0, acc2_fvec0, acc3_fvec0);
396
+ acc_fvec1 = map_fun(acc_fvec1, acc2_fvec1, acc3_fvec1);
397
+ for (; d < size - (size % bVec::size()); d += bVec::size()) {
398
+ bVec data_bvec = bVec::loadu(data + d);
399
+ auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
400
+ bVec data2_bvec = bVec::loadu(data2 + d);
401
+ auto [data2_fvec0, data2_fvec1] = convert_to_float<scalar_t>(data2_bvec);
402
+ bVec data3_bvec = bVec::loadu(data3 + d);
403
+ auto [data3_fvec0, data3_fvec1] = convert_to_float<scalar_t>(data3_bvec);
404
+ data_fvec0 = map_fun(data_fvec0, data2_fvec0, data3_fvec0);
405
+ data_fvec1 = map_fun(data_fvec1, data2_fvec1, data3_fvec1);
406
+ acc_fvec0 = red_fun(acc_fvec0, data_fvec0);
407
+ acc_fvec1 = red_fun(acc_fvec1, data_fvec1);
408
+ }
409
+ if (size - d > 0) {
410
+ bVec data_bvec = bVec::loadu(data + d, size - d);
411
+ auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
412
+ bVec data2_bvec = bVec::loadu(data2 + d, size - d);
413
+ auto [data2_fvec0, data2_fvec1] = convert_to_float<scalar_t>(data2_bvec);
414
+ bVec data3_bvec = bVec::loadu(data3 + d, size - d);
415
+ auto [data3_fvec0, data3_fvec1] = convert_to_float<scalar_t>(data3_bvec);
416
+ if (size - d > fVec::size()) {
417
+ data_fvec0 = map_fun(data_fvec0, data2_fvec0, data3_fvec0);
418
+ data_fvec1 = map_fun(data_fvec1, data2_fvec1, data3_fvec1);
419
+ acc_fvec0 = red_fun(acc_fvec0, data_fvec0);
420
+ acc_fvec1 = fVec::set(
421
+ acc_fvec1, red_fun(acc_fvec1, data_fvec1), size - d - fVec::size());
422
+ } else {
423
+ data_fvec0 = map_fun(data_fvec0, data2_fvec0, data3_fvec0);
424
+ acc_fvec0 =
425
+ fVec::set(acc_fvec0, red_fun(acc_fvec0, data_fvec0), size - d);
426
+ }
427
+ }
428
+ acc_fvec0 = red_fun(acc_fvec0, acc_fvec1);
429
+ return vec_reduce_all<float>(red_fun, acc_fvec0);
430
+ }
431
+
432
+ template <
433
+ typename scalar_t,
434
+ typename Op,
435
+ typename std::enable_if_t<
436
+ !(!detail::should_prefer_converting_through_float_v<scalar_t> &&
437
+ std::is_invocable_v<Op, vec::Vectorized<scalar_t>>),
438
+ int> = 0>
439
+ inline void map(
440
+ const Op& vec_fun,
441
+ scalar_t* output_data,
442
+ const scalar_t* input_data,
443
+ int64_t size) {
444
+ using bVec = vec::Vectorized<scalar_t>;
445
+ using fVec = vec::Vectorized<float>;
446
+ int64_t d = 0;
447
+ for (; d < size - (size % bVec::size()); d += bVec::size()) {
448
+ bVec data_bvec = bVec::loadu(input_data + d);
449
+ auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
450
+ fVec output_fvec0 = vec_fun(data_fvec0);
451
+ fVec output_fvec1 = vec_fun(data_fvec1);
452
+ bVec output_bvec = convert_from_float<scalar_t>(output_fvec0, output_fvec1);
453
+ output_bvec.store(output_data + d);
454
+ }
455
+ if (size - d > 0) {
456
+ bVec data_bvec = bVec::loadu(input_data + d, size - d);
457
+ auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
458
+ fVec output_fvec0 = vec_fun(data_fvec0);
459
+ fVec output_fvec1 = vec_fun(data_fvec1);
460
+ bVec output_bvec = convert_from_float<scalar_t>(output_fvec0, output_fvec1);
461
+ output_bvec.store(output_data + d, size - d);
462
+ }
463
+ }
464
+
465
+ template <
466
+ typename scalar_t,
467
+ typename Op,
468
+ typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
469
+ inline void map(
470
+ const Op& vec_fun,
471
+ scalar_t* output_data,
472
+ const float* input_data,
473
+ int64_t size) {
474
+ using bVec = vec::Vectorized<scalar_t>;
475
+ using fVec = vec::Vectorized<float>;
476
+ int64_t d = 0;
477
+ for (; d < size - (size % bVec::size()); d += bVec::size()) {
478
+ fVec data_fvec0 = fVec::loadu(input_data + d);
479
+ fVec data_fvec1 = fVec::loadu(input_data + d + fVec::size());
480
+ fVec output_fvec0 = vec_fun(data_fvec0);
481
+ fVec output_fvec1 = vec_fun(data_fvec1);
482
+ bVec output_bvec = convert_from_float<scalar_t>(output_fvec0, output_fvec1);
483
+ output_bvec.store(output_data + d);
484
+ }
485
+ if (size - d > 0) {
486
+ fVec data_fvec0, data_fvec1;
487
+ if (size - d > fVec::size()) {
488
+ data_fvec0 = fVec::loadu(input_data + d);
489
+ data_fvec1 =
490
+ fVec::loadu(input_data + d + fVec::size(), size - d - fVec::size());
491
+ } else {
492
+ // choose to align with behaviour of bVec::loadu(ptr, size),
493
+ // which leaves data_fvec1 uninitialized
494
+ data_fvec0 = fVec::loadu(input_data + d, size - d);
495
+ }
496
+ fVec output_fvec0 = vec_fun(data_fvec0);
497
+ fVec output_fvec1 = vec_fun(data_fvec1);
498
+ bVec output_bvec = convert_from_float<scalar_t>(output_fvec0, output_fvec1);
499
+ output_bvec.store(output_data + d, size - d);
500
+ }
501
+ }
502
+
503
+ template <
504
+ typename scalar_t,
505
+ typename Op,
506
+ typename std::enable_if_t<
507
+ !(!detail::should_prefer_converting_through_float_v<scalar_t> &&
508
+ std::is_invocable_v<
509
+ Op,
510
+ vec::Vectorized<scalar_t>,
511
+ vec::Vectorized<scalar_t>>),
512
+ int> = 0>
513
+ inline void map2(
514
+ const Op& vec_fun,
515
+ scalar_t* output_data,
516
+ const scalar_t* input_data,
517
+ const scalar_t* input_data2,
518
+ int64_t size) {
519
+ using bVec = vec::Vectorized<scalar_t>;
520
+ using fVec = vec::Vectorized<float>;
521
+ int64_t d = 0;
522
+ for (; d < size - (size % bVec::size()); d += bVec::size()) {
523
+ bVec data_bvec = bVec::loadu(input_data + d);
524
+ auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
525
+ bVec data2_bvec = bVec::loadu(input_data2 + d);
526
+ auto [data2_fvec0, data2_fvec1] = convert_to_float<scalar_t>(data2_bvec);
527
+ fVec output_fvec0 = vec_fun(data_fvec0, data2_fvec0);
528
+ fVec output_fvec1 = vec_fun(data_fvec1, data2_fvec1);
529
+ bVec output_bvec = convert_from_float<scalar_t>(output_fvec0, output_fvec1);
530
+ output_bvec.store(output_data + d);
531
+ }
532
+ if (size - d > 0) {
533
+ bVec data_bvec = bVec::loadu(input_data + d, size - d);
534
+ auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
535
+ bVec data2_bvec = bVec::loadu(input_data2 + d, size - d);
536
+ auto [data2_fvec0, data2_fvec1] = convert_to_float<scalar_t>(data2_bvec);
537
+ fVec output_fvec0 = vec_fun(data_fvec0, data2_fvec0);
538
+ fVec output_fvec1 = vec_fun(data_fvec1, data2_fvec1);
539
+ bVec output_bvec = convert_from_float<scalar_t>(output_fvec0, output_fvec1);
540
+ output_bvec.store(output_data + d, size - d);
541
+ }
542
+ }
543
+
544
+ template <
545
+ typename scalar_t,
546
+ typename Op,
547
+ typename std::enable_if_t<
548
+ !(!detail::should_prefer_converting_through_float_v<scalar_t> &&
549
+ std::is_invocable_v<
550
+ Op,
551
+ vec::Vectorized<scalar_t>,
552
+ vec::Vectorized<scalar_t>,
553
+ vec::Vectorized<scalar_t>>),
554
+ int> = 0>
555
+ inline void map3(
556
+ const Op& vec_fun,
557
+ scalar_t* output_data,
558
+ const scalar_t* input_data1,
559
+ const scalar_t* input_data2,
560
+ const scalar_t* input_data3,
561
+ int64_t size) {
562
+ using bVec = vec::Vectorized<scalar_t>;
563
+ using fVec = vec::Vectorized<float>;
564
+ int64_t d = 0;
565
+ for (; d < size - (size % bVec::size()); d += bVec::size()) {
566
+ bVec data1_bvec = bVec::loadu(input_data1 + d);
567
+ auto [data1_fvec0, data1_fvec1] = convert_to_float<scalar_t>(data1_bvec);
568
+ bVec data2_bvec = bVec::loadu(input_data2 + d);
569
+ auto [data2_fvec0, data2_fvec1] = convert_to_float<scalar_t>(data2_bvec);
570
+ bVec data3_bvec = bVec::loadu(input_data3 + d);
571
+ auto [data3_fvec0, data3_fvec1] = convert_to_float<scalar_t>(data3_bvec);
572
+ fVec output_fvec0 = vec_fun(data1_fvec0, data2_fvec0, data3_fvec0);
573
+ fVec output_fvec1 = vec_fun(data1_fvec1, data2_fvec1, data3_fvec1);
574
+ bVec output_bvec = convert_from_float<scalar_t>(output_fvec0, output_fvec1);
575
+ output_bvec.store(output_data + d);
576
+ }
577
+ if (size - d > 0) {
578
+ bVec data1_bvec = bVec::loadu(input_data1 + d, size - d);
579
+ auto [data1_fvec0, data1_fvec1] = convert_to_float<scalar_t>(data1_bvec);
580
+ bVec data2_bvec = bVec::loadu(input_data2 + d, size - d);
581
+ auto [data2_fvec0, data2_fvec1] = convert_to_float<scalar_t>(data2_bvec);
582
+ bVec data3_bvec = bVec::loadu(input_data3 + d, size - d);
583
+ auto [data3_fvec0, data3_fvec1] = convert_to_float<scalar_t>(data3_bvec);
584
+ fVec output_fvec0 = vec_fun(data1_fvec0, data2_fvec0, data3_fvec0);
585
+ fVec output_fvec1 = vec_fun(data1_fvec1, data2_fvec1, data3_fvec1);
586
+ bVec output_bvec = convert_from_float<scalar_t>(output_fvec0, output_fvec1);
587
+ output_bvec.store(output_data + d, size - d);
588
+ }
589
+ }
590
+
591
+ template <
592
+ typename scalar_t,
593
+ typename Op,
594
+ typename std::enable_if_t<
595
+ !(!detail::should_prefer_converting_through_float_v<scalar_t> &&
596
+ std::is_invocable_v<
597
+ Op,
598
+ vec::Vectorized<scalar_t>,
599
+ vec::Vectorized<scalar_t>,
600
+ vec::Vectorized<scalar_t>,
601
+ vec::Vectorized<scalar_t>>),
602
+ int> = 0>
603
+ inline void map4(
604
+ const Op& vec_fun,
605
+ scalar_t* output_data,
606
+ const scalar_t* input_data1,
607
+ const scalar_t* input_data2,
608
+ const scalar_t* input_data3,
609
+ const scalar_t* input_data4,
610
+ int64_t size) {
611
+ using bVec = vec::Vectorized<scalar_t>;
612
+ using fVec = vec::Vectorized<float>;
613
+ int64_t d = 0;
614
+ for (; d < size - (size % bVec::size()); d += bVec::size()) {
615
+ bVec data1_bvec = bVec::loadu(input_data1 + d);
616
+ auto [data1_fvec0, data1_fvec1] = convert_to_float<scalar_t>(data1_bvec);
617
+ bVec data2_bvec = bVec::loadu(input_data2 + d);
618
+ auto [data2_fvec0, data2_fvec1] = convert_to_float<scalar_t>(data2_bvec);
619
+ bVec data3_bvec = bVec::loadu(input_data3 + d);
620
+ auto [data3_fvec0, data3_fvec1] = convert_to_float<scalar_t>(data3_bvec);
621
+ bVec data4_bvec = bVec::loadu(input_data4 + d);
622
+ auto [data4_fvec0, data4_fvec1] = convert_to_float<scalar_t>(data4_bvec);
623
+ fVec output_fvec0 =
624
+ vec_fun(data1_fvec0, data2_fvec0, data3_fvec0, data4_fvec0);
625
+ fVec output_fvec1 =
626
+ vec_fun(data1_fvec1, data2_fvec1, data3_fvec1, data4_fvec1);
627
+ bVec output_bvec = convert_from_float<scalar_t>(output_fvec0, output_fvec1);
628
+ output_bvec.store(output_data + d);
629
+ }
630
+ if (size - d > 0) {
631
+ bVec data1_bvec = bVec::loadu(input_data1 + d, size - d);
632
+ auto [data1_fvec0, data1_fvec1] = convert_to_float<scalar_t>(data1_bvec);
633
+ bVec data2_bvec = bVec::loadu(input_data2 + d, size - d);
634
+ auto [data2_fvec0, data2_fvec1] = convert_to_float<scalar_t>(data2_bvec);
635
+ bVec data3_bvec = bVec::loadu(input_data3 + d, size - d);
636
+ auto [data3_fvec0, data3_fvec1] = convert_to_float<scalar_t>(data3_bvec);
637
+ bVec data4_bvec = bVec::loadu(input_data4 + d, size - d);
638
+ auto [data4_fvec0, data4_fvec1] = convert_to_float<scalar_t>(data4_bvec);
639
+ fVec output_fvec0 =
640
+ vec_fun(data1_fvec0, data2_fvec0, data3_fvec0, data4_fvec0);
641
+ fVec output_fvec1 =
642
+ vec_fun(data1_fvec1, data2_fvec1, data3_fvec1, data4_fvec1);
643
+ bVec output_bvec = convert_from_float<scalar_t>(output_fvec0, output_fvec1);
644
+ output_bvec.store(output_data + d, size - d);
645
+ }
646
+ }
647
+
648
+ } // namespace at::vec
649
+
650
+ #else
651
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
652
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/intrinsics.h ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #include <torch/headeronly/cpu/vec/intrinsics.h>
3
+
4
+ #else
5
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
6
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec.h ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #if defined(CPU_CAPABILITY_AVX512)
5
+ #include <ATen/cpu/vec/vec512/vec512.h>
6
+ #else
7
+ #include <ATen/cpu/vec/vec128/vec128.h>
8
+ #include <ATen/cpu/vec/vec256/vec256.h>
9
+ #endif
10
+
11
+ namespace at::vec {
12
+ // See Note [CPU_CAPABILITY namespace]
13
+ inline namespace CPU_CAPABILITY {
14
+
15
+ inline Vectorized<bool> convert_to_bool(Vectorized<int8_t> x) {
16
+ __at_align__ bool buffer[x.size()];
17
+ x.ne(Vectorized<int8_t>(0)).store(buffer);
18
+
19
+ Vectorized<bool> ret;
20
+ static_assert(x.size() == ret.size());
21
+ std::memcpy(ret, buffer, ret.size() * sizeof(bool));
22
+ return ret;
23
+ }
24
+
25
+ template <>
26
+ inline Vectorized<bool> Vectorized<bool>::loadu(const void* ptr) {
27
+ // See NOTE [Loading boolean values]
28
+ return convert_to_bool(Vectorized<int8_t>::loadu(ptr));
29
+ }
30
+
31
+ template <>
32
+ inline Vectorized<bool> Vectorized<bool>::loadu(
33
+ const void* ptr,
34
+ int64_t count) {
35
+ // See NOTE [Loading boolean values]
36
+ return convert_to_bool(Vectorized<int8_t>::loadu(ptr, count));
37
+ }
38
+
39
+ template <typename VT>
40
+ struct VecHoldType {
41
+ using hold_type = typename VT::value_type;
42
+ };
43
+
44
+ template <>
45
+ struct VecHoldType<Vectorized<BFloat16>> {
46
+ using hold_type = BFloat16;
47
+ };
48
+
49
+ template <>
50
+ struct VecHoldType<Vectorized<Half>> {
51
+ using hold_type = Half;
52
+ };
53
+
54
+ template <typename VT>
55
+ using vechold_type = typename VecHoldType<VT>::hold_type;
56
+
57
+ } // namespace CPU_CAPABILITY
58
+ } // namespace at::vec
59
+
60
+ #else
61
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
62
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec128/vec128_half_neon.h ADDED
@@ -0,0 +1,627 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ // DO NOT DEFINE STATIC DATA IN THIS HEADER!
5
+ // See Note [Do not compile initializers with AVX]
6
+
7
+ #include <ATen/cpu/vec/intrinsics.h>
8
+ #include <ATen/cpu/vec/vec128/vec128_convert.h>
9
+ #include <ATen/cpu/vec/vec128/vec128_float_neon.h>
10
+ #include <ATen/cpu/vec/vec128/vec128_reduced_precision_common_neon.h>
11
+ #include <ATen/cpu/vec/vec_base.h>
12
+ #include <c10/util/Half.h>
13
+ #include <c10/util/irange.h>
14
+
15
+ namespace at::vec {
16
+ // See Note [CPU_CAPABILITY namespace]
17
+ inline namespace CPU_CAPABILITY {
18
+
19
+ // Right now contains only aarch64 implementation.
20
+ // Due to follow two reasons aarch32 is not currently supported.
21
+ // 1. Due to difference in ISA been aarch32 and aarch64, intrinsics
22
+ // that work for aarch64 dont work for aarch32.
23
+ // 2. Android NDK r21 has problems with compiling aarch32.
24
+ // Clang seg faults.
25
+ // https://github.com/android/ndk/issues/1248
26
+ // https://bugs.llvm.org/show_bug.cgi?id=45824
27
+ // Most likely we will do aarch32 support with inline asm.
28
+ #if !defined(C10_MOBILE) && defined(__aarch64__)
29
+
30
+ #ifdef __BIG_ENDIAN__
31
+ #error "Big endian is not supported."
32
+ #endif
33
+
34
+ template <int index, bool mask_val>
35
+ struct BlendHalfRegs {
36
+ static float16x8_t impl(
37
+ const float16x8_t& a,
38
+ const float16x8_t& b,
39
+ float16x8_t& res);
40
+ };
41
+
42
+ template <int index>
43
+ struct BlendHalfRegs<index, true> {
44
+ static float16x8_t impl(
45
+ const float16x8_t& a,
46
+ const float16x8_t& b,
47
+ float16x8_t& res) {
48
+ return vsetq_lane_f16(vgetq_lane_f16(b, index), res, index);
49
+ }
50
+ };
51
+
52
+ template <int index>
53
+ struct BlendHalfRegs<index, false> {
54
+ static float16x8_t impl(
55
+ const float16x8_t& a,
56
+ const float16x8_t& b,
57
+ float16x8_t& res) {
58
+ return vsetq_lane_f16(vgetq_lane_f16(a, index), res, index);
59
+ }
60
+ };
61
+
62
+ template <>
63
+ struct is_vec_specialized_for<c10::Half> : std::bool_constant<true> {};
64
+
65
+ // On ARM, Half type supports float16_t->Half constructor and Half->float16_t
66
+ // conversion
67
+ template <>
68
+ class Vectorized<c10::Half> : public Vectorized16<
69
+ float16x8_t,
70
+ c10::Half,
71
+ BlendHalfRegs,
72
+ Vectorized<c10::Half>> {
73
+ using Base = Vectorized16<
74
+ float16x8_t,
75
+ c10::Half,
76
+ BlendHalfRegs,
77
+ Vectorized<c10::Half>>;
78
+ friend Base;
79
+
80
+ private:
81
+ // We use these private map functions to implement various methods
82
+ Vectorized<c10::Half> map_with_vec_float_method(
83
+ Vectorized<float> (Vectorized<float>::*m)() const) const {
84
+ float32x4_t v00 = vcvt_f32_f16(vget_low_f16(values));
85
+ float32x4_t v01 = vcvt_f32_f16(vget_high_f16(values));
86
+ Vectorized<float> mv0 = (Vectorized<float>(v00).*m)();
87
+ Vectorized<float> mv1 = (Vectorized<float>(v01).*m)();
88
+ float16x4_t r00 = vcvt_f16_f32(mv0);
89
+ float16x4_t r01 = vcvt_f16_f32(mv1);
90
+ return Vectorized<c10::Half>(vcombine_f16(r00, r01));
91
+ }
92
+
93
+ Vectorized<c10::Half> map2_with_vec_float_method(
94
+ const Vectorized<c10::Half>& second,
95
+ Vectorized<float> (Vectorized<float>::*m)(const Vectorized<float>&)
96
+ const) const {
97
+ float32x4_t v00 = vcvt_f32_f16(vget_low_f16(values));
98
+ float32x4_t v01 = vcvt_f32_f16(vget_high_f16(values));
99
+ float32x4_t second_v00 = vcvt_f32_f16(vget_low_f16(second.values));
100
+ float32x4_t second_v01 = vcvt_f32_f16(vget_high_f16(second.values));
101
+ Vectorized<float> mv0 =
102
+ (Vectorized<float>(v00).*m)(Vectorized<float>(second_v00));
103
+ Vectorized<float> mv1 =
104
+ (Vectorized<float>(v01).*m)(Vectorized<float>(second_v01));
105
+ float16x4_t r00 = vcvt_f16_f32(mv0);
106
+ float16x4_t r01 = vcvt_f16_f32(mv1);
107
+
108
+ // Pack result into Vectorized<c10::Half>
109
+ return Vectorized<c10::Half>(vcombine_f16(r00, r01));
110
+ }
111
+
112
+ Vectorized<c10::Half> map2_bitmask_with_vec_float_method(
113
+ const Vectorized<c10::Half>& second,
114
+ Vectorized<float> (Vectorized<float>::*m)(const Vectorized<float>&)
115
+ const) const {
116
+ float32x4_t v00 = vcvt_f32_f16(vget_low_f16(values));
117
+ float32x4_t v01 = vcvt_f32_f16(vget_high_f16(values));
118
+ float32x4_t second_v00 = vcvt_f32_f16(vget_low_f16(second.values));
119
+ float32x4_t second_v01 = vcvt_f32_f16(vget_high_f16(second.values));
120
+ Vectorized<float> mv0 =
121
+ (Vectorized<float>(v00).*m)(Vectorized<float>(second_v00));
122
+ Vectorized<float> mv1 =
123
+ (Vectorized<float>(v01).*m)(Vectorized<float>(second_v01));
124
+ // Assume the operator returns a bitmask, not "real" floats, and
125
+ // just narrow the bits. All-ones is a NaN and will get mangled by
126
+ // conversion!
127
+ float16x4_t r00 =
128
+ vreinterpret_f16_u16(vmovn_u32(vreinterpretq_u32_f32(mv0)));
129
+ float16x4_t r01 =
130
+ vreinterpret_f16_u16(vmovn_u32(vreinterpretq_u32_f32(mv1)));
131
+
132
+ // Pack result into Vectorized<c10::Half>
133
+ return Vectorized<c10::Half>(vcombine_f16(r00, r01));
134
+ }
135
+
136
+ public:
137
+ using Vectorized16::Vectorized16;
138
+
139
+ Vectorized() = default;
140
+
141
+ // A ctor that accepts c10::Half is needed to fit interface with vec_base.h
142
+ // A second constructor that takes float16_t is also included
143
+ Vectorized(c10::Half val) : Vectorized((float16_t)val) {}
144
+ Vectorized(float16_t val) : Vectorized16(vdupq_n_f16(val)) {}
145
+ Vectorized(
146
+ value_type val0,
147
+ value_type val1,
148
+ value_type val2,
149
+ value_type val3,
150
+ value_type val4,
151
+ value_type val5,
152
+ value_type val6,
153
+ value_type val7)
154
+ : Vectorized16(
155
+ float16x8_t{val0, val1, val2, val3, val4, val5, val6, val7}) {}
156
+
157
+ static Vectorized<c10::Half> blendv(
158
+ const Vectorized<c10::Half>& a,
159
+ const Vectorized<c10::Half>& b,
160
+ const Vectorized<c10::Half>& mask) {
161
+ // Note: using blendv is very awkward because 0xFFFF is one of
162
+ // many NaN's in FP16 It's unfortunate that the mask has type Half
163
+ // (required from vec_base)
164
+
165
+ // TODO
166
+ // NB: This requires that each value, i.e., each uint value,
167
+ // of the mask either all be zeros or all be 1s.
168
+ // We perhaps need some kind of an assert?
169
+ // But that will affect performance.
170
+
171
+ // NOTE [vbslq_f16]: vbslq_f16 doesn't work on clang without
172
+ // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC. vbslq_u16 generates the
173
+ // same instruction anyway. see https://godbolt.org/z/cY4a55Y7P
174
+ Vectorized<c10::Half> vec(mask.values);
175
+ vec.values = vreinterpretq_f16_u16(vbslq_u16(
176
+ vreinterpretq_u16_f16(vec.values),
177
+ vreinterpretq_u16_f16(b.values),
178
+ vreinterpretq_u16_f16(a.values)));
179
+ return vec;
180
+ }
181
+ static Vectorized<c10::Half> set(
182
+ const Vectorized<c10::Half>& a,
183
+ const Vectorized<c10::Half>& b,
184
+ int64_t count = size()) {
185
+ uint16_t pre_mask[size()] = {0};
186
+ for (int i = 0; i < count; i++) {
187
+ pre_mask[i] = 0xFFFF;
188
+ }
189
+ uint16x8_t mask = vld1q_u16(pre_mask);
190
+
191
+ // Using blendv is awkward because 0xFFFF is one of many NaN's in FP16
192
+ // so we directly use vbslq_u16 instead. (See NOTE [vbslq_f16] above.)
193
+ Vectorized<c10::Half> vec(vreinterpretq_f16_u16(vbslq_u16(
194
+ mask,
195
+ vreinterpretq_u16_f16(b.values),
196
+ vreinterpretq_u16_f16(a.values))));
197
+
198
+ return vec;
199
+ }
200
+ static Vectorized<c10::Half> loadu(const void* ptr, int64_t count = size()) {
201
+ if (count == size()) {
202
+ return vld1q_f16(reinterpret_cast<const float16_t*>(ptr));
203
+ }
204
+ __at_align__ float16_t tmp_values[size()];
205
+ for (const auto i : c10::irange(size())) {
206
+ tmp_values[i] = 0;
207
+ }
208
+ std::memcpy(
209
+ tmp_values,
210
+ reinterpret_cast<const float16_t*>(ptr),
211
+ count * sizeof(float16_t));
212
+ return vld1q_f16(reinterpret_cast<const float16_t*>(tmp_values));
213
+ }
214
+ void store(void* ptr, int64_t count = size()) const {
215
+ if (count == size()) {
216
+ vst1q_f16(reinterpret_cast<float16_t*>(ptr), values);
217
+ return;
218
+ } else {
219
+ float16_t tmp_values[size()];
220
+ vst1q_f16(reinterpret_cast<float16_t*>(tmp_values), values);
221
+ std::memcpy(ptr, tmp_values, count * sizeof(float16_t));
222
+ }
223
+ }
224
+ int zero_mask() const {
225
+ #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
226
+ uint16x8_t is_zero_vec = vceqzq_f16(values);
227
+ const int16x8_t shift = vcombine_s16(
228
+ vcreate_s16(
229
+ 0x0 | (int64_t(0x1) << 16) | (int64_t(0x2) << 32) |
230
+ (int64_t(0x3) << 48)),
231
+ vcreate_s16(
232
+ 0x4 | (int64_t(0x5) << 16) | (int64_t(0x6) << 32) |
233
+ (int64_t(0x7) << 48)));
234
+ uint16x8_t bits_vec =
235
+ vshlq_u16(vandq_u16(is_zero_vec, vdupq_n_u16(1)), shift);
236
+ return vaddvq_u16(bits_vec);
237
+ #else // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
238
+ // use known working implementation.
239
+ __at_align__ value_type tmp[size()];
240
+ store(tmp);
241
+ int mask = 0;
242
+ for (int i = 0; i < size(); ++i) {
243
+ if (tmp[i] == 0) {
244
+ mask |= (1 << i);
245
+ }
246
+ }
247
+ return mask;
248
+ #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
249
+ }
250
+ Vectorized<c10::Half> isnan() const {
251
+ #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
252
+ return vreinterpretq_f16_u16(vmvnq_u16(vceqq_f16(values, values)));
253
+ #else
254
+ // NOTE: we could make this faster by doing vectorized checks of
255
+ // exponent/payload bits.
256
+ __at_align__ c10::Half tmp[size()];
257
+ __at_align__ c10::Half res[size()];
258
+ store(tmp);
259
+ for (const auto i : c10::irange(size())) {
260
+ if (_isnan(tmp[i])) {
261
+ std::memset(static_cast<void*>(&res[i]), 0xFF, sizeof(c10::Half));
262
+ } else {
263
+ std::memset(static_cast<void*>(&res[i]), 0, sizeof(c10::Half));
264
+ }
265
+ }
266
+ return loadu(res);
267
+ #endif
268
+ }
269
+ bool has_inf_nan() const {
270
+ __at_align__ c10::Half tmp[size()];
271
+ store(tmp);
272
+ for (const auto i : c10::irange(size())) {
273
+ if (_isnan(tmp[i]) || _isinf(tmp[i])) {
274
+ return true;
275
+ }
276
+ }
277
+ return false;
278
+ }
279
+ Vectorized<c10::Half> abs() const {
280
+ #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
281
+ return Vectorized<c10::Half>(vabsq_f16(values));
282
+ #else
283
+ return map_with_vec_float_method(&Vectorized<float>::abs);
284
+ #endif
285
+ }
286
+ Vectorized<c10::Half> frac() const;
287
+ Vectorized<c10::Half> neg() const {
288
+ #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
289
+ return Vectorized<c10::Half>(vnegq_f16(values));
290
+ #else
291
+ return map_with_vec_float_method(&Vectorized<float>::neg);
292
+ #endif
293
+ }
294
+ Vectorized<c10::Half> trunc() const {
295
+ #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
296
+ return Vectorized<c10::Half>(vrndq_f16(values));
297
+ #else
298
+ return map_with_vec_float_method(&Vectorized<float>::trunc);
299
+ #endif
300
+ }
301
+ Vectorized<c10::Half> sqrt() const {
302
+ #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
303
+ return Vectorized<c10::Half>(vsqrtq_f16(values));
304
+ #else
305
+ return map_with_vec_float_method(&Vectorized<float>::sqrt);
306
+ #endif
307
+ }
308
+ Vectorized<c10::Half> reciprocal() const {
309
+ #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
310
+ auto ones = vdupq_n_f16(1.0f);
311
+ return Vectorized<c10::Half>(vdivq_f16(ones, values));
312
+ #else
313
+ return map_with_vec_float_method(&Vectorized<float>::reciprocal);
314
+ #endif
315
+ }
316
+ Vectorized<c10::Half> operator==(const Vectorized<c10::Half>& other) const {
317
+ #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
318
+ return Vectorized<c10::Half>(
319
+ vreinterpretq_f16_u16(vceqq_f16(values, other.values)));
320
+ #else
321
+ return map2_bitmask_with_vec_float_method(
322
+ other, &Vectorized<float>::operator==);
323
+ #endif
324
+ }
325
+
326
+ Vectorized<c10::Half> operator!=(const Vectorized<c10::Half>& other) const {
327
+ #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
328
+ return Vectorized<c10::Half>(
329
+ vreinterpretq_f16_u16(vmvnq_u16(vceqq_f16(values, other.values))));
330
+ #else
331
+ return map2_bitmask_with_vec_float_method(
332
+ other, &Vectorized<float>::operator!=);
333
+ #endif
334
+ }
335
+
336
+ Vectorized<c10::Half> operator<(const Vectorized<c10::Half>& other) const {
337
+ #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
338
+ return Vectorized<c10::Half>(
339
+ vreinterpretq_f16_u16(vcltq_f16(values, other.values)));
340
+ #else
341
+ return map2_bitmask_with_vec_float_method(
342
+ other, &Vectorized<float>::operator<);
343
+ #endif
344
+ }
345
+
346
+ Vectorized<c10::Half> operator<=(const Vectorized<c10::Half>& other) const {
347
+ #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
348
+ return Vectorized<c10::Half>(
349
+ vreinterpretq_f16_u16(vcleq_f16(values, other.values)));
350
+ #else
351
+ return map2_bitmask_with_vec_float_method(
352
+ other, &Vectorized<float>::operator<=);
353
+ #endif
354
+ }
355
+
356
+ Vectorized<c10::Half> operator>(const Vectorized<c10::Half>& other) const {
357
+ #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
358
+ return Vectorized<c10::Half>(
359
+ vreinterpretq_f16_u16(vcgtq_f16(values, other.values)));
360
+ #else
361
+ return map2_bitmask_with_vec_float_method(
362
+ other, &Vectorized<float>::operator>);
363
+ #endif
364
+ }
365
+
366
+ Vectorized<c10::Half> operator>=(const Vectorized<c10::Half>& other) const {
367
+ #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
368
+ return Vectorized<c10::Half>(
369
+ vreinterpretq_f16_u16(vcgeq_f16(values, other.values)));
370
+ #else
371
+ return map2_bitmask_with_vec_float_method(
372
+ other, &Vectorized<float>::operator>=);
373
+ #endif
374
+ }
375
+
376
+ Vectorized<c10::Half> eq(const Vectorized<c10::Half>& other) const;
377
+ Vectorized<c10::Half> ne(const Vectorized<c10::Half>& other) const;
378
+ Vectorized<c10::Half> gt(const Vectorized<c10::Half>& other) const;
379
+ Vectorized<c10::Half> ge(const Vectorized<c10::Half>& other) const;
380
+ Vectorized<c10::Half> lt(const Vectorized<c10::Half>& other) const;
381
+ Vectorized<c10::Half> le(const Vectorized<c10::Half>& other) const;
382
+ }; // Vectorized<Half>
383
+
384
+ inline std::tuple<Vectorized<float>, Vectorized<float>> convert_half_float(
385
+ const Vectorized<Half>& a) {
386
+ static_assert(Vectorized<Half>::size() == 2 * Vectorized<float>::size());
387
+ float16x8_t x = a;
388
+ float32x4_t x1 = vcvt_f32_f16(vget_low_f16(x));
389
+ float32x4_t x2 = vcvt_f32_f16(vget_high_f16(x));
390
+ return {Vectorized<float>(x1), Vectorized<float>(x2)};
391
+ }
392
+ inline Vectorized<Half> convert_float_half(
393
+ const Vectorized<float>& a,
394
+ const Vectorized<float>& b) {
395
+ static_assert(Vectorized<Half>::size() == 2 * Vectorized<float>::size());
396
+ float32x4_t x = a;
397
+ float32x4_t y = b;
398
+ float16x4_t x1 = vcvt_f16_f32(x);
399
+ float16x4_t x2 = vcvt_f16_f32(y);
400
+ return Vectorized<Half>(vcombine_f16(x1, x2));
401
+ }
402
+
403
+ template <typename Op>
404
+ Vectorized<c10::Half> binary_operator_via_float(
405
+ Op op,
406
+ const Vectorized<c10::Half>& a,
407
+ const Vectorized<c10::Half>& b) {
408
+ const auto [a_float_low, a_float_high] = convert_half_float(a);
409
+ const auto [b_float_low, b_float_high] = convert_half_float(b);
410
+ return convert_float_half(
411
+ op(a_float_low, b_float_low), op(a_float_high, b_float_high));
412
+ }
413
+
414
+ template <>
415
+ Vectorized<c10::Half> inline operator+(
416
+ const Vectorized<c10::Half>& a,
417
+ const Vectorized<c10::Half>& b) {
418
+ #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
419
+ return Vectorized<c10::Half>(vaddq_f16(a, b));
420
+ #else
421
+ return binary_operator_via_float(std::plus<Vectorized<float>>(), a, b);
422
+ #endif
423
+ }
424
+
425
+ template <>
426
+ Vectorized<c10::Half> inline operator-(
427
+ const Vectorized<c10::Half>& a,
428
+ const Vectorized<c10::Half>& b) {
429
+ #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
430
+ return Vectorized<c10::Half>(vsubq_f16(a, b));
431
+ #else
432
+ return binary_operator_via_float(std::minus<Vectorized<float>>(), a, b);
433
+ #endif
434
+ }
435
+
436
+ template <>
437
+ Vectorized<c10::Half> inline operator*(
438
+ const Vectorized<c10::Half>& a,
439
+ const Vectorized<c10::Half>& b) {
440
+ #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
441
+ return Vectorized<c10::Half>(vmulq_f16(a, b));
442
+ #else
443
+ return binary_operator_via_float(std::multiplies<Vectorized<float>>(), a, b);
444
+ #endif
445
+ }
446
+
447
+ template <>
448
+ Vectorized<c10::Half> inline operator/(
449
+ const Vectorized<c10::Half>& a,
450
+ const Vectorized<c10::Half>& b) {
451
+ #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
452
+ return Vectorized<c10::Half>(vdivq_f16(a, b));
453
+ #else
454
+ return binary_operator_via_float(std::divides<Vectorized<float>>(), a, b);
455
+ #endif
456
+ }
457
+
458
+ // frac. Implement this here so we can use subtraction
459
+ inline Vectorized<c10::Half> Vectorized<c10::Half>::frac() const {
460
+ return *this - this->trunc();
461
+ }
462
+
463
+ // Implements the IEEE 754 201X `maximum` operation, which propagates NaN if
464
+ // either input is a NaN.
465
+ template <>
466
+ Vectorized<c10::Half> inline maximum(
467
+ const Vectorized<c10::Half>& a,
468
+ const Vectorized<c10::Half>& b) {
469
+ #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
470
+ return Vectorized<c10::Half>(vmaxq_f16(a, b));
471
+ #else
472
+ return binary_operator_via_float(
473
+ static_cast<Vectorized<float> (*)(
474
+ const Vectorized<float>&, const Vectorized<float>&)>(&maximum),
475
+ a,
476
+ b);
477
+ #endif
478
+ }
479
+
480
+ // Implements the IEEE 754 201X `minimum` operation, which propagates NaN if
481
+ // either input is a NaN.
482
+ template <>
483
+ Vectorized<c10::Half> inline minimum(
484
+ const Vectorized<c10::Half>& a,
485
+ const Vectorized<c10::Half>& b) {
486
+ #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
487
+ return Vectorized<c10::Half>(vminq_f16(a, b));
488
+ #else
489
+ return binary_operator_via_float(
490
+ static_cast<Vectorized<float> (*)(
491
+ const Vectorized<float>&, const Vectorized<float>&)>(&minimum),
492
+ a,
493
+ b);
494
+ #endif
495
+ }
496
+
497
+ template <>
498
+ Vectorized<c10::Half> inline clamp(
499
+ const Vectorized<c10::Half>& a,
500
+ const Vectorized<c10::Half>& min,
501
+ const Vectorized<c10::Half>& max) {
502
+ return minimum(max, maximum(min, a));
503
+ }
504
+
505
+ template <>
506
+ Vectorized<c10::Half> inline clamp_max(
507
+ const Vectorized<c10::Half>& a,
508
+ const Vectorized<c10::Half>& max) {
509
+ return minimum(max, a);
510
+ }
511
+
512
+ template <>
513
+ Vectorized<c10::Half> inline clamp_min(
514
+ const Vectorized<c10::Half>& a,
515
+ const Vectorized<c10::Half>& min) {
516
+ return maximum(min, a);
517
+ }
518
+
519
+ template <>
520
+ Vectorized<c10::Half> inline operator&(
521
+ const Vectorized<c10::Half>& a,
522
+ const Vectorized<c10::Half>& b) {
523
+ return Vectorized<c10::Half>(vreinterpretq_f16_u16(
524
+ vandq_u16(vreinterpretq_u16_f16(a), vreinterpretq_u16_f16(b))));
525
+ }
526
+
527
+ template <>
528
+ Vectorized<c10::Half> inline operator|(
529
+ const Vectorized<c10::Half>& a,
530
+ const Vectorized<c10::Half>& b) {
531
+ return Vectorized<c10::Half>(vreinterpretq_f16_u16(
532
+ vorrq_u16(vreinterpretq_u16_f16(a), vreinterpretq_u16_f16(b))));
533
+ }
534
+
535
+ template <>
536
+ Vectorized<c10::Half> inline operator^(
537
+ const Vectorized<c10::Half>& a,
538
+ const Vectorized<c10::Half>& b) {
539
+ return Vectorized<c10::Half>(vreinterpretq_f16_u16(
540
+ veorq_u16(vreinterpretq_u16_f16(a), vreinterpretq_u16_f16(b))));
541
+ }
542
+
543
+ inline Vectorized<c10::Half> Vectorized<c10::Half>::eq(
544
+ const Vectorized<c10::Half>& other) const {
545
+ return (*this == other) & Vectorized<c10::Half>(1);
546
+ }
547
+
548
+ inline Vectorized<c10::Half> Vectorized<c10::Half>::ne(
549
+ const Vectorized<c10::Half>& other) const {
550
+ return (*this != other) & Vectorized<c10::Half>(1);
551
+ }
552
+
553
+ inline Vectorized<c10::Half> Vectorized<c10::Half>::gt(
554
+ const Vectorized<c10::Half>& other) const {
555
+ return (*this > other) & Vectorized<c10::Half>(1);
556
+ }
557
+
558
+ inline Vectorized<c10::Half> Vectorized<c10::Half>::ge(
559
+ const Vectorized<c10::Half>& other) const {
560
+ return (*this >= other) & Vectorized<c10::Half>(1);
561
+ }
562
+
563
+ inline Vectorized<c10::Half> Vectorized<c10::Half>::lt(
564
+ const Vectorized<c10::Half>& other) const {
565
+ return (*this < other) & Vectorized<c10::Half>(1);
566
+ }
567
+
568
+ inline Vectorized<c10::Half> Vectorized<c10::Half>::le(
569
+ const Vectorized<c10::Half>& other) const {
570
+ return (*this <= other) & Vectorized<c10::Half>(1);
571
+ }
572
+
573
+ template <>
574
+ Vectorized<c10::Half> inline fmadd(
575
+ const Vectorized<c10::Half>& a,
576
+ const Vectorized<c10::Half>& b,
577
+ const Vectorized<c10::Half>& c) {
578
+ #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
579
+ return Vectorized<c10::Half>(vfmaq_f16(c, a, b));
580
+ #else
581
+ return a * b + c;
582
+ #endif
583
+ }
584
+
585
+ template <>
586
+ Vectorized<c10::Half> inline fnmadd(
587
+ const Vectorized<c10::Half>& a,
588
+ const Vectorized<c10::Half>& b,
589
+ const Vectorized<c10::Half>& c) {
590
+ #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
591
+ return Vectorized<c10::Half>(vfmsq_f16(c, a, b));
592
+ #else
593
+ return -a * b + c;
594
+ #endif
595
+ }
596
+
597
+ template <>
598
+ Vectorized<c10::Half> inline fmsub(
599
+ const Vectorized<c10::Half>& a,
600
+ const Vectorized<c10::Half>& b,
601
+ const Vectorized<c10::Half>& c) {
602
+ #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
603
+ return Vectorized<c10::Half>(vnegq_f16(vfmsq_f16(c, a, b)));
604
+ #else
605
+ return a * b - c;
606
+ #endif
607
+ }
608
+
609
+ template <>
610
+ Vectorized<c10::Half> inline fnmsub(
611
+ const Vectorized<c10::Half>& a,
612
+ const Vectorized<c10::Half>& b,
613
+ const Vectorized<c10::Half>& c) {
614
+ #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
615
+ return Vectorized<c10::Half>(vnegq_f16(vfmaq_f16(c, a, b)));
616
+ #else
617
+ return -a * b - c;
618
+ #endif
619
+ }
620
+ #endif // !defined(C10_MOBILE) && defined(__aarch64__)
621
+
622
+ } // namespace CPU_CAPABILITY
623
+ } // namespace at::vec
624
+
625
+ #else
626
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
627
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec128/vec128_reduced_precision_common_neon.h ADDED
@@ -0,0 +1,316 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+ // Shared code for bfloat16 and float16.
4
+
5
+ // DO NOT DEFINE STATIC DATA IN THIS HEADER!
6
+ // See Note [Do not compile initializers with AVX]
7
+
8
+ namespace at::vec {
9
+ inline namespace CPU_CAPABILITY {
10
+
11
+ // Shared implementation between Vectorized<c10::Half> and
12
+ // Vectorized<c10::BFloat16>. Uses CRTP to allow derived class
13
+ // customization.
14
+ template <
15
+ typename VecT,
16
+ typename ValueT,
17
+ template <int, bool> typename BlendRegs,
18
+ typename Derived>
19
+ struct Vectorized16 {
20
+ protected:
21
+ VecT values;
22
+
23
+ public:
24
+ using value_type = ValueT;
25
+ using size_type = int;
26
+ static constexpr size_type size() {
27
+ static_assert(sizeof(VecT) == 8 * sizeof(value_type));
28
+ return 8;
29
+ }
30
+
31
+ protected:
32
+ Derived map2(
33
+ const Derived& second,
34
+ value_type (*const f)(value_type, value_type)) const {
35
+ __at_align__ value_type tmp_first[size()];
36
+ __at_align__ value_type tmp_second[size()];
37
+ static_cast<const Derived*>(this)->store(
38
+ tmp_first); // store this to tmp_first
39
+ second.store(tmp_second);
40
+ for (const auto i : c10::irange(size())) {
41
+ tmp_first[i] = f(tmp_first[i], tmp_second[i]);
42
+ }
43
+ return Derived::loadu(tmp_first);
44
+ }
45
+
46
+ public:
47
+ Vectorized16() = default;
48
+ Vectorized16(VecT v) : values(v) {}
49
+
50
+ operator VecT() const {
51
+ return values;
52
+ }
53
+
54
+ template <int64_t mask>
55
+ static Derived blend(const Derived& a, const Derived& b) {
56
+ Derived vec;
57
+ vec.values = BlendRegs < 0,
58
+ (mask & 0x01) != 0 > ::impl(a.values, b.values, vec.values);
59
+ vec.values = BlendRegs < 1,
60
+ (mask & 0x02) != 0 > ::impl(a.values, b.values, vec.values);
61
+ vec.values = BlendRegs < 2,
62
+ (mask & 0x04) != 0 > ::impl(a.values, b.values, vec.values);
63
+ vec.values = BlendRegs < 3,
64
+ (mask & 0x08) != 0 > ::impl(a.values, b.values, vec.values);
65
+
66
+ vec.values = BlendRegs < 4,
67
+ (mask & 0x10) != 0 > ::impl(a.values, b.values, vec.values);
68
+ vec.values = BlendRegs < 5,
69
+ (mask & 0x20) != 0 > ::impl(a.values, b.values, vec.values);
70
+ vec.values = BlendRegs < 6,
71
+ (mask & 0x40) != 0 > ::impl(a.values, b.values, vec.values);
72
+ vec.values = BlendRegs < 7,
73
+ (mask & 0x80) != 0 > ::impl(a.values, b.values, vec.values);
74
+
75
+ return vec;
76
+ }
77
+
78
+ template <typename step_t>
79
+ static Derived arange(
80
+ value_type base = 0,
81
+ step_t step = static_cast<step_t>(1)) {
82
+ const Derived base_vec(base);
83
+ const Derived step_vec(step);
84
+ const Derived step_sizes(
85
+ value_type(0),
86
+ value_type(1),
87
+ value_type(2),
88
+ value_type(3),
89
+ value_type(4),
90
+ value_type(5),
91
+ value_type(6),
92
+ value_type(7));
93
+ return fmadd(step_sizes, step_vec, base_vec);
94
+ }
95
+
96
+ // Very slow implementation of indexing.
97
+ // Only required because vec256_qint refers to this.
98
+ // Once we specialize that implementation for ARM
99
+ // this should be removed. TODO (kimishpatel)
100
+ value_type operator[](int idx) const {
101
+ __at_align__ value_type tmp[size()];
102
+ static_cast<const Derived*>(this)->store(tmp);
103
+ return tmp[idx];
104
+ }
105
+
106
+ int zero_mask() const {
107
+ __at_align__ value_type tmp[size()];
108
+ static_cast<const Derived*>(this)->store(tmp);
109
+ int mask = 0;
110
+ for (int i = 0; i < size(); ++i) {
111
+ if (tmp[i] == 0) {
112
+ mask |= (1 << i);
113
+ }
114
+ }
115
+ return mask;
116
+ }
117
+
118
+ Derived map(value_type (*const f)(value_type)) const {
119
+ __at_align__ value_type tmp[size()];
120
+ static_cast<const Derived*>(this)->store(tmp);
121
+ for (const auto i : c10::irange(size())) {
122
+ tmp[i] = f(tmp[i]);
123
+ }
124
+ return Derived::loadu(tmp);
125
+ }
126
+
127
+ Derived angle() const {
128
+ auto zero = Derived(0);
129
+ auto pi = Derived(c10::pi<value_type>);
130
+ auto tmp =
131
+ Derived::blendv(zero, pi, *static_cast<const Derived*>(this) < zero);
132
+ return Derived::blendv(
133
+ tmp,
134
+ *static_cast<const Derived*>(this),
135
+ static_cast<const Derived*>(this)->isnan());
136
+ }
137
+ Derived real() const {
138
+ return *this;
139
+ }
140
+ Derived imag() const {
141
+ return Derived(0);
142
+ }
143
+ Derived conj() const {
144
+ return *this;
145
+ }
146
+
147
+ // Sleef does not support FP16/BF16, so many math functions are applied by
148
+ // converting to FP32, applying the math function, and then converting back to
149
+ // FP16/BF16.
150
+ Derived acos() const {
151
+ return static_cast<const Derived*>(this)->map_with_vec_float_method(
152
+ &Vectorized<float>::acos);
153
+ }
154
+ Derived acosh() const {
155
+ return static_cast<const Derived*>(this)->map_with_vec_float_method(
156
+ &Vectorized<float>::acosh);
157
+ }
158
+ Derived asin() const {
159
+ return static_cast<const Derived*>(this)->map_with_vec_float_method(
160
+ &Vectorized<float>::asin);
161
+ }
162
+ Derived asinh() const {
163
+ return static_cast<const Derived*>(this)->map_with_vec_float_method(
164
+ &Vectorized<float>::asinh);
165
+ }
166
+ Derived atan() const {
167
+ return static_cast<const Derived*>(this)->map_with_vec_float_method(
168
+ &Vectorized<float>::atan);
169
+ }
170
+ Derived atanh() const {
171
+ return static_cast<const Derived*>(this)->map_with_vec_float_method(
172
+ &Vectorized<float>::atanh);
173
+ }
174
+ Derived atan2(const Derived& exp) const {
175
+ return static_cast<const Derived*>(this)->map2_with_vec_float_method(
176
+ exp, &Vectorized<float>::atan2);
177
+ }
178
+ Derived copysign(const Derived& sign) const {
179
+ return static_cast<const Derived*>(this)->map2_with_vec_float_method(
180
+ sign, &Vectorized<float>::copysign);
181
+ }
182
+ Derived erf() const {
183
+ return static_cast<const Derived*>(this)->map_with_vec_float_method(
184
+ &Vectorized<float>::erf);
185
+ }
186
+ Derived erfc() const {
187
+ return static_cast<const Derived*>(this)->map_with_vec_float_method(
188
+ &Vectorized<float>::erfc);
189
+ }
190
+ Derived erfinv() const {
191
+ return static_cast<const Derived*>(this)->map_with_vec_float_method(
192
+ &Vectorized<float>::erfinv);
193
+ }
194
+ Derived exp() const {
195
+ return static_cast<const Derived*>(this)->map_with_vec_float_method(
196
+ &Vectorized<float>::exp);
197
+ }
198
+ Derived exp2() const {
199
+ return static_cast<const Derived*>(this)->map_with_vec_float_method(
200
+ &Vectorized<float>::exp2);
201
+ }
202
+ Derived expm1() const {
203
+ return static_cast<const Derived*>(this)->map_with_vec_float_method(
204
+ &Vectorized<float>::expm1);
205
+ }
206
+ Derived exp_u20() const {
207
+ return static_cast<const Derived*>(this)->map_with_vec_float_method(
208
+ &Vectorized<float>::exp_u20);
209
+ }
210
+ Derived fexp_u20() const {
211
+ return static_cast<const Derived*>(this)->map_with_vec_float_method(
212
+ &Vectorized<float>::exp_u20);
213
+ }
214
+ Derived fmod(const Derived& q) const {
215
+ // This function is questionable with a conversion, so we use map2
216
+ return map2(q, std::fmod);
217
+ }
218
+ Derived hypot(const Derived& b) const {
219
+ return static_cast<const Derived*>(this)->map2_with_vec_float_method(
220
+ b, &Vectorized<float>::hypot);
221
+ }
222
+ Derived i0() const {
223
+ return static_cast<const Derived*>(this)->map_with_vec_float_method(
224
+ &Vectorized<float>::i0);
225
+ }
226
+ Derived i0e() const {
227
+ return static_cast<const Derived*>(this)->map_with_vec_float_method(
228
+ &Vectorized<float>::i0e);
229
+ }
230
+ Derived digamma() const {
231
+ return static_cast<const Derived*>(this)->map_with_vec_float_method(
232
+ &Vectorized<float>::digamma);
233
+ }
234
+ Derived igamma(const Derived& x) const {
235
+ return static_cast<const Derived*>(this)->map2_with_vec_float_method(
236
+ x, &Vectorized<float>::igamma);
237
+ }
238
+ Derived igammac(const Derived& x) const {
239
+ return static_cast<const Derived*>(this)->map2_with_vec_float_method(
240
+ x, &Vectorized<float>::igammac);
241
+ }
242
+ Derived log() const {
243
+ return static_cast<const Derived*>(this)->map_with_vec_float_method(
244
+ &Vectorized<float>::log);
245
+ }
246
+ Derived log10() const {
247
+ return static_cast<const Derived*>(this)->map_with_vec_float_method(
248
+ &Vectorized<float>::log10);
249
+ }
250
+ Derived log1p() const {
251
+ return static_cast<const Derived*>(this)->map_with_vec_float_method(
252
+ &Vectorized<float>::log1p);
253
+ }
254
+ Derived log2() const {
255
+ return static_cast<const Derived*>(this)->map_with_vec_float_method(
256
+ &Vectorized<float>::log2);
257
+ }
258
+ Derived nextafter(const Derived& b) const {
259
+ // This function does not make sense with conversion, so we use map2
260
+ return map2(b, std::nextafter);
261
+ }
262
+ Derived sin() const {
263
+ return static_cast<const Derived*>(this)->map_with_vec_float_method(
264
+ &Vectorized<float>::sin);
265
+ }
266
+ Derived sinh() const {
267
+ return static_cast<const Derived*>(this)->map_with_vec_float_method(
268
+ &Vectorized<float>::sinh);
269
+ }
270
+ Derived cos() const {
271
+ return static_cast<const Derived*>(this)->map_with_vec_float_method(
272
+ &Vectorized<float>::cos);
273
+ }
274
+ Derived cosh() const {
275
+ return static_cast<const Derived*>(this)->map_with_vec_float_method(
276
+ &Vectorized<float>::cosh);
277
+ }
278
+ Derived ceil() const {
279
+ // This function is questionable with a conversion, so we use map
280
+ return map(at::native::ceil_impl);
281
+ }
282
+ Derived floor() const {
283
+ // This function is questionable with a conversion, so we use map
284
+ return map(at::native::floor_impl);
285
+ }
286
+ Derived round() const {
287
+ // This function is questionable with a conversion, so we use map
288
+ return map(at::native::round_impl);
289
+ }
290
+ Derived tan() const {
291
+ return static_cast<const Derived*>(this)->map_with_vec_float_method(
292
+ &Vectorized<float>::tan);
293
+ }
294
+ Derived tanh() const {
295
+ return static_cast<const Derived*>(this)->map_with_vec_float_method(
296
+ &Vectorized<float>::tanh);
297
+ }
298
+ Derived lgamma() const {
299
+ return static_cast<const Derived*>(this)->map_with_vec_float_method(
300
+ &Vectorized<float>::lgamma);
301
+ }
302
+ Derived rsqrt() const {
303
+ return static_cast<const Derived*>(this)->sqrt().reciprocal();
304
+ }
305
+ Derived pow(const Derived& exp) const {
306
+ return static_cast<const Derived*>(this)->map2_with_vec_float_method(
307
+ exp, &Vectorized<float>::pow);
308
+ }
309
+ };
310
+
311
+ } // namespace CPU_CAPABILITY
312
+ } // namespace at::vec
313
+
314
+ #else
315
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
316
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec_base.h ADDED
@@ -0,0 +1,1537 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+ #if defined(__GNUC__) && __GNUC__ == 10 && __GNUC_MINOR__ <= 2 && \
4
+ defined(__ARM_FEATURE_SVE)
5
+ // Workaround for https: //gcc.gnu.org/bugzilla/show_bug.cgi?id=117161
6
+ #pragma GCC optimize("no-tree-vectorize")
7
+ #endif
8
+
9
+ // DO NOT DEFINE STATIC DATA IN THIS HEADER!
10
+ // See Note [Do not compile initializers with AVX]
11
+ //
12
+ // Note [Do not compile initializers with AVX]
13
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
14
+ // If you define a static initializer in this file, the initialization will use
15
+ // AVX instructions because these object files are compiled with AVX enabled.
16
+ // We need to avoid non-trivial global data in these architecture specific files
17
+ // because there's no way to guard the global initializers with CPU capability
18
+ // detection.
19
+ //
20
+ // See https://github.com/pytorch/pytorch/issues/37577 for an instance
21
+ // of this bug in the past.
22
+
23
+ #include <algorithm>
24
+ #include <array>
25
+ #include <cassert>
26
+ #include <climits>
27
+ #include <cmath>
28
+ #include <cstring>
29
+ #include <functional>
30
+ #include <type_traits>
31
+
32
+ #include <ATen/NumericUtils.h>
33
+ #include <ATen/cpu/vec/intrinsics.h>
34
+ #include <ATen/native/Math.h>
35
+ #include <ATen/native/cpu/zmath.h>
36
+ #include <c10/macros/Macros.h>
37
+ #include <c10/util/BFloat16-math.h>
38
+ #include <c10/util/BFloat16.h>
39
+ #include <c10/util/Half.h>
40
+ #include <c10/util/Load.h>
41
+ #include <c10/util/TypeCast.h>
42
+ #include <c10/util/copysign.h>
43
+ #include <c10/util/irange.h>
44
+
45
+ #if defined(__GNUC__)
46
+ #define __FORCE_INLINE __attribute__((always_inline)) inline
47
+ #elif defined(_MSC_VER)
48
+ #define __FORCE_INLINE __forceinline
49
+ #endif
50
+
51
+ #if defined(_MSC_FULL_VER)
52
+ /*
53
+ https://learn.microsoft.com/en-us/cpp/overview/compiler-versions?view=msvc-170
54
+ Use _MSC_FULL_VER to identify current compiler is msvc,
55
+ Windows llvm will not have this definition.
56
+ */
57
+ #define __msvc_cl__
58
+ #endif
59
+
60
+ // These macros helped us unify vec_base.h
61
+ #ifdef CPU_CAPABILITY_AVX512
62
+ #if defined(__GNUC__)
63
+ #define __at_align__ __attribute__((aligned(64)))
64
+ #elif defined(_WIN32)
65
+ #define __at_align__ __declspec(align(64))
66
+ #else
67
+ #define __at_align__
68
+ #endif
69
+ #define VECTOR_WIDTH 64
70
+ #define int_vector __m512i
71
+ #elif defined(__aarch64__) && \
72
+ !defined(CPU_CAPABILITY_SVE) // CPU_CAPABILITY_AVX512
73
+ // SVE code expects 256-vectors; leave that set for SVE?
74
+ #if defined(__GNUC__)
75
+ #define __at_align__ __attribute__((aligned(16)))
76
+ #elif defined(_WIN32)
77
+ #define __at_align__ __declspec(align(16))
78
+ #else
79
+ #define __at_align__
80
+ #endif
81
+ #define VECTOR_WIDTH 16
82
+ #else // CPU_CAPABILITY_AVX512
83
+ #if defined(__GNUC__)
84
+ #define __at_align__ __attribute__((aligned(32)))
85
+ #elif defined(_WIN32)
86
+ #define __at_align__ __declspec(align(32))
87
+ #else
88
+ #define __at_align__
89
+ #endif
90
+ #define VECTOR_WIDTH 32
91
+ #define int_vector __m256i
92
+ #endif // CPU_CAPABILITY_AVX512
93
+
94
+ namespace at::vec {
95
+ // See Note [CPU_CAPABILITY namespace]
96
+ inline namespace CPU_CAPABILITY {
97
+ // at::Half and at::BFloat16 should be treated as floating point
98
+ template <typename T>
99
+ struct is_floating_point
100
+ : std::integral_constant<
101
+ bool,
102
+ std::is_floating_point_v<T> || std::is_same_v<T, at::Half> ||
103
+ std::is_same_v<T, at::BFloat16>> {};
104
+
105
+ template <typename T>
106
+ constexpr bool is_floating_point_v = is_floating_point<T>::value;
107
+
108
+ template <typename T>
109
+ struct is_reduced_floating_point
110
+ : std::integral_constant<
111
+ bool,
112
+ std::is_same_v<T, at::Half> || std::is_same_v<T, at::BFloat16>> {};
113
+
114
+ template <typename T>
115
+ constexpr bool is_reduced_floating_point_v =
116
+ is_reduced_floating_point<T>::value;
117
+
118
+ template <typename T>
119
+ struct is_8bit_integer
120
+ : std::integral_constant<
121
+ bool,
122
+ std::is_same_v<T, unsigned char> || std::is_same_v<T, signed char>> {
123
+ };
124
+
125
+ template <typename T>
126
+ constexpr bool is_8bit_integer_v = is_8bit_integer<T>::value;
127
+
128
+ template <size_t n>
129
+ struct int_of_size;
130
+
131
+ #define DEFINE_INT_OF_SIZE(int_t) \
132
+ template <> \
133
+ struct int_of_size<sizeof(int_t)> { \
134
+ using type = int_t; \
135
+ }
136
+
137
+ DEFINE_INT_OF_SIZE(int64_t);
138
+ DEFINE_INT_OF_SIZE(int32_t);
139
+ DEFINE_INT_OF_SIZE(int16_t);
140
+ DEFINE_INT_OF_SIZE(int8_t);
141
+
142
+ #undef DEFINE_INT_OF_SIZE
143
+
144
+ template <typename T>
145
+ using int_same_size_t = typename int_of_size<sizeof(T)>::type;
146
+
147
+ /**
148
+ * Detect at compile time whether Vectorized has an explicit
149
+ * specialization for T. (You are required to specialize this type
150
+ * whenever you specialize Vectorized). Useful for generic algorithms
151
+ * to decide whether to rely on a specialization being fast. For
152
+ * example, they might choose to handle reduced-precision floating
153
+ * point types directly if they're supported, or convert through float
154
+ * if not.
155
+ */
156
+ #if defined(__s390x__)
157
+ template <class T, class TEMP = void>
158
+ #else
159
+ template <typename T>
160
+ #endif
161
+ struct is_vec_specialized_for : std::bool_constant<false> {
162
+ };
163
+
164
+ template <typename T>
165
+ constexpr bool is_vec_specialized_for_v = is_vec_specialized_for<T>::value;
166
+
167
+ // NOTE: If you specialize Vectorized on a type, you must define all
168
+ // operations! You must also specialize is_vec_specialized_for for
169
+ // that type.
170
+
171
+ // emulates Vectorized types
172
+ #if defined(__s390x__)
173
+ template <class T, class TEMP = void>
174
+ #else
175
+ template <class T>
176
+ #endif
177
+ struct Vectorized {
178
+ private:
179
+ __at_align__ T values[VECTOR_WIDTH / sizeof(T)];
180
+
181
+ public:
182
+ using value_type = T;
183
+ using size_type = int;
184
+
185
+ static constexpr size_type kSize = VECTOR_WIDTH / sizeof(T);
186
+ static constexpr size_type size() {
187
+ return kSize;
188
+ }
189
+ Vectorized() : values{static_cast<T>(0)} {}
190
+ Vectorized(T val) {
191
+ for (int i = 0; i != size(); i++) {
192
+ values[i] = val;
193
+ }
194
+ }
195
+ template <
196
+ typename... Args,
197
+ typename = std::enable_if_t<(sizeof...(Args) == size())>>
198
+ Vectorized(Args... vals) : values{vals...} {}
199
+ Vectorized(const T (&arr)[kSize]) {
200
+ std::memcpy(values, arr, sizeof(values));
201
+ }
202
+ // This also implies const T& operator[](int idx) const
203
+ inline operator const T*() const {
204
+ return values;
205
+ }
206
+ // This also implies T& operator[](int idx)
207
+ inline operator T*() {
208
+ return values;
209
+ }
210
+ // Return the values as char* for type punning
211
+ auto as_bytes() const -> const char* {
212
+ return reinterpret_cast<const char*>(values);
213
+ }
214
+ template <int64_t mask_>
215
+ static Vectorized<T> blend(const Vectorized<T>& a, const Vectorized<T>& b) {
216
+ int64_t mask = mask_;
217
+ Vectorized vector;
218
+ for (const auto i : c10::irange(size())) {
219
+ if (mask & 0x01) {
220
+ vector[i] = b[i];
221
+ } else {
222
+ vector[i] = a[i];
223
+ }
224
+ mask = mask >> 1;
225
+ }
226
+ return vector;
227
+ }
228
+ // Workaround for https: //gcc.gnu.org/bugzilla/show_bug.cgi?id=117001
229
+ #if __GNUC__ <= 12 && !defined(__clang__) && defined(__ARM_FEATURE_SVE)
230
+ static Vectorized<T> __attribute__((optimize("-fno-tree-loop-vectorize")))
231
+ blendv(
232
+ const Vectorized<T>& a,
233
+ #else
234
+ static Vectorized<T> blendv(
235
+ const Vectorized<T>& a,
236
+ #endif
237
+ const Vectorized<T>& b,
238
+ const Vectorized<T>& mask) {
239
+ Vectorized vector;
240
+ int_same_size_t<T> buffer[size()];
241
+ mask.store(buffer);
242
+ for (const auto i : c10::irange(size())) {
243
+ if (buffer[i] & 0x01) {
244
+ vector[i] = b[i];
245
+ } else {
246
+ vector[i] = a[i];
247
+ }
248
+ }
249
+ return vector;
250
+ }
251
+ template <typename step_t> // step sometimes requires a higher precision type
252
+ // (e.g., T=int, step_t=double)
253
+ static Vectorized<T> arange(
254
+ T base = static_cast<T>(0),
255
+ step_t step = static_cast<step_t>(1)) {
256
+ Vectorized vector;
257
+ for (const auto i : c10::irange(size())) {
258
+ vector.values[i] = base + i * step;
259
+ }
260
+ return vector;
261
+ }
262
+ static Vectorized<T> set(
263
+ const Vectorized<T>& a,
264
+ const Vectorized<T>& b,
265
+ int64_t count = size()) {
266
+ Vectorized vector;
267
+ for (const auto i : c10::irange(size())) {
268
+ if (i < count) {
269
+ vector[i] = b[i];
270
+ } else {
271
+ vector[i] = a[i];
272
+ }
273
+ }
274
+ return vector;
275
+ }
276
+ static Vectorized<T> loadu(const void* ptr) {
277
+ Vectorized vector;
278
+ std::memcpy(vector.values, ptr, VECTOR_WIDTH);
279
+ return vector;
280
+ }
281
+ static Vectorized<T> loadu(const void* ptr, int64_t count) {
282
+ Vectorized vector;
283
+ std::memcpy(vector.values, ptr, count * sizeof(T));
284
+ return vector;
285
+ }
286
+ static Vectorized<T> loadu_one_fourth(const void* ptr) {
287
+ static_assert(
288
+ std::is_same_v<T, signed char> || std::is_same_v<T, unsigned char>,
289
+ "For byte types only");
290
+ return Vectorized::loadu(ptr, 8);
291
+ }
292
+
293
+ void store(void* ptr, int count = size()) const {
294
+ std::memcpy(ptr, values, count * sizeof(T));
295
+ }
296
+ int zero_mask() const {
297
+ // returns an integer mask where all zero elements are translated to 1-bit
298
+ // and others are translated to 0-bit
299
+ int mask = 0;
300
+ for (int i = 0; i < size(); ++i) {
301
+ if (values[i] == static_cast<T>(0)) {
302
+ mask |= (1 << i);
303
+ }
304
+ }
305
+ return mask;
306
+ }
307
+ Vectorized<T> isnan() const {
308
+ Vectorized<T> vector;
309
+ for (int64_t i = 0; i != size(); i++) {
310
+ if (_isnan(values[i])) {
311
+ std::memset(static_cast<void*>(vector.values + i), 0xFF, sizeof(T));
312
+ } else {
313
+ std::memset(static_cast<void*>(vector.values + i), 0, sizeof(T));
314
+ }
315
+ }
316
+ return vector;
317
+ }
318
+ bool has_inf_nan() const {
319
+ for (int64_t i = 0; i != size(); i++) {
320
+ if (_isnan(values[i]) || _isinf(values[i])) {
321
+ return true;
322
+ }
323
+ }
324
+ return false;
325
+ }
326
+ // MSVC versions between 14.36 and 14.42 has a loop unrolling bug on Windows
327
+ // Arm64
328
+ // See
329
+ // https://developercommunity.visualstudio.com/t/MSVC-loop-unrolling-problem-194033813-/10720692
330
+ #if defined(_WIN32) && defined(__aarch64__) && \
331
+ ((_MSVC_VER >= 1936) && (_MSVC_VER <= 1942))
332
+ Vectorized<T> map(T (*const f)(T)) const {
333
+ Vectorized<T> ret;
334
+ for (int64_t i = 0; i < size(); i++) {
335
+ ret[i] = f(values[i]);
336
+ if (++i < size())
337
+ ret[i] = f(values[i]);
338
+ }
339
+ return ret;
340
+ }
341
+ T reduce(T (*const f)(T)) const {
342
+ T ret = 0;
343
+ for (int64_t i = 0; i < size(); i++) {
344
+ ret = f(ret, values[i]);
345
+ if (++i < size())
346
+ ret = f(ret, values[i]);
347
+ }
348
+ return ret;
349
+ }
350
+ #else
351
+ Vectorized<T> map(T (*const f)(T)) const {
352
+ Vectorized<T> ret;
353
+ for (int64_t i = 0; i != size(); i++) {
354
+ ret[i] = f(values[i]);
355
+ }
356
+ return ret;
357
+ }
358
+ T reduce(T (*const f)(T)) const {
359
+ T ret = 0;
360
+ for (int64_t i = 0; i != size(); i++) {
361
+ ret = f(ret, values[i]);
362
+ }
363
+ return ret;
364
+ }
365
+ #endif
366
+ Vectorized<T> map(T (*const f)(const T&)) const {
367
+ Vectorized<T> ret;
368
+ for (int64_t i = 0; i != size(); i++) {
369
+ ret[i] = f(values[i]);
370
+ }
371
+ return ret;
372
+ }
373
+ T reduce(T (*const f)(const T&)) const {
374
+ T ret = 0;
375
+ for (int64_t i = 0; i != size(); i++) {
376
+ ret = f(ret, values[i]);
377
+ }
378
+ return ret;
379
+ }
380
+ template <
381
+ typename other_t_abs = T,
382
+ typename std::enable_if_t<
383
+ !is_floating_point_v<other_t_abs> &&
384
+ !c10::is_complex<other_t_abs>::value,
385
+ int> = 0>
386
+ Vectorized<T> abs() const {
387
+ // other_t_abs is for SFINAE and clarity. Make sure it is not changed.
388
+ static_assert(std::is_same_v<other_t_abs, T>, "other_t_abs must be T");
389
+ return map([](T x) -> T { return x < static_cast<T>(0) ? -x : x; });
390
+ }
391
+ template <
392
+ typename float_t_abs = T,
393
+ typename std::enable_if_t<is_floating_point_v<float_t_abs>, int> = 0>
394
+ Vectorized<T> abs() const {
395
+ // float_t_abs is for SFINAE and clarity. Make sure it is not changed.
396
+ static_assert(std::is_same_v<float_t_abs, T>, "float_t_abs must be T");
397
+ // Specifically deal with floating-point because the generic code above
398
+ // won't handle -0.0 (which should result in 0.0) properly.
399
+ return map([](T x) -> T { return std::abs(x); });
400
+ }
401
+ template <
402
+ typename complex_t_abs = T,
403
+ typename std::enable_if_t<c10::is_complex<complex_t_abs>::value, int> = 0>
404
+ Vectorized<T> abs() const {
405
+ // complex_t_abs is for SFINAE and clarity. Make sure it is not changed.
406
+ static_assert(std::is_same_v<complex_t_abs, T>, "complex_t_abs must be T");
407
+ // Specifically map() does not perform the type conversion needed by abs.
408
+ return map([](T x) { return static_cast<T>(std::abs(x)); });
409
+ }
410
+
411
+ template <
412
+ typename other_t_sgn = T,
413
+ typename std::enable_if_t<c10::is_complex<other_t_sgn>::value, int> = 0>
414
+ Vectorized<T> sgn() const {
415
+ return map(at::native::sgn_impl);
416
+ }
417
+
418
+ template <
419
+ typename other_t_angle = T,
420
+ typename std::enable_if_t<!c10::is_complex<other_t_angle>::value, int> =
421
+ 0>
422
+ Vectorized<T> angle() const {
423
+ // other_t_angle is for SFINAE and clarity. Make sure it is not changed.
424
+ static_assert(std::is_same_v<other_t_angle, T>, "other_t_angle must be T");
425
+ return map(at::native::angle_impl<T>); // compiler is unable to resolve the
426
+ // overload without <T>
427
+ }
428
+ template <
429
+ typename complex_t_angle = T,
430
+ typename std::enable_if_t<c10::is_complex<complex_t_angle>::value, int> =
431
+ 0>
432
+ Vectorized<T> angle() const {
433
+ // complex_t_angle is for SFINAE and clarity. Make sure it is not changed.
434
+ static_assert(
435
+ std::is_same_v<complex_t_angle, T>, "complex_t_angle must be T");
436
+ return map([](T x) { return static_cast<T>(std::arg(x)); });
437
+ }
438
+ template <
439
+ typename other_t_real = T,
440
+ typename std::enable_if_t<!c10::is_complex<other_t_real>::value, int> = 0>
441
+ Vectorized<T> real() const {
442
+ // other_t_real is for SFINAE and clarity. Make sure it is not changed.
443
+ static_assert(std::is_same_v<other_t_real, T>, "other_t_real must be T");
444
+ return *this;
445
+ }
446
+ template <
447
+ typename complex_t_real = T,
448
+ typename std::enable_if_t<c10::is_complex<complex_t_real>::value, int> =
449
+ 0>
450
+ Vectorized<T> real() const {
451
+ // complex_t_real is for SFINAE and clarity. Make sure it is not changed.
452
+ static_assert(
453
+ std::is_same_v<complex_t_real, T>, "complex_t_real must be T");
454
+ return map([](T x) { return static_cast<T>(x.real()); });
455
+ }
456
+ template <
457
+ typename other_t_imag = T,
458
+ typename std::enable_if_t<!c10::is_complex<other_t_imag>::value, int> = 0>
459
+ Vectorized<T> imag() const {
460
+ // other_t_imag is for SFINAE and clarity. Make sure it is not changed.
461
+ static_assert(std::is_same_v<other_t_imag, T>, "other_t_imag must be T");
462
+ return Vectorized(0);
463
+ }
464
+ template <
465
+ typename complex_t_imag = T,
466
+ typename std::enable_if_t<c10::is_complex<complex_t_imag>::value, int> =
467
+ 0>
468
+ Vectorized<T> imag() const {
469
+ // complex_t_imag is for SFINAE and clarity. Make sure it is not changed.
470
+ static_assert(
471
+ std::is_same_v<complex_t_imag, T>, "complex_t_imag must be T");
472
+ return map([](T x) { return static_cast<T>(x.imag()); });
473
+ }
474
+ template <
475
+ typename other_t_conj = T,
476
+ typename std::enable_if_t<!c10::is_complex<other_t_conj>::value, int> = 0>
477
+ Vectorized<T> conj() const {
478
+ // other_t_conj is for SFINAE and clarity. Make sure it is not changed.
479
+ static_assert(std::is_same_v<other_t_conj, T>, "other_t_conj must be T");
480
+ return *this;
481
+ }
482
+ template <
483
+ typename complex_t_conj = T,
484
+ typename std::enable_if_t<c10::is_complex<complex_t_conj>::value, int> =
485
+ 0>
486
+ Vectorized<T> conj() const {
487
+ // complex_t_conj is for SFINAE and clarity. Make sure it is not changed.
488
+ static_assert(
489
+ std::is_same_v<complex_t_conj, T>, "complex_t_conj must be T");
490
+ return map([](T x) { return static_cast<T>(std::conj(x)); });
491
+ }
492
+ Vectorized<T> acos() const {
493
+ return map(std::acos);
494
+ }
495
+ Vectorized<T> acosh() const {
496
+ return map(std::acosh);
497
+ }
498
+ Vectorized<T> asin() const {
499
+ return map(std::asin);
500
+ }
501
+ Vectorized<T> asinh() const {
502
+ return map(std::asinh);
503
+ }
504
+ Vectorized<T> atan() const {
505
+ return map(std::atan);
506
+ }
507
+ Vectorized<T> atanh() const {
508
+ return map(std::atanh);
509
+ }
510
+ Vectorized<T> atan2(const Vectorized<T>& exp) const {
511
+ Vectorized<T> ret;
512
+ for (const auto i : c10::irange(size())) {
513
+ ret[i] = std::atan2(values[i], exp[i]);
514
+ }
515
+ return ret;
516
+ }
517
+ template <
518
+ typename U = T,
519
+ typename std::enable_if_t<is_floating_point_v<U>, int> = 0>
520
+ Vectorized<T> copysign(const Vectorized<T>& sign) const {
521
+ Vectorized<T> ret;
522
+ for (size_type i = 0; i < size(); i++) {
523
+ ret[i] = c10::copysign(values[i], sign[i]);
524
+ }
525
+ return ret;
526
+ }
527
+ Vectorized<T> erf() const {
528
+ return map(std::erf);
529
+ }
530
+ Vectorized<T> erfc() const {
531
+ return map(std::erfc);
532
+ }
533
+ Vectorized<T> erfinv() const {
534
+ return map(calc_erfinv);
535
+ }
536
+ Vectorized<T> exp() const {
537
+ return map(std::exp);
538
+ }
539
+ Vectorized<T> exp2() const {
540
+ return map(exp2_impl);
541
+ }
542
+ Vectorized<T> expm1() const {
543
+ return map(std::expm1);
544
+ }
545
+ Vectorized<T> exp_u20() const {
546
+ return map(std::exp);
547
+ }
548
+ Vectorized<T> fexp_u20() const {
549
+ return map(std::exp);
550
+ }
551
+ Vectorized<T> frac() const {
552
+ return *this - this->trunc();
553
+ }
554
+ template <
555
+ typename U = T,
556
+ typename std::enable_if_t<is_floating_point_v<U>, int> = 0>
557
+ Vectorized<T> fmod(const Vectorized<T>& q) const {
558
+ // U is for SFINAE purposes only. Make sure it is not changed.
559
+ static_assert(std::is_same_v<U, T>, "U must be T");
560
+ Vectorized<T> ret;
561
+ for (const auto i : c10::irange(size())) {
562
+ ret[i] = std::fmod(values[i], q[i]);
563
+ }
564
+ return ret;
565
+ }
566
+ Vectorized<T> log() const {
567
+ return map(std::log);
568
+ }
569
+ Vectorized<T> log10() const {
570
+ return map(std::log10);
571
+ }
572
+ Vectorized<T> log1p() const {
573
+ return map(std::log1p);
574
+ }
575
+ template <
576
+ typename other_t_log2 = T,
577
+ typename std::enable_if_t<!c10::is_complex<other_t_log2>::value, int> = 0>
578
+ Vectorized<T> log2() const {
579
+ // other_t_log2 is for SFINAE and clarity. Make sure it is not changed.
580
+ static_assert(std::is_same_v<other_t_log2, T>, "other_t_log2 must be T");
581
+ return map(std::log2);
582
+ }
583
+ template <
584
+ typename complex_t_log2 = T,
585
+ typename std::enable_if_t<c10::is_complex<complex_t_log2>::value, int> =
586
+ 0>
587
+ Vectorized<T> log2() const {
588
+ // complex_t_log2 is for SFINAE and clarity. Make sure it is not changed.
589
+ static_assert(
590
+ std::is_same_v<complex_t_log2, T>, "complex_t_log2 must be T");
591
+ const T log_2 = T(std::log(2.0));
592
+ return Vectorized(map(std::log)) / Vectorized(log_2);
593
+ }
594
+ Vectorized<T> ceil() const {
595
+ return map(at::native::ceil_impl);
596
+ }
597
+ Vectorized<T> cos() const {
598
+ return map(std::cos);
599
+ }
600
+ Vectorized<T> cosh() const {
601
+ return map(std::cosh);
602
+ }
603
+ Vectorized<T> floor() const {
604
+ return map(at::native::floor_impl);
605
+ }
606
+ Vectorized<T> hypot(const Vectorized<T>& b) const {
607
+ Vectorized<T> ret;
608
+ for (const auto i : c10::irange(size())) {
609
+ ret[i] = std::hypot(values[i], b[i]);
610
+ }
611
+ return ret;
612
+ }
613
+ Vectorized<T> i0() const {
614
+ return map(calc_i0);
615
+ }
616
+ Vectorized<T> i0e() const {
617
+ return map(calc_i0e);
618
+ }
619
+ Vectorized<T> digamma() const {
620
+ return map(calc_digamma);
621
+ }
622
+ Vectorized<T> igamma(const Vectorized<T>& x) const {
623
+ Vectorized<T> ret;
624
+ for (const auto i : c10::irange(size())) {
625
+ ret[i] = calc_igamma(values[i], x[i]);
626
+ }
627
+ return ret;
628
+ }
629
+ Vectorized<T> igammac(const Vectorized<T>& x) const {
630
+ Vectorized<T> ret;
631
+ for (const auto i : c10::irange(size())) {
632
+ ret[i] = calc_igammac(values[i], x[i]);
633
+ }
634
+ return ret;
635
+ }
636
+ Vectorized<T> neg() const {
637
+ // NB: the trailing return type is needed because we need to coerce the
638
+ // return value back to T in the case of unary operator- incurring a
639
+ // promotion
640
+ return map([](T x) -> T { return -x; });
641
+ }
642
+ Vectorized<T> nextafter(const Vectorized<T>& b) const {
643
+ Vectorized<T> ret;
644
+ for (const auto i : c10::irange(size())) {
645
+ ret[i] = std::nextafter(values[i], b[i]);
646
+ }
647
+ return ret;
648
+ }
649
+ Vectorized<T> round() const {
650
+ // We do not use std::round because we would like to round midway numbers to
651
+ // the nearest even integer.
652
+ return map(at::native::round_impl);
653
+ }
654
+ Vectorized<T> sin() const {
655
+ return map(std::sin);
656
+ }
657
+ Vectorized<T> sinh() const {
658
+ return map(std::sinh);
659
+ }
660
+ Vectorized<T> tan() const {
661
+ return map(std::tan);
662
+ }
663
+ Vectorized<T> tanh() const {
664
+ return map(std::tanh);
665
+ }
666
+ Vectorized<T> trunc() const {
667
+ return map(at::native::trunc_impl);
668
+ }
669
+ Vectorized<T> lgamma() const {
670
+ return map(std::lgamma);
671
+ }
672
+ Vectorized<T> sqrt() const {
673
+ return map(std::sqrt);
674
+ }
675
+ Vectorized<T> reciprocal() const {
676
+ return map([](T x) { return (T)1 / x; });
677
+ }
678
+ Vectorized<T> rsqrt() const {
679
+ return map([](T x) { return (T)1 / std::sqrt(x); });
680
+ }
681
+ Vectorized<T> pow(const Vectorized<T>& exp) const {
682
+ Vectorized<T> ret;
683
+ for (const auto i : c10::irange(size())) {
684
+ ret[i] = std::pow(values[i], exp[i]);
685
+ }
686
+ return ret;
687
+ }
688
+ T reduce_add() const {
689
+ return reduce([](T x, T y) -> T { return x + y; });
690
+ }
691
+ T reduce_max() const {
692
+ return reduce(std::max);
693
+ }
694
+
695
+ private:
696
+ template <typename Op>
697
+ inline Vectorized<T> binary_pred(const Vectorized<T>& other, Op op) const {
698
+ // All bits are set to 1 if the pred is true, otherwise 0.
699
+ Vectorized<T> vector;
700
+ for (int64_t i = 0; i != size(); i++) {
701
+ if (op(values[i], other.values[i])) {
702
+ std::memset(static_cast<void*>(vector.values + i), 0xFF, sizeof(T));
703
+ } else {
704
+ std::memset(static_cast<void*>(vector.values + i), 0, sizeof(T));
705
+ }
706
+ }
707
+ return vector;
708
+ }
709
+
710
+ public:
711
+ Vectorized<T> operator==(const Vectorized<T>& other) const {
712
+ return binary_pred(other, std::equal_to<T>());
713
+ }
714
+ Vectorized<T> operator!=(const Vectorized<T>& other) const {
715
+ return binary_pred(other, std::not_equal_to<T>());
716
+ }
717
+ Vectorized<T> operator>=(const Vectorized<T>& other) const {
718
+ return binary_pred(other, std::greater_equal<T>());
719
+ }
720
+ Vectorized<T> operator<=(const Vectorized<T>& other) const {
721
+ return binary_pred(other, std::less_equal<T>());
722
+ }
723
+ Vectorized<T> operator>(const Vectorized<T>& other) const {
724
+ return binary_pred(other, std::greater<T>());
725
+ }
726
+ Vectorized<T> operator<(const Vectorized<T>& other) const {
727
+ return binary_pred(other, std::less<T>());
728
+ }
729
+
730
+ private:
731
+ template <typename Op>
732
+ inline Vectorized<T> binary_pred_bool(const Vectorized<T>& other, Op op)
733
+ const {
734
+ // 1 if the pred is true, otherwise 0.
735
+ Vectorized<T> vector;
736
+ for (int i = 0; i != size(); ++i) {
737
+ vector[i] = static_cast<T>(op(values[i], other.values[i]));
738
+ }
739
+ return vector;
740
+ }
741
+
742
+ public:
743
+ Vectorized<T> eq(const Vectorized<T>& other) const {
744
+ return binary_pred_bool(other, std::equal_to<T>());
745
+ }
746
+ Vectorized<T> ne(const Vectorized<T>& other) const {
747
+ return binary_pred_bool(other, std::not_equal_to<T>());
748
+ }
749
+ Vectorized<T> gt(const Vectorized<T>& other) const {
750
+ return binary_pred_bool(other, std::greater<T>());
751
+ }
752
+ Vectorized<T> ge(const Vectorized<T>& other) const {
753
+ return binary_pred_bool(other, std::greater_equal<T>());
754
+ }
755
+ Vectorized<T> lt(const Vectorized<T>& other) const {
756
+ return binary_pred_bool(other, std::less<T>());
757
+ }
758
+ Vectorized<T> le(const Vectorized<T>& other) const {
759
+ return binary_pred_bool(other, std::less_equal<T>());
760
+ }
761
+ };
762
+
763
+ template <class T>
764
+ Vectorized<T> inline operator-(const Vectorized<T>& a) {
765
+ return a.neg();
766
+ }
767
+
768
+ // There is an implicit conversion that would make this work if
769
+ // these operators weren't template functions, but they are template
770
+ // functions (and can't be moved to be non-member friends defined in
771
+ // the class body as suggested in
772
+ // https://stackoverflow.com/questions/9787593/implicit-type-conversion-with-template/9788255#9788255
773
+ // because we have a lot of disparate specializations of
774
+ // Vectorized). So, just explicitly make scalars work.
775
+ #define VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC(name) \
776
+ template <class T> \
777
+ Vectorized<T> inline name(const Vectorized<T>& a, T b) { \
778
+ return name(a, Vectorized<T>(b)); \
779
+ } \
780
+ template <class T> \
781
+ Vectorized<T> inline name(T a, const Vectorized<T>& b) { \
782
+ return name(Vectorized<T>(a), b); \
783
+ }
784
+ #define VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(op) \
785
+ VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC(operator op)
786
+
787
+ template <class T>
788
+ Vectorized<T> inline operator+(const Vectorized<T>& a, const Vectorized<T>& b) {
789
+ Vectorized<T> c;
790
+ for (int i = 0; i != Vectorized<T>::size(); i++) {
791
+ c[i] = a[i] + b[i];
792
+ }
793
+ return c;
794
+ }
795
+
796
+ VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(+)
797
+
798
+ template <class T>
799
+ Vectorized<T> inline operator-(const Vectorized<T>& a, const Vectorized<T>& b) {
800
+ Vectorized<T> c;
801
+ for (int i = 0; i != Vectorized<T>::size(); i++) {
802
+ c[i] = a[i] - b[i];
803
+ }
804
+ return c;
805
+ }
806
+
807
+ VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(-)
808
+
809
+ template <class T>
810
+ Vectorized<T> inline operator*(const Vectorized<T>& a, const Vectorized<T>& b) {
811
+ Vectorized<T> c;
812
+ for (int i = 0; i != Vectorized<T>::size(); i++) {
813
+ c[i] = a[i] * b[i];
814
+ }
815
+ return c;
816
+ }
817
+
818
+ VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(*)
819
+
820
+ template <class T>
821
+ Vectorized<T> inline operator/(const Vectorized<T>& a, const Vectorized<T>& b)
822
+ __ubsan_ignore_float_divide_by_zero__ {
823
+ Vectorized<T> c;
824
+ for (int i = 0; i != Vectorized<T>::size(); i++) {
825
+ c[i] = a[i] / b[i];
826
+ }
827
+ return c;
828
+ }
829
+
830
+ VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(/)
831
+
832
+ template <class T, typename std::enable_if_t<!is_floating_point_v<T>, int> = 0>
833
+ Vectorized<T> inline operator%(const Vectorized<T>& a, const Vectorized<T>& b)
834
+ __ubsan_ignore_float_divide_by_zero__ {
835
+ return a - a / b * b;
836
+ }
837
+
838
+ VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(%)
839
+
840
+ template <class T>
841
+ Vectorized<T> inline operator||(
842
+ const Vectorized<T>& a,
843
+ const Vectorized<T>& b) {
844
+ Vectorized<T> c;
845
+ for (int i = 0; i != Vectorized<T>::size(); i++) {
846
+ c[i] = a[i] || b[i];
847
+ }
848
+ return c;
849
+ }
850
+
851
+ VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(||)
852
+
853
+ // Implements the IEEE 754 201X `maximum` operation, which propagates NaN if
854
+ // either input is a NaN.
855
+ template <
856
+ class T,
857
+ typename std::enable_if_t<!c10::is_complex<T>::value, int> = 0>
858
+ Vectorized<T> inline maximum(const Vectorized<T>& a, const Vectorized<T>& b) {
859
+ Vectorized<T> c;
860
+ for (int i = 0; i != Vectorized<T>::size(); i++) {
861
+ c[i] = (a[i] > b[i]) ? a[i] : b[i];
862
+ if (_isnan(a[i])) {
863
+ // If either input is NaN, propagate a NaN.
864
+ // NOTE: The case where b[i] was NaN is handled correctly by the naive
865
+ // ternary operator above.
866
+ c[i] = a[i];
867
+ }
868
+ }
869
+ return c;
870
+ }
871
+
872
+ template <
873
+ class T,
874
+ typename std::enable_if_t<c10::is_complex<T>::value, int> = 0>
875
+ Vectorized<T> inline maximum(const Vectorized<T>& a, const Vectorized<T>& b) {
876
+ Vectorized<T> c;
877
+ for (int i = 0; i != Vectorized<T>::size(); i++) {
878
+ c[i] = (std::abs(a[i]) > std::abs(b[i])) ? a[i] : b[i];
879
+ if (_isnan(a[i])) {
880
+ // If either input is NaN, propagate a NaN.
881
+ // NOTE: The case where b[i] was NaN is handled correctly by the naive
882
+ // ternary operator above.
883
+ c[i] = a[i];
884
+ }
885
+ }
886
+ return c;
887
+ }
888
+
889
+ VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC(maximum)
890
+
891
+ // Implements the IEEE 754 201X `minimum` operation, which propagates NaN if
892
+ // either input is a NaN.
893
+ template <
894
+ class T,
895
+ typename std::enable_if_t<!c10::is_complex<T>::value, int> = 0>
896
+ Vectorized<T> inline minimum(const Vectorized<T>& a, const Vectorized<T>& b) {
897
+ Vectorized<T> c;
898
+ for (int i = 0; i != Vectorized<T>::size(); i++) {
899
+ c[i] = (a[i] < b[i]) ? a[i] : b[i];
900
+ if (_isnan(a[i])) {
901
+ // If either input is NaN, propagate a NaN.
902
+ // NOTE: The case where b[i] was NaN is handled correctly by the naive
903
+ // ternary operator above.
904
+ c[i] = a[i];
905
+ }
906
+ }
907
+ return c;
908
+ }
909
+
910
+ template <
911
+ class T,
912
+ typename std::enable_if_t<c10::is_complex<T>::value, int> = 0>
913
+ Vectorized<T> inline minimum(const Vectorized<T>& a, const Vectorized<T>& b) {
914
+ Vectorized<T> c;
915
+ for (int i = 0; i != Vectorized<T>::size(); i++) {
916
+ c[i] = (std::abs(a[i]) < std::abs(b[i])) ? a[i] : b[i];
917
+ if (_isnan(a[i])) {
918
+ // If either input is NaN, propagate a NaN.
919
+ // NOTE: The case where b[i] was NaN is handled correctly by the naive
920
+ // ternary operator above.
921
+ c[i] = a[i];
922
+ }
923
+ }
924
+ return c;
925
+ }
926
+
927
+ VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC(minimum)
928
+
929
+ template <
930
+ class T,
931
+ typename std::enable_if_t<!c10::is_complex<T>::value, int> = 0>
932
+ Vectorized<T> inline clamp(
933
+ const Vectorized<T>& a,
934
+ const Vectorized<T>& min_vec,
935
+ const Vectorized<T>& max_vec) {
936
+ Vectorized<T> c;
937
+ for (int i = 0; i != Vectorized<T>::size(); i++) {
938
+ c[i] = std::min(std::max(a[i], min_vec[i]), max_vec[i]);
939
+ }
940
+ return c;
941
+ }
942
+
943
+ #define VECTORIZED_SUPPORT_SCALARS_FOR_TERNARY_FUNC(name) \
944
+ template <class T> \
945
+ Vectorized<T> inline name( \
946
+ const Vectorized<T>& a, const Vectorized<T>& b, T c) { \
947
+ return name(a, b, Vectorized<T>(c)); \
948
+ } \
949
+ \
950
+ template <class T> \
951
+ Vectorized<T> inline name( \
952
+ const Vectorized<T>& a, T b, const Vectorized<T>& c) { \
953
+ return name(a, Vectorized<T>(b), c); \
954
+ } \
955
+ \
956
+ template <class T> \
957
+ Vectorized<T> inline name(const Vectorized<T>& a, T b, T c) { \
958
+ return name(a, Vectorized<T>(b), Vectorized<T>(c)); \
959
+ } \
960
+ \
961
+ template <class T> \
962
+ Vectorized<T> inline name( \
963
+ T a, const Vectorized<T>& b, const Vectorized<T>& c) { \
964
+ return name(Vectorized<T>(a), b, c); \
965
+ } \
966
+ \
967
+ template <class T> \
968
+ Vectorized<T> inline name(T a, const Vectorized<T>& b, T c) { \
969
+ return name(Vectorized<T>(a), b, Vectorized<T>(c)); \
970
+ } \
971
+ \
972
+ template <class T> \
973
+ Vectorized<T> inline name(T a, T b, const Vectorized<T>& c) { \
974
+ return name(Vectorized<T>(a), Vectorized<T>(b), c); \
975
+ }
976
+
977
+ VECTORIZED_SUPPORT_SCALARS_FOR_TERNARY_FUNC(clamp)
978
+
979
+ template <
980
+ class T,
981
+ typename std::enable_if_t<!c10::is_complex<T>::value, int> = 0>
982
+ Vectorized<T> inline clamp_max(
983
+ const Vectorized<T>& a,
984
+ const Vectorized<T>& max_vec) {
985
+ Vectorized<T> c;
986
+ for (int i = 0; i != Vectorized<T>::size(); i++) {
987
+ c[i] = a[i] > max_vec[i] ? max_vec[i] : a[i];
988
+ }
989
+ return c;
990
+ }
991
+
992
+ VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC(clamp_max)
993
+
994
+ template <
995
+ class T,
996
+ typename std::enable_if_t<!c10::is_complex<T>::value, int> = 0>
997
+ Vectorized<T> inline clamp_min(
998
+ const Vectorized<T>& a,
999
+ const Vectorized<T>& min_vec) {
1000
+ Vectorized<T> c;
1001
+ for (int i = 0; i != Vectorized<T>::size(); i++) {
1002
+ c[i] = a[i] < min_vec[i] ? min_vec[i] : a[i];
1003
+ }
1004
+ return c;
1005
+ }
1006
+
1007
+ VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC(clamp_min)
1008
+
1009
+ struct Vectorizedi;
1010
+
1011
+ #if defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)
1012
+ template <class T, typename Op>
1013
+ static inline Vectorized<T> bitwise_binary_op(
1014
+ const Vectorized<T>& a,
1015
+ const Vectorized<T>& b,
1016
+ Op op) {
1017
+ int_vector buffer;
1018
+ #if defined(CPU_CAPABILITY_AVX2)
1019
+ int_vector a_buffer =
1020
+ _mm256_load_si256(reinterpret_cast<const int_vector*>((const T*)a));
1021
+ int_vector b_buffer =
1022
+ _mm256_load_si256(reinterpret_cast<const int_vector*>((const T*)b));
1023
+ #elif defined(CPU_CAPABILITY_AVX512)
1024
+ int_vector a_buffer =
1025
+ _mm512_load_si512(reinterpret_cast<const int_vector*>((const T*)a));
1026
+ int_vector b_buffer =
1027
+ _mm512_load_si512(reinterpret_cast<const int_vector*>((const T*)b));
1028
+ #endif
1029
+ buffer = op(a_buffer, b_buffer);
1030
+ __at_align__ T results[Vectorized<T>::size()];
1031
+
1032
+ #if defined(CPU_CAPABILITY_AVX2)
1033
+ _mm256_store_si256(reinterpret_cast<int_vector*>(results), buffer);
1034
+ #elif defined(CPU_CAPABILITY_AVX512)
1035
+ _mm512_store_si512(reinterpret_cast<int_vector*>(results), buffer);
1036
+ #endif
1037
+ return Vectorized<T>::loadu(results);
1038
+ }
1039
+
1040
+ template <
1041
+ class T,
1042
+ typename std::enable_if_t<
1043
+ !std::is_base_of<Vectorizedi, Vectorized<T>>::value,
1044
+ int> = 0>
1045
+ inline Vectorized<T> operator&(const Vectorized<T>& a, const Vectorized<T>& b) {
1046
+ // We enclose _mm512_and_si512 or _mm256_and_si256 with lambda because it is
1047
+ // always_inline
1048
+ #if defined(CPU_CAPABILITY_AVX2)
1049
+ return bitwise_binary_op(
1050
+ a, b, [](int_vector a, int_vector b) { return _mm256_and_si256(a, b); });
1051
+ #elif defined(CPU_CAPABILITY_AVX512)
1052
+ return bitwise_binary_op(
1053
+ a, b, [](int_vector a, int_vector b) { return _mm512_and_si512(a, b); });
1054
+ #endif
1055
+ }
1056
+ template <
1057
+ class T,
1058
+ typename std::enable_if_t<
1059
+ !std::is_base_of<Vectorizedi, Vectorized<T>>::value,
1060
+ int> = 0>
1061
+ inline Vectorized<T> operator|(const Vectorized<T>& a, const Vectorized<T>& b) {
1062
+ // We enclose _mm512_or_si512 or _mm256_or_si256 with lambda because it is
1063
+ // always_inline
1064
+ #if defined(CPU_CAPABILITY_AVX2)
1065
+ return bitwise_binary_op(
1066
+ a, b, [](int_vector a, int_vector b) { return _mm256_or_si256(a, b); });
1067
+ #elif defined(CPU_CAPABILITY_AVX512)
1068
+ return bitwise_binary_op(
1069
+ a, b, [](int_vector a, int_vector b) { return _mm512_or_si512(a, b); });
1070
+ #endif
1071
+ }
1072
+ template <
1073
+ class T,
1074
+ typename std::enable_if_t<
1075
+ !std::is_base_of<Vectorizedi, Vectorized<T>>::value,
1076
+ int> = 0>
1077
+ inline Vectorized<T> operator^(const Vectorized<T>& a, const Vectorized<T>& b) {
1078
+ // We enclose _mm512_xor_si512 or _mm256_xor_si256 with lambda because it is
1079
+ // always_inline
1080
+ #if defined(CPU_CAPABILITY_AVX2)
1081
+ return bitwise_binary_op(
1082
+ a, b, [](int_vector a, int_vector b) { return _mm256_xor_si256(a, b); });
1083
+ #elif defined(CPU_CAPABILITY_AVX512)
1084
+ return bitwise_binary_op(
1085
+ a, b, [](int_vector a, int_vector b) { return _mm512_xor_si512(a, b); });
1086
+ #endif
1087
+ }
1088
+
1089
+ #else
1090
+
1091
+ template <typename T>
1092
+ auto load(char const* data) -> T {
1093
+ T ret;
1094
+ std::memcpy(&ret, data, sizeof(ret));
1095
+ return ret;
1096
+ }
1097
+
1098
+ template <class T, typename Op>
1099
+ static inline Vectorized<T> bitwise_binary_op(
1100
+ const Vectorized<T>& a,
1101
+ const Vectorized<T>& b,
1102
+ Op op) {
1103
+ static constexpr uint32_t element_no = VECTOR_WIDTH / sizeof(intmax_t);
1104
+ __at_align__ intmax_t buffer[element_no];
1105
+ static_assert(
1106
+ VECTOR_WIDTH % sizeof(intmax_t) == 0,
1107
+ "VECTOR_WIDTH not a multiple of sizeof(intmax_t)");
1108
+ static_assert(
1109
+ sizeof(buffer) == sizeof(Vectorized<T>),
1110
+ "sizeof(buffer) must match sizeof(Vectorized<T>)");
1111
+ // We should be using memcpy in order to respect the strict aliasing rule
1112
+ // see: https://github.com/pytorch/pytorch/issues/66119
1113
+ // Using char* is defined in the C11 standard 6.5 Expression paragraph 7
1114
+ // (http://www.open-std.org/jtc1/sc22/wg14/www/docs/n1570.pdf)
1115
+ const auto* a_data = a.as_bytes();
1116
+ const auto* b_data = b.as_bytes();
1117
+ // load each intmax_t chunk and process; increase pointers by sizeof(intmax_t)
1118
+ for (auto& out : buffer) {
1119
+ out = op(load<intmax_t>(a_data), load<intmax_t>(b_data));
1120
+ a_data += sizeof(intmax_t);
1121
+ b_data += sizeof(intmax_t);
1122
+ }
1123
+ assert(a_data == a.as_bytes() + sizeof(a));
1124
+ assert(b_data == b.as_bytes() + sizeof(b));
1125
+ return Vectorized<T>::loadu(buffer);
1126
+ }
1127
+
1128
+ template <
1129
+ class T,
1130
+ typename std::
1131
+ enable_if_t<!std::is_base_of_v<Vectorizedi, Vectorized<T>>, int> = 0>
1132
+ inline Vectorized<T> operator&(const Vectorized<T>& a, const Vectorized<T>& b) {
1133
+ return bitwise_binary_op(a, b, std::bit_and<intmax_t>());
1134
+ }
1135
+ template <
1136
+ class T,
1137
+ typename std::
1138
+ enable_if_t<!std::is_base_of_v<Vectorizedi, Vectorized<T>>, int> = 0>
1139
+ inline Vectorized<T> operator|(const Vectorized<T>& a, const Vectorized<T>& b) {
1140
+ return bitwise_binary_op(a, b, std::bit_or<intmax_t>());
1141
+ }
1142
+ template <
1143
+ class T,
1144
+ typename std::
1145
+ enable_if_t<!std::is_base_of_v<Vectorizedi, Vectorized<T>>, int> = 0>
1146
+ inline Vectorized<T> operator^(const Vectorized<T>& a, const Vectorized<T>& b) {
1147
+ return bitwise_binary_op(a, b, std::bit_xor<intmax_t>());
1148
+ }
1149
+
1150
+ #endif // defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)
1151
+
1152
+ VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(&)
1153
+ VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(|)
1154
+ VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(^)
1155
+
1156
+ template <
1157
+ class T,
1158
+ typename std::
1159
+ enable_if_t<!std::is_base_of_v<Vectorizedi, Vectorized<T>>, int> = 0>
1160
+ inline Vectorized<T> operator~(const Vectorized<T>& a) {
1161
+ using int_t = int_same_size_t<T>;
1162
+ Vectorized<T> ones(c10::bit_cast<T>((int_t)(~(int_t)0))); // All bits are 1
1163
+ return a ^ ones;
1164
+ }
1165
+
1166
+ template <class T>
1167
+ Vectorized<T> inline operator<<(
1168
+ const Vectorized<T>& a,
1169
+ const Vectorized<T>& b) {
1170
+ constexpr T max_shift = sizeof(T) * CHAR_BIT;
1171
+ Vectorized<T> c;
1172
+ for (int i = 0; i != Vectorized<T>::size(); i++) {
1173
+ T shift = b[i];
1174
+ if ((static_cast<std::make_signed_t<T>>(shift) < 0) ||
1175
+ (shift >= max_shift)) {
1176
+ c[i] = 0;
1177
+ } else {
1178
+ c[i] = static_cast<std::make_unsigned_t<T>>(a[i]) << shift;
1179
+ }
1180
+ }
1181
+ return c;
1182
+ }
1183
+
1184
+ template <class T>
1185
+ Vectorized<T> inline operator>>(
1186
+ const Vectorized<T>& a,
1187
+ const Vectorized<T>& b) {
1188
+ // right shift value to retain sign bit for signed and no bits for unsigned
1189
+ constexpr T max_shift = sizeof(T) * CHAR_BIT - std::is_signed_v<T>;
1190
+ Vectorized<T> c;
1191
+ for (int i = 0; i != Vectorized<T>::size(); i++) {
1192
+ T shift = b[i];
1193
+ if ((static_cast<std::make_signed_t<T>>(shift) < 0) ||
1194
+ (shift >= max_shift)) {
1195
+ c[i] = a[i] >> max_shift;
1196
+ } else {
1197
+ c[i] = a[i] >> shift;
1198
+ }
1199
+ }
1200
+ return c;
1201
+ }
1202
+
1203
+ template <typename T>
1204
+ inline Vectorized<T>& operator+=(Vectorized<T>& a, const Vectorized<T>& b) {
1205
+ a = a + b;
1206
+ return a;
1207
+ }
1208
+ template <typename T>
1209
+ inline Vectorized<T>& operator-=(Vectorized<T>& a, const Vectorized<T>& b) {
1210
+ a = a - b;
1211
+ return a;
1212
+ }
1213
+ template <typename T>
1214
+ inline Vectorized<T>& operator/=(Vectorized<T>& a, const Vectorized<T>& b) {
1215
+ a = a / b;
1216
+ return a;
1217
+ }
1218
+ template <typename T>
1219
+ inline Vectorized<T>& operator%=(Vectorized<T>& a, const Vectorized<T>& b) {
1220
+ a = a % b;
1221
+ return a;
1222
+ }
1223
+ template <typename T>
1224
+ inline Vectorized<T>& operator*=(Vectorized<T>& a, const Vectorized<T>& b) {
1225
+ a = a * b;
1226
+ return a;
1227
+ }
1228
+
1229
+ template <typename T>
1230
+ inline Vectorized<T>& operator<<=(Vectorized<T>& a, const Vectorized<T>& b) {
1231
+ a = a << b;
1232
+ return a;
1233
+ }
1234
+
1235
+ template <typename T>
1236
+ inline Vectorized<T>& operator>>=(Vectorized<T>& a, const Vectorized<T>& b) {
1237
+ a = a >> b;
1238
+ return a;
1239
+ }
1240
+
1241
+ template <typename T>
1242
+ inline Vectorized<T> fmadd(
1243
+ const Vectorized<T>& a,
1244
+ const Vectorized<T>& b,
1245
+ const Vectorized<T>& c) {
1246
+ return a * b + c;
1247
+ }
1248
+
1249
+ VECTORIZED_SUPPORT_SCALARS_FOR_TERNARY_FUNC(fmadd)
1250
+
1251
+ template <typename T>
1252
+ inline Vectorized<T> fnmadd(
1253
+ const Vectorized<T>& a,
1254
+ const Vectorized<T>& b,
1255
+ const Vectorized<T>& c) {
1256
+ return -(a * b) + c;
1257
+ }
1258
+
1259
+ VECTORIZED_SUPPORT_SCALARS_FOR_TERNARY_FUNC(fnmadd)
1260
+
1261
+ template <typename T>
1262
+ inline Vectorized<T> fmsub(
1263
+ const Vectorized<T>& a,
1264
+ const Vectorized<T>& b,
1265
+ const Vectorized<T>& c) {
1266
+ return a * b - c;
1267
+ }
1268
+
1269
+ VECTORIZED_SUPPORT_SCALARS_FOR_TERNARY_FUNC(fmsub)
1270
+
1271
+ template <typename T>
1272
+ inline Vectorized<T> fnmsub(
1273
+ const Vectorized<T>& a,
1274
+ const Vectorized<T>& b,
1275
+ const Vectorized<T>& c) {
1276
+ return -(a * b) - c;
1277
+ }
1278
+
1279
+ VECTORIZED_SUPPORT_SCALARS_FOR_TERNARY_FUNC(fnmsub)
1280
+
1281
+ template <typename T>
1282
+ Vectorized<T> inline operator&&(
1283
+ const Vectorized<T>& a,
1284
+ const Vectorized<T>& b) {
1285
+ Vectorized<T> ret;
1286
+ for (int i = 0; i != Vectorized<T>::size(); i++) {
1287
+ ret[i] = a[i] && b[i];
1288
+ }
1289
+ return ret;
1290
+ }
1291
+
1292
+ VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(&&)
1293
+
1294
+ template <int64_t scale = 1, typename T = void>
1295
+ std::enable_if_t<
1296
+ scale == 1 || scale == 2 || scale == 4 || scale == 8,
1297
+ Vectorized<
1298
+ T>> inline gather(T const* base_addr, const Vectorized<int_same_size_t<T>>& vindex) {
1299
+ static constexpr int size = Vectorized<T>::size();
1300
+ int_same_size_t<T> index_arr[size];
1301
+ vindex.store(static_cast<void*>(index_arr));
1302
+ T buffer[size];
1303
+ for (const auto i : c10::irange(size)) {
1304
+ buffer[i] = base_addr[index_arr[i] * scale / sizeof(T)];
1305
+ }
1306
+ return Vectorized<T>::loadu(static_cast<void*>(buffer));
1307
+ }
1308
+
1309
+ template <int64_t scale = 1, typename T = void>
1310
+ std::
1311
+ enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<T>> inline mask_gather(
1312
+ const Vectorized<T>& src,
1313
+ T const* base_addr,
1314
+ const Vectorized<int_same_size_t<T>>& vindex,
1315
+ Vectorized<T>& mask) {
1316
+ static constexpr int size = Vectorized<T>::size();
1317
+ T src_arr[size];
1318
+ int_same_size_t<T> mask_arr[size]; // use int type so we can logical and
1319
+ int_same_size_t<T> index_arr[size];
1320
+ src.store(static_cast<void*>(src_arr));
1321
+ mask.store(static_cast<void*>(mask_arr));
1322
+ vindex.store(static_cast<void*>(index_arr));
1323
+ T buffer[size];
1324
+ for (const auto i : c10::irange(size)) {
1325
+ if (mask_arr[i] & 0x01) { // check highest bit
1326
+ buffer[i] = base_addr[index_arr[i] * scale / sizeof(T)];
1327
+ } else {
1328
+ buffer[i] = src_arr[i];
1329
+ }
1330
+ }
1331
+ mask = Vectorized<T>(static_cast<T>(0)); // "zero out" mask
1332
+ return Vectorized<T>::loadu(static_cast<void*>(buffer));
1333
+ }
1334
+
1335
+ // Cast a given vector to another type without changing the bits representation.
1336
+ // So a Vectorized<double> of 512 bits containing all ones can be cast to a
1337
+ // Vectorized<int64_t> of 512 bits containing all ones (i.e., eight negative
1338
+ // 1s). A Vec<double> of 256 bits containing all ones can be cast to a
1339
+ // Vec<int64_t> of 256 bits containing all ones (i.e., four negative 1s).
1340
+ // There is a struct here because we don't have static_if and I can't
1341
+ // partially specialize a templated function.
1342
+ template <typename dst_t, typename src_t>
1343
+ struct CastImpl {
1344
+ static inline Vectorized<dst_t> apply(const Vectorized<src_t>& src) {
1345
+ src_t src_arr[Vectorized<src_t>::size()];
1346
+ src.store(static_cast<void*>(src_arr));
1347
+ return Vectorized<dst_t>::loadu(static_cast<const void*>(src_arr));
1348
+ }
1349
+ };
1350
+
1351
+ template <typename scalar_t>
1352
+ struct CastImpl<scalar_t, scalar_t> {
1353
+ static inline Vectorized<scalar_t> apply(const Vectorized<scalar_t>& src) {
1354
+ return src;
1355
+ }
1356
+ };
1357
+
1358
+ template <typename dst_t, typename src_t>
1359
+ inline Vectorized<dst_t> cast(const Vectorized<src_t>& src) {
1360
+ return CastImpl<dst_t, src_t>::apply(src);
1361
+ }
1362
+
1363
+ template <typename T, typename IntType = int_same_size_t<T>>
1364
+ inline Vectorized<IntType> convert_to_int_of_same_size(
1365
+ const Vectorized<T>& src) {
1366
+ static_assert(sizeof(T) == sizeof(IntType));
1367
+ static constexpr int size = Vectorized<T>::size();
1368
+
1369
+ std::array<T, size> src_arr = {};
1370
+ src.store(static_cast<void*>(src_arr.data()));
1371
+ std::array<IntType, size> buffer;
1372
+ std::transform(
1373
+ src_arr.cbegin(), src_arr.cend(), buffer.begin(), [](const T& x) {
1374
+ return static_cast<IntType>(x);
1375
+ });
1376
+ return Vectorized<IntType>::loadu(static_cast<const void*>(buffer.data()));
1377
+ }
1378
+
1379
+ template <typename T, typename IntType = int_same_size_t<T>>
1380
+ inline Vectorized<T> convert_to_fp_of_same_size(
1381
+ const Vectorized<IntType>& src) {
1382
+ static_assert(sizeof(T) == sizeof(IntType));
1383
+ static constexpr int size = Vectorized<T>::size();
1384
+
1385
+ std::array<IntType, size> src_arr;
1386
+ src.store(static_cast<void*>(src_arr.data()));
1387
+ std::array<T, size> buffer;
1388
+ std::transform(
1389
+ src_arr.cbegin(), src_arr.cend(), buffer.begin(), [](const IntType& x) {
1390
+ return static_cast<T>(x);
1391
+ });
1392
+ return Vectorized<T>::loadu(static_cast<const void*>(buffer.data()));
1393
+ }
1394
+
1395
+ // clang-format off
1396
+ // Example inputs for AVX512:
1397
+ // a Vectorized<float> = {a0, b0, a1, b1, a2, b2, a3, b3, a4, b4, a5, b5, a6, b6, a7, b7}
1398
+ // b Vectorized<float> = {a8, b8, a9, b9, a10, b10, a11, b11, a12, b12, a13, b13, a14, b14, a15, b15}
1399
+ // returns:
1400
+ // Vectorized<float> = {a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15}
1401
+ // Vectorized<float> = {b0, b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, b15}
1402
+ // Example inputs for AVX2: a Vectorized<float> = {a0, b0, a1, b1, a2, b2, a3, b3}
1403
+ // b Vectorized<float> = {a4, b4, a5, b5, a6, b6, a7, b7}
1404
+ // returns: Vectorized<float> = {a0, a1, a2, a3, a4, a5, a6, a7}
1405
+ // Vectorized<float> = {b0, b1, b2, b3, b4, b5, b6, b7}
1406
+ // clang-format on
1407
+ template <typename T>
1408
+ inline std::enable_if_t<
1409
+ Vectorized<T>::size() % 2 == 0,
1410
+ std::pair<Vectorized<T>, Vectorized<T>>>
1411
+ deinterleave2(const Vectorized<T>& a, const Vectorized<T>& b) {
1412
+ static constexpr int size = Vectorized<T>::size();
1413
+ static constexpr int half_size = size / 2;
1414
+ T a_arr[size];
1415
+ T b_arr[size];
1416
+ T buffer1[size];
1417
+ T buffer2[size];
1418
+ a.store(static_cast<void*>(a_arr));
1419
+ b.store(static_cast<void*>(b_arr));
1420
+ for (const auto i : c10::irange(half_size)) {
1421
+ buffer1[i] = a_arr[i * 2];
1422
+ buffer1[half_size + i] = b_arr[i * 2];
1423
+ buffer2[i] = a_arr[i * 2 + 1];
1424
+ buffer2[half_size + i] = b_arr[i * 2 + 1];
1425
+ }
1426
+ return std::make_pair(
1427
+ Vectorized<T>::loadu(static_cast<void*>(buffer1)),
1428
+ Vectorized<T>::loadu(static_cast<void*>(buffer2)));
1429
+ }
1430
+
1431
+ VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC(deinterleave2)
1432
+
1433
+ // clang-format off
1434
+ // inverse operation of deinterleave2
1435
+ // Example inputs for AVX512:
1436
+ // a Vectorized<float> = {a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15}
1437
+ // b Vectorized<float> = {b0, b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, b15}
1438
+ // returns, for AVX512:
1439
+ // Vectorized<float> = {a0, b0, a1, b1, a2, b2, a3, b3, a4, b4, a5, b5, a6, b6, a7, b7}
1440
+ // Vectorized<float> = {a8, b8, a9, b9, a10, b10, a11, b11, a12, b12, a13, b13, a14, b14, a15, b15}
1441
+ // Example inputs for AVX2 : a Vectorized<float> = {a0, a1, a2, a3, a4, a5, a6, a7}
1442
+ // b Vectorized<float> = {b0, b1, b2, b3, b4, b5, b6, b7}
1443
+ // returns: Vectorized<float> = {a0, b0, a1, b1, a2, b2, a3, b3}
1444
+ // Vectorized<float> = {a4, b4, a5, b5, a6, b6, a7, b7}
1445
+ // clang-format on
1446
+ template <typename T>
1447
+ inline std::enable_if_t<
1448
+ Vectorized<T>::size() % 2 == 0,
1449
+ std::pair<Vectorized<T>, Vectorized<T>>>
1450
+ interleave2(const Vectorized<T>& a, const Vectorized<T>& b) {
1451
+ static constexpr int size = Vectorized<T>::size();
1452
+ static constexpr int half_size = size / 2;
1453
+ T a_arr[size];
1454
+ T b_arr[size];
1455
+ T buffer1[size];
1456
+ T buffer2[size];
1457
+ a.store(static_cast<void*>(a_arr));
1458
+ b.store(static_cast<void*>(b_arr));
1459
+ for (const auto i : c10::irange(half_size)) {
1460
+ buffer1[i * 2] = a_arr[i];
1461
+ buffer1[i * 2 + 1] = b_arr[i];
1462
+ buffer2[i * 2] = a_arr[half_size + i];
1463
+ buffer2[i * 2 + 1] = b_arr[half_size + i];
1464
+ }
1465
+ return std::make_pair(
1466
+ Vectorized<T>::loadu(static_cast<void*>(buffer1)),
1467
+ Vectorized<T>::loadu(static_cast<void*>(buffer2)));
1468
+ }
1469
+
1470
+ VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC(interleave2)
1471
+
1472
+ #undef VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC
1473
+ #undef VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP
1474
+ #undef VECTORIZED_SUPPORT_SCALARS_FOR_TERNARY_FUNC
1475
+
1476
+ template <typename src_T, typename dst_T>
1477
+ inline void convert(const src_T* src, dst_T* dst, int64_t n) {
1478
+ #ifndef _MSC_VER
1479
+ #pragma unroll
1480
+ #endif
1481
+ for ([[maybe_unused]] const auto i : c10::irange(n)) {
1482
+ *dst = c10::convert<dst_T>(c10::load(src));
1483
+ src++;
1484
+ dst++;
1485
+ }
1486
+ }
1487
+
1488
+ template <typename T>
1489
+ inline Vectorized<T> flip(const Vectorized<T>& data) {
1490
+ static constexpr int size = Vectorized<T>::size();
1491
+ T output[size];
1492
+ T buffer[size];
1493
+ data.store(static_cast<void*>(buffer));
1494
+ for (const auto i : c10::irange(size)) {
1495
+ output[i] = buffer[size - i - 1];
1496
+ }
1497
+ return Vectorized<T>::loadu(static_cast<void*>(output));
1498
+ }
1499
+
1500
+ // Transpose the `src` buffer of type `T` and size (M,N) into the `dst` buffer.
1501
+ // `ld_src` is the leading dimension of `src` and `ld_dst` is the leading
1502
+ // dimension of `dst`.
1503
+ template <typename T>
1504
+ inline void transpose_mxn(
1505
+ const T* src,
1506
+ int64_t ld_src,
1507
+ T* dst,
1508
+ int64_t ld_dst,
1509
+ int M,
1510
+ int N) {
1511
+ for (int i = 0; i < M; i++) {
1512
+ for (int j = 0; j < N; j++) {
1513
+ dst[j * ld_dst + i] = src[i * ld_src + j];
1514
+ }
1515
+ }
1516
+ }
1517
+
1518
+ template <typename T, int M, int N>
1519
+ inline void transpose_mxn(
1520
+ const T* src,
1521
+ int64_t ld_src,
1522
+ T* dst,
1523
+ int64_t ld_dst) {
1524
+ transpose_mxn<T>(src, ld_src, dst, ld_dst, M, N);
1525
+ }
1526
+
1527
+ } // namespace CPU_CAPABILITY
1528
+ } // namespace at::vec
1529
+
1530
+ // additional headers for more operations that depend on vec_base
1531
+ #include <ATen/cpu/vec/vec_convert.h>
1532
+ #include <ATen/cpu/vec/vec_mask.h>
1533
+ #include <ATen/cpu/vec/vec_n.h>
1534
+
1535
+ #else
1536
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
1537
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec_convert.h ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <ATen/cpu/vec/vec_base.h>
5
+ #include <ATen/cpu/vec/vec_n.h>
6
+
7
+ namespace at::vec {
8
+ inline namespace CPU_CAPABILITY {
9
+
10
+ template <
11
+ typename dst_t,
12
+ int dst_n,
13
+ typename src_t,
14
+ int src_n,
15
+ typename Enabled = void>
16
+ struct VecConvert {
17
+ static inline VectorizedN<dst_t, dst_n> apply(
18
+ const VectorizedN<src_t, src_n>& src) {
19
+ constexpr int count = std::min(
20
+ VectorizedN<src_t, src_n>::size(), VectorizedN<dst_t, dst_n>::size());
21
+ __at_align__ src_t src_buf[VectorizedN<src_t, src_n>::size()];
22
+ src.store(src_buf);
23
+ __at_align__ dst_t dst_buf[VectorizedN<dst_t, dst_n>::size()];
24
+ for (int i = 0; i < count; i++) {
25
+ dst_buf[i] = static_cast<dst_t>(src_buf[i]);
26
+ }
27
+ return VectorizedN<dst_t, dst_n>::loadu(dst_buf, count);
28
+ }
29
+ };
30
+
31
+ template <typename dst_t, typename src_t>
32
+ inline std::enable_if_t<std::is_same_v<dst_t, src_t>, Vectorized<src_t>> convert(
33
+ const Vectorized<src_t>& src) {
34
+ return src;
35
+ }
36
+
37
+ template <typename dst_t, typename src_t>
38
+ inline std::enable_if_t<!std::is_same_v<dst_t, src_t>, Vectorized<dst_t>>
39
+ convert(const Vectorized<src_t>& src) {
40
+ return VecConvert<dst_t, 1, src_t, 1>::apply(src);
41
+ }
42
+
43
+ template <
44
+ typename dst_t,
45
+ int dst_n,
46
+ typename src_t,
47
+ int src_n,
48
+ std::enable_if_t<dst_n != 1, int> = 0>
49
+ inline VectorizedN<dst_t, dst_n> convert(const VectorizedN<src_t, src_n>& src) {
50
+ return VecConvert<dst_t, dst_n, src_t, src_n>::apply(src);
51
+ }
52
+
53
+ template <
54
+ typename dst_t,
55
+ int dst_n,
56
+ typename src_t,
57
+ int src_n,
58
+ bool keep = false,
59
+ std::enable_if_t<dst_n == 1, int> = 0>
60
+ inline std::conditional_t<keep, VectorizedN<dst_t, 1>, Vectorized<dst_t>>
61
+ convert(const VectorizedN<src_t, src_n>& src) {
62
+ return VecConvert<dst_t, dst_n, src_t, src_n>::apply(src);
63
+ }
64
+
65
+ } // namespace CPU_CAPABILITY
66
+
67
+ template <
68
+ typename scalar_t,
69
+ typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
70
+ inline std::tuple<Vectorized<float>, Vectorized<float>> convert_to_float(
71
+ const Vectorized<scalar_t>&);
72
+
73
+ template <
74
+ typename scalar_t,
75
+ typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
76
+ inline Vectorized<scalar_t> convert_from_float(
77
+ const Vectorized<float>&,
78
+ const Vectorized<float>&);
79
+
80
+ } // namespace at::vec
81
+
82
+ #else
83
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
84
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec_half.h ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <ATen/cpu/vec/intrinsics.h>
5
+ #include <c10/util/Exception.h>
6
+
7
+ #include <torch/headeronly/cpu/vec/vec_half.h>
8
+
9
+ namespace at::vec {
10
+ // See Note [CPU_CAPABILITY namespace]
11
+ inline namespace CPU_CAPABILITY {
12
+
13
+ // Transpose a [2, 32] matrix to [32, 2]
14
+ // Note: the output leading dimension should be 2,
15
+ // that is, the output must be contiguous
16
+ template <typename scalar_t, typename = std::enable_if_t<sizeof(scalar_t) == 2>>
17
+ static inline void transpose_pad_2x32_block(
18
+ const scalar_t* src,
19
+ scalar_t* dst,
20
+ int64_t ld_src,
21
+ int krem = 2,
22
+ int nrem = 32) {
23
+ #if defined(CPU_CAPABILITY_AVX512)
24
+ __m512i r0, r1;
25
+ __m512i d0, d1;
26
+ // load
27
+ if (nrem < 32) {
28
+ __mmask32 mask_krem_v = (1LL << nrem) - 1;
29
+ r0 = _mm512_maskz_loadu_epi16(mask_krem_v, src);
30
+ // if krem is not 2, pad with zeros
31
+ if (krem == 2) {
32
+ r1 = _mm512_maskz_loadu_epi16(mask_krem_v, src + ld_src);
33
+ } else {
34
+ r1 = _mm512_setzero_si512();
35
+ }
36
+ } else {
37
+ r0 = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(src));
38
+ if (krem == 2) {
39
+ r1 = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(src + ld_src));
40
+ } else {
41
+ r1 = _mm512_setzero_si512();
42
+ }
43
+ }
44
+ // transpose
45
+ d0 = _mm512_unpacklo_epi16(r0, r1);
46
+ d1 = _mm512_unpackhi_epi16(r0, r1);
47
+ r0 = _mm512_shuffle_i32x4(d0, d1, 0x88);
48
+ r1 = _mm512_shuffle_i32x4(d0, d1, 0xdd);
49
+ d0 = _mm512_shuffle_i32x4(r0, r1, 0x88);
50
+ d1 = _mm512_shuffle_i32x4(r0, r1, 0xdd);
51
+
52
+ // store
53
+ if (nrem < 16) {
54
+ __mmask32 mask_rem_v = (1LL << (nrem * 2)) - 1;
55
+ _mm512_mask_storeu_epi16(dst, mask_rem_v, d0);
56
+ } else if (nrem == 16) {
57
+ _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), d0);
58
+ } else if (nrem < 32) {
59
+ __mmask32 mask_rem_v = (1LL << (nrem * 2 - 32)) - 1;
60
+ _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), d0);
61
+ _mm512_mask_storeu_epi16(
62
+ reinterpret_cast<__m512i*>(dst + 32), mask_rem_v, d1);
63
+ } else {
64
+ // normal store
65
+ _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), d0);
66
+ _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 32), d1);
67
+ }
68
+ #else
69
+ TORCH_CHECK(
70
+ false,
71
+ "transpose_pad_2x32_block is only supported when avx512 is supported")
72
+ #endif
73
+ }
74
+
75
+ // To use AMX to accelerate GEMM,
76
+ // reorder the memory format [K, N] -> [K/2, N, 2]
77
+ // Note: If K % 2 != 0, pad K implicitly
78
+ template <typename scalar_t, typename = std::enable_if_t<sizeof(scalar_t) == 2>>
79
+ static inline void pack_vnni2(
80
+ const scalar_t* src,
81
+ scalar_t* dst,
82
+ int64_t ld_src,
83
+ int64_t K,
84
+ int64_t N) {
85
+ #if defined(CPU_CAPABILITY_AVX512)
86
+ int64_t bk = 0;
87
+ int64_t _K = K / 2 * 2;
88
+ int64_t _N = N / 32 * 32;
89
+ for (; bk < _K; bk += 2) {
90
+ int64_t bn = 0;
91
+ for (; bn < _N; bn += 32) {
92
+ transpose_pad_2x32_block(
93
+ src + bk * ld_src + bn, dst + bk * N + bn * 2, ld_src);
94
+ }
95
+ int64_t nrem = N - bn;
96
+ if (nrem > 0) {
97
+ transpose_pad_2x32_block(
98
+ src + bk * ld_src + bn, dst + bk * N + bn * 2, ld_src, 2, nrem);
99
+ }
100
+ }
101
+ if (K % 2 == 1) {
102
+ int64_t bn = 0;
103
+ for (; bn < _N; bn += 32) {
104
+ transpose_pad_2x32_block(
105
+ src + bk * ld_src + bn, dst + bk * N + bn * 2, ld_src, 1);
106
+ }
107
+ int64_t nrem = N - bn;
108
+ if (nrem > 0) {
109
+ transpose_pad_2x32_block(
110
+ src + bk * ld_src + bn, dst + bk * N + bn * 2, ld_src, 1, nrem);
111
+ }
112
+ }
113
+ #else
114
+ TORCH_CHECK(false, "pack_vnni2 is only supported when avx512 is supported")
115
+ #endif
116
+ }
117
+
118
+ } // namespace CPU_CAPABILITY
119
+ } // namespace at::vec
120
+
121
+ #else
122
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
123
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec_mask.h ADDED
@@ -0,0 +1,318 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <ATen/cpu/vec/vec_base.h>
5
+ #include <ATen/cpu/vec/vec_n.h>
6
+ namespace at::vec {
7
+ inline namespace CPU_CAPABILITY {
8
+
9
+ /**
10
+ * The `VecMask` class provides a convenient interface for working with
11
+ * vectorized masks in SIMD operations. It encapsulates a `Vectorized<T, N>`
12
+ * mask that can be directly usable in masked vectorized operations. It provides
13
+ * various methods for manipulating and accessing the mask elements:
14
+ * 1. `from` and `to`: Conversion between a vector of boolean values and a
15
+ * vectorized mask.
16
+ * 2. `cast`: Casts the mask to a different base type.
17
+ * 3. `all_zero`: Checks if all mask elements are zero.
18
+ * 4. `is_masked`: Checks if a specific element is masked.
19
+ * 5. `loadu`: Loads data from memory using the mask.
20
+ * 6. `all_masked`: Checks if all mask elements are masked.
21
+ *
22
+ * Some helper template classes are provided to simplify the specialization of
23
+ * the `VecMask` for the specific CPU arch:
24
+ * 1. `VecMaskLoad`: Loads data from memory using the mask.
25
+ * 2. `VecMaskTo`: Converts the mask to boolean.
26
+ * 3. `VecMaskCast`: Casts the mask to a different base type.
27
+ *
28
+ */
29
+ template <typename T, int N>
30
+ class VecMask;
31
+
32
+ template <
33
+ typename data_t,
34
+ int data_n,
35
+ typename mask_t,
36
+ int mask_n,
37
+ typename Enabled = void>
38
+ struct VecMaskLoad {
39
+ static inline VectorizedN<data_t, data_n> apply(
40
+ const data_t* ptr,
41
+ const VecMask<mask_t, mask_n>& vec_mask) {
42
+ constexpr typename VecMask<mask_t, mask_n>::size_type size =
43
+ VecMask<mask_t, mask_n>::size();
44
+ static_assert(VectorizedN<data_t, data_n>::size() >= size);
45
+ __at_align__ data_t data[size];
46
+ __at_align__ mask_t mask[size];
47
+ auto mask_ = VectorizedN<mask_t, mask_n>(vec_mask);
48
+ mask_.store(mask);
49
+ for (int i = 0; i < size; i++) {
50
+ data[i] = mask[i] ? ptr[i] : static_cast<data_t>(0);
51
+ }
52
+ return VectorizedN<data_t, data_n>::loadu(data, size);
53
+ }
54
+ };
55
+
56
+ template <
57
+ typename dst_t,
58
+ int dst_n,
59
+ typename src_t,
60
+ int src_n,
61
+ typename Enabled = void>
62
+ struct VecMaskTo {
63
+ static inline VecMask<dst_t, dst_n> apply(
64
+ const VecMask<src_t, src_n>& vec_mask) {
65
+ auto zeros = VectorizedN<dst_t, dst_n>(static_cast<dst_t>(0));
66
+ auto ones = VectorizedN<dst_t, dst_n>(static_cast<dst_t>(1));
67
+ return VectorizedN<dst_t, dst_n>::blendv(
68
+ zeros, ones, vec_mask.template cast<dst_t, dst_n>());
69
+ }
70
+ };
71
+
72
+ template <
73
+ typename dst_t,
74
+ int dst_n,
75
+ typename src_t,
76
+ int src_n,
77
+ typename Enabled = void>
78
+ struct VecMaskCast {
79
+ static inline VecMask<dst_t, dst_n> apply(
80
+ const VecMask<src_t, src_n>& vec_mask) {
81
+ return VecMask<dst_t, dst_n>::from(VectorizedN<src_t, src_n>(vec_mask));
82
+ }
83
+ };
84
+
85
+ template <typename T, int N>
86
+ struct VecMaskCast<T, N, T, N> {
87
+ static inline VecMask<T, N> apply(const VecMask<T, N>& vec_mask) {
88
+ return vec_mask;
89
+ }
90
+ };
91
+
92
+ template <typename T, int N>
93
+ struct VecMaskCheck {
94
+ static inline bool all_zero(const VectorizedN<T, N>& vec_mask) {
95
+ __at_align__ T mask[VectorizedN<T, N>::size()];
96
+ vec_mask.store(mask);
97
+ return std::all_of(mask, mask + VectorizedN<T, N>::size(), [](T m) {
98
+ return m == static_cast<T>(0);
99
+ });
100
+ }
101
+
102
+ static inline bool all_masked(const VectorizedN<T, N>& vec_mask) {
103
+ __at_align__ T mask[VectorizedN<T, N>::size()];
104
+ vec_mask.store(mask);
105
+ return std::all_of(mask, mask + VectorizedN<T, N>::size(), [](T m) {
106
+ return m != static_cast<T>(0);
107
+ });
108
+ }
109
+
110
+ static inline bool is_masked(const VectorizedN<T, N>& vec_mask, int i) {
111
+ __at_align__ T mask[VectorizedN<T, N>::size()];
112
+ vec_mask.store(mask);
113
+ return mask[i] != static_cast<T>(0);
114
+ }
115
+ };
116
+
117
+ template <typename T, int N>
118
+ class VecMask {
119
+ public:
120
+ using size_type = int;
121
+ static constexpr size_type size() {
122
+ return VectorizedN<T, N>::size();
123
+ }
124
+
125
+ private:
126
+ VectorizedN<T, N> mask_;
127
+
128
+ public:
129
+ VecMask() : mask_(static_cast<T>(0)) {}
130
+ VecMask(const VectorizedN<T, N>& mask) : mask_(mask) {}
131
+
132
+ template <int L = N, typename std::enable_if_t<L == 1, int> = 0>
133
+ VecMask(const Vectorized<T>& mask) : mask_(mask) {}
134
+
135
+ template <typename U, int L>
136
+ static VecMask<T, N> from(const VectorizedN<U, L>& b_vec) {
137
+ __at_align__ U b_buf[size()];
138
+ if constexpr (size() >= VectorizedN<U, L>::size()) {
139
+ b_vec.store(b_buf);
140
+ for (int i = VectorizedN<U, L>::size(); i < size(); i++) {
141
+ b_buf[i] = static_cast<U>(0);
142
+ }
143
+ } else {
144
+ b_vec.store(b_buf, size());
145
+ }
146
+ return from(b_buf);
147
+ }
148
+
149
+ template <typename U>
150
+ static VecMask<T, N> from(U b) {
151
+ using int_t = int_same_size_t<T>;
152
+ T mask = b ? c10::bit_cast<T>((int_t)(~(int_t)0)) : (T)0;
153
+ return VectorizedN<T, N>(mask);
154
+ }
155
+
156
+ template <typename U>
157
+ static VecMask<T, N> from(U* b) {
158
+ using int_t = int_same_size_t<T>;
159
+ __at_align__ T mask[size()];
160
+ #ifndef __msvc_cl__
161
+ #pragma unroll
162
+ #endif
163
+ for (int i = 0; i < size(); i++) {
164
+ *(int_t*)(mask + i) = b[i] ? ~(int_t)0 : (int_t)0;
165
+ }
166
+ return VectorizedN<T, N>(VectorizedN<T, N>::loadu(mask));
167
+ }
168
+
169
+ template <typename U>
170
+ static VecMask<T, N> from(U* b, int count) {
171
+ using int_t = int_same_size_t<T>;
172
+ __at_align__ T mask[size()];
173
+ #ifndef __msvc_cl__
174
+ #pragma unroll
175
+ #endif
176
+ for (int i = 0; i < count; i++) {
177
+ *(int_t*)(mask + i) = b[i] ? ~(int_t)0 : (int_t)0;
178
+ }
179
+ return VectorizedN<T, N>(VectorizedN<T, N>::loadu(mask, count));
180
+ }
181
+
182
+ static VecMask<T, N> blendv(
183
+ const VecMask<T, N>& c,
184
+ const VecMask<T, N>& b,
185
+ const VecMask<T, N>& a) {
186
+ VectorizedN<T, N> result = VectorizedN<T, N>::blendv(
187
+ VectorizedN<T, N>(c), VectorizedN<T, N>(b), VectorizedN<T, N>(a));
188
+ return result;
189
+ }
190
+
191
+ static VecMask<T, N> set(
192
+ const VecMask<T, N>& a,
193
+ const VecMask<T, N>& b,
194
+ int64_t count = size()) {
195
+ VectorizedN<T, N> result = VectorizedN<T, N>::set(
196
+ VectorizedN<T, N>(a), VectorizedN<T, N>(b), count);
197
+ return result;
198
+ }
199
+
200
+ void store(bool* b, int count = size()) {
201
+ constexpr int L =
202
+ (VectorizedN<T, N>::size() + Vectorized<bool>::size() - 1) /
203
+ Vectorized<bool>::size();
204
+ auto res = this->to<bool, L>();
205
+ res.store(b, count);
206
+ return;
207
+ }
208
+
209
+ template <typename U, int L, std::enable_if_t<L >= 2, int> = 0>
210
+ inline VectorizedN<U, L> to() const {
211
+ return VecMaskTo<U, L, T, N>::apply(*this);
212
+ }
213
+
214
+ template <typename U, int L, std::enable_if_t<L == 1, int> = 0>
215
+ inline Vectorized<U> to() const {
216
+ return VecMaskTo<U, L, T, N>::apply(*this);
217
+ }
218
+
219
+ template <typename U, int L>
220
+ inline VecMask<U, L> cast() const {
221
+ return VecMaskCast<U, L, T, N>::apply(*this);
222
+ }
223
+
224
+ inline bool all_zero() const {
225
+ return VecMaskCheck<T, N>::all_zero(mask_);
226
+ }
227
+
228
+ inline bool all_masked() const {
229
+ return VecMaskCheck<T, N>::all_masked(mask_);
230
+ }
231
+
232
+ inline bool is_masked(int i) const {
233
+ return VecMaskCheck<T, N>::is_masked(mask_, i);
234
+ }
235
+
236
+ inline operator VectorizedN<T, N>() const {
237
+ return mask_;
238
+ }
239
+
240
+ template <int L = N, typename std::enable_if_t<L == 1, int> = 0>
241
+ inline operator Vectorized<T>() const {
242
+ return mask_[0];
243
+ }
244
+
245
+ inline Vectorized<T> operator[](int i) const {
246
+ return mask_[i];
247
+ }
248
+
249
+ template <
250
+ typename U,
251
+ int L,
252
+ std::enable_if_t<L >= 2 && VectorizedN<U, L>::size() >= size(), int> = 0>
253
+ VectorizedN<U, L> loadu(const U* ptr) const {
254
+ return VecMaskLoad<U, L, T, N>::apply(ptr, *this);
255
+ }
256
+
257
+ template <
258
+ typename U,
259
+ int L,
260
+ std::enable_if_t<L == 1 && Vectorized<U>::size() >= size(), int> = 0>
261
+ Vectorized<U> loadu(const U* ptr) const {
262
+ return VecMaskLoad<U, L, T, N>::apply(ptr, *this);
263
+ }
264
+ };
265
+
266
+ #define VEC_MASK_DEFINE_UNARY_OP_GLOBAL(op) \
267
+ template <typename T, int N> \
268
+ inline VecMask<T, N> op(const VecMask<T, N>& a) { \
269
+ return op(VectorizedN<T, N>(a)); \
270
+ }
271
+
272
+ #define VEC_MASK_DEFINE_BINARY_OP_GLOBAL(op) \
273
+ template < \
274
+ typename T, \
275
+ int N, \
276
+ typename V, \
277
+ int M, \
278
+ std::enable_if_t<VecMask<T, N>::size() == VecMask<V, M>::size(), int> = \
279
+ 0> \
280
+ inline VecMask<T, N> op(const VecMask<T, N>& a, const VecMask<V, M>& b) { \
281
+ return op( \
282
+ VectorizedN<T, N>(a), VectorizedN<T, N>(b.template cast<T, N>())); \
283
+ }
284
+
285
+ #define VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL(op, EXPR) \
286
+ template < \
287
+ typename T, \
288
+ int N, \
289
+ typename V, \
290
+ int M, \
291
+ std::enable_if_t<VecMask<T, N>::size() == VecMask<V, M>::size(), int> = \
292
+ 0> \
293
+ inline VecMask<T, N> op(const VecMask<T, N>& a, const VecMask<V, M>& b) { \
294
+ return EXPR; \
295
+ }
296
+
297
+ VEC_MASK_DEFINE_UNARY_OP_GLOBAL(operator~)
298
+ VEC_MASK_DEFINE_BINARY_OP_GLOBAL(operator&)
299
+ VEC_MASK_DEFINE_BINARY_OP_GLOBAL(operator|)
300
+ VEC_MASK_DEFINE_BINARY_OP_GLOBAL(operator^)
301
+ VEC_MASK_DEFINE_BINARY_OP_GLOBAL(operator*)
302
+ VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL(operator>, a & ~b)
303
+ VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL(operator<, ~a& b)
304
+ VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL(operator==, ~(a ^ b))
305
+ VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL(operator>=, (a == b) | (a > b))
306
+ VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL(operator<=, (a == b) | (a < b))
307
+ VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL(operator!=, (a ^ b))
308
+
309
+ #undef VEC_MASK_DEFINE_UNARY_OP_GLOBAL
310
+ #undef VEC_MASK_DEFINE_BINARY_OP_GLOBAL
311
+ #undef VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL
312
+
313
+ } // namespace CPU_CAPABILITY
314
+ } // namespace at::vec
315
+
316
+ #else
317
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
318
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec_n.h ADDED
@@ -0,0 +1,412 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <ATen/cpu/vec/vec_base.h>
5
+ #include <array>
6
+
7
+ namespace at::vec {
8
+ inline namespace CPU_CAPABILITY {
9
+
10
+ /**
11
+ * @brief A class template representing a vectorized type with
12
+ * `N * Vectorized<T>::size()` elements, aiming to support vectors of
13
+ * arbitrary size. A specific use case of it is to represent vectors
14
+ * converted from data types with different sizes but with the same
15
+ * number of vector elements, e.g., `VectorizedN<float, 2>` can be
16
+ * a vector converted from two `Vectorized<bfloat16>`, `VectorizedN<int64_t, 2>`
17
+ * can be a vector converted from two `Vectorized<int32_t>` etc.
18
+ *
19
+ * It supports most of the operations of `Vectorized<T>`
20
+ * and the implementation delegates to `Vectorized<T>` with loops over `N`.
21
+ *
22
+ * @tparam T The underlying type of the vectorized elements.
23
+ * @tparam N The number of underlying `Vectorized<T>`.
24
+ */
25
+ template <typename T, int N>
26
+ class VectorizedN {
27
+ public:
28
+ using value_type = T;
29
+ using size_type = int;
30
+
31
+ static constexpr size_type size_T = sizeof(T);
32
+ static constexpr size_type size() {
33
+ return Vectorized<T>::size() * N;
34
+ }
35
+
36
+ private:
37
+ std::array<Vectorized<T>, N> values;
38
+
39
+ public:
40
+ // methods not implemented yet:
41
+ // variadic constructor, operator T*, as_bytes, zero_mask
42
+
43
+ #define VECTORIZEDN_DEFINE_UNARY_OP(op) \
44
+ VectorizedN<T, N> op() const { \
45
+ return unary_op([](const Vectorized<T>& a) { return a.op(); }); \
46
+ }
47
+
48
+ #define VECTORIZEDN_DEFINE_BINARY_OP(op) \
49
+ VectorizedN<T, N> op(const VectorizedN<T, N>& other) const { \
50
+ return binary_op( \
51
+ other, [](const Vectorized<T>& a, const Vectorized<T>& b) { \
52
+ return a.op(b); \
53
+ }); \
54
+ }
55
+
56
+ template <typename Op>
57
+ inline VectorizedN<T, N> unary_op(Op op) const {
58
+ VectorizedN<T, N> result;
59
+ #ifndef _MSC_VER
60
+ #pragma unroll
61
+ #endif
62
+ for (int i = 0; i < N; ++i) {
63
+ result.values[i] = op(values[i]);
64
+ }
65
+ return result;
66
+ }
67
+
68
+ template <typename Op>
69
+ inline VectorizedN<T, N> binary_op(const VectorizedN<T, N>& other, Op op)
70
+ const {
71
+ VectorizedN<T, N> result;
72
+ #ifndef _MSC_VER
73
+ #pragma unroll
74
+ #endif
75
+ for (int i = 0; i < N; ++i) {
76
+ result.values[i] = op(values[i], other.values[i]);
77
+ }
78
+ return result;
79
+ }
80
+
81
+ template <typename Op>
82
+ inline VectorizedN<T, N> ternary_op(
83
+ const VectorizedN<T, N>& other,
84
+ const VectorizedN<T, N>& other2,
85
+ Op op) const {
86
+ VectorizedN<T, N> result;
87
+ #ifndef _MSC_VER
88
+ #pragma unroll
89
+ #endif
90
+ for (int i = 0; i < N; ++i) {
91
+ result.values[i] = op(values[i], other.values[i], other2.values[i]);
92
+ }
93
+ return result;
94
+ }
95
+
96
+ VectorizedN() = default;
97
+
98
+ explicit VectorizedN(T val) {
99
+ for (int i = 0; i < N; ++i) {
100
+ values[i] = Vectorized<T>(val);
101
+ }
102
+ }
103
+
104
+ template <int L = N, typename std::enable_if_t<L == 1, int> = 0>
105
+ VectorizedN(const Vectorized<T>& val) : values({val}) {}
106
+
107
+ template <int L = N, typename std::enable_if_t<L == 2, int> = 0>
108
+ VectorizedN(const Vectorized<T>& val_0, const Vectorized<T>& val_1)
109
+ : values({val_0, val_1}) {}
110
+
111
+ template <int L = N, typename std::enable_if_t<L == 1, int> = 0>
112
+ inline operator Vectorized<T>() const {
113
+ return values[0];
114
+ }
115
+
116
+ inline const Vectorized<T>& operator[](int i) const {
117
+ return values[i];
118
+ }
119
+
120
+ inline Vectorized<T>& operator[](int i) {
121
+ return values[i];
122
+ }
123
+
124
+ template <int64_t mask>
125
+ static VectorizedN<T, N> blend(
126
+ const VectorizedN<T, N>& a,
127
+ const VectorizedN<T, N>& b) {
128
+ VectorizedN<T, N> result;
129
+ for (int i = 0; i < N; ++i) {
130
+ result.values[i] =
131
+ Vectorized<T>::template blend<mask>(a.values[i], b.values[i]);
132
+ }
133
+ return result;
134
+ }
135
+
136
+ static VectorizedN<T, N> blendv(
137
+ const VectorizedN<T, N>& a,
138
+ const VectorizedN<T, N>& b,
139
+ const VectorizedN<T, N>& mask) {
140
+ VectorizedN<T, N> result;
141
+ for (int i = 0; i < N; ++i) {
142
+ result.values[i] =
143
+ Vectorized<T>::blendv(a.values[i], b.values[i], mask.values[i]);
144
+ }
145
+ return result;
146
+ }
147
+
148
+ template <typename step_t>
149
+ static VectorizedN<T, N> arange(
150
+ T base = static_cast<T>(0),
151
+ step_t step = static_cast<step_t>(1)) {
152
+ VectorizedN<T, N> result;
153
+ for (int i = 0; i < N; ++i) {
154
+ result.values[i] = Vectorized<T>::arange(base, step);
155
+ base += step * Vectorized<T>::size();
156
+ }
157
+ return result;
158
+ }
159
+
160
+ static VectorizedN<T, N> set(
161
+ const VectorizedN<T, N>& a,
162
+ const VectorizedN<T, N>& b,
163
+ int64_t count = size()) {
164
+ VectorizedN<T, N> result;
165
+ for (int i = 0; i < N; ++i) {
166
+ if (count > 0) {
167
+ result.values[i] = Vectorized<T>::set(
168
+ a.values[i],
169
+ b.values[i],
170
+ std::min(count, (int64_t)Vectorized<T>::size()));
171
+ count -= Vectorized<T>::size();
172
+ } else {
173
+ result.values[i] = a.values[i];
174
+ }
175
+ }
176
+ return result;
177
+ }
178
+
179
+ static VectorizedN<T, N> loadu(const void* ptr) {
180
+ VectorizedN<T, N> result;
181
+ for (int i = 0; i < N; ++i) {
182
+ result.values[i] = Vectorized<T>::loadu(ptr);
183
+ ptr = static_cast<const T*>(ptr) + Vectorized<T>::size();
184
+ }
185
+ return result;
186
+ }
187
+
188
+ static VectorizedN<T, N> loadu(const void* ptr, int64_t count) {
189
+ VectorizedN<T, N> result;
190
+ for (int i = 0; i < N; ++i) {
191
+ if (count > 0) {
192
+ result.values[i] = Vectorized<T>::loadu(
193
+ ptr, std::min(count, (int64_t)Vectorized<T>::size()));
194
+ ptr = static_cast<const T*>(ptr) + Vectorized<T>::size();
195
+ count -= Vectorized<T>::size();
196
+ } else {
197
+ result.values[i] = Vectorized<T>((T)1);
198
+ }
199
+ }
200
+ return result;
201
+ }
202
+
203
+ void store(void* ptr) const {
204
+ for (int i = 0; i < N; ++i) {
205
+ values[i].store(ptr);
206
+ ptr = static_cast<T*>(ptr) + Vectorized<T>::size();
207
+ }
208
+ }
209
+
210
+ void store(void* ptr, int count) const {
211
+ for (int i = 0; i < N; ++i) {
212
+ values[i].store(ptr, std::min(count, (int)Vectorized<T>::size()));
213
+ ptr = static_cast<T*>(ptr) + Vectorized<T>::size();
214
+ count -= Vectorized<T>::size();
215
+ if (count <= 0) {
216
+ break;
217
+ }
218
+ }
219
+ }
220
+
221
+ bool has_inf_nan() const {
222
+ for (int i = 0; i < N; ++i) {
223
+ if (values[i].has_inf_nan()) {
224
+ return true;
225
+ }
226
+ }
227
+ return false;
228
+ }
229
+
230
+ VectorizedN<T, N> map(T (*const f)(T)) const {
231
+ VectorizedN<T, N> result;
232
+ for (int i = 0; i < N; ++i) {
233
+ result.values[i] = values[i].map(f);
234
+ }
235
+ return result;
236
+ }
237
+
238
+ VectorizedN<T, N> map(T (*const f)(const T&)) const {
239
+ VectorizedN<T, N> result;
240
+ for (int i = 0; i < N; ++i) {
241
+ result.values[i] = values[i].map(f);
242
+ }
243
+ return result;
244
+ }
245
+
246
+ VECTORIZEDN_DEFINE_UNARY_OP(isnan)
247
+ VECTORIZEDN_DEFINE_UNARY_OP(abs)
248
+ VECTORIZEDN_DEFINE_UNARY_OP(sgn)
249
+ VECTORIZEDN_DEFINE_UNARY_OP(angle)
250
+ VECTORIZEDN_DEFINE_UNARY_OP(real)
251
+ VECTORIZEDN_DEFINE_UNARY_OP(imag)
252
+ VECTORIZEDN_DEFINE_UNARY_OP(conj)
253
+ VECTORIZEDN_DEFINE_UNARY_OP(acos)
254
+ VECTORIZEDN_DEFINE_UNARY_OP(acosh)
255
+ VECTORIZEDN_DEFINE_UNARY_OP(asin)
256
+ VECTORIZEDN_DEFINE_UNARY_OP(asinh)
257
+ VECTORIZEDN_DEFINE_UNARY_OP(atan)
258
+ VECTORIZEDN_DEFINE_UNARY_OP(atanh)
259
+ VECTORIZEDN_DEFINE_BINARY_OP(atan2)
260
+ VECTORIZEDN_DEFINE_BINARY_OP(copysign)
261
+ VECTORIZEDN_DEFINE_UNARY_OP(erf)
262
+ VECTORIZEDN_DEFINE_UNARY_OP(erfc)
263
+ VECTORIZEDN_DEFINE_UNARY_OP(erfinv)
264
+ VECTORIZEDN_DEFINE_UNARY_OP(exp)
265
+ VECTORIZEDN_DEFINE_UNARY_OP(exp2)
266
+ VECTORIZEDN_DEFINE_UNARY_OP(expm1)
267
+ VECTORIZEDN_DEFINE_UNARY_OP(exp_u20)
268
+ VECTORIZEDN_DEFINE_UNARY_OP(fexp_u20)
269
+ VECTORIZEDN_DEFINE_UNARY_OP(frac)
270
+ VECTORIZEDN_DEFINE_BINARY_OP(fmod)
271
+ VECTORIZEDN_DEFINE_UNARY_OP(log)
272
+ VECTORIZEDN_DEFINE_UNARY_OP(log10)
273
+ VECTORIZEDN_DEFINE_UNARY_OP(log1p)
274
+ VECTORIZEDN_DEFINE_UNARY_OP(log2)
275
+ VECTORIZEDN_DEFINE_UNARY_OP(ceil)
276
+ VECTORIZEDN_DEFINE_UNARY_OP(cos)
277
+ VECTORIZEDN_DEFINE_UNARY_OP(cosh)
278
+ VECTORIZEDN_DEFINE_UNARY_OP(floor)
279
+ VECTORIZEDN_DEFINE_BINARY_OP(hypot)
280
+ VECTORIZEDN_DEFINE_UNARY_OP(i0)
281
+ VECTORIZEDN_DEFINE_UNARY_OP(i0e)
282
+ VECTORIZEDN_DEFINE_UNARY_OP(digamma)
283
+ VECTORIZEDN_DEFINE_BINARY_OP(igamma)
284
+ VECTORIZEDN_DEFINE_BINARY_OP(igammac)
285
+ VECTORIZEDN_DEFINE_UNARY_OP(neg)
286
+ VECTORIZEDN_DEFINE_BINARY_OP(nextafter)
287
+ VECTORIZEDN_DEFINE_UNARY_OP(round)
288
+ VECTORIZEDN_DEFINE_UNARY_OP(sin)
289
+ VECTORIZEDN_DEFINE_UNARY_OP(sinh)
290
+ VECTORIZEDN_DEFINE_UNARY_OP(tan)
291
+ VECTORIZEDN_DEFINE_UNARY_OP(tanh)
292
+ VECTORIZEDN_DEFINE_UNARY_OP(trunc)
293
+ VECTORIZEDN_DEFINE_UNARY_OP(lgamma)
294
+ VECTORIZEDN_DEFINE_UNARY_OP(sqrt)
295
+ VECTORIZEDN_DEFINE_UNARY_OP(reciprocal)
296
+ VECTORIZEDN_DEFINE_UNARY_OP(rsqrt)
297
+ VECTORIZEDN_DEFINE_BINARY_OP(pow)
298
+ VECTORIZEDN_DEFINE_BINARY_OP(operator==)
299
+ VECTORIZEDN_DEFINE_BINARY_OP(operator!=)
300
+ VECTORIZEDN_DEFINE_BINARY_OP(operator>=)
301
+ VECTORIZEDN_DEFINE_BINARY_OP(operator<=)
302
+ VECTORIZEDN_DEFINE_BINARY_OP(operator>)
303
+ VECTORIZEDN_DEFINE_BINARY_OP(operator<)
304
+ VECTORIZEDN_DEFINE_BINARY_OP(eq)
305
+ VECTORIZEDN_DEFINE_BINARY_OP(ne)
306
+ VECTORIZEDN_DEFINE_BINARY_OP(gt)
307
+ VECTORIZEDN_DEFINE_BINARY_OP(ge)
308
+ VECTORIZEDN_DEFINE_BINARY_OP(lt)
309
+ VECTORIZEDN_DEFINE_BINARY_OP(le)
310
+
311
+ #undef VECTORIZEDN_DEFINE_UNARY_OP
312
+ #undef VECTORIZEDN_DEFINE_BINARY_OP
313
+ };
314
+
315
+ #define VECTORIZEDN_DEFINE_UNARY_OP_GLOBAL(op) \
316
+ template <typename T, int N> \
317
+ inline VectorizedN<T, N> op(const VectorizedN<T, N>& a) { \
318
+ return a.unary_op([](const Vectorized<T>& a) { return op(a); }); \
319
+ }
320
+
321
+ #define VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(op) \
322
+ template <typename T, int N> \
323
+ inline VectorizedN<T, N> op( \
324
+ const VectorizedN<T, N>& a, const VectorizedN<T, N>& b) { \
325
+ return a.binary_op(b, [](const Vectorized<T>& a, const Vectorized<T>& b) { \
326
+ return op(a, b); \
327
+ }); \
328
+ }
329
+
330
+ #define VECTORIZEDN_DEFINE_TERNARY_OP_GLOBAL(op) \
331
+ template <typename T, int N> \
332
+ inline VectorizedN<T, N> op( \
333
+ const VectorizedN<T, N>& a, \
334
+ const VectorizedN<T, N>& b, \
335
+ const VectorizedN<T, N>& c) { \
336
+ return a.ternary_op( \
337
+ b, \
338
+ c, \
339
+ [](const Vectorized<T>& a, \
340
+ const Vectorized<T>& b, \
341
+ const Vectorized<T>& c) { return op(a, b, c); }); \
342
+ }
343
+
344
+ #define VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(op) \
345
+ template <typename T, int N> \
346
+ inline VectorizedN<T, N>& op( \
347
+ VectorizedN<T, N>& a, const VectorizedN<T, N>& b) { \
348
+ a = a.binary_op(b, [](const Vectorized<T>& a, const Vectorized<T>& b) { \
349
+ return op(a, b); \
350
+ }); \
351
+ return a; \
352
+ }
353
+
354
+ VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator+)
355
+ VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator-)
356
+ VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator*)
357
+ VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator/)
358
+ VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator%)
359
+ VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator||)
360
+ VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator<<)
361
+ VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator>>)
362
+ VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(maximum)
363
+ VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(minimum)
364
+ VECTORIZEDN_DEFINE_TERNARY_OP_GLOBAL(fmadd)
365
+ VECTORIZEDN_DEFINE_TERNARY_OP_GLOBAL(fmsub)
366
+ VECTORIZEDN_DEFINE_TERNARY_OP_GLOBAL(clamp)
367
+ VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(clamp_max)
368
+ VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(clamp_min)
369
+ VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator&)
370
+ VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator|)
371
+ VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator^)
372
+ VECTORIZEDN_DEFINE_UNARY_OP_GLOBAL(operator~)
373
+
374
+ VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(operator+=)
375
+ VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(operator-=)
376
+ VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(operator*=)
377
+ VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(operator/=)
378
+ VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(operator%=)
379
+ VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(operator<<=)
380
+ VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(operator>>=)
381
+
382
+ #undef VECTORIZEDN_DEFINE_UNARY_OP_GLOBAL
383
+ #undef VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL
384
+ #undef VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL
385
+
386
+ template <typename T, int N, typename OpVec>
387
+ inline T vec_reduce_all(const OpVec& vec_fun, VectorizedN<T, N> acc_vec) {
388
+ Vectorized<T> vec_result = acc_vec[0];
389
+ for (int i = 1; i < N; i++) {
390
+ vec_result = vec_fun(vec_result, acc_vec[i]);
391
+ }
392
+ return vec_reduce_all(vec_fun, vec_result);
393
+ }
394
+
395
+ template <typename T, int N>
396
+ std::ostream& operator<<(std::ostream& stream, const VectorizedN<T, N>& vec_n) {
397
+ stream << "vec_n[";
398
+ for (int i = 0; i < N; ++i) {
399
+ if (i != 0) {
400
+ stream << ", ";
401
+ }
402
+ stream << vec_n[i];
403
+ }
404
+ stream << ']';
405
+ return stream;
406
+ }
407
+ } // namespace CPU_CAPABILITY
408
+ } // namespace at::vec
409
+
410
+ #else
411
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
412
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec_quant.h ADDED
@@ -0,0 +1,258 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <ATen/cpu/vec/intrinsics.h>
5
+ #include <c10/util/Exception.h>
6
+
7
+ namespace at::vec {
8
+ // See Note [CPU_CAPABILITY namespace]
9
+ inline namespace CPU_CAPABILITY {
10
+
11
+ // Transpose a [4, 64] block to [64, 4] (with contiguous output, ld=4)
12
+ template <typename scalar_t, typename = std::enable_if_t<sizeof(scalar_t) == 1>>
13
+ static inline void transpose_pad_4x64_block(
14
+ const scalar_t* src,
15
+ scalar_t* dst,
16
+ int64_t ld_src,
17
+ int krem = 4,
18
+ int nrem = 64) {
19
+ #if defined(CPU_CAPABILITY_AVX512)
20
+ __m512i r[4];
21
+ // Load with mask if partial
22
+ if (nrem < 64) {
23
+ __mmask64 mask = (1ULL << nrem) - 1;
24
+ for (int i = 0; i < krem; ++i) {
25
+ r[i] = _mm512_maskz_loadu_epi8(mask, src + i * ld_src);
26
+ }
27
+ for (int i = krem; i < 4; ++i) {
28
+ r[i] = _mm512_setzero_si512();
29
+ }
30
+ } else {
31
+ for (int i = 0; i < krem; ++i) {
32
+ r[i] = _mm512_loadu_si512(
33
+ reinterpret_cast<const __m512i*>(src + i * ld_src));
34
+ }
35
+ for (int i = krem; i < 4; ++i) {
36
+ r[i] = _mm512_setzero_si512();
37
+ }
38
+ }
39
+
40
+ // Transpose 4x64 bytes using unpack and shuffle
41
+ __m512i t0 = _mm512_unpacklo_epi8(r[0], r[1]);
42
+ __m512i t1 = _mm512_unpackhi_epi8(r[0], r[1]);
43
+ __m512i t2 = _mm512_unpacklo_epi8(r[2], r[3]);
44
+ __m512i t3 = _mm512_unpackhi_epi8(r[2], r[3]);
45
+
46
+ __m512i u0 = _mm512_unpacklo_epi16(t0, t2);
47
+ __m512i u1 = _mm512_unpackhi_epi16(t0, t2);
48
+ __m512i u2 = _mm512_unpacklo_epi16(t1, t3);
49
+ __m512i u3 = _mm512_unpackhi_epi16(t1, t3);
50
+
51
+ __m512i v0 = _mm512_shuffle_i32x4(u0, u1, 0x88);
52
+ __m512i v1 = _mm512_shuffle_i32x4(u0, u1, 0xdd);
53
+ __m512i v2 = _mm512_shuffle_i32x4(u2, u3, 0x88);
54
+ __m512i v3 = _mm512_shuffle_i32x4(u2, u3, 0xdd);
55
+
56
+ __m512i r0 = _mm512_shuffle_i32x4(v0, v2, 0x88);
57
+ __m512i r1 = _mm512_shuffle_i32x4(v1, v3, 0x88);
58
+ __m512i r2 = _mm512_shuffle_i32x4(v0, v2, 0xdd);
59
+ __m512i r3 = _mm512_shuffle_i32x4(v1, v3, 0xdd);
60
+
61
+ // Store output
62
+ if (nrem < 16) {
63
+ __mmask64 mask = (1ULL << (nrem * 4)) - 1;
64
+ _mm512_mask_storeu_epi8(dst, mask, r0);
65
+ } else if (nrem == 16) {
66
+ _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), r0);
67
+ } else if (nrem < 32) {
68
+ int n_bytes1 = 64;
69
+ int n_bytes2 = (nrem * 4) - n_bytes1;
70
+ __mmask64 mask = (1ULL << n_bytes2) - 1;
71
+ _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), r0);
72
+ _mm512_mask_storeu_epi8(reinterpret_cast<__m512i*>(dst + 64), mask, r1);
73
+ } else if (nrem == 32) {
74
+ _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), r0);
75
+ _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 64), r1);
76
+ } else if (nrem < 48) {
77
+ int n_bytes1 = 64 * 2;
78
+ int n_bytes2 = (nrem * 4) - n_bytes1;
79
+ __mmask64 mask = (1ULL << n_bytes2) - 1;
80
+ _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), r0);
81
+ _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 64), r1);
82
+ _mm512_mask_storeu_epi8(reinterpret_cast<__m512i*>(dst + 64 * 2), mask, r2);
83
+ } else if (nrem == 48) {
84
+ _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), r0);
85
+ _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 64), r1);
86
+ _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 64 * 2), r2);
87
+ } else if (nrem < 64) {
88
+ int n_bytes1 = 64 * 3;
89
+ int n_bytes2 = (nrem * 4) - n_bytes1;
90
+ __mmask64 mask = (1ULL << n_bytes2) - 1;
91
+ _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), r0);
92
+ _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 64), r1);
93
+ _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 64 * 2), r2);
94
+ _mm512_mask_storeu_epi8(reinterpret_cast<__m512i*>(dst + 64 * 3), mask, r3);
95
+ } else {
96
+ // normal case, nrem == 64
97
+ _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), r0);
98
+ _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 64), r1);
99
+ _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 64 * 2), r2);
100
+ _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 64 * 3), r3);
101
+ }
102
+ #else
103
+ TORCH_CHECK(
104
+ false,
105
+ "transpose_pad_4x64_block is only supported when AVX-512 is supported")
106
+ #endif
107
+ }
108
+
109
+ // Reorder [K, N] → [K/4, N, 4] (VNNI4-style layout for bit8)
110
+ template <typename scalar_t, typename = std::enable_if_t<sizeof(scalar_t) == 1>>
111
+ static inline void pack_vnni4(
112
+ const scalar_t* src,
113
+ scalar_t* dst,
114
+ int64_t ld_src,
115
+ int64_t K,
116
+ int64_t N) {
117
+ #if defined(CPU_CAPABILITY_AVX512)
118
+ int64_t bk = 0;
119
+ int64_t _K = K / 4 * 4;
120
+ int64_t _N = N / 64 * 64;
121
+ for (; bk < _K; bk += 4) {
122
+ int64_t bn = 0;
123
+ for (; bn < _N; bn += 64) {
124
+ transpose_pad_4x64_block(
125
+ src + bk * ld_src + bn, dst + bk * N + bn * 4, ld_src);
126
+ }
127
+ int64_t nrem = N - bn;
128
+ if (nrem > 0) {
129
+ transpose_pad_4x64_block(
130
+ src + bk * ld_src + bn, dst + bk * N + bn * 4, ld_src, 4, nrem);
131
+ }
132
+ }
133
+
134
+ // Handle leftover K rows (< 4)
135
+ if (K % 4 != 0) {
136
+ int krem = K - bk;
137
+ int64_t bn = 0;
138
+ for (; bn < _N; bn += 64) {
139
+ transpose_pad_4x64_block(
140
+ src + bk * ld_src + bn, dst + bk * N + bn * 4, ld_src, krem);
141
+ }
142
+ int64_t nrem = N - bn;
143
+ if (nrem > 0) {
144
+ transpose_pad_4x64_block(
145
+ src + bk * ld_src + bn, dst + bk * N + bn * 4, ld_src, krem, nrem);
146
+ }
147
+ }
148
+ #else
149
+ TORCH_CHECK(false, "pack_vnni4 is only supported when AVX-512 is supported")
150
+ #endif
151
+ }
152
+
153
+ // This is a helper function for transpose_pack_vnni4
154
+ // Transform a [4, 16] block (with incontiguous output)
155
+ // Src:
156
+ // a1 a2 a3 a4 a5 a6 a7 a8 a9 a10 a11 a12 a13 a14 a15 a16
157
+ // b1 b2 b3 b4 b5 b6 b7 b8 b9 b10 b11 b12 b13 b14 b15 b16
158
+ // c1 c2 c3 c4 c5 c6 c7 c8 c9 c10 c11 c12 c13 c14 c15 c16
159
+ // d1 d2 d3 d4 d5 d6 d7 d8 d9 d10 d11 d12 d13 d14 d15 d16
160
+ // Dst:
161
+ // a1 a2 a3 a4 b1 b2 b3 b4 c1 c2 c3 c4 d1 d2 d3 d4
162
+ // a5 a6 a7 a8 b5 b6 b7 b8 c5 c6 c7 c8 d5 d6 d7 d8
163
+ // a9 a10 a11 a12 b9 b10 b11 b12 c9 c10 c11 c12 d9 d10 d11 d12
164
+ // a13 a14 a15 a16 b13 b14 b15 b16 c13 c14 c15 c16 d13 d14 d15 d16
165
+ template <typename scalar_t, typename = std::enable_if_t<sizeof(scalar_t) == 1>>
166
+ static inline void transpose_vnni4_pad_4x16_block(
167
+ const scalar_t* src,
168
+ scalar_t* dst,
169
+ int64_t ld_src,
170
+ int64_t ld_dst,
171
+ int krem = 4) {
172
+ #if defined(CPU_CAPABILITY_AVX512)
173
+ __m128i r[4];
174
+ for (int i = 0; i < krem; ++i) {
175
+ r[i] = _mm_loadu_si128(reinterpret_cast<const __m128i*>(src + i * ld_src));
176
+ }
177
+ for (int i = krem; i < 4; ++i) {
178
+ r[i] = _mm_setzero_si128();
179
+ }
180
+
181
+ // Transpose 4x16 bytes using unpack and shuffle
182
+ __m128i t0 = _mm_unpacklo_epi32(r[0], r[1]);
183
+ __m128i t1 = _mm_unpackhi_epi32(r[0], r[1]);
184
+ __m128i t2 = _mm_unpacklo_epi32(r[2], r[3]);
185
+ __m128i t3 = _mm_unpackhi_epi32(r[2], r[3]);
186
+
187
+ __m128i r0 = _mm_unpacklo_epi64(t0, t2);
188
+ __m128i r1 = _mm_unpackhi_epi64(t0, t2);
189
+ __m128i r2 = _mm_unpacklo_epi64(t1, t3);
190
+ __m128i r3 = _mm_unpackhi_epi64(t1, t3);
191
+
192
+ // Store output
193
+ if (krem == 4) {
194
+ // normal case
195
+ _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), r0);
196
+ _mm_storeu_si128(reinterpret_cast<__m128i*>(dst + ld_dst), r1);
197
+ _mm_storeu_si128(reinterpret_cast<__m128i*>(dst + ld_dst * 2), r2);
198
+ _mm_storeu_si128(reinterpret_cast<__m128i*>(dst + ld_dst * 3), r3);
199
+ } else {
200
+ // masked case
201
+ __mmask16 mask = (1ULL << (krem * 4)) - 1;
202
+ _mm_mask_storeu_epi8(dst, mask, r0);
203
+ _mm_mask_storeu_epi8(reinterpret_cast<__m128i*>(dst + ld_dst), mask, r1);
204
+ _mm_mask_storeu_epi8(
205
+ reinterpret_cast<__m128i*>(dst + ld_dst * 2), mask, r2);
206
+ _mm_mask_storeu_epi8(
207
+ reinterpret_cast<__m128i*>(dst + ld_dst * 3), mask, r3);
208
+ }
209
+ #else
210
+ TORCH_CHECK(
211
+ false,
212
+ "transpose_vnni4_pad_4x16_block is only supported when AVX-512 is supported")
213
+ #endif
214
+ }
215
+
216
+ // Do the transpose packing fusion with VNNI4
217
+ // Reorder [K, N] → [N/4, K, 4] (VNNI4-style layout for bit8)
218
+ template <typename scalar_t, typename = std::enable_if_t<sizeof(scalar_t) == 1>>
219
+ static inline void transpose_pack_vnni4(
220
+ const scalar_t* src,
221
+ scalar_t* dst,
222
+ int64_t ld_src,
223
+ int64_t K,
224
+ int64_t N) {
225
+ #if defined(CPU_CAPABILITY_AVX512)
226
+ TORCH_CHECK(
227
+ N % 16 == 0, "N needs to be multiple of 16 for transpose_pack_vnni4");
228
+ int64_t bk = 0;
229
+ int64_t _K = K / 4 * 4;
230
+ for (; bk < _K; bk += 4) {
231
+ int64_t bn = 0;
232
+ for (; bn < N; bn += 16) {
233
+ transpose_vnni4_pad_4x16_block(
234
+ src + bk * ld_src + bn, dst + bn * K + bk * 4, ld_src, K * 4);
235
+ }
236
+ }
237
+
238
+ // Handle leftover K rows (< 4)
239
+ if (K % 4 != 0) {
240
+ int krem = K - bk;
241
+ int64_t bn = 0;
242
+ for (; bn < N; bn += 16) {
243
+ transpose_vnni4_pad_4x16_block(
244
+ src + bk * ld_src + bn, dst + bn * K + bk * 4, ld_src, K * 4, krem);
245
+ }
246
+ }
247
+ #else
248
+ TORCH_CHECK(
249
+ false, "transpose_pack_vnni4 is only supported when AVX-512 is supported")
250
+ #endif
251
+ }
252
+
253
+ } // namespace CPU_CAPABILITY
254
+ } // namespace at::vec
255
+
256
+ #else
257
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
258
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/ADInterpreters.h ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+ #include <ATen/functorch/Interpreter.h>
4
+
5
+ namespace at::functorch {
6
+
7
+ // These are the interpreters for our AD transforms
8
+ // (grad, vjp and jvp).
9
+ // See NOTE: [functorch interpreter stack] for more details.
10
+
11
+ struct TORCH_API GradInterpreterPtr {
12
+ explicit GradInterpreterPtr(const Interpreter* base): base_(base) { TORCH_INTERNAL_ASSERT(base->key() == TransformType::Grad); }
13
+ TransformType key() const { return base_->key(); }
14
+ int64_t level() const { return base_->level(); }
15
+ void processImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack);
16
+ void sendToNextInterpreterImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool grad_special_case);
17
+ bool prevGradMode() const {
18
+ return std::get<GradInterpreterMeta>(base_->meta()).prevGradMode_;
19
+ }
20
+ Tensor lift(const Tensor& tensor) const;
21
+ private:
22
+ const Interpreter* base_;
23
+ };
24
+
25
+ struct TORCH_API JvpInterpreterPtr {
26
+ explicit JvpInterpreterPtr(const Interpreter* base): base_(base) { TORCH_INTERNAL_ASSERT(base->key() == TransformType::Jvp); }
27
+ TransformType key() const { return base_->key(); }
28
+ int64_t level() const { return base_->level(); }
29
+ void processImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack);
30
+ void sendToNextInterpreterImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool grad_special_case);
31
+ bool prevFwdGradMode() const {
32
+ return std::get<JvpInterpreterMeta>(base_->meta()).prevFwdGradMode_;
33
+ }
34
+ Tensor lift(const Tensor& tensor) const;
35
+ private:
36
+ const Interpreter* base_;
37
+ };
38
+
39
+ } // namespace at::functorch
40
+
41
+ #else
42
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
43
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/BatchedFallback.h ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ // Copyright (c) Facebook, Inc. and its affiliates.
3
+ // All rights reserved.
4
+ //
5
+ // This source code is licensed under the BSD-style license found in the
6
+ // LICENSE file in the root directory of this source tree.
7
+
8
+ #pragma once
9
+ #include <ATen/ATen.h>
10
+ #include <ATen/core/op_registration/op_registration.h>
11
+ #include <torch/library.h>
12
+
13
+ namespace at::functorch {
14
+
15
+ // This file contains code for the vmap fallback (also known as the
16
+ // BatchedTensor fallback or the Batched fallback). This code runs
17
+ // when an operation doesn't have a batching rule implemented.
18
+
19
+ // If an operator doesn't have a batching rule implemented then we fallback
20
+ // to this implementation. The fallback doesn't work on out= variants or
21
+ // view operations; that is, it works for out-of-place operations and
22
+ // in-place non-view operations.
23
+ //
24
+ // For out-of-place operations, the fallback effectively takes all of the
25
+ // BatchedTensors in `stack`, slices them, and runs `op` on all of the
26
+ // corresponding slices to produce slices of the outputs. The output slices
27
+ // then get `torch.stack`ed to create the
28
+ // final returns.
29
+ //
30
+ // The performance of the fallback is not very good because it introduces an
31
+ // extra copy from stacking the sliced outputs. Because of this, we prefer to
32
+ // write batching rules for operators whenever possible.
33
+ void batchedTensorForLoopFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack);
34
+ void batchedNestedTensorForLoopFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack);
35
+
36
+ void vmapErrorFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack);
37
+
38
+ // The vmap fallback emits a warning by default, but it may be disabled if
39
+ // the user finds it to be too annoying.
40
+ TORCH_API bool isVmapFallbackWarningEnabled();
41
+ TORCH_API void setVmapFallbackWarningEnabled(bool enabled);
42
+
43
+ // Used for testing. The vmap fallback is enabled by default. When it is disabled,
44
+ // it raises an error.
45
+ TORCH_API bool isVmapFallbackEnabled();
46
+ TORCH_API void setVmapFallbackEnabled(bool enabled);
47
+
48
+ template <typename A> A vector_to_result(const std::vector<IValue>& buffer) {
49
+ return buffer[0].to<A>();
50
+ }
51
+ template <typename A, typename B> std::tuple<A, B> vector_to_result(const std::vector<IValue>& buffer) {
52
+ return std::make_tuple(buffer[0].to<A>(), buffer[1].to<B>());
53
+ }
54
+ template <typename A, typename B, typename C> std::tuple<A, B, C> vector_to_result(const std::vector<IValue>& buffer) {
55
+ return std::make_tuple(buffer[0].to<A>(), buffer[1].to<B>(), buffer[2].to<B>());
56
+ }
57
+
58
+ // slow_fallback is a way to call the vmap fallback inside some boxed kernel.
59
+ // There is probably some better way to metaprogram this.
60
+ template <typename Ret>
61
+ Ret slow_fallback(const c10::OperatorHandle& op, ArrayRef<IValue> args) {
62
+ std::vector<IValue> stack(args.begin(), args.end());
63
+ batchedTensorForLoopFallback(op, &stack);
64
+ return vector_to_result<Ret>(stack);
65
+ }
66
+
67
+ template <typename A, typename B>
68
+ std::tuple<A, B> slow_fallback(const c10::OperatorHandle& op, ArrayRef<IValue> args) {
69
+ std::vector<IValue> stack(args.begin(), args.end());
70
+ batchedTensorForLoopFallback(op, &stack);
71
+ return vector_to_result<A, B>(stack);
72
+ }
73
+
74
+ template <typename A, typename B, typename C>
75
+ std::tuple<A, B, C> slow_fallback(const c10::OperatorHandle& op, ArrayRef<IValue> args) {
76
+ std::vector<IValue> stack(args.begin(), args.end());
77
+ batchedTensorForLoopFallback(op, &stack);
78
+ return vector_to_result<A, B, C>(stack);
79
+ }
80
+
81
+
82
+ } // namespace at::functorch
83
+
84
+ #else
85
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
86
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/BatchedTensorImpl.h ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ // Copyright (c) Facebook, Inc. and its affiliates.
3
+ // All rights reserved.
4
+ //
5
+ // This source code is licensed under the BSD-style license found in the
6
+ // LICENSE file in the root directory of this source tree.
7
+
8
+ #pragma once
9
+
10
+ #include <bitset>
11
+
12
+ #include <ATen/ArrayRef.h>
13
+ #include <ATen/SmallVector.h>
14
+ #include <ATen/Tensor.h>
15
+
16
+ namespace at::functorch {
17
+
18
+ using Tensor = at::Tensor;
19
+
20
+ // We assume this in a few other places in the codebase,
21
+ // but there isn't a centralized definition.
22
+ constexpr int64_t kVmapMaxTensorDims = 64;
23
+
24
+ // The valid vmap levels range from [0, 64). This effectively means that we
25
+ // support a maximum of 64 nested vmaps.
26
+ constexpr int64_t kVmapNumLevels = 64;
27
+
28
+ // Store this number of elements of BatchDims on the stack. Most people will
29
+ // probably use <= 5 nested vmaps, but adjust this number as necessary.
30
+ constexpr int64_t kBatchDimsStackSize = 5;
31
+
32
+ // A BatchedTensorImpl holds an underlying Tensor and a single batch dim
33
+ // NB: We use the term "BatchedTensor" to mean a Tensor that is backed with a
34
+ // BatchedTensorImpl.
35
+ //
36
+ // The batch dimensions are treated as being "private"; they are not user-visible.
37
+ // For example, in the following Tensor,
38
+ // bt = BatchedTensorImpl(ones(2, 3, 5, 7), lvl=1, dim=0)
39
+ // dimension 0 is batch dimension.
40
+ //
41
+ // bt.sizes() returns (5, 7); bt.sum(0) performs a reduction over the (public)
42
+ // dim 0, which is equivalent to dim 3 in the underlying ones(2, 3, 5, 7) tensor.
43
+ struct TORCH_API BatchedTensorImpl : public c10::TensorImpl {
44
+ explicit BatchedTensorImpl(at::DispatchKeySet key_set, Tensor value, int64_t dim, int64_t level);
45
+
46
+ // Returns batch dimension of this tensor
47
+ int64_t bdim() const { return bdim_; }
48
+
49
+ // Returns batch dimension of this tensor
50
+ int64_t level() const { return level_; }
51
+
52
+ // BatchedTensorImpl wraps a Tensor
53
+ const Tensor& value() const { return value_; }
54
+
55
+ // Given a public dimension index, return the dimension index in the underlying
56
+ // value() tensor.
57
+ // For example, if we have
58
+ // bt = BatchedTensorImpl(ones(2, 3, 5, 7), lvl=1, dim=0)
59
+ // bt.actualDim(0) -> 1
60
+ // bt.actualDim(1) -> 2
61
+ // bt.actualDim(2) -> 3
62
+ // bt.actualDim(3) -> Error
63
+ int64_t actualDim(int64_t dim, bool wrap_dim = true) const;
64
+
65
+ IntArrayRef sizes_custom() const override;
66
+ SymIntArrayRef sym_sizes_custom() const override;
67
+ int64_t size_custom(int64_t d) const override;
68
+ c10::SymInt sym_size_custom(int64_t d) const override;
69
+ // We have to override this because we opted into CustomStrides
70
+ IntArrayRef strides_custom() const override;
71
+ SymIntArrayRef sym_strides_custom() const override;
72
+ // Override a bunch of methods inherited from TensorImpl to return error messages.
73
+ c10::SymBool sym_is_contiguous_custom(at::MemoryFormat memory_format) const override;
74
+ void set_size(int64_t dim, int64_t new_size) override;
75
+ void set_stride(int64_t dim, int64_t new_stride) override;
76
+ c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
77
+ const c10::VariableVersion& version_counter,
78
+ bool allow_tensor_metadata_change) const override;
79
+ c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
80
+ c10::VariableVersion&& version_counter,
81
+ bool allow_tensor_metadata_change) const override;
82
+ void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override;
83
+ #ifdef DEBUG
84
+ bool has_storage() const override;
85
+ #endif
86
+
87
+ void refreshTensorMetadata();
88
+
89
+ // Used in torchdim. torchdim uses non-lexical BatchedTensor; the way it
90
+ // accomplishes this is a hack where it is able to modify the levels of
91
+ // BatchedTensor to match the level of the current vmap transform.
92
+ void _unsafe_set_level(int64_t level) {
93
+ level_ = level;
94
+ }
95
+
96
+ // Used in batching rule for in-place view operations that can change
97
+ // the index of the bdim (think squeeze_, unsqueeze_)
98
+ void unsafe_set_bdim(int64_t bdim) {
99
+ // NB: you MUST call refreshTensorMetadata after doing this.
100
+ bdim_ = bdim;
101
+ }
102
+ private:
103
+ // see NOTE: [BatchedTensorImpl levels invariant]
104
+ void checkInvariants() const;
105
+ const char* tensorimpl_type_name() const override;
106
+
107
+ Tensor value_;
108
+
109
+ int64_t level_;
110
+ int64_t bdim_;
111
+ };
112
+
113
+ // NB: We use the term "BatchedTensor" to mean a Tensor that is backed with a
114
+ // BatchedTensorImpl.
115
+ inline bool isBatchedTensor(const Tensor& tensor) {
116
+ return tensor.unsafeGetTensorImpl()->key_set().has(DispatchKey::FuncTorchBatched) ||
117
+ tensor.unsafeGetTensorImpl()->key_set().has(DispatchKey::BatchedNestedTensor);
118
+ }
119
+
120
+ // It is unsafe to call this on a Tensor that is not backed by a
121
+ // BatchedTensorImpl. Please use `maybeGetBatchedImpl` whenever possible.
122
+ inline BatchedTensorImpl* unsafeGetBatchedImpl(const Tensor& tensor) {
123
+ return static_cast<BatchedTensorImpl*>(tensor.unsafeGetTensorImpl());
124
+ }
125
+
126
+ inline BatchedTensorImpl* maybeGetBatchedImpl(const Tensor& tensor) {
127
+ if (!isBatchedTensor(tensor)) {
128
+ return nullptr;
129
+ }
130
+ return unsafeGetBatchedImpl(tensor);
131
+ }
132
+
133
+ // Returns a bitset. If bit i is set, then that means dim i is a batchdim.
134
+ inline std::bitset<kVmapMaxTensorDims> createBatchDimBitset(int64_t dim) {
135
+ std::bitset<kVmapMaxTensorDims> is_bdim;
136
+ is_bdim.set(dim);
137
+ return is_bdim;
138
+ }
139
+
140
+ // Creates a bitset for the given level
141
+ inline std::bitset<kVmapNumLevels> createVmapLevelsBitset(int64_t level) {
142
+ std::bitset<kVmapNumLevels> result;
143
+ result.set(level);
144
+ return result;
145
+ }
146
+
147
+ // Use this to construct a BatchedTensor from a regular Tensor
148
+ TORCH_API Tensor makeBatched(Tensor tensor, int64_t dim, int64_t level);
149
+
150
+ // Adds a batch dim to `tensor`, returning a BatchedTensor
151
+ TORCH_API Tensor addBatchDim(Tensor tensor, int64_t dim, int64_t level);
152
+
153
+ // Certain dispatch keys must be propagated to the BatchedTensor (or, in general,
154
+ // any wrapper Tensor subclasses). This is because there are methods on Tensor
155
+ // that skip dispatch and check for the presence of a dispatch key (e.g. is_cpu()).
156
+ // TODO: should probably contain more (or all?) backend keys
157
+ constexpr DispatchKeySet kKeysToPropagateToWrapper({
158
+ DispatchKey::Negative,
159
+ DispatchKey::Conjugate,
160
+ DispatchKey::XLA,
161
+ DispatchKey::XPU,
162
+ DispatchKey::HPU,
163
+ DispatchKey::CUDA,
164
+ DispatchKey::CPU,
165
+ DispatchKey::PrivateUse1,
166
+ DispatchKey::SparseCPU,
167
+ DispatchKey::SparseCUDA,
168
+ DispatchKey::SparseCsrCPU,
169
+ DispatchKey::SparseCsrCUDA,
170
+ });
171
+
172
+ inline DispatchKeySet getKeysToPropagateToWrapper(const Tensor& tensor, DispatchKeySet to_propagate=kKeysToPropagateToWrapper) {
173
+ auto key_set = tensor.unsafeGetTensorImpl()->key_set();
174
+ return key_set & kKeysToPropagateToWrapper;
175
+ }
176
+
177
+ } // namespace at::functorch
178
+
179
+ #else
180
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
181
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/DynamicLayer.h ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ // Copyright (c) Facebook, Inc. and its affiliates.
3
+ // All rights reserved.
4
+ //
5
+ // This source code is licensed under the BSD-style license found in the
6
+ // LICENSE file in the root directory of this source tree.
7
+
8
+ #pragma once
9
+ #include <ATen/functorch/Macros.h>
10
+ #include <c10/core/DispatchKey.h>
11
+ #include <ATen/core/function_schema.h>
12
+ #include <optional>
13
+ #include <c10/core/impl/LocalDispatchKeySet.h>
14
+ #include <ATen/functorch/Interpreter.h>
15
+ #include <ATen/functorch/VmapInterpreter.h>
16
+ #include <ATen/functorch/ADInterpreters.h>
17
+ #include <ATen/functorch/FunctionalizeInterpreter.h>
18
+
19
+ // Forward declared
20
+ namespace c10 { struct AutogradMetaInterface; }
21
+
22
+ namespace at::functorch {
23
+
24
+ // This file contains the implementation of functorch's interpreter stack.
25
+ // See NOTE: [functorch interpreter stack] first before reading on.
26
+ //
27
+ // NB: the functorch interpreter stack is also referred to as:
28
+ // - the "dynamic layer stack" -- an older name for "interpreter" was
29
+ // "dynamic layer".
30
+ // - the "functorch mode stack". You can think of each functorch transform as a
31
+ // "mode" (in the same sense as torch_dispatch mode or torch_function mode),
32
+ // and functorch being an implementation of a "mode stack" where the modes
33
+ // may be arbitrary composed.
34
+
35
+ // DynamicLayer is basically the same thing as an Interpreter.
36
+ // It represents a functorch transform and it holds an Interpreter,
37
+ // which contains metadata related to the transform and instructions on
38
+ // how to perform the transform.
39
+ //
40
+ // TODO: we can excise DynamicLayer in favor of Interpreter,
41
+ // But I am going to leave it for now as a compatibility shim to avoid
42
+ // needing to refactor a lot of callsites...
43
+ struct TORCH_API DynamicLayer {
44
+ explicit DynamicLayer(
45
+ TransformType transform_type,
46
+ int64_t layerId,
47
+ std::optional<c10::SymInt> batchSize = std::nullopt,
48
+ std::optional<RandomnessType> randomness = std::nullopt,
49
+ std::optional<bool> prev_grad_mode = std::nullopt,
50
+ std::optional<bool> pre_fwd_grad_mode = std::nullopt,
51
+ std::optional<bool> functionalize_add_back_views = std::nullopt);
52
+
53
+ TransformType key() const;
54
+ int64_t layerId() const;
55
+
56
+ const Interpreter& interpreter() const { return interpreter_; }
57
+ Interpreter& interpreter() { return interpreter_; }
58
+
59
+ // Only valid for vmap
60
+ c10::SymInt batchSize() const;
61
+ RandomnessType randomness() const;
62
+
63
+ private:
64
+ Interpreter interpreter_;
65
+ };
66
+
67
+ TORCH_API int64_t initAndPushDynamicLayer(
68
+ TransformType transform_type,
69
+ std::optional<c10::SymInt> batch_size = std::nullopt,
70
+ std::optional<RandomnessType> randomness = std::nullopt,
71
+ std::optional<bool> prev_grad_mode = std::nullopt,
72
+ std::optional<bool> prev_fwd_grad_mode = std::nullopt,
73
+ std::optional<bool> functionalize_add_back_views = std::nullopt);
74
+ TORCH_API DynamicLayer popDynamicLayerAndDeleteMetadata();
75
+ TORCH_API std::optional<DynamicLayer> maybeCurrentDynamicLayer();
76
+ TORCH_API const std::vector<DynamicLayer>& getDynamicLayerStack();
77
+ TORCH_API void setDynamicLayerStack(const std::vector<DynamicLayer>& stack);
78
+ TORCH_API void setDynamicLayerFrontBackKeysIncluded(bool included);
79
+
80
+ // NOTE: [Life handles and lexically scoped transforms]
81
+ // functorch transforms are lexically scoped.
82
+ // Given a level, we store a "life handle" that is a boolean that tells us if the
83
+ // transform with that level is active or not.
84
+ //
85
+ // functorch's TensorWrapper (for grad transforms) stores a life handle.
86
+ // If a TensorWrapper escapes from the scope of the transform, then somehow
87
+ // it must know it escaped; it can tell by querying the life handle.
88
+ TORCH_API const std::shared_ptr<bool>& getLifeHandleForLevel(int64_t level);
89
+
90
+ // Returns if an operator is in-place. An operator is inplace if:
91
+ // 1. The first argument is a Tensor and it is being written to
92
+ // 2. The first argument is being returned
93
+ // 3. No other arguments are aliased
94
+ // Here is an example of an in-place operator:
95
+ // add_(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
96
+ TORCH_API bool isInplaceOp(const c10::FunctionSchema& schema);
97
+
98
+ // Given the indices of unwrapped inputs and the schema, this returns the indices of any outputs that should remain unwrapped
99
+ TORCH_API std::optional<size_t> findAliasedOutput(const FunctionSchema& schema, const int64_t immutable_input);
100
+
101
+ TORCH_API Tensor unwrapIfDead(const Tensor& tensor);
102
+ TORCH_API bool isDeadTensorWrapper(const Tensor& tensor);
103
+
104
+ // Pretty printers
105
+ TORCH_API std::ostream& operator<<(std::ostream& os, const DynamicLayer& layer);
106
+ TORCH_API std::ostream& operator<<(std::ostream& os, const std::vector<DynamicLayer>& dynamicLayerStack);
107
+
108
+ // While a functorch transform is active, torch.autograd.function._SingleLevelFunction
109
+ // is disabled by default. The following two APIs are APIs for enabling
110
+ // it. These are not user-facing APIs. We can delete this in the future, but
111
+ // it is useful for debugging when something goes wrong with the
112
+ // autograd.Function <> functorch interaction, which uses _SingleLevelFunction,
113
+ // because it leads to loud errors if something is incorrect.
114
+ TORCH_API void setSingleLevelAutogradFunctionAllowed(bool allowed);
115
+ TORCH_API bool getSingleLevelAutogradFunctionAllowed();
116
+
117
+ // While a functorch grad transform is active, Tensor.requires_grad_() gets
118
+ // disabled. These two functions are the mechanism to controlling that.
119
+ TORCH_API void setInplaceRequiresGradAllowed(bool allowed);
120
+ TORCH_API bool getInplaceRequiresGradAllowed();
121
+
122
+ TORCH_API DynamicLayer popDynamicLayer();
123
+ TORCH_API int64_t pushDynamicLayer(DynamicLayer&& layer);
124
+
125
+ } // namespace at::functorch
126
+
127
+ #else
128
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
129
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/FunctionalizeInterpreter.h ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+ #include <ATen/functorch/Interpreter.h>
4
+
5
+ namespace at::functorch {
6
+
7
+ // This is the interpreter that handles the functionalize() transform.
8
+ // See NOTE: [functorch interpreter stack] for more details.
9
+
10
+ struct FunctionalizeInterpreterPtr {
11
+ explicit FunctionalizeInterpreterPtr(const Interpreter* base): base_(base) { TORCH_INTERNAL_ASSERT(base->key() == TransformType::Functionalize); }
12
+ TransformType key() const { return base_->key(); }
13
+ int64_t level() const { return base_->level(); }
14
+ void processImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack);
15
+ void sendToNextInterpreterImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool grad_special_case);
16
+ bool functionalizeAddBackViews() const {
17
+ return std::get<FunctionalizeInterpreterMeta>(base_->meta()).functionalizeAddBackViews_;
18
+ }
19
+ private:
20
+ const Interpreter* base_;
21
+ };
22
+
23
+ } // namespace at::functorch
24
+
25
+ #else
26
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
27
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/Interpreter.h ADDED
@@ -0,0 +1,358 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <ATen/functorch/Macros.h>
5
+ #include <ATen/core/dispatch/Dispatcher.h>
6
+ #include <c10/core/impl/LocalDispatchKeySet.h>
7
+ #include <c10/util/Exception.h>
8
+ #include <optional>
9
+ #include <bitset>
10
+ #include <utility>
11
+ #include <variant>
12
+
13
+ #include <nlohmann/json.hpp>
14
+
15
+ namespace at::functorch {
16
+
17
+ // NOTE: [functorch interpreter stack]
18
+ //
19
+ // functorch's dispatching system uses a stack of interpreters.
20
+ // Historically we've referred to this as the "DynamicLayerStack".
21
+ //
22
+ // An interpreter is something that reads in the code it is passed
23
+ // and then executes it. We have a different interpreter per-transform:
24
+ // the "VmapInterpreter" is responsible for reading in operators (like aten::mv)
25
+ // and executing the batched version of it (the batching rule for aten::mv).
26
+ //
27
+ // Concretely, each interpreter is responsible for two things:
28
+ //
29
+ // 1) process(ophandle, stack)
30
+ // Given an operator handle and a stack of arguments, the interpreter is
31
+ // responsible for figuring out how to execute the operation under the semantics
32
+ // of the interpreter. For e.g. VmapInterpreter, this is figuring out how to call
33
+ // the batching rule.
34
+ //
35
+ // The batching rules are stored as kernels on the FuncTorchBatched key, so the way
36
+ // VmapInterpreter calls the batching rule is roughly: (A) exclude all
37
+ // dispatch keys aside from the Batched key, (B) redispatch so we get to the
38
+ // Batched key.
39
+ //
40
+ // 2) sendToNextInterpreter(ophandle, stack)
41
+ // The VmapInterpreter, when it sees aten::mv, will process it into a call to
42
+ // aten::mm. It then needs to send the call to aten::mm to the next interpreter
43
+ // in the interpreter stack.
44
+ //
45
+ // The VmapInterpreter just does this via a call to ophandle.callBoxed(stack)
46
+ // and most Interpreters will implement it this way.
47
+
48
+ enum class RandomnessType {
49
+ Error, // always errors when calling a random function
50
+ Same, // randomness appears the same across batches
51
+ Different, // randomness appears different across batches
52
+ END
53
+ };
54
+
55
+ enum class TransformType {
56
+ Torch, // Unused
57
+ Vmap,
58
+ Grad, // reverse-mode AD, aka vjp
59
+ Jvp, // forward-mode AD
60
+ Functionalize,
61
+ };
62
+
63
+ std::ostream& operator<<(std::ostream& os, const TransformType& t);
64
+
65
+ // NOTE: [Interpreter "subclassing" design]
66
+ //
67
+ // How are various Interpreters for different transforms (vmap, grad, ...)
68
+ // implemented?
69
+ //
70
+ // Accessing interpreters is in the hot-path of functorch so we have a constraint
71
+ // that this code must be as fast as possible.
72
+ //
73
+ // As a result, we stay away from virtual methods and this causes our code
74
+ // to look a little funny.
75
+ //
76
+ // `Interpreter` is the struct for Interpreters. It holds ALL of the
77
+ // relevant information (what type of interpreter it is and the metadata).
78
+ // Metadata for each interpreter is represented as a Union (std::variant)
79
+ // of all possible metadata (VmapInterpreterMeta, GradInterpreterMeta, ...).
80
+ //
81
+ // Given an Interpreter, how do I get a "VmapInterpreter"? You may wish to do this
82
+ // if you want to access the metadata fields (like batchSize and randomness).
83
+ //
84
+ // Each type of interpreter (e.g. Vmap) has a convenience struct
85
+ // (e.g. VmapInterpreterPtr) associated with it.
86
+ //
87
+ // Construct the convenience struct with VmapInterpreterPtr(Interpreter*),
88
+ // and then one can access methods on VmapInterpreterPtr like so:
89
+ // >>> VmapInterpreterPtr(&interpreter).batchSize()
90
+ //
91
+ // Finally, Interpreter::process switches on the type of the interpreter
92
+ // and calls one of {Transform}Interpreter::processImpl under the hood.
93
+ // Same for Interpreter::sendToNextInterpreter :)
94
+
95
+ struct VmapInterpreterMeta {
96
+ explicit VmapInterpreterMeta(c10::SymInt batchSize, RandomnessType randomness) :
97
+ batchSize_(std::move(batchSize)), randomness_(randomness) {}
98
+
99
+ c10::SymInt batchSize_;
100
+ RandomnessType randomness_;
101
+
102
+ VmapInterpreterMeta() = default;
103
+ VmapInterpreterMeta(const VmapInterpreterMeta&) = default;
104
+ VmapInterpreterMeta(VmapInterpreterMeta&&) = default;
105
+ VmapInterpreterMeta& operator=(const VmapInterpreterMeta&) = default;
106
+ VmapInterpreterMeta& operator=(VmapInterpreterMeta&&) = default;
107
+ ~VmapInterpreterMeta() = default;
108
+
109
+ template <typename T>
110
+ friend void to_json(T& json_j, const VmapInterpreterMeta& json_t) {
111
+ TORCH_CHECK(
112
+ !json_t.batchSize_.is_heap_allocated(),
113
+ "Serialization for heap-allocated SymInt is not implemented yet"
114
+ );
115
+ json_j["batchSize"] = json_t.batchSize_.as_int_unchecked();
116
+ json_j["randomness"] = static_cast<int64_t>(json_t.randomness_);
117
+ }
118
+
119
+ template <typename T>
120
+ friend void from_json(const T& json_j, VmapInterpreterMeta& json_t) {
121
+ json_t.batchSize_ = c10::SymInt(SymInt::Unchecked::UNCHECKED, json_j["batchSize"]);
122
+ json_t.randomness_ = static_cast<RandomnessType>(json_j["randomness"]);
123
+ }
124
+ };
125
+
126
+ struct GradInterpreterMeta {
127
+ explicit GradInterpreterMeta(bool prevGradMode): prevGradMode_(prevGradMode) {}
128
+ GradInterpreterMeta() = default;
129
+ GradInterpreterMeta(const GradInterpreterMeta&) = default;
130
+ GradInterpreterMeta(GradInterpreterMeta&&) = default;
131
+ GradInterpreterMeta& operator=(const GradInterpreterMeta&) = default;
132
+ GradInterpreterMeta& operator=(GradInterpreterMeta&&) = default;
133
+ ~GradInterpreterMeta() = default;
134
+
135
+ bool prevGradMode_;
136
+ template <typename T>
137
+ friend void to_json(T& json_j, const GradInterpreterMeta& json_t) {
138
+ json_j["prevGradMode"] = json_t.prevGradMode_;
139
+ }
140
+
141
+ template <typename T>
142
+ friend void from_json(const T& json_j, GradInterpreterMeta& json_t) {
143
+ json_t.prevGradMode_ = json_j["prevGradMode"];
144
+ }
145
+ };
146
+
147
+ struct JvpInterpreterMeta {
148
+ explicit JvpInterpreterMeta(bool prevFwdGradMode) : prevFwdGradMode_(prevFwdGradMode) {}
149
+ JvpInterpreterMeta() = default;
150
+ JvpInterpreterMeta(const JvpInterpreterMeta&) = default;
151
+ JvpInterpreterMeta(JvpInterpreterMeta&&) = default;
152
+ JvpInterpreterMeta& operator=(const JvpInterpreterMeta&) = default;
153
+ JvpInterpreterMeta& operator=(JvpInterpreterMeta&&) = default;
154
+ ~JvpInterpreterMeta() = default;
155
+
156
+ bool prevFwdGradMode_;
157
+ template <typename T>
158
+ friend void to_json(T& json_j, const JvpInterpreterMeta& json_t) {
159
+ json_j["prevFwdGradMode"] = json_t.prevFwdGradMode_;
160
+ }
161
+
162
+ template <typename T>
163
+ friend void from_json(const T& json_j, JvpInterpreterMeta& json_t) {
164
+ json_t.prevFwdGradMode_ = json_j["prevFwdGradMode"];
165
+ }
166
+ };
167
+
168
+ struct FunctionalizeInterpreterMeta {
169
+ explicit FunctionalizeInterpreterMeta(bool functionalizeAddBackViews) :
170
+ functionalizeAddBackViews_(functionalizeAddBackViews) {}
171
+ FunctionalizeInterpreterMeta() = default;
172
+ FunctionalizeInterpreterMeta(const FunctionalizeInterpreterMeta&) = default;
173
+ FunctionalizeInterpreterMeta(FunctionalizeInterpreterMeta&&) = default;
174
+ FunctionalizeInterpreterMeta& operator=(const FunctionalizeInterpreterMeta&) = default;
175
+ FunctionalizeInterpreterMeta& operator=(FunctionalizeInterpreterMeta&&) = default;
176
+ ~FunctionalizeInterpreterMeta() = default;
177
+
178
+ bool functionalizeAddBackViews_;
179
+ template <typename T>
180
+ friend void to_json(T& json_j, const FunctionalizeInterpreterMeta& json_t) {
181
+ json_j["functionalizeAddBackViews"] = json_t.functionalizeAddBackViews_;
182
+ }
183
+
184
+ template <typename T>
185
+ friend void from_json(const T& json_j, FunctionalizeInterpreterMeta& json_t) {
186
+ json_t.functionalizeAddBackViews_ = json_j["functionalizeAddBackViews"];
187
+ }
188
+ };
189
+
190
+ typedef std::variant<
191
+ int64_t,
192
+ GradInterpreterMeta,
193
+ JvpInterpreterMeta,
194
+ VmapInterpreterMeta,
195
+ FunctionalizeInterpreterMeta
196
+ > InterpreterMeta;
197
+
198
+
199
+ struct Interpreter {
200
+ // factory functions
201
+ static Interpreter Vmap(int64_t level, c10::SymInt batchSize, RandomnessType randomness) {
202
+ return Interpreter(TransformType::Vmap, level, VmapInterpreterMeta(std::move(batchSize), randomness));
203
+ }
204
+ static Interpreter Grad(int64_t level, bool prevGradMode) {
205
+ return Interpreter(TransformType::Grad, level, GradInterpreterMeta(prevGradMode));
206
+ }
207
+ static Interpreter Jvp(int64_t level, bool prevFwdGradMode) {
208
+ return Interpreter(TransformType::Jvp, level, JvpInterpreterMeta(prevFwdGradMode));
209
+ }
210
+ static Interpreter Functionalize(int64_t level, bool functionalizeAddBackViews) {
211
+ return Interpreter(TransformType::Functionalize, level, FunctionalizeInterpreterMeta(functionalizeAddBackViews));
212
+ }
213
+
214
+ // methods
215
+ TransformType key() const { return type_; }
216
+ int64_t level() const { return level_; }
217
+ const InterpreterMeta& meta() const { return meta_; }
218
+
219
+ void process(const c10::OperatorHandle& op, torch::jit::Stack* stack);
220
+ void sendToNextInterpreter(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool grad_special_case);
221
+
222
+ void saveLocalDispatchKeySet(c10::impl::LocalDispatchKeySet keyset) {
223
+ TORCH_INTERNAL_ASSERT(!savedLocalDispatchKeySet_.has_value());
224
+ savedLocalDispatchKeySet_ = keyset;
225
+ }
226
+ void clearSavedLocalDispatchKeySet() {
227
+ TORCH_INTERNAL_ASSERT(savedLocalDispatchKeySet_.has_value());
228
+ savedLocalDispatchKeySet_ = std::nullopt;
229
+ }
230
+ c10::impl::LocalDispatchKeySet getSavedLocalDispatchKeySet() const {
231
+ TORCH_INTERNAL_ASSERT(savedLocalDispatchKeySet_.has_value());
232
+ return *savedLocalDispatchKeySet_;
233
+ }
234
+
235
+ // An Interpreter is alive if we are currently inside the ongoing transform
236
+ // for the interpreter. For example, vmap(f)(x); inside of f, the vmap's
237
+ // corresponding Interpreter is alive, even when it is not on the DynamicLayerStack.
238
+ bool is_alive() const {
239
+ return *is_alive_;
240
+ }
241
+ const std::shared_ptr<bool>& is_alive_ptr() const {
242
+ return is_alive_;
243
+ }
244
+ void set_is_alive(bool alive) {
245
+ *is_alive_ = alive;
246
+ }
247
+
248
+ // Please don't use this
249
+ explicit Interpreter() = default;
250
+
251
+ template <typename T>
252
+ friend void to_json(T& json_j, const Interpreter& json_t) {
253
+ json_j["type"] = static_cast<int64_t>(json_t.type_);
254
+ json_j["level"] = json_t.level_;
255
+ if (json_t.savedLocalDispatchKeySet_) {
256
+ json_j["savedLocalDispatchKeySet"] = {
257
+ {"included", json_t.savedLocalDispatchKeySet_->included_.raw_repr()},
258
+ {"excluded", json_t.savedLocalDispatchKeySet_->excluded_.raw_repr()}
259
+ };
260
+ } else {
261
+ json_j["savedLocalDispatchKeySet"] = nlohmann::json();
262
+ }
263
+ json_j["is_alive"] = *json_t.is_alive_;
264
+ std::visit([&](auto&& arg) {
265
+ using V = std::decay_t<decltype(arg)>;
266
+ if constexpr (std::is_same_v<V, int64_t>) {
267
+ json_j["meta"] = {{"Torch", arg}};
268
+ } else if constexpr (std::is_same_v<V, GradInterpreterMeta>) {
269
+ json_j["meta"] = {{"Grad", arg}};
270
+ } else if constexpr (std::is_same_v<V, JvpInterpreterMeta>) {
271
+ json_j["meta"] = {{"Jvp", arg}};
272
+ } else if constexpr (std::is_same_v<V, VmapInterpreterMeta>) {
273
+ json_j["meta"] = {{"Vmap", arg}};
274
+ } else if constexpr (std::is_same_v<V, FunctionalizeInterpreterMeta>) {
275
+ json_j["meta"] = {{"Functionalize", arg}};
276
+ } else {
277
+ static_assert(false && sizeof(V), "unknown variant case");
278
+ }
279
+ }, json_t.meta_);
280
+ }
281
+
282
+ template <typename T>
283
+ friend void from_json(const T& json_j, Interpreter& json_t) {
284
+ json_t.type_ = static_cast<TransformType>(json_j["type"]);
285
+ json_t.level_ = json_j["level"];
286
+ auto savedLocalDispatchKeySet = json_j["savedLocalDispatchKeySet"];
287
+ if (savedLocalDispatchKeySet.is_null()) {
288
+ json_t.savedLocalDispatchKeySet_ = std::nullopt;
289
+ } else {
290
+ c10::impl::PODLocalDispatchKeySet pod;
291
+ pod.set_included(DispatchKeySet::from_raw_repr(savedLocalDispatchKeySet["included"].template get<uint64_t>()));
292
+ pod.set_excluded(DispatchKeySet::from_raw_repr(savedLocalDispatchKeySet["excluded"].template get<uint64_t>()));
293
+ json_t.savedLocalDispatchKeySet_ = c10::impl::LocalDispatchKeySet(pod);
294
+ }
295
+ json_t.is_alive_ = std::make_shared<bool>(json_j["is_alive"]);
296
+ auto meta = json_j["meta"];
297
+ if (meta.contains("Torch")) {
298
+ json_t.meta_.emplace<int64_t>(meta["Torch"].template get<int64_t>());
299
+ } else if (meta.contains("Grad")) {
300
+ json_t.meta_.emplace<GradInterpreterMeta>(meta["Grad"].template get<GradInterpreterMeta>());
301
+ } else if (meta.contains("Jvp")) {
302
+ json_t.meta_.emplace<JvpInterpreterMeta>(meta["Jvp"].template get<JvpInterpreterMeta>());
303
+ } else if (meta.contains("Vmap")) {
304
+ json_t.meta_.emplace<VmapInterpreterMeta>(meta["Vmap"].template get<VmapInterpreterMeta>());
305
+ } else if (meta.contains("Functionalize")) {
306
+ json_t.meta_.emplace<FunctionalizeInterpreterMeta>(meta["Functionalize"].template get<FunctionalizeInterpreterMeta>());
307
+ } else {
308
+ TORCH_CHECK(false, "unknown interpreter metadata type");
309
+ }
310
+ }
311
+
312
+ std::string serialize() const {
313
+ return nlohmann::json(*this).dump();
314
+ }
315
+
316
+ static Interpreter deserialize(const std::string& serialized) {
317
+ return nlohmann::json::parse(serialized).get<Interpreter>();
318
+ }
319
+
320
+ private:
321
+ explicit Interpreter(TransformType type, int64_t level, InterpreterMeta meta):
322
+ type_(type), level_(level), is_alive_(std::make_shared<bool>(false)), meta_(std::move(meta)) {}
323
+
324
+ // fields
325
+ TransformType type_{};
326
+ int64_t level_{};
327
+ std::optional<c10::impl::LocalDispatchKeySet> savedLocalDispatchKeySet_;
328
+ std::shared_ptr<bool> is_alive_;
329
+ InterpreterMeta meta_;
330
+ };
331
+
332
+ // Applies the following for-loop:
333
+ // for i in range(begin, end):
334
+ // args[i] = func(args[i])
335
+ void foreachTensorInplace(std::vector<IValue>& args, int64_t begin, int64_t end,
336
+ std::function<Tensor(const Tensor&)> func);
337
+
338
+ // Applies the following for-loop:
339
+ // for i in range(begin, end):
340
+ // if use_flag_relative[i] == 1: <-- treats use_flag_relative as a bitset
341
+ // args[i] = func(args[i], i - begin, true)
342
+ // args[i] = func(args[i], i - begin)
343
+ void foreachTensorInplaceWithFlag(std::vector<IValue>& args, int64_t begin, int64_t end,
344
+ const std::bitset<64> use_flag_relative, const std::function<Tensor(const Tensor&, bool)>& func);
345
+
346
+ std::vector<int64_t> findUnwrappedInputs(std::vector<IValue>& args, int64_t begin, int64_t end);
347
+
348
+ DispatchKeySet keysToExcludeWhenEnteringDynamicLayer(TransformType key);
349
+
350
+ void setup_dispatch_key_tls(TransformType key, DispatchKeySet include);
351
+
352
+ void sanityCheckStack(const c10::OperatorHandle& op, torch::jit::Stack* stack);
353
+
354
+ } // namespace at::functorch
355
+
356
+ #else
357
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
358
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/PlumbingHelper.h ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ // Copyright (c) Facebook, Inc. and its affiliates.
3
+ // All rights reserved.
4
+ //
5
+ // This source code is licensed under the BSD-style license found in the
6
+ // LICENSE file in the root directory of this source tree.
7
+ #pragma once
8
+ #include <ATen/Tensor.h>
9
+ #include <ATen/functorch/BatchedTensorImpl.h>
10
+ #include <ATen/functorch/DynamicLayer.h>
11
+
12
+ // NOTE: [vmap plumbing]
13
+ //
14
+ // Here's how "batching rules" work.
15
+ // - we register kernels to the Batched key
16
+ // - these kernels have the same signatures as the original operators.
17
+ // For example, at::sin(Tensor self) accepts a Tensor, and the batched kernel
18
+ // must also accept a Tensor
19
+ // - However, it is more natural for users to write a batching rule like the
20
+ // following: sin_batch_rule(Tensor self, std::optional<int> self_bdim)
21
+ // - There is some codegenerated layer (the "plumbing") that wraps the user
22
+ // defined batching rule (e.g. sin_batch_rule) in a kernel that can be
23
+ // registered to the Batched key.
24
+ //
25
+ // The plumbing is responsible for wrapping a batching rule into a form that may
26
+ // be registered as the kernel for the batched key.
27
+
28
+ namespace at::functorch {
29
+
30
+ void vmap_check_escaped(const std::optional<DynamicLayer> &layer, const char* what);
31
+
32
+ // Create a BatchedTensor given a tensor, bdim, and level
33
+ TORCH_API Tensor makeBatched(Tensor tensor, std::optional<int64_t> bdim, int64_t level);
34
+
35
+ // Given a Tensor that may or may not be a BatchedTensor, unwrap it.
36
+ // If `tensor` is not a BatchedTensor, or is a BatchedTensor but the level
37
+ // doesn't match, then this returns (tensor, std::nullopt).
38
+ // Otherwise, it returns (unwrap(tensor), bdim).
39
+ TORCH_API std::tuple<Tensor, std::optional<int64_t>> unwrapTensorAtLevel(const Tensor& tensor, int64_t level);
40
+
41
+ // Creates a vector of BatchedTensor
42
+ TORCH_API std::vector<Tensor> makeBatchedVector(std::vector<Tensor> tensors, std::optional<int64_t> bdim, int64_t level);
43
+
44
+ // Returns True if ANY tensor in tensors is batched at level
45
+ TORCH_API bool isBatchedAtLevel(ITensorListRef tensors, int64_t level);
46
+ TORCH_API bool isBatchedAtLevel(const c10::List<std::optional<Tensor>>& maybe_tensors, int64_t level);
47
+ TORCH_API bool isBatchedAtLevel(const Tensor& tensor, int64_t level);
48
+ TORCH_API bool isBatchedAtLevel(const std::optional<Tensor>& maybe_tensor, int64_t level);
49
+
50
+ // Convenience helper. Returns true if any tensor is batched at level
51
+ TORCH_API bool areAnyBatchedAtLevel(ArrayRef<std::optional<Tensor>> maybe_tensors, int64_t level);
52
+
53
+ inline bool ivalueParticipatesInCurrentLevel(const IValue& ivalue) {
54
+ if (ivalue.isTensor()) {
55
+ auto maybe_level = maybeCurrentDynamicLayer();
56
+ TORCH_INTERNAL_ASSERT(maybe_level.has_value());
57
+ auto current_level = maybe_level->layerId();
58
+ return isBatchedAtLevel(ivalue.toTensor(), current_level);
59
+ }
60
+ // TODO: should really check this
61
+ return false;
62
+ }
63
+
64
+ } // namespace at::functorch
65
+
66
+ #else
67
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
68
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/TensorWrapper.h ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ // Copyright (c) Facebook, Inc. and its affiliates.
3
+ // All rights reserved.
4
+ //
5
+ // This source code is licensed under the BSD-style license found in the
6
+ // LICENSE file in the root directory of this source tree.
7
+
8
+ #pragma once
9
+
10
+ #include <ATen/functorch/Macros.h>
11
+ #include <ATen/Tensor.h>
12
+ #include <ATen/functorch/Interpreter.h>
13
+
14
+ namespace at::functorch {
15
+
16
+ // NOTE: [functorch's TensorWrapper]
17
+ //
18
+ // Taking better suggestions for a name. TensorWrapper is the wrapper Tensor
19
+ // Subclass for functorch's grad-based transforms (grad, vjp, jvp). It is
20
+ // analogous to how vmap uses BatchedTensor as the wrapper Tensor subclass.
21
+ //
22
+ // If you're familiar with the Tensor-Variable merge, TensorWrapper is effectively
23
+ // another Variable.
24
+ //
25
+ // Consider grad(grad(torch.sin))(x). This wraps `x` as TensorWrapper(TensorWrapper(x)).
26
+ // The reason why is so that each TensorWrapper can hold its own AutogradMeta and
27
+ // participate in a **separate** autograd graph.
28
+ //
29
+ // There are alternative designs we could have chosen (e.g. each grad transform
30
+ // stores a weak map of Tensor -> AutogradMeta); the benefit of the TensorWrapper
31
+ // design is that we can reuse existing VariableType kernels (i.e. Autograd kernels)
32
+ // without much modification. Since a TensorWrapper looks like a regular Tensor,
33
+ // the VariableType kernel can pull out the AutogradMeta struct from where it
34
+ // expects and extend the autograd graph
35
+
36
+ struct TORCH_API TensorWrapper : public c10::TensorImpl {
37
+ explicit TensorWrapper(
38
+ c10::DispatchKeySet key_set,
39
+ Tensor value,
40
+ int64_t level,
41
+ std::shared_ptr<bool> is_alive,
42
+ bool is_immutable = false, // if true, this came from an operation that aliases an immutable tensor
43
+ bool use_value_sizes_strides = true);
44
+
45
+ void refreshMetadata();
46
+
47
+ const Tensor& value() const {
48
+ return value_;
49
+ }
50
+ std::optional<int64_t> level() const {
51
+ if (is_alive()) {
52
+ return level_;
53
+ }
54
+ return {};
55
+ }
56
+ bool is_immutable() const {
57
+ return is_immutable_;
58
+ }
59
+ bool is_alive() const;
60
+
61
+ // Overrides necessary for autograd
62
+ c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
63
+ const c10::VariableVersion& version_counter,
64
+ bool allow_tensor_metadata_change) const override;
65
+ c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
66
+ c10::VariableVersion&& version_counter,
67
+ bool allow_tensor_metadata_change) const override;
68
+ void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override;
69
+
70
+ private:
71
+ const char* tensorimpl_type_name() const override;
72
+ Tensor value_;
73
+ int64_t level_;
74
+ bool is_immutable_;
75
+
76
+ // TensorWrapper receives a boolean flag on whether or not the Grad Interpreter
77
+ // that created it is still alive or not.
78
+ // If the Grad Interpreter is no longer alive then it attempts to behave like
79
+ // a regular Tensor.
80
+ //
81
+ // When we exit the level, this wrapper may be marked as "not alive".
82
+ // Wrappers that are not alive:
83
+ // 1) May still have autograd metadata on them
84
+ // 2) Forward dispatches to the underlying value()
85
+ std::shared_ptr<bool> is_alive_;
86
+ };
87
+
88
+ // There are two variants of makeTensorWrapper: one that accepts a level
89
+ // and one that accepts an Interpreter.
90
+ //
91
+ // The one that accepts a level tries to automatically get the life handle from the
92
+ // interpreter on the DynamicLayerStack.
93
+ // It needs to be used with caution: if the interpreter is not on the
94
+ // DynamicLayerStack, then we won't be able to find the life handle.
95
+ //
96
+ // In practice this isn't a problem: when we're constructing TensorWrapper in
97
+ // Python, the corresponding interpreter is on the stack.
98
+ TORCH_API Tensor makeTensorWrapper(const Tensor& tensor, int64_t level, bool is_immutable=false);
99
+ TORCH_API Tensor makeTensorWrapper(const Tensor& tensor, const Interpreter& interpreter, bool is_immutable=false);
100
+ TORCH_API TensorWrapper* maybeGetTensorWrapper(const Tensor& tensor);
101
+ TORCH_API void dumpTensor(std::ostream & ss, const Tensor& tensor);
102
+ TORCH_API void dumpTensorCout(const Tensor& tensor);
103
+
104
+ } // namespace at::functorch
105
+
106
+ #else
107
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
108
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/ao_sparse/quantized/cpu/fbgemm_utils.h ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <ATen/Tensor.h>
5
+ #include <c10/core/QScheme.h>
6
+
7
+ #ifdef USE_FBGEMM
8
+ C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wextra-semi")
9
+ #include <fbgemm/Fbgemm.h>
10
+ #include <fbgemm/FbgemmSparse.h>
11
+ #include <ATen/native/ao_sparse/quantized/cpu/packed_params.h>
12
+ C10_DIAGNOSTIC_POP()
13
+
14
+
15
+ namespace ao::sparse {
16
+
17
+ struct TORCH_API PackedLinearWeight
18
+ : public LinearPackedParamsBase {
19
+ PackedLinearWeight(std::unique_ptr<fbgemm::BCSRMatrix<int8_t>> w,
20
+ std::optional<at::Tensor> bias,
21
+ std::vector<int32_t> col_offsets,
22
+ std::vector<float> w_scale,
23
+ std::vector<int32_t> w_zp,
24
+ c10::QScheme q_scheme,
25
+ const int64_t out_features_block_size /* block sparsity size across output_features */,
26
+ const int64_t in_features_block_size /* block sparsity size across input_features */)
27
+ : LinearPackedParamsBase(
28
+ out_features_block_size,
29
+ in_features_block_size),
30
+ w(std::move(w)),
31
+ bias_(std::move(bias)),
32
+ col_offsets(std::move(col_offsets)),
33
+ w_scale(std::move(w_scale)),
34
+ w_zp(std::move(w_zp)),
35
+ q_scheme(q_scheme) {}
36
+ std::unique_ptr<fbgemm::BCSRMatrix<int8_t>> w;
37
+ std::optional<at::Tensor> bias_;
38
+ std::vector<int32_t> col_offsets;
39
+ std::vector<float> w_scale;
40
+ std::vector<int32_t> w_zp;
41
+ c10::QScheme q_scheme;
42
+
43
+ at::Tensor apply(
44
+ const at::Tensor& input,
45
+ double output_scale,
46
+ int64_t output_zero_point) override;
47
+ at::Tensor apply_relu(
48
+ const at::Tensor& input,
49
+ double output_scale,
50
+ int64_t output_zero_point) override;
51
+
52
+ at::Tensor apply_dynamic(const at::Tensor& input) override {
53
+ TORCH_INTERNAL_ASSERT(
54
+ false,
55
+ "Sparse quantized dynamic linear with fused relu is not yet "
56
+ "supported on qnnpack backend.");
57
+ return at::Tensor();
58
+ }
59
+ at::Tensor apply_dynamic_relu(const at::Tensor& input) override {
60
+ TORCH_INTERNAL_ASSERT(
61
+ false,
62
+ "Sparse quantized dynamic linear with fused relu is not yet "
63
+ "supported on qnnpack backend.");
64
+ return at::Tensor();
65
+ }
66
+
67
+ LinearPackedSerializationType unpack() override;
68
+
69
+ BCSRSerializationType serialize() override;
70
+
71
+ static c10::intrusive_ptr<LinearPackedParamsBase> deserialize(
72
+ const BCSRSerializationType& serialized);
73
+
74
+ std::optional<at::Tensor> bias() override {
75
+ return bias_;
76
+ }
77
+
78
+ static c10::intrusive_ptr<LinearPackedParamsBase> prepack(
79
+ const at::Tensor& weight,
80
+ const std::optional<at::Tensor>& bias,
81
+ const int64_t out_features_block_size,
82
+ const int64_t in_features_block_size);
83
+
84
+ private:
85
+ template <bool ReluFused>
86
+ at::Tensor apply_impl(
87
+ const at::Tensor& input,
88
+ double output_scale,
89
+ int64_t output_zero_point);
90
+ };
91
+
92
+ } // namespace ao::sparse
93
+
94
+ #endif // USE_FBGEMM
95
+
96
+ namespace ao::sparse {
97
+ int register_linear_params();
98
+ } // namespace ao::sparse
99
+
100
+ #else
101
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
102
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/ao_sparse/quantized/cpu/packed_params.h ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <cstdint>
5
+
6
+ #include <ATen/core/ivalue.h>
7
+ #include <c10/util/Exception.h>
8
+
9
+ namespace ao::sparse {
10
+
11
+ // <Weight, bias, out_features_block_size, in_features_block_size>
12
+ using LinearPackedSerializationType =
13
+ std::tuple<at::Tensor, std::optional<at::Tensor>, std::vector<int64_t>>;
14
+
15
+ #define SPARSE_LINEAR_PACKED_PARAM_SERIALIZATION_VERSION 2
16
+
17
+ using BCSRSerializationType =
18
+ std::tuple<
19
+ int64_t, // Serialization Version
20
+ std::optional<at::Tensor>, // Bias
21
+ int64_t, // Out Features (Row) Block Size
22
+ int64_t, // In Features (Column) Block Size
23
+ at::Tensor, // Weight Scales (single element vector if per-tensor) (float)
24
+ at::Tensor, // Wrapper for Weight Zero Points (single element vector if per-tensor) (int8_t)
25
+ bool, // Quantization Scheme (true: per tensor, false: per channel)
26
+ at::Tensor, // Wrapper for Row Block Indices (int8_t, int16_t, or int32_t)
27
+ at::Tensor, // Wrapper for Column Block Indices (int8_t, int16_t, or int32_t)
28
+ at::Tensor, // Wrapper for Non-Zero Weight Values, each +128 (uint8_t)
29
+ int64_t, // Number of Output Channels
30
+ int64_t // Number of Input Channels
31
+ >;
32
+
33
+ using BCSR =
34
+ std::tuple<
35
+ std::vector<int8_t>, // Non-Zero Weight Values
36
+ std::vector<int32_t>, // Compressed Row Block Indices
37
+ std::vector<int32_t> // Column Block Indices
38
+ >;
39
+
40
+ struct LinearPackedParamsBase : public torch::jit::CustomClassHolder {
41
+ public:
42
+ LinearPackedParamsBase(
43
+ const int64_t out_features_block_size,
44
+ const int64_t in_features_block_size)
45
+ : out_features_block_size_(out_features_block_size),
46
+ in_features_block_size_(in_features_block_size) {}
47
+
48
+ virtual at::Tensor apply(
49
+ const at::Tensor& input,
50
+ double output_scale,
51
+ int64_t output_zero_point) = 0;
52
+ virtual at::Tensor apply_relu(
53
+ const at::Tensor& input,
54
+ double output_scale,
55
+ int64_t output_zero_point) = 0;
56
+
57
+ virtual at::Tensor apply_dynamic(const at::Tensor& input) = 0;
58
+ virtual at::Tensor apply_dynamic_relu(const at::Tensor& input) = 0;
59
+
60
+ virtual LinearPackedSerializationType unpack() = 0;
61
+
62
+ virtual BCSRSerializationType serialize() = 0;
63
+
64
+ virtual std::optional<at::Tensor> bias() = 0;
65
+
66
+ virtual void set_bias(const std::optional<at::Tensor>& bias) {
67
+ TORCH_CHECK(false, "set_bias is not implemented for this packed parameter type");
68
+ }
69
+
70
+ protected:
71
+ const int64_t out_features_block_size_, in_features_block_size_;
72
+ };
73
+
74
+ } // namespace ao::sparse
75
+
76
+ #else
77
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
78
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/ao_sparse/quantized/cpu/qnnpack_utils.h ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <ATen/Tensor.h>
5
+ #include <c10/core/QScheme.h>
6
+
7
+ #ifdef USE_PYTORCH_QNNPACK
8
+ // TODO: Refacto QnnpackUtils.h so as to separate code
9
+ // needed for quantized op from the generic qnnpack specific
10
+ // quantization utilities.
11
+ #include <ATen/native/ao_sparse/quantized/cpu/packed_params.h>
12
+ #include <ATen/native/quantized/cpu/QnnpackUtils.h>
13
+ #include <pack_block_sparse.h>
14
+
15
+ namespace ao::sparse {
16
+
17
+ struct TORCH_API PackedLinearWeightQnnp : public LinearPackedParamsBase {
18
+ PackedLinearWeightQnnp(const at::Tensor& weight, const std::optional<at::Tensor>& bias, const int64_t out_features_block_size /* block sparsity size across output_features */, const int64_t in_features_block_size /* block sparsity size across input_features */);
19
+ explicit PackedLinearWeightQnnp(const BCSRSerializationType& serialized);
20
+ std::optional<at::Tensor> orig_bias_;
21
+ // Separate copy of bias exist so that we can fill in zeros when
22
+ // optional bias does not exist. This is to compy with qnnpack operator that
23
+ // expects bias to be present.
24
+ // In case bias is present bias_ is just a reference to orig_bias_
25
+ at::Tensor bias_;
26
+ c10::QScheme q_scheme_;
27
+ double input_scale_{};
28
+ std::unique_ptr<qnnpack::BCSRMatrix> bcsr_matrix_;
29
+ at::Tensor w_scales_;
30
+ std::vector<uint8_t> w_zero_points_;
31
+ std::vector<float> requantization_scales_;
32
+ std::unique_ptr<pytorch_qnnp_operator, QnnpackOperatorDeleter>
33
+ sparse_linear_op_{nullptr};
34
+ int64_t output_channels_;
35
+ int64_t input_channels_;
36
+ // Deserialized Tensors are stored to maintain the lifetime of underlying
37
+ // BCSR data.
38
+ // These are left empty if PackedLinearWeightQnnp is created via prepacking
39
+ // rather than deserializing.
40
+ at::Tensor deserialized_bcsr_row_block_indices_;
41
+ at::Tensor deserialized_bcsr_col_block_indices_;
42
+ at::Tensor deserialized_bcsr_weight_values_;
43
+
44
+ at::Tensor apply(
45
+ const at::Tensor& input,
46
+ double output_scale,
47
+ int64_t output_zero_point) override {
48
+ TORCH_CHECK(
49
+ false, "Static quantized sparse linear unimplemented on QNNPACK");
50
+ }
51
+ at::Tensor apply_relu(
52
+ const at::Tensor& input,
53
+ double output_scale,
54
+ int64_t output_zero_point) override {
55
+ TORCH_CHECK(
56
+ false, "Static quantized sparse linear unimplemented on QNNPACK");
57
+ }
58
+
59
+ at::Tensor apply_dynamic(const at::Tensor& input) override;
60
+ at::Tensor apply_dynamic_relu(const at::Tensor& input) override;
61
+
62
+ LinearPackedSerializationType unpack() override;
63
+
64
+ BCSRSerializationType serialize() override;
65
+
66
+ static c10::intrusive_ptr<LinearPackedParamsBase> deserialize(
67
+ const BCSRSerializationType& serialized);
68
+
69
+ std::optional<at::Tensor> bias() override {
70
+ return orig_bias_;
71
+ }
72
+
73
+ static c10::intrusive_ptr<LinearPackedParamsBase> prepack(
74
+ const at::Tensor& weight,
75
+ const std::optional<at::Tensor>& bias,
76
+ const int64_t out_features_block_size,
77
+ const int64_t in_features_block_size);
78
+
79
+ private:
80
+ template <bool ReluFused>
81
+ at::Tensor apply_impl(
82
+ const at::Tensor& input,
83
+ double output_scale,
84
+ int64_t output_zero_point);
85
+ template <bool ReluFused>
86
+ at::Tensor apply_dynamic_impl(const at::Tensor& input);
87
+ };
88
+
89
+ } // namespace ao::sparse
90
+
91
+ #endif // USE_PYTORCH_QNNPACK
92
+
93
+ #else
94
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
95
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/AtomicAddFloat.h ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #ifndef ATOMIC_ADD_FLOAT
3
+ #define ATOMIC_ADD_FLOAT
4
+
5
+ #if (defined(__x86_64__) || defined(__i386__) || defined(__aarch64__))
6
+ #include <ATen/native/cpu/Intrinsics.h>
7
+ #else
8
+ #define _mm_pause()
9
+ #endif
10
+
11
+ #include <atomic>
12
+
13
+ static inline void cpu_atomic_add_float(float* dst, float fvalue)
14
+ {
15
+ typedef union {
16
+ unsigned intV;
17
+ float floatV;
18
+ } uf32_t;
19
+
20
+ uf32_t new_value, old_value;
21
+ std::atomic<unsigned>* dst_intV = (std::atomic<unsigned>*)dst;
22
+
23
+ old_value.floatV = *dst;
24
+ new_value.floatV = old_value.floatV + fvalue;
25
+
26
+ unsigned* old_intV = &old_value.intV;
27
+ while (!std::atomic_compare_exchange_strong(dst_intV, old_intV, new_value.intV)) {
28
+ #ifdef __aarch64__
29
+ __asm__ __volatile__("yield;" : : : "memory");
30
+ #else
31
+ _mm_pause();
32
+ #endif
33
+ old_value.floatV = *dst;
34
+ new_value.floatV = old_value.floatV + fvalue;
35
+ }
36
+ }
37
+
38
+ #endif
39
+
40
+ #else
41
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
42
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/ChannelShuffleKernel.h ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+ #include <ATen/native/DispatchStub.h>
4
+ #include <cstdint>
5
+
6
+ namespace at {
7
+ class TensorBase;
8
+ }
9
+
10
+ namespace at::native {
11
+
12
+ using channel_shuffle_fn = void(*)(TensorBase&, const TensorBase&, int64_t);
13
+ DECLARE_DISPATCH(channel_shuffle_fn, channel_shuffle_kernel)
14
+
15
+ } // at::native
16
+
17
+ #else
18
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
19
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/DepthwiseConvKernel.h ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <ATen/native/DispatchStub.h>
5
+ #include <c10/util/ArrayRef.h>
6
+
7
+ /*
8
+ Depthwise 3x3 Winograd convolution operator
9
+ */
10
+
11
+ namespace at {
12
+ class Tensor;
13
+
14
+ namespace native {
15
+
16
+ using convolution_depthwise3x3_winograd_fn =
17
+ Tensor (*)(const Tensor &, const Tensor &, const Tensor &, IntArrayRef, IntArrayRef, int64_t);
18
+
19
+ DECLARE_DISPATCH(convolution_depthwise3x3_winograd_fn, convolution_depthwise3x3_winograd_stub)
20
+
21
+ } // namespace native
22
+ } // namespace at
23
+
24
+ #else
25
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
26
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/Elu.h ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ // On Windows, math.h needs to be included with _USE_MATH_DEFINES defined to
5
+ // access constants such as M_SQRT2 and M_2_SQRTPI.
6
+ #ifdef _WIN32
7
+ #define _USE_MATH_DEFINES
8
+ #include <cmath>
9
+ #endif // _WIN32
10
+
11
+ #include <ATen/cpu/vec/vec.h>
12
+ #include <c10/util/BFloat16.h> // For c10::is_reduced_floating_point_v.
13
+
14
+ namespace at::native {
15
+ inline namespace CPU_CAPABILITY {
16
+ /**
17
+ * Return a function object that calculates ELU with the given
18
+ * parameters on its input element. ParamT is the type of the input
19
+ * and output to the ELU, and MathT is the type (possibly
20
+ * higher-precision, e.g. float if ParamT is reduced-precision float)
21
+ * in which to do intermediate calculations.
22
+ */
23
+ template <typename ParamT, typename MathT=ParamT>
24
+ auto get_scalar_elu_elementwise_func(MathT alpha, MathT scale, MathT input_scale) {
25
+ const auto negcoef = alpha * scale;
26
+ const auto poscoef = scale;
27
+ const auto negiptcoef = input_scale;
28
+ return [negcoef, negiptcoef, poscoef](ParamT a) -> ParamT {
29
+ return MathT(a) < MathT(0)
30
+ ? std::expm1(MathT(a) * negiptcoef) * negcoef
31
+ : MathT(a) * poscoef;
32
+ };
33
+ }
34
+
35
+ /**
36
+ * Return a function object that calculates ELU with the given
37
+ * parameters on its input element. The function object takes and
38
+ * returns Vectorized<T>.
39
+ */
40
+ template <typename T, std::enable_if_t<!c10::is_reduced_floating_point_v<T>, bool> = true>
41
+ auto get_vectorized_elu_elementwise_func(T alpha, T scale, T input_scale) {
42
+ const vec::Vectorized<T> negcoef_vec(alpha * scale);
43
+ const vec::Vectorized<T> poscoef_vec(scale);
44
+ const vec::Vectorized<T> negiptcoef_vec(input_scale);
45
+ const vec::Vectorized<T> zero_vec(static_cast<T>(0));
46
+ return [negcoef_vec, poscoef_vec, negiptcoef_vec, zero_vec](vec::Vectorized<T> a) -> vec::Vectorized<T> {
47
+ const auto cmp = a >= zero_vec;
48
+ if (!cmp.zero_mask()) {
49
+ return a * poscoef_vec;
50
+ } else {
51
+ return vec::Vectorized<T>::blendv((a * negiptcoef_vec).expm1() * negcoef_vec, a * poscoef_vec, cmp);
52
+ }
53
+ };
54
+ }
55
+
56
+ /**
57
+ * Return a function object that calculates ELU with the given
58
+ * parameters on its input element. The function object takes and
59
+ * returns Vectorized<ParamT>, and Vectorized<MathT> is the type
60
+ * (possibly higher-precision) in which to do intermediate
61
+ * calculations.
62
+ */
63
+ template <typename T, std::enable_if_t<c10::is_reduced_floating_point_v<T>, bool> = true>
64
+ auto get_vectorized_elu_elementwise_func(float alpha, float scale, float input_scale) {
65
+ // Takes float->float.
66
+ const auto float_func = get_vectorized_elu_elementwise_func<float>(alpha, scale, input_scale);
67
+ return [float_func](vec::Vectorized<T> a) -> vec::Vectorized<T> {
68
+ auto [a0, a1] = vec::convert_to_float<T>(a);
69
+ auto res0 = float_func(a0);
70
+ auto res1 = float_func(a1);
71
+ return vec::convert_from_float<T>(res0, res1);
72
+ };
73
+ }
74
+ } // namespace CPU_CAPABILITY
75
+ } // namespace at::native
76
+
77
+ #else
78
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
79
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/GridSamplerKernel.h ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <ATen/native/DispatchStub.h>
5
+
6
+ #include <array>
7
+ #include <cstdint>
8
+
9
+ namespace at {
10
+ class TensorBase;
11
+ }
12
+
13
+ namespace at::native {
14
+
15
+ using forward_2d_fn = void (*) (
16
+ const TensorBase &output,
17
+ const TensorBase &input,
18
+ const TensorBase &grid,
19
+ int64_t interpolation_mode,
20
+ int64_t padding_mode,
21
+ bool align_corners);
22
+ using backward_2d_fn = void (*) (
23
+ const TensorBase &grad_input,
24
+ const TensorBase &grad_grid,
25
+ const TensorBase &grad_output,
26
+ const TensorBase &input,
27
+ const TensorBase &grid,
28
+ int64_t interpolation_mode,
29
+ int64_t padding_mode,
30
+ bool align_corners,
31
+ std::array<bool, 2> output_mask);
32
+ DECLARE_DISPATCH(forward_2d_fn, grid_sampler_2d_cpu_kernel)
33
+ DECLARE_DISPATCH(backward_2d_fn, grid_sampler_2d_backward_cpu_kernel)
34
+
35
+ } // namespace at::native
36
+
37
+ #else
38
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
39
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/IndexKernelUtils.h ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+ #include <ATen/native/TensorIterator.h>
4
+ #include <c10/util/irange.h>
5
+
6
+ namespace at::native {
7
+
8
+ inline bool is_constant_index(int ntensor, const int64_t* strides) {
9
+ AT_ASSERT(ntensor >= 3);
10
+ for (const auto arg : c10::irange(2, ntensor)) {
11
+ if (strides[arg] != 0) {
12
+ return false;
13
+ }
14
+ }
15
+ return true;
16
+ }
17
+
18
+
19
+ struct Indexer {
20
+ Indexer(int64_t num_indexers, char** indexers, const int64_t* indexer_strides,
21
+ IntArrayRef original_sizes, IntArrayRef original_strides)
22
+ : num_indexers(num_indexers)
23
+ , indexers(indexers)
24
+ , indexer_strides(indexer_strides)
25
+ , original_strides(original_strides.data())
26
+ , original_sizes(original_sizes.data()) {
27
+ AT_ASSERT(static_cast<int64_t>(original_strides.size()) == num_indexers);
28
+ AT_ASSERT(static_cast<int64_t>(original_sizes.size()) == num_indexers);
29
+ }
30
+
31
+ int64_t num_indexers;
32
+ char** indexers;
33
+ const int64_t* indexer_strides;
34
+ const int64_t* original_strides;
35
+ const int64_t* original_sizes;
36
+
37
+ int64_t get(int64_t idx) {
38
+ int64_t offset = 0;
39
+ for (const auto j : c10::irange(num_indexers)) {
40
+ int64_t value = *(int64_t*)&indexers[j][idx * indexer_strides[j]];
41
+ int64_t size = original_sizes[j];
42
+ TORCH_CHECK_INDEX(value >= -size && value < size,
43
+ "index ", value, " is out of bounds for dimension ", j, " with size ", size);
44
+ if (value < 0) {
45
+ value += size;
46
+ }
47
+ offset += value * original_strides[j];
48
+ }
49
+ return offset;
50
+ }
51
+ };
52
+
53
+ template <typename scalar_t, typename func_t>
54
+ void cpu_index_kernel(TensorIteratorBase& iter, IntArrayRef index_size, IntArrayRef index_stride,
55
+ const func_t& f, bool serial_execution=false)
56
+ {
57
+ int ntensor = iter.ntensors();
58
+ // When launch the index parallel version, set a relative small grain size less than the INTERNAL::GRAIN_SIZE
59
+ // to make the whole available thread numbers get more balanced work load and a better cache location.
60
+ // The grain size here is chosen by the op benchmark to overcome the thread launch overhead
61
+ const int index_parallel_grain_size = 3000;
62
+ auto loop = [&](char** data, const int64_t* strides, int64_t n) {
63
+ auto indexer = Indexer(ntensor - 2, &data[2], &strides[2], index_size, index_stride);
64
+ char* dst = data[0];
65
+ char* src = data[1];
66
+ if (is_constant_index(ntensor, strides)) {
67
+ // specialization for when every element uses the same index
68
+ int64_t offset = indexer.get(0);
69
+ for (const auto i : c10::irange(n)) {
70
+ f(dst + strides[0] * i, src + strides[1] * i, offset);
71
+ }
72
+ } else {
73
+ for (const auto i : c10::irange(n)) {
74
+ int64_t offset = indexer.get(i);
75
+ f(dst + strides[0] * i, src + strides[1] * i, offset);
76
+ }
77
+ }
78
+ };
79
+ if (serial_execution) {
80
+ iter.serial_for_each(loop, {0, iter.numel()});
81
+ } else {
82
+ iter.for_each(loop, index_parallel_grain_size);
83
+ }
84
+ }
85
+ } // at
86
+ // native
87
+
88
+ #else
89
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
90
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/Intrinsics.h ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #if defined(__clang__) && (defined(__x86_64__) || defined(__i386__))
5
+ /* Clang-compatible compiler, targeting x86/x86-64 */
6
+ #include <x86intrin.h>
7
+ #elif defined(_MSC_VER)
8
+ /* Microsoft C/C++-compatible compiler */
9
+ #include <intrin.h>
10
+ #if _MSC_VER <= 1900
11
+ #define _mm256_extract_epi64(X, Y) (((uint64_t*)&X)[Y])
12
+ #endif
13
+ #elif defined(__GNUC__) && (defined(__x86_64__) || defined(__i386__))
14
+ /* GCC-compatible compiler, targeting x86/x86-64 */
15
+ #include <x86intrin.h>
16
+ #elif defined(__GNUC__) && defined(__ARM_NEON__)
17
+ /* GCC-compatible compiler, targeting ARM with NEON */
18
+ #include <arm_neon.h>
19
+ #elif defined(__GNUC__) && defined(__IWMMXT__)
20
+ /* GCC-compatible compiler, targeting ARM with WMMX */
21
+ #include <mmintrin.h>
22
+ #elif (defined(__GNUC__) || defined(__xlC__)) && \
23
+ (defined(__VEC__) || defined(__ALTIVEC__))
24
+ /* XLC or GCC-compatible compiler, targeting PowerPC with VMX/VSX */
25
+ #include <altivec.h>
26
+ /* We need to undef those tokens defined by <altivec.h> to avoid conflicts
27
+ with the C++ types. => Can still use __bool/__vector */
28
+ #undef bool
29
+ #undef vector
30
+ #undef pixel
31
+ #elif defined(__GNUC__) && defined(__SPE__)
32
+ /* GCC-compatible compiler, targeting PowerPC with SPE */
33
+ #include <spe.h>
34
+ #endif
35
+
36
+ #else
37
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
38
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/IsContiguous.h ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ namespace at::native { inline namespace CPU_CAPABILITY {
5
+
6
+ // n: number of function arguments (arity)
7
+ // traits: function_traits (see FunctionTraits.h)
8
+ // s: index of scalar argument or -1
9
+ template <int n, int stride_index, typename traits, int s=-1>
10
+ struct IsContiguous {
11
+ static bool eval(const int64_t* strides) {
12
+ using type = typename traits::template arg<n - 1>::type;
13
+ return strides[stride_index] == (s == n ? 0 : sizeof(type)) &&
14
+ IsContiguous<n - 1, stride_index - 1, traits, s>::eval(strides);
15
+ }
16
+ };
17
+
18
+ // will be called when there is an output exists
19
+ template <typename traits, int s>
20
+ struct IsContiguous<0, 0, traits, s> {
21
+ static bool eval(const int64_t* strides) {
22
+ return strides[0] == sizeof(typename traits::result_type);
23
+ }
24
+ };
25
+
26
+ // will be called when there is no output
27
+ template <typename traits, int s>
28
+ struct IsContiguous<0, -1, traits, s> {
29
+ static bool eval(const int64_t* /*strides*/) {
30
+ return true;
31
+ }
32
+ };
33
+
34
+ // output and all inputs are contiguous
35
+ template <
36
+ typename traits,
37
+ std::enable_if_t<std::is_void_v<typename traits::result_type>>* =
38
+ nullptr>
39
+ static inline bool is_contiguous(const int64_t* strides) {
40
+ return IsContiguous<traits::arity, traits::arity - 1, traits>::eval(strides);
41
+ }
42
+
43
+ template <typename traits,
44
+ std::enable_if_t<!std::is_void_v<typename traits::result_type>>* = nullptr>
45
+ static inline bool is_contiguous(const int64_t* strides) {
46
+ return IsContiguous<traits::arity, traits::arity, traits>::eval(strides);
47
+ }
48
+
49
+ // input at `s` is scalar (stride 0); output and other inputs are contiguous
50
+ // NB: output is typically at strides[0] so first input corresponds to s=1
51
+ template <typename traits, int s,
52
+ std::enable_if_t<std::is_void_v<typename traits::result_type>>* = nullptr>
53
+ static inline bool is_contiguous_scalar(const int64_t* strides) {
54
+ static_assert(s > 0 && s <= traits::arity, "scalar argument index out of bounds");
55
+ return IsContiguous<traits::arity, traits::arity - 1, traits, s>::eval(strides);
56
+ }
57
+
58
+ template <typename traits, int s,
59
+ std::enable_if_t<!std::is_void_v<typename traits::result_type>>* = nullptr>
60
+ static inline bool is_contiguous_scalar(const int64_t* strides) {
61
+ static_assert(s > 0 && s <= traits::arity, "scalar argument index out of bounds");
62
+ return IsContiguous<traits::arity, traits::arity, traits, s>::eval(strides);
63
+ }
64
+
65
+ }}
66
+
67
+ #else
68
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
69
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/LogSoftmaxKernelImpl.h ADDED
@@ -0,0 +1,342 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <ATen/OpMathType.h>
5
+ #include <ATen/Parallel.h>
6
+ #include <ATen/cpu/vec/functional.h>
7
+ #include <ATen/cpu/vec/vec.h>
8
+ #include <c10/util/irange.h>
9
+
10
+ #include <algorithm>
11
+ #include <cmath>
12
+ #include <cstdint>
13
+ #include <limits>
14
+ #include <memory>
15
+ #include <type_traits>
16
+
17
+ namespace at::native {
18
+ inline namespace CPU_CAPABILITY {
19
+ template <typename scalar_t>
20
+ int64_t vec_log_softmax_lastdim_chunk_size(int64_t grain_size, int64_t outer_size, int64_t dim_size) {
21
+ // Coincidentally, at::internal::GRAIN_SIZE is 32768, which is equal to the
22
+ // size of L1D cache on many processors. Some processors have 48 KB L1D cache
23
+ // nowadays, so maybe in the future, we can leverage the knowledge of a
24
+ // machine's L1D cache size.
25
+ int64_t MAX_CHUNK_SIZE = std::max<int64_t>(
26
+ 1,
27
+ grain_size / (sizeof(scalar_t) * dim_size));
28
+ return std::min<int64_t>(MAX_CHUNK_SIZE, outer_size);
29
+ }
30
+
31
+ template <typename scalar_t>
32
+ void serial_vec_log_softmax_lastdim_range(
33
+ const scalar_t* input_data_base,
34
+ scalar_t* output_data_base,
35
+ int64_t dim_size,
36
+ int64_t chunk_size,
37
+ int64_t begin,
38
+ int64_t end) {
39
+ if (end <= begin) {
40
+ return;
41
+ }
42
+ using Vec = vec::Vectorized<vec::vec_scalar_t<scalar_t>>;
43
+ // MSVC requires such a declaration of dynamic arrays
44
+ // Source: https://stackoverflow.com/a/33423538
45
+ auto tmp_sum_scalar = std::make_unique<scalar_t[]>(chunk_size);
46
+ auto max_input_arr = std::make_unique<scalar_t[]>(chunk_size);
47
+ for (int64_t ii = begin; ii < end; ii += chunk_size) {
48
+ int64_t loop_end = chunk_size;
49
+ if (ii + chunk_size > end) {
50
+ loop_end = end - ii;
51
+ }
52
+ for (const auto j : c10::irange(loop_end)) {
53
+ int64_t i = ii + j;
54
+ const scalar_t* input_data = input_data_base + i * dim_size;
55
+ max_input_arr[j] = vec::reduce_all<scalar_t>(
56
+ [](Vec& x, Vec& y) { return vec::maximum(x, y); },
57
+ input_data,
58
+ dim_size);
59
+ }
60
+ for (const auto j : c10::irange(loop_end)) {
61
+ int64_t i = ii + j;
62
+ const scalar_t* input_data = input_data_base + i * dim_size;
63
+ scalar_t max_input = max_input_arr[j];
64
+ tmp_sum_scalar[j] = vec::map_reduce_all<scalar_t>(
65
+ [max_input](Vec x) { return (x - Vec(max_input)).exp(); },
66
+ [](Vec x, Vec y) { return x + y; },
67
+ input_data,
68
+ dim_size);
69
+ }
70
+ // See [Note AVX-SSE transitions] for why this should call the
71
+ // vectorized version (aside from perf improvements).
72
+ vec::map(
73
+ [](Vec x) { return x.log(); },
74
+ tmp_sum_scalar.get(),
75
+ tmp_sum_scalar.get(),
76
+ loop_end);
77
+ for (const auto j : c10::irange(loop_end)) {
78
+ int64_t i = ii + j;
79
+ const scalar_t* input_data = input_data_base + i * dim_size;
80
+ scalar_t* output_data = output_data_base + i * dim_size;
81
+ scalar_t tmp_sum = tmp_sum_scalar[j];
82
+ scalar_t max_input = max_input_arr[j];
83
+
84
+ // It's necessary to keep the order of the operations below.
85
+ // In some cases that input is large digits and the difference
86
+ // is small, if we compute `max_input` plus `tmp_sum` before,
87
+ // there would be a numerical problem. See an example in
88
+ // https://github.com/pytorch/pytorch/issues/11752#issuecomment-422883379
89
+ vec::map(
90
+ [tmp_sum, max_input](Vec x) {
91
+ return x - Vec(max_input) - Vec(tmp_sum);
92
+ },
93
+ output_data,
94
+ input_data,
95
+ dim_size);
96
+ }
97
+ }
98
+ }
99
+
100
+ // Can't include ATen/Parallel.h.
101
+ // TODO: find a way to have only one copy of divup.
102
+ inline int64_t divup(int64_t x, int64_t y) {
103
+ return (x + y - 1) / y;
104
+ }
105
+
106
+ template <typename scalar_t, int64_t BLOCK_SIZE = 128 * 1024>
107
+ std::pair<int64_t,int64_t> vec_logsoftmax_chunk_size_and_num_chunks(int64_t inner_size, int64_t dim_size) {
108
+ using Vec = vec::Vectorized<scalar_t>;
109
+ int64_t MAX_CHUNK_SIZE = std::max<int64_t>(BLOCK_SIZE / dim_size / sizeof(scalar_t), Vec::size());
110
+ MAX_CHUNK_SIZE = MAX_CHUNK_SIZE / Vec::size() * Vec::size();
111
+ int64_t CHUNK_SIZE = std::min<int64_t>(MAX_CHUNK_SIZE, inner_size);
112
+ int64_t num_chunks = divup(inner_size, CHUNK_SIZE);
113
+ return {CHUNK_SIZE, num_chunks};
114
+ }
115
+
116
+ template <typename scalar_t>
117
+ std::enable_if_t<std::is_same_v<scalar_t, at::opmath_type<scalar_t>>, void>
118
+ serial_vec_logsoftmax_range(
119
+ const scalar_t* input_data_base,
120
+ scalar_t* output_data_base,
121
+ int64_t inner_size,
122
+ int64_t chunk_size,
123
+ int64_t num_chunks,
124
+ int64_t dim_size,
125
+ int64_t begin,
126
+ int64_t end) {
127
+ using Vec = vec::Vectorized<scalar_t>;
128
+ // thread local temp buffer which holds vertical reduction result: max and sum.
129
+ auto buffer = std::make_unique<scalar_t []>(chunk_size * 2);
130
+ scalar_t* input_max_data = buffer.get();
131
+ scalar_t* tmp_sum_data = buffer.get() + chunk_size;
132
+
133
+ for (int64_t i = begin; i < end; i++) {
134
+ int64_t outer_idx = i / num_chunks;
135
+ int64_t k = i % num_chunks;
136
+ int64_t inner_idx_begin = k * chunk_size;
137
+ int64_t size = std::min(chunk_size, inner_size - inner_idx_begin);
138
+
139
+ // init
140
+ Vec zero_vec = Vec(scalar_t(0));
141
+ Vec min_vec = Vec(-std::numeric_limits<scalar_t>::infinity());
142
+ int64_t d0 = 0;
143
+ for (; d0 < size - (size % Vec::size()); d0 += Vec::size()) {
144
+ min_vec.store(input_max_data + d0);
145
+ zero_vec.store(tmp_sum_data + d0);
146
+ }
147
+ for (; d0 < size; d0++) {
148
+ input_max_data[d0] = -std::numeric_limits<scalar_t>::infinity();
149
+ tmp_sum_data[d0] = scalar_t(0);
150
+ }
151
+
152
+ // compute max
153
+ for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) {
154
+ const scalar_t* input_ptr = input_data_base + outer_idx * dim_size * inner_size
155
+ + dim_idx * inner_size + inner_idx_begin;
156
+
157
+ int64_t d1 = 0;
158
+ for (; d1 < size - (size % Vec::size()); d1 += Vec::size()) {
159
+ Vec data_vec = Vec::loadu(input_ptr + d1);
160
+ Vec max_vec = Vec::loadu(input_max_data + d1);
161
+ max_vec = Vec::blendv(max_vec, data_vec, data_vec > max_vec);
162
+ max_vec.store(input_max_data + d1);
163
+ }
164
+ for (; d1 < size; d1++) {
165
+ scalar_t data_val = input_ptr[d1];
166
+ scalar_t max_val = input_max_data[d1];
167
+ input_max_data[d1] = data_val > max_val ? data_val : max_val;
168
+ }
169
+ }
170
+
171
+ // compute sum of (x - max).exp()
172
+ for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) {
173
+ const scalar_t* input_ptr = input_data_base + outer_idx * dim_size * inner_size
174
+ + dim_idx * inner_size + inner_idx_begin;
175
+
176
+ int64_t d2 = 0;
177
+ for (; d2 < size - (size % Vec::size()); d2 += Vec::size()) {
178
+ Vec data_vec = Vec::loadu(input_ptr + d2);
179
+ Vec sum_vec = Vec::loadu(tmp_sum_data + d2);
180
+ Vec max_vec = Vec::loadu(input_max_data + d2);
181
+ sum_vec += (data_vec - max_vec).exp();
182
+ sum_vec.store(tmp_sum_data + d2);
183
+ }
184
+ for (; d2 < size; d2++) {
185
+ scalar_t data_val = input_ptr[d2];
186
+ scalar_t max_val = input_max_data[d2];
187
+ tmp_sum_data[d2] += std::exp(data_val - max_val);
188
+ }
189
+ }
190
+
191
+ // apply log
192
+ vec::map([](Vec x) { return x.log(); }, tmp_sum_data, tmp_sum_data, size);
193
+
194
+ // compute x - max - sum
195
+ for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) {
196
+ int64_t offset = outer_idx * dim_size * inner_size + dim_idx * inner_size + inner_idx_begin;
197
+ const scalar_t* input_ptr = input_data_base + offset;
198
+ scalar_t* output_ptr = output_data_base + offset;
199
+
200
+ int64_t d3 = 0;
201
+ for (; d3 < size - (size % Vec::size()); d3 += Vec::size()) {
202
+ Vec data_vec = Vec::loadu(input_ptr + d3);
203
+ Vec max_vec = Vec::loadu(input_max_data + d3);
204
+ Vec sum_vec = Vec::loadu(tmp_sum_data + d3);
205
+ Vec out_vec = data_vec - max_vec - sum_vec;
206
+ out_vec.store(output_ptr + d3);
207
+ }
208
+ for (; d3 < size; d3++) {
209
+ output_ptr[d3] = input_ptr[d3] - input_max_data[d3] - tmp_sum_data[d3];
210
+ }
211
+ }
212
+ }
213
+ }
214
+
215
+ template <typename scalar_t>
216
+ std::enable_if_t<!std::is_same_v<scalar_t, at::opmath_type<scalar_t>>, void>
217
+ serial_vec_logsoftmax_range(
218
+ const scalar_t* input_data_base,
219
+ scalar_t* output_data_base,
220
+ int64_t inner_size,
221
+ int64_t chunk_size,
222
+ int64_t num_chunks,
223
+ int64_t dim_size,
224
+ int64_t begin,
225
+ int64_t end) {
226
+ using Vec = vec::Vectorized<scalar_t>;
227
+ using fVec = vec::Vectorized<float>;
228
+ auto buffer = std::make_unique<float []>(chunk_size * 2);
229
+ float* input_max_data = buffer.get();
230
+ float* tmp_sum_data = buffer.get() + chunk_size;
231
+
232
+ // thread local buffer that holds input data in float32 to save next 2 dtype conversion
233
+ auto input_buffer = std::make_unique<float []>(dim_size * chunk_size);
234
+ float* input_buffer_data = input_buffer.get();
235
+
236
+ // init
237
+ for (int64_t i = begin; i < end; i++) {
238
+ int64_t outer_idx = i / num_chunks;
239
+ int64_t k = i % num_chunks;
240
+ int64_t inner_idx_begin = k * chunk_size;
241
+ int64_t size = std::min(chunk_size, inner_size - inner_idx_begin);
242
+
243
+ fVec zero_fvec = fVec(float(0));
244
+ fVec min_fvec = fVec(-std::numeric_limits<float>::infinity());
245
+ int64_t d0 = 0;
246
+ for (; d0 < size - (size % Vec::size()); d0 += Vec::size()) {
247
+ min_fvec.store(input_max_data + d0);
248
+ min_fvec.store(input_max_data + d0 + fVec::size());
249
+ zero_fvec.store(tmp_sum_data + d0);
250
+ zero_fvec.store(tmp_sum_data + d0 + fVec::size());
251
+ }
252
+ for (; d0 < size; d0++) {
253
+ input_max_data[d0] = -std::numeric_limits<float>::infinity();
254
+ tmp_sum_data[d0] = float(0);
255
+ }
256
+
257
+ // compute max
258
+ for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) {
259
+ const scalar_t* input_ptr = input_data_base + outer_idx * dim_size * inner_size
260
+ + dim_idx * inner_size + inner_idx_begin;
261
+ float* input_buffer_ptr = input_buffer_data + dim_idx * chunk_size;
262
+
263
+ int64_t d1 = 0;
264
+ for (; d1 < size - (size % Vec::size()); d1 += Vec::size()) {
265
+ Vec data_vec = Vec::loadu(input_ptr + d1);
266
+ auto [data_fvec0, data_fvec1] = vec::convert_to_float<scalar_t>(data_vec);
267
+ fVec max_fvec0 = fVec::loadu(input_max_data + d1);
268
+ fVec max_fvec1 = fVec::loadu(input_max_data + d1 + fVec::size());
269
+ max_fvec0 = fVec::blendv(max_fvec0, data_fvec0, data_fvec0 > max_fvec0);
270
+ max_fvec1 = fVec::blendv(max_fvec1, data_fvec1, data_fvec1 > max_fvec1);
271
+ max_fvec0.store(input_max_data + d1);
272
+ max_fvec1.store(input_max_data + d1 + fVec::size());
273
+
274
+ // cache the 'converted' float input
275
+ data_fvec0.store(input_buffer_ptr + d1);
276
+ data_fvec1.store(input_buffer_ptr + d1 + fVec::size());
277
+ }
278
+ for (; d1 < size; d1++) {
279
+ float data_val = float(input_ptr[d1]);
280
+ float max_val = input_max_data[d1];
281
+ input_max_data[d1] = data_val > max_val ? data_val : max_val;
282
+ input_buffer_ptr[d1] = data_val;
283
+ }
284
+ }
285
+
286
+ // compute sum of (x - max).exp()
287
+ for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) {
288
+ float* input_buffer_ptr = input_buffer_data + dim_idx * chunk_size;
289
+
290
+ int64_t d2 = 0;
291
+ for (; d2 < size - (size % Vec::size()); d2 += Vec::size()) {
292
+ fVec data_fvec0 = fVec::loadu(input_buffer_ptr + d2);
293
+ fVec data_fvec1 = fVec::loadu(input_buffer_ptr + d2 + fVec::size());
294
+ fVec sum_fvec0 = fVec::loadu(tmp_sum_data + d2);
295
+ fVec sum_fvec1 = fVec::loadu(tmp_sum_data + d2 + fVec::size());
296
+ fVec max_fvec0 = fVec::loadu(input_max_data + d2);
297
+ fVec max_fvec1 = fVec::loadu(input_max_data + d2 + fVec::size());
298
+ sum_fvec0 += (data_fvec0 - max_fvec0).exp();
299
+ sum_fvec1 += (data_fvec1 - max_fvec1).exp();
300
+ sum_fvec0.store(tmp_sum_data + d2);
301
+ sum_fvec1.store(tmp_sum_data + d2 + fVec::size());
302
+ }
303
+ for (; d2 < size; d2++) {
304
+ float data_val = input_buffer_ptr[d2];
305
+ float max_val = input_max_data[d2];
306
+ tmp_sum_data[d2] += std::exp(data_val - max_val);
307
+ }
308
+ }
309
+
310
+ // apply log
311
+ vec::map([](fVec x) { return x.log(); }, tmp_sum_data, tmp_sum_data, size);
312
+
313
+ // compute x - max - sum
314
+ for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) {
315
+ float* input_buffer_ptr = input_buffer_data + dim_idx * chunk_size;
316
+ scalar_t* output_ptr = output_data_base + outer_idx * dim_size * inner_size
317
+ + dim_idx * inner_size + inner_idx_begin;
318
+
319
+ int64_t d3 = 0;
320
+ for (; d3 < size - (size % Vec::size()); d3 += Vec::size()) {
321
+ fVec data_fvec0 = fVec::loadu(input_buffer_ptr + d3);
322
+ fVec data_fvec1 = fVec::loadu(input_buffer_ptr + d3 + fVec::size());
323
+ fVec max_fvec0 = fVec::loadu(input_max_data + d3);
324
+ fVec max_fvec1 = fVec::loadu(input_max_data + d3 + fVec::size());
325
+ fVec sum_fvec0 = fVec::loadu(tmp_sum_data + d3);
326
+ fVec sum_fvec1 = fVec::loadu(tmp_sum_data + d3 + fVec::size());
327
+ fVec out_fvec0 = data_fvec0 - max_fvec0 - sum_fvec0;
328
+ fVec out_fvec1 = data_fvec1 - max_fvec1 - sum_fvec1;
329
+ Vec out_vec = vec::convert_from_float<scalar_t>(out_fvec0, out_fvec1);
330
+ out_vec.store(output_ptr + d3);
331
+ }
332
+ for (; d3 < size; d3++) {
333
+ output_ptr[d3] = scalar_t(input_buffer_ptr[d3] - input_max_data[d3] - tmp_sum_data[d3]);
334
+ }
335
+ }
336
+ }
337
+ } // namespace CPU_CAPABILITY
338
+ }} // namespace at::native
339
+
340
+ #else
341
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
342
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/Loops.h ADDED
@@ -0,0 +1,400 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ // This file provides two functions to help write elementwise kernels:
5
+ //
6
+ // cpu_kernel(TensorIterator iter, <lambda>)
7
+ // cpu_kernel_vec(TensorIterator iter, <lambda>, <vec_lambda>)
8
+ //
9
+ // Both functions may generate vectorized code. The cpu_kernel implementation
10
+ // relies on the compiler's auto-vectorization. The cpu_kernel_vec
11
+ // implementation uses x86 SIMD intrinsics when available. These functions
12
+ // are only intended to be used in the ATen/native/cpu subdirectory, since files
13
+ // in other directories are not compiled with AVX/AVX2 enabled. See README.md
14
+ // for more details.
15
+ //
16
+ // For example, to write a multiplication kernel for float:
17
+ //
18
+ // cpu_kernel(iter, [](float a, float b) { return a * b; });
19
+ //
20
+ // Or you may write:
21
+ //
22
+ // cpu_kernel_vec(iter,
23
+ // [](float a, float b) { return a * b; },
24
+ // [](Vectorized<float> a, Vectorized<float> b) { return a * b; });
25
+ //
26
+ // See BinaryOpsKernel.cpp for the complete implementation
27
+ //
28
+ //
29
+
30
+ #include <cstdint>
31
+ #include <c10/util/C++17.h>
32
+ #include <c10/util/Load.h>
33
+ #include <c10/util/irange.h>
34
+ #include <ATen/detail/FunctionTraits.h>
35
+ #include <ATen/native/cpu/IsContiguous.h>
36
+ #include <ATen/native/TensorIterator.h>
37
+ #include <ATen/native/TensorIteratorDynamicCasting.h>
38
+ #include <ATen/cpu/vec/vec.h>
39
+
40
+ #include <tuple>
41
+ #include <utility>
42
+
43
+ namespace at::native { inline namespace CPU_CAPABILITY {
44
+
45
+ using namespace vec;
46
+
47
+ template <typename traits, std::size_t... INDEX>
48
+ typename traits::ArgsTuple
49
+ dereference_impl(char* C10_RESTRICT data[], const int64_t* strides, int64_t i,
50
+ std::index_sequence<INDEX...> /*unused*/) {
51
+ return std::make_tuple(
52
+ c10::load<typename traits::template arg<INDEX>::type>(
53
+ data[INDEX] + i * strides[INDEX])...);
54
+ }
55
+
56
+ template <typename traits>
57
+ typename traits::ArgsTuple
58
+ dereference(char* C10_RESTRICT data[], const int64_t* strides, int64_t i) {
59
+ using Indices = std::make_index_sequence<traits::arity>;
60
+ return dereference_impl<traits>(data, strides, i, Indices{});
61
+ }
62
+
63
+ template <typename traits, std::size_t... INDEX>
64
+ typename traits::ArgsTuple
65
+ dereference_vec_impl(char* C10_RESTRICT data[],
66
+ const typename traits::result_type& opt_scalar,
67
+ size_t S,
68
+ int64_t i,
69
+ std::index_sequence<INDEX...> /*unused*/) {
70
+ using Vec = typename traits::result_type;
71
+ using scalar_t = typename Vec::value_type;
72
+ return std::make_tuple(
73
+ S == INDEX + 1 ?
74
+ opt_scalar :
75
+ Vec::loadu(data[INDEX] + i * sizeof(scalar_t))...);
76
+ }
77
+
78
+ template <typename traits>
79
+ typename traits::ArgsTuple
80
+ dereference_vec(char* C10_RESTRICT data[], const typename traits::result_type& opt_scalar, size_t S, int64_t i) {
81
+ using Indices = std::make_index_sequence<traits::arity>;
82
+ return dereference_vec_impl<traits>(data, opt_scalar, S, i, Indices{});
83
+ }
84
+
85
+ template <typename func_t,
86
+ std::enable_if_t<!std::is_void_v<typename function_traits<func_t>::result_type>>* = nullptr>
87
+ inline void
88
+ execute_op(char* C10_RESTRICT data[], const int64_t* strides, int64_t i, int64_t n, func_t&& op) {
89
+ using traits = function_traits<func_t>;
90
+ using result_type = typename traits::result_type;
91
+ for (; i < n; i++) {
92
+ result_type* out_ptr = (result_type*)(data[0] + i * strides[0]);
93
+ *out_ptr = std::apply(op, dereference<traits>(
94
+ &data[1],
95
+ &strides[1],
96
+ i));
97
+ }
98
+ }
99
+
100
+ template <typename func_t,
101
+ std::enable_if_t<std::is_void_v<typename function_traits<func_t>::result_type>>* = nullptr>
102
+ inline void
103
+ execute_op(char* C10_RESTRICT data[], const int64_t* strides, int64_t i, int64_t n, func_t&& op) {
104
+ using traits = function_traits<func_t>;
105
+ for (; i < n; i++) {
106
+ std::apply(op, dereference<traits>(
107
+ &data[0],
108
+ &strides[0],
109
+ i));
110
+ }
111
+ }
112
+
113
+ // Basic loop operation (one output, N inputs). May be auto-vectorized
114
+ // by the compiler. Supports inputs and outputs of different types.
115
+ template <typename func_t>
116
+ inline void
117
+ basic_loop(char* C10_RESTRICT data[], const int64_t* strides_, int64_t i, int64_t n, func_t&& op) {
118
+ using traits = function_traits<func_t>;
119
+ constexpr int ntensors = traits::arity + 1;
120
+
121
+ // Copying strides to temporary array helps auto vectorization in older GCC
122
+ // versions.
123
+ int64_t strides[ntensors];
124
+ for (const auto arg : c10::irange(ntensors)) {
125
+ strides[arg] = strides_[arg];
126
+ }
127
+
128
+ execute_op(data, strides, i, n, std::forward<func_t>(op));
129
+ }
130
+
131
+ // the recursive variadic template for iterating over the returned tuple
132
+ template<class T, size_t N>
133
+ struct TupleOutput {
134
+ static void handle(char *C10_RESTRICT data[], const int64_t *strides, int64_t i,
135
+ const T &tuple) {
136
+ TupleOutput<T, N - 1>::handle(data, strides, i, tuple);
137
+
138
+ auto output = std::get<N - 1>(tuple);
139
+ using output_type = decltype(output);
140
+ output_type * out_ptr = (output_type *)(data[N - 1] + i * strides[N - 1]);
141
+ *out_ptr = output;
142
+ }
143
+ };
144
+
145
+ // Base case for the above recursive template
146
+ template<class T>
147
+ struct TupleOutput<T, 1> {
148
+ static void handle(char *C10_RESTRICT data[], const int64_t *strides, int64_t i,
149
+ const T &tuple) {
150
+ auto output = std::get<0>(tuple);
151
+ using output_type = decltype(output);
152
+ output_type* out_ptr = (output_type *)(data[0] + i * strides[0]);
153
+ *out_ptr = output;
154
+ }
155
+ };
156
+
157
+ template<class... Args>
158
+ void handle_tuple_outputs(char* C10_RESTRICT data[],
159
+ const int64_t* strides,
160
+ int64_t i,
161
+ const std::tuple<Args...> &tuple) {
162
+ TupleOutput<decltype(tuple), sizeof...(Args)>::handle(data, strides, i, tuple);
163
+ }
164
+
165
+ // Loop operation for `cpu_kernel_multiple_outputs`.
166
+ // 1. Use `std::apply` to make dynamic method invocation
167
+ // for the lambda passed in `cpu_kernel_multiple_outputs`.
168
+ // 2. Iterate over the members of the returned tuple, set the corresponding
169
+ // output tensor by the tuple member in `handle_tuple_outputs` function.
170
+ template <typename func_t>
171
+ inline void
172
+ multiple_outputs_loop(char* C10_RESTRICT data[], const int64_t* strides_, int64_t i, int64_t n, func_t&& op) {
173
+ using traits = function_traits<func_t>;
174
+
175
+ using result_type = typename traits::result_type;
176
+ constexpr int num_outputs = std::tuple_size_v<result_type>;
177
+ constexpr int ntensors = traits::arity + num_outputs;
178
+
179
+ // Copying strides to temporary array helps auto vectorization in older GCC
180
+ // versions.
181
+ int64_t strides[ntensors];
182
+ for (const auto arg : c10::irange(ntensors)) {
183
+ strides[arg] = strides_[arg];
184
+ }
185
+
186
+ for (; i < n; i++) {
187
+ auto output = std::apply(op, dereference<traits>(
188
+ &data[num_outputs],
189
+ &strides[num_outputs],
190
+ i));
191
+ handle_tuple_outputs(data, strides, i, output);
192
+ }
193
+ }
194
+
195
+ // Explicitly vectorized loop implementation. All inputs and outputs must be
196
+ // the same type and contiguous with one exception: a single input may be
197
+ // a scalar (stride 0). It's position is indicated by the argument `S`. If `S`
198
+ // is 0, then there are no scalar inputs.
199
+ template <typename func_t, typename vec_func_t>
200
+ inline void
201
+ vectorized_loop(char** C10_RESTRICT data_, int64_t n, int64_t S, func_t&& op, vec_func_t&& vop) {
202
+ using traits = function_traits<vec_func_t>;
203
+ using scalar_t = typename function_traits<func_t>::result_type;
204
+ using Vec = Vectorized<scalar_t>;
205
+ constexpr int ntensors = traits::arity + 1;
206
+
207
+ char* C10_RESTRICT data[ntensors];
208
+ for (const auto arg : c10::irange(ntensors)) {
209
+ data[arg] = data_[arg];
210
+ }
211
+
212
+ Vec opt_scalar = Vec(S > 0 ? c10::load((scalar_t*)data[S]) : scalar_t(0));
213
+ int64_t i = 0;
214
+ for (; i <= n - 2 * Vec::size(); i += 2 * Vec::size()) {
215
+ auto args1 = dereference_vec<traits>(&data[1], opt_scalar, S, i);
216
+ auto args2 = dereference_vec<traits>(&data[1], opt_scalar, S, i + Vec::size());
217
+ auto out1 = std::apply(vop, std::move(args1));
218
+ auto out2 = std::apply(vop, std::move(args2));
219
+ out1.store(data[0] + i * sizeof(scalar_t));
220
+ out2.store(data[0] + (i + Vec::size()) * sizeof(scalar_t));
221
+ }
222
+ if (i < n) {
223
+ int64_t strides[ntensors];
224
+ for (const auto arg : c10::irange(ntensors)) {
225
+ strides[arg] = (S > 0 && arg == S) ? 0 : sizeof(scalar_t);
226
+ }
227
+ basic_loop(data, strides, i, n, std::forward<func_t>(op));
228
+ }
229
+ }
230
+
231
+
232
+ template <typename traits, typename cb_t>
233
+ inline void unroll_contiguous_scalar_checks(
234
+ const int64_t* /*strides*/,
235
+ std::index_sequence<> /*unused*/,
236
+ cb_t&& cb) {
237
+ cb(0);
238
+ }
239
+
240
+ template <typename traits, typename cb_t, size_t INDEX0, size_t ...INDEX>
241
+ inline void unroll_contiguous_scalar_checks(
242
+ const int64_t* strides,
243
+ std::index_sequence<INDEX0, INDEX...> /*unused*/,
244
+ cb_t&& cb) {
245
+ if (is_contiguous_scalar<traits, INDEX0 + 1>(strides)) {
246
+ cb(INDEX0 + 1);
247
+ } else {
248
+ unroll_contiguous_scalar_checks<traits>(strides, std::index_sequence<INDEX...>{}, std::forward<cb_t>(cb));
249
+ }
250
+ }
251
+
252
+ template <typename op_t, typename vop_t>
253
+ struct VectorizedLoop2d {
254
+ op_t op;
255
+ vop_t vop;
256
+
257
+ using traits = function_traits<op_t>;
258
+ static constexpr int ntensors = traits::arity + 1;
259
+ using data_t = std::array<char*, ntensors>;
260
+
261
+ VectorizedLoop2d(op_t op, vop_t vop):
262
+ op(std::move(op)), vop(std::move(vop)) {}
263
+
264
+ static void advance(data_t &data, const int64_t *outer_strides) {
265
+ for (const auto arg : c10::irange(data.size())) {
266
+ data[arg] += outer_strides[arg];
267
+ }
268
+ }
269
+
270
+ void operator()(char** base, const int64_t *strides, int64_t size0, int64_t size1) {
271
+ data_t data;
272
+ std::copy_n(base, ntensors, data.data());
273
+ const int64_t *outer_strides = &strides[ntensors];
274
+
275
+ if (is_contiguous<traits>(strides)) {
276
+ for ([[maybe_unused]] const auto i : c10::irange(size1)) {
277
+ vectorized_loop(data.data(), size0, 0, op, vop);
278
+ advance(data, outer_strides);
279
+ }
280
+ } else {
281
+ using Indices = std::make_index_sequence<traits::arity>;
282
+ unroll_contiguous_scalar_checks<traits>(strides, Indices{}, [&](size_t idx) {
283
+ if (idx) {
284
+ for ([[maybe_unused]] const auto i : c10::irange(size1)) {
285
+ vectorized_loop(data.data(), size0, idx, op, vop);
286
+ advance(data, outer_strides);
287
+ }
288
+ } else {
289
+ for ([[maybe_unused]] const auto i : c10::irange(size1)) {
290
+ basic_loop(data.data(), strides, 0, size0, op);
291
+ advance(data, outer_strides);
292
+ }
293
+ }
294
+ });
295
+ }
296
+ }
297
+ };
298
+
299
+ template <typename op_t, typename vop_t>
300
+ VectorizedLoop2d<op_t, vop_t> make_vectorized_loop2d(
301
+ op_t &&op, vop_t &&vop) {
302
+ return VectorizedLoop2d<op_t, vop_t>(std::forward<op_t>(op), std::forward<vop_t>(vop));
303
+ }
304
+
305
+ template <typename func_t>
306
+ void cpu_kernel(TensorIteratorBase& iter, func_t&& op, int64_t grain_size = at::internal::GRAIN_SIZE) {
307
+ using traits = function_traits<func_t>;
308
+ // this could be extended to work with void return types
309
+ TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity);
310
+ TORCH_INTERNAL_ASSERT(iter.noutputs() == 1);
311
+ // dynamic casting not currently supported on CPU
312
+ TORCH_INTERNAL_ASSERT(!needs_dynamic_casting<func_t>::check(iter));
313
+
314
+ iter.for_each([&](char** data, const int64_t* strides, int64_t n) {
315
+ // basic loop can handle 1d slices with arbitrary strides, and 1d slices is all that
316
+ // iter.for_each is ever sending to the loop lambda
317
+ basic_loop(data, strides, 0, n, op);
318
+ }, grain_size);
319
+ iter.cast_outputs();
320
+ }
321
+
322
+ // This function helps write elementwise kernels that requires multiple outputs.
323
+ // It follows the similar structure of cpu_kernel.
324
+ // Instead of `basic_loop` function, a new `multiple_outputs_loop` function is
325
+ // manipulated to handle multiple return values.
326
+ // For now `needs_dynamic_casting` check is not added as the passed lambda (`func_t`)
327
+ // of `multiple_outputs_loop` returns `std::tuple` instead of `scalar_t`.
328
+ // The `gpu_kernel_multiple_outputs` is also implemented without this check,
329
+ // We could extend `needs_dynamic_casting` to support both `std::tuple` and
330
+ // `thrust::tuple` in the future.
331
+ template <typename func_t>
332
+ void cpu_kernel_multiple_outputs(TensorIteratorBase& iter, func_t&& op, int64_t grain_size = at::internal::GRAIN_SIZE) {
333
+ using traits = function_traits<func_t>;
334
+ TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity);
335
+
336
+ iter.for_each([&](char** data, const int64_t* strides, int64_t n) {
337
+ multiple_outputs_loop(data, strides, 0, n, op);
338
+ }, grain_size);
339
+ iter.cast_outputs();
340
+ }
341
+
342
+ template <bool check_dynamic_cast=true, typename func_t, typename vec_func_t>
343
+ void cpu_kernel_vec(TensorIteratorBase& iter, func_t&& op, vec_func_t&& vop, int64_t grain_size = at::internal::GRAIN_SIZE) {
344
+ using traits = function_traits<func_t>;
345
+ // this could be extended to work with void return types
346
+ TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity);
347
+ TORCH_INTERNAL_ASSERT(iter.noutputs() == 1);
348
+ // dynamic casting not currently supported on CPU, but some kernels (like Fill)
349
+ // explicitly dynamic_cast, so we give the opt-out of checking.
350
+ if constexpr (check_dynamic_cast) {
351
+ TORCH_INTERNAL_ASSERT(!needs_dynamic_casting<func_t>::check(iter));
352
+ }
353
+
354
+ iter.for_each(make_vectorized_loop2d(std::forward<func_t>(op), std::forward<vec_func_t>(vop)), grain_size);
355
+ iter.cast_outputs();
356
+ }
357
+
358
+ template <typename func_t>
359
+ void cpu_serial_kernel(TensorIteratorBase& iter, func_t&& op, const Range& range) {
360
+ using traits = function_traits<func_t>;
361
+ constexpr bool result_void = std::is_void_v<typename traits::result_type>;
362
+ TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity &&
363
+ ((result_void && iter.noutputs() == 0) || (!result_void && iter.noutputs() == 1)));
364
+ // dynamic casting not currently supported on CPU
365
+ TORCH_INTERNAL_ASSERT(!needs_dynamic_casting<func_t>::check(iter));
366
+
367
+ iter.serial_for_each([&](char** data, const int64_t* strides, int64_t n) {
368
+ basic_loop(data, strides, 0, n, op);
369
+ }, range);
370
+ iter.cast_outputs();
371
+ }
372
+
373
+ template <typename func_t>
374
+ void cpu_serial_kernel(TensorIteratorBase& iter, func_t&& op) {
375
+ cpu_serial_kernel(iter, std::forward<func_t>(op), {0, iter.numel()});
376
+ }
377
+
378
+ template <typename func_t, typename vec_func_t>
379
+ void cpu_serial_kernel_vec(TensorIteratorBase& iter, func_t&& op, vec_func_t&& vop, const Range& range) {
380
+ using traits = function_traits<func_t>;
381
+ // this could be extended to work with void return types
382
+ TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity);
383
+ TORCH_INTERNAL_ASSERT(iter.noutputs() == 1);
384
+ // dynamic casting not currently supported on CPU
385
+ TORCH_INTERNAL_ASSERT(!needs_dynamic_casting<func_t>::check(iter));
386
+
387
+ iter.serial_for_each(make_vectorized_loop2d(std::forward<func_t>(op), std::forward<vec_func_t>(vop)), range);
388
+ iter.cast_outputs();
389
+ }
390
+
391
+ template <typename func_t, typename vec_func_t>
392
+ void cpu_serial_kernel_vec(TensorIteratorBase& iter, func_t&& op, vec_func_t&& vop) {
393
+ cpu_serial_kernel_vec(iter, std::forward<func_t>(op), std::forward<vec_func_t>(vop), {0, iter.numel()});
394
+ }
395
+
396
+ }} // namespace at::native::<anonymous>
397
+
398
+ #else
399
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
400
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/MaxUnpoolKernel.h ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+ #include <ATen/native/DispatchStub.h>
4
+
5
+ namespace at {
6
+ class Tensor;
7
+
8
+ namespace native {
9
+
10
+ using max_unpooling_fn = void(*)(Tensor&, const Tensor&, const Tensor&);
11
+
12
+ DECLARE_DISPATCH(max_unpooling_fn, max_unpool2d_kernel)
13
+ DECLARE_DISPATCH(max_unpooling_fn, max_unpool3d_kernel)
14
+
15
+ }} // at::native
16
+
17
+ #else
18
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
19
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/PixelShuffleKernel.h ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+ #include <ATen/native/DispatchStub.h>
4
+
5
+ namespace at {
6
+ class TensorBase;
7
+ }
8
+
9
+ namespace at::native {
10
+
11
+ using pixel_shuffle_fn = void(*)(TensorBase&, const TensorBase&, int64_t);
12
+ DECLARE_DISPATCH(pixel_shuffle_fn, pixel_shuffle_kernel)
13
+ DECLARE_DISPATCH(pixel_shuffle_fn, pixel_unshuffle_kernel)
14
+
15
+ } // at::native
16
+
17
+ #else
18
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
19
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/ReduceUtils.h ADDED
@@ -0,0 +1,242 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <ATen/Parallel.h>
5
+ #include <ATen/NumericUtils.h>
6
+ #include <ATen/cpu/vec/vec.h>
7
+ #include <ATen/cpu/vec/functional.h>
8
+ #include <ATen/native/ReductionType.h>
9
+ #include <c10/util/irange.h>
10
+ #include <ATen/OpMathType.h>
11
+ #include <ATen/native/cpu/utils.h>
12
+
13
+ namespace at::native {
14
+ inline namespace CPU_CAPABILITY {
15
+
16
+ using namespace vec;
17
+
18
+ #define AT_DISPATCH_REDUCTION_TYPES(op, ...) \
19
+ [&] { \
20
+ switch (op) { \
21
+ case ReductionType::SUM: { \
22
+ static constexpr auto reduce = ReductionType::SUM; \
23
+ return __VA_ARGS__(); \
24
+ } \
25
+ case ReductionType::MEAN: { \
26
+ static constexpr auto reduce = ReductionType::MEAN; \
27
+ return __VA_ARGS__(); \
28
+ } \
29
+ case ReductionType::MIN: { \
30
+ static constexpr auto reduce = ReductionType::MIN; \
31
+ return __VA_ARGS__(); \
32
+ } \
33
+ case ReductionType::MAX: { \
34
+ static constexpr auto reduce = ReductionType::MAX; \
35
+ return __VA_ARGS__(); \
36
+ } \
37
+ case ReductionType::PROD: { \
38
+ static constexpr auto reduce = ReductionType::PROD; \
39
+ return __VA_ARGS__(); \
40
+ } \
41
+ } \
42
+ }()
43
+
44
+ template <typename scalar_t, ReductionType reduce>
45
+ inline vec_scalar_t<scalar_t> init_value() {
46
+ using acc_t = vec_scalar_t<scalar_t>;
47
+ acc_t val;
48
+ if (reduce == ReductionType::SUM ||
49
+ reduce == ReductionType::MEAN) {
50
+ val = static_cast<acc_t>(0);
51
+ } else if (reduce == ReductionType::PROD) {
52
+ val = static_cast<acc_t>(1);
53
+ } else if (reduce == ReductionType::MAX) {
54
+ val = -std::numeric_limits<acc_t>::infinity();
55
+ } else {
56
+ TORCH_INTERNAL_ASSERT(reduce == ReductionType::MIN);
57
+ val = std::numeric_limits<acc_t>::infinity();
58
+ }
59
+ return val;
60
+ }
61
+
62
+ template <typename scalar_t, ReductionType reduce>
63
+ inline vec_scalar_t<scalar_t> init_value(const std::optional<Scalar>& initial) {
64
+ using acc_t = vec_scalar_t<scalar_t>;
65
+ if (initial.has_value()) {
66
+ return initial.value().to<acc_t>();
67
+ } else {
68
+ return init_value<scalar_t, reduce>();
69
+ }
70
+ }
71
+
72
+ template <typename scalar_t>
73
+ inline void init(scalar_t* out, int64_t size, const vec_scalar_t<scalar_t>& val) {
74
+ using Vec = Vectorized<vec_scalar_t<scalar_t>>;
75
+ map<scalar_t>(
76
+ [val](Vec x) { return Vec(val); },
77
+ out,
78
+ out,
79
+ size);
80
+ }
81
+
82
+ template <typename scalar_t, ReductionType reduce>
83
+ inline void init(scalar_t* out, int64_t size, const std::optional<Scalar>& initial) {
84
+ using acc_t = vec_scalar_t<scalar_t>;
85
+ acc_t val = init_value<scalar_t, reduce>(initial);
86
+ init(out, size, val);
87
+ }
88
+
89
+ // overload with `include_self`, used by scatter_reduce
90
+ template <typename scalar_t, ReductionType reduce>
91
+ inline void init(scalar_t* out, int64_t size, bool include_self = false) {
92
+ using acc_t = vec_scalar_t<scalar_t>;
93
+ if (!include_self) {
94
+ acc_t val = init_value<scalar_t, reduce>();
95
+ init(out, size, val);
96
+ }
97
+ }
98
+
99
+ template <typename scalar_t, ReductionType reduce>
100
+ inline void _init(scalar_t* self_ptr, at::opmath_type<scalar_t>* buffer_ptr, int64_t size, bool include_self) {
101
+ if (!include_self) {
102
+ init<at::opmath_type<scalar_t>, reduce>(buffer_ptr, size, include_self);
103
+ } else {
104
+ vec::convert(self_ptr, buffer_ptr, size);
105
+ }
106
+ }
107
+
108
+ template <typename scalar_t>
109
+ inline std::enable_if_t<!std::is_same_v<scalar_t, Vec2>, scalar_t>
110
+ _max(const scalar_t& x, const scalar_t& y) {
111
+ return at::_isnan(y) ? y : std::max(x, y);
112
+ }
113
+
114
+ template <typename scalar_t>
115
+ inline Vectorized<scalar_t> _max(const Vectorized<scalar_t>& x, const Vectorized<scalar_t>& y) {
116
+ // vec::maximum propagates NaN
117
+ return vec::maximum(x, y);
118
+ }
119
+
120
+ template <typename vec_t>
121
+ inline std::enable_if_t<std::is_same_v<vec_t, Vec2>, Vec2>
122
+ _max(const vec_t& x, const vec_t& y) {
123
+ // vec::maximum propagates NaN
124
+ return maximum(x, y);
125
+ }
126
+
127
+ template <typename scalar_t>
128
+ inline std::enable_if_t<!std::is_same_v<scalar_t, Vec2>, scalar_t>
129
+ _min(const scalar_t& x, const scalar_t& y) {
130
+ return at::_isnan(y) ? y : std::min(x, y);
131
+ }
132
+
133
+ template <typename scalar_t>
134
+ inline Vectorized<scalar_t> _min(const Vectorized<scalar_t>& x, const Vectorized<scalar_t>& y) {
135
+ // vec::minimum propagates NaN
136
+ return vec::minimum(x, y);
137
+ }
138
+
139
+ template <typename vec_t>
140
+ inline std::enable_if_t<std::is_same_v<vec_t, Vec2>, Vec2>
141
+ _min(const vec_t& x, const vec_t& y) {
142
+ // vec::minimum propagates NaN
143
+ return minimum(x, y);
144
+ }
145
+
146
+ template <typename scalar_t, typename accumut, typename Op,
147
+ typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
148
+ inline void map_acc(
149
+ const Op& vec_fun,
150
+ accumut* output_data,
151
+ const accumut* input_data,
152
+ const scalar_t* input_data2,
153
+ int64_t size) {
154
+ using Vec = vec::Vectorized<scalar_t>;
155
+ using aVec = vec::Vectorized<accumut>;
156
+ int64_t d = 0;
157
+ constexpr int64_t kVecSize = Vec::size();
158
+ constexpr int64_t kaVecSize = aVec::size();
159
+ for (d = 0; d < size - (size % kVecSize); d += kVecSize) {
160
+ Vec data2_vec = Vec::loadu(input_data2 + d);
161
+ auto [data2_avec0, data2_avec1] = convert_to_float<scalar_t>(data2_vec);
162
+ aVec input_vec0 = aVec::loadu(input_data + d);
163
+ aVec input_vec1 = aVec::loadu(input_data + d + kaVecSize);
164
+ vec_fun(input_vec0, data2_avec0).store(output_data + d);
165
+ vec_fun(input_vec1, data2_avec1).store(output_data + d + kaVecSize);
166
+ }
167
+ if (size - d > 0) {
168
+ int64_t tail_size = size - d;
169
+ Vec data2_vec = Vec::loadu(input_data2 + d, tail_size);
170
+ auto [data2_avec0, data2_avec1] = convert_to_float<scalar_t>(data2_vec);
171
+ if (tail_size > kaVecSize) {
172
+ aVec input_vec0 = aVec::loadu(input_data + d);
173
+ aVec input_vec1 = aVec::loadu(input_data + d + kaVecSize, tail_size - kaVecSize);
174
+ vec_fun(input_vec0, data2_avec0).store(output_data + d);
175
+ vec_fun(input_vec1, data2_avec1).store(output_data + d + kaVecSize, tail_size - kaVecSize);
176
+ } else {
177
+ aVec input_vec0 = aVec::loadu(input_data + d, tail_size);
178
+ vec_fun(input_vec0, data2_avec0).store(output_data + d, tail_size);
179
+ }
180
+ }
181
+ }
182
+
183
+ // for Max and Min, propagate NaN:
184
+ template <typename T, ReductionType reduce>
185
+ inline T update(const T& x, const T& y) {
186
+ if (reduce == ReductionType::SUM ||
187
+ reduce == ReductionType::MEAN) {
188
+ return x + y;
189
+ } else if (reduce == ReductionType::PROD) {
190
+ return x * y;
191
+ } else if (reduce == ReductionType::MAX) {
192
+ return _max(x, y);
193
+ } else {
194
+ TORCH_INTERNAL_ASSERT(reduce == ReductionType::MIN);
195
+ return _min(x, y);
196
+ }
197
+ }
198
+
199
+ template <typename scalar_t, ReductionType reduce>
200
+ inline void update(scalar_t* out, const scalar_t* data, int64_t K) {
201
+ using Vec = vec::Vectorized<vec_scalar_t<scalar_t>>;
202
+ map2<scalar_t>(
203
+ [](Vec x, Vec y) { return update<Vec, reduce>(x, y); },
204
+ out,
205
+ out,
206
+ data,
207
+ K);
208
+ }
209
+
210
+ template <typename scalar_t, ReductionType reduce,
211
+ typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
212
+ inline void update(at::opmath_type<scalar_t>* out, const scalar_t* data, int64_t K) {
213
+ using opmath_t = at::opmath_type<scalar_t>;
214
+ using Vec = vec::Vectorized<opmath_t>;
215
+ map_acc<scalar_t, opmath_t>(
216
+ [](Vec x, Vec y) { return update<Vec, reduce>(x, y); },
217
+ out,
218
+ out,
219
+ data,
220
+ K);
221
+ }
222
+
223
+ template <typename scalar_t, ReductionType reduce>
224
+ inline void write(scalar_t* out, int64_t count, int64_t K) {
225
+ using Vec = vec::Vectorized<vec_scalar_t<scalar_t>>;
226
+ if (reduce == ReductionType::MEAN) {
227
+ if (count > 0) {
228
+ vec::map<scalar_t>(
229
+ [count](Vec x) { return x / Vec(count); },
230
+ out,
231
+ out,
232
+ K);
233
+ }
234
+ }
235
+ }
236
+
237
+ } // namespace CPU_CAPABILITY
238
+ } // namespace at::native
239
+
240
+ #else
241
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
242
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/ReducedPrecisionFloatGemvFastPathKernel.h ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <ATen/native/DispatchStub.h>
5
+ #include <c10/macros/Macros.h>
6
+ #include <c10/util/BFloat16.h>
7
+ #include <c10/util/Half.h>
8
+
9
+ namespace at::native {
10
+ #if !defined(C10_MOBILE)
11
+ using fp16_gemv_fn = void(*)(int, int, float, const Half*, int, const Half*, int, float, Half*, int);
12
+ DECLARE_DISPATCH(fp16_gemv_fn, fp16_gemv_trans_stub)
13
+
14
+ using bf16_gemv_fn = void(*)(int, int, BFloat16, const BFloat16*, int, const BFloat16*, int, BFloat16, BFloat16*, int);
15
+ DECLARE_DISPATCH(bf16_gemv_fn, bf16_gemv_trans_stub)
16
+
17
+ using fp16_dot_fn = float(*)(const int64_t, const Half*, const int64_t, const Half*, const int64_t);
18
+ DECLARE_DISPATCH(fp16_dot_fn, fp16_dot_stub)
19
+
20
+ using bf16_dot_fn = float(*)(const int64_t, const BFloat16*, const int64_t, const BFloat16*, const int64_t);
21
+ DECLARE_DISPATCH(bf16_dot_fn, bf16_dot_stub)
22
+
23
+ inline namespace CPU_CAPABILITY {
24
+ float fp16_dot_with_fp32_arith(const Half* vec1, const Half* vec2, int64_t len);
25
+ float bf16_dot_with_fp32_arith(const BFloat16* vec1, const BFloat16* vec2, int64_t len);
26
+ } // inline namespace CPU_CAPABILITY
27
+ #endif // !defined(C10_MOBILE)
28
+ } // namespace at::native
29
+
30
+ #else
31
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
32
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/SampledAddmmKernel.h ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <ATen/core/Tensor.h>
5
+ #include <ATen/native/DispatchStub.h>
6
+
7
+ namespace at::native {
8
+
9
+ using sampled_addmm_sparse_csr_fn = void(*)(const Tensor&, const Tensor&, const Scalar&, const Scalar&, const Tensor&);
10
+
11
+ DECLARE_DISPATCH(sampled_addmm_sparse_csr_fn, sampled_addmm_sparse_csr_stub)
12
+
13
+ } // at::native
14
+
15
+ #else
16
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
17
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/SerialStackImpl.h ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ // Copyright 2004-present Facebook. All Rights Reserved.
3
+ #pragma once
4
+
5
+ #include <ATen/core/Tensor.h>
6
+
7
+ #include <ATen/MemoryOverlap.h>
8
+ #include <ATen/Parallel.h>
9
+ #include <ATen/TensorIterator.h>
10
+ #include <ATen/cpu/vec/functional.h>
11
+ #include <ATen/cpu/vec/vec.h>
12
+ #include <c10/util/irange.h>
13
+
14
+ namespace at::native::detail {
15
+
16
+ struct InputMeta {
17
+ void* data_ptr;
18
+ int64_t inner_size;
19
+
20
+ InputMeta(const Tensor& t, int64_t dim, int64_t inner)
21
+ : data_ptr(t.data_ptr()), inner_size(t.sizes()[dim] * inner) {}
22
+ };
23
+
24
+ // This kernel is used by two TensorList types:
25
+ // 1. stack_serial_kernel uses at::ArrayRef<Tensor>
26
+ // 2. Static runtime calls this kernel directly (csrc/jit/runtime/static/ops.cpp) with
27
+ // ProcessedNodeInputWrapper.
28
+ // When making changes, make sure that they are compatible with both types!
29
+ template <typename scalar_t, typename TensorListType>
30
+ void stack_serial_kernel_impl(Tensor& result, TensorListType tensors, int64_t dim) {
31
+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
32
+ dim >= 0 && dim <= result.dim(),
33
+ "dim out of range in stack_serial_kernel_impl");
34
+ int64_t outer =
35
+ result.numel() / (result.sizes()[dim] * result.strides()[dim]);
36
+ scalar_t* result_data = result.data_ptr<scalar_t>();
37
+ int64_t ninputs = tensors.size();
38
+ std::vector<InputMeta> inputs;
39
+ inputs.reserve(ninputs);
40
+ for (const auto& tensor : tensors) {
41
+ inputs.emplace_back(tensor, dim, tensor.strides()[dim]);
42
+ }
43
+
44
+ using Vec = vec::Vectorized<scalar_t>;
45
+ scalar_t* result_ptr = result_data;
46
+ for (const auto i : c10::irange(outer)) {
47
+ for (const auto j : c10::irange(ninputs)) {
48
+ int64_t local_inner = inputs[j].inner_size;
49
+ scalar_t* input_ptr = (scalar_t*)(inputs[j].data_ptr) + i * local_inner;
50
+
51
+ if (local_inner < Vec::size()) {
52
+ for (const auto k : c10::irange(local_inner)) {
53
+ result_ptr[k] = input_ptr[k];
54
+ }
55
+ } else {
56
+ vec::map(
57
+ [](Vec x) { return x; }, result_ptr, input_ptr, local_inner);
58
+ }
59
+ result_ptr += local_inner;
60
+ }
61
+ }
62
+ }
63
+
64
+ // Checks to see whether native stack can be invoked under these conditions:
65
+ // - result and input tensors are contiguous
66
+ // - only one thread is used
67
+ // - no type promotion has to occur
68
+ // - tensors dtype is Double or Float
69
+ template <typename TensorListType>
70
+ bool can_use_native_serial_stack_impl(Tensor& result, TensorListType tensors, int64_t dim) {
71
+ TORCH_CHECK(!tensors.empty(), "expected a non-empty list of Tensors");
72
+ const Tensor& first_tensor = tensors[0];
73
+ // stack dimension should be in range [0,firstTensor.dim())
74
+ // dim == firstTensor.dim() is a valid input, but it is handled by default code path
75
+ // that uses unsqueeze
76
+ if (dim >= first_tensor.dim()) return false;
77
+ // Native stack doesn't apply any tensor is skipped.
78
+ if (first_tensor.numel() == 0 && first_tensor.dim() == 1) return false;
79
+ // there should be no type promotion
80
+ if (result.dtype() != first_tensor.dtype()) return false;
81
+
82
+ auto first_tensor_mem_format = first_tensor.suggest_memory_format();
83
+ ScalarType dtype = first_tensor.scalar_type();
84
+
85
+ if (!result.is_contiguous(first_tensor_mem_format)) {
86
+ return false;
87
+ }
88
+
89
+ // fast path only works for Double and Float
90
+ if (dtype != ScalarType::Double && dtype != ScalarType::Float) {
91
+ return false;
92
+ }
93
+
94
+ // check remainder of inputs
95
+ #ifndef STRIP_ERROR_MESSAGES
96
+ auto const &first_tensor_shape = first_tensor.sizes();
97
+ #endif
98
+ for (const auto i : c10::irange(1, tensors.size())) {
99
+ auto const &tensor = tensors[i];
100
+ TORCH_CHECK(tensors[i].sizes() == first_tensor.sizes(),
101
+ "stack expects each tensor to be equal size, but got ", first_tensor_shape,
102
+ " at entry 0 and ", tensor.sizes(), " at entry ", i);
103
+
104
+ // every tensor must be contiguous
105
+ // tensor sizes and strides must be the same
106
+ // there should be no type promotion
107
+ if (!tensor.is_contiguous(first_tensor_mem_format) ||
108
+ tensor.strides() != first_tensor.strides() ||
109
+ tensor.dtype() != dtype) {
110
+ return false;
111
+ }
112
+ }
113
+
114
+ // fast native stack should only be used when it is not worth using multiple threads
115
+ // or there is only one thread. Note that we aren't checking result.numel() here because
116
+ // it may not have been resized and we want to defer that cost till later.
117
+ int64_t numel_in_stack = first_tensor.numel() * tensors.size();
118
+ return numel_in_stack < at::internal::GRAIN_SIZE || at::get_num_threads() == 1;
119
+ }
120
+
121
+ template <typename TensorListType, bool should_skip_overlap_check>
122
+ struct CanUseNativeSerialStack;
123
+
124
+ template <typename TensorListType>
125
+ struct CanUseNativeSerialStack<TensorListType, false> {
126
+ static bool call(Tensor& result, TensorListType tensors, int64_t dim) {
127
+ // Inputs cannot alias the output tensor
128
+ for (const auto i : c10::irange(tensors.size())) {
129
+ auto lap = at::get_overlap_status(result, tensors[i]);
130
+ TORCH_CHECK(lap != at::MemOverlapStatus::Partial &&
131
+ lap != at::MemOverlapStatus::Full, 0,
132
+ "unsupported operation: the input tensors cannot refer to any of the "
133
+ "output memory locations. Found overlap in input tensor ", i);
134
+ }
135
+
136
+ return can_use_native_serial_stack_impl(result, tensors, dim);
137
+ }
138
+ };
139
+
140
+ template <typename TensorListType>
141
+ struct CanUseNativeSerialStack<TensorListType, true> {
142
+ static bool call(Tensor& result, TensorListType tensors, int64_t dim) {
143
+ return can_use_native_serial_stack_impl(result, tensors, dim);
144
+ }
145
+ };
146
+
147
+ } // namespace at::native::detail
148
+
149
+ #else
150
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
151
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/SoftmaxKernel.h ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <ATen/native/DispatchStub.h>
5
+ #include <cstdint>
6
+
7
+ namespace at {
8
+ class Tensor;
9
+
10
+ namespace native {
11
+
12
+ using forward_fn = void (*)(const Tensor&, const Tensor&);
13
+ using backward_fn = void(*)(const Tensor &, const Tensor &, const Tensor&);
14
+
15
+ DECLARE_DISPATCH(forward_fn, softmax_lastdim_kernel)
16
+ DECLARE_DISPATCH(forward_fn, log_softmax_lastdim_kernel)
17
+ DECLARE_DISPATCH(backward_fn, softmax_backward_lastdim_kernel)
18
+ DECLARE_DISPATCH(backward_fn, log_softmax_backward_lastdim_kernel)
19
+
20
+ using forward_fn_with_dim = void(*)(const Tensor &, const Tensor &, const int64_t);
21
+ using backward_fn_with_dim =
22
+ void (*)(const Tensor&, const Tensor&, const Tensor&, const int64_t);
23
+
24
+ DECLARE_DISPATCH(forward_fn_with_dim, softmax_kernel)
25
+ DECLARE_DISPATCH(forward_fn_with_dim, log_softmax_kernel)
26
+ DECLARE_DISPATCH(backward_fn_with_dim, softmax_backward_kernel)
27
+ DECLARE_DISPATCH(backward_fn_with_dim, log_softmax_backward_kernel)
28
+ }
29
+ }
30
+
31
+ #else
32
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
33
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/SpmmReduceKernel.h ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <ATen/core/Tensor.h>
5
+ #include <ATen/native/DispatchStub.h>
6
+ #include <ATen/native/ReductionType.h>
7
+
8
+ namespace at::native {
9
+
10
+ using spmm_reduce_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, ReductionType op);
11
+ using spmm_reduce_arg_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, ReductionType op);
12
+ using spmm_reduce_backward_input_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, ReductionType op);
13
+ using spmm_reduce_backward_input_arg_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, ReductionType op);
14
+ using spmm_reduce_backward_other_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, ReductionType op);
15
+
16
+ DECLARE_DISPATCH(spmm_reduce_fn, spmm_reduce_stub)
17
+ DECLARE_DISPATCH(spmm_reduce_arg_fn, spmm_reduce_arg_stub)
18
+ DECLARE_DISPATCH(spmm_reduce_backward_input_fn, spmm_reduce_backward_input_stub)
19
+ DECLARE_DISPATCH(spmm_reduce_backward_input_arg_fn, spmm_reduce_backward_input_arg_stub)
20
+ DECLARE_DISPATCH(spmm_reduce_backward_other_fn, spmm_reduce_backward_other_stub)
21
+ DECLARE_DISPATCH(spmm_reduce_backward_input_arg_fn, spmm_reduce_backward_other_arg_stub)
22
+
23
+ } // at::native
24
+
25
+ #else
26
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
27
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/StackKernel.h ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ // Copyright 2004-present Facebook. All Rights Reserved.
3
+ #pragma once
4
+
5
+ #include <ATen/core/Tensor.h>
6
+ #include <ATen/native/DispatchStub.h>
7
+
8
+ namespace at::native {
9
+
10
+ using stack_serial_fn = void(*)(Tensor &, TensorList, int64_t);
11
+ DECLARE_DISPATCH(stack_serial_fn, stack_serial_stub)
12
+
13
+ } // namespace at::native
14
+
15
+ #else
16
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
17
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/UpSampleKernelAVXAntialias.h ADDED
@@ -0,0 +1,1381 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ /*
3
+ The Python Imaging Library (PIL) is
4
+
5
+ Copyright © 1997-2011 by Secret Labs AB
6
+ Copyright © 1995-2011 by Fredrik Lundh
7
+
8
+ Pillow is the friendly PIL fork. It is
9
+
10
+ Copyright © 2010-2022 by Alex Clark and contributors
11
+
12
+ Like PIL, Pillow is licensed under the open source HPND License
13
+ */
14
+
15
+ // This code is heavily inspired from PILLOW-SIMD's implementation:
16
+ // https://github.com/uploadcare/pillow-simd/blob/simd/master/src/libImaging/Resample.c
17
+
18
+ #pragma once
19
+ #ifdef CPU_CAPABILITY_AVX2
20
+ // TODO: This file only supports AVX2. We could split the AVX kernels into
21
+ // smaller logical blocks in order to port them into the Vec.h logic. This would
22
+ // allow to support other vectorization architectures and perhaps also support
23
+ // the non-vectorized fallback (we'd need to make sure it's not slower than the
24
+ // current fallback).
25
+
26
+ #include <ATen/core/Tensor.h>
27
+ #include <ATen/cpu/vec/intrinsics.h>
28
+ #include <c10/util/irange.h>
29
+
30
+ #ifndef AT_PER_OPERATOR_HEADERS
31
+ #include <ATen/Functions.h>
32
+ #else
33
+ #include <ATen/ops/empty.h>
34
+ #endif
35
+
36
+
37
+ namespace {
38
+
39
+ inline __m128i mm_cvtsi32_si128(const uint8_t* C10_RESTRICT ptr, bool i32_aligned) {
40
+ int32_t v;
41
+ if (i32_aligned) {
42
+ v = *(const int32_t*)ptr;
43
+ } else {
44
+ std::memcpy(&v, ptr, 4);
45
+ }
46
+ return _mm_cvtsi32_si128(v);
47
+ }
48
+
49
+ inline __m128i mm_cvtepu8_epi32(const uint8_t* C10_RESTRICT ptr, bool i32_aligned) {
50
+ return _mm_cvtepu8_epi32(mm_cvtsi32_si128(ptr, i32_aligned));
51
+ }
52
+
53
+ inline void _write_endline_rgb_as_uint32(
54
+ uint8_t* C10_RESTRICT output,
55
+ uint32_t data
56
+ ) {
57
+ // data is (R G B X), output is (X1 X2 X3 | R1 B1 G1 R2 ...)
58
+ // Here we explicitly set X as R1
59
+ uint8_t* data_ptr = reinterpret_cast<uint8_t*>(&data);
60
+ data_ptr[3] = output[3];
61
+ std::memcpy(output, data_ptr, 4);
62
+ }
63
+
64
+ at::Tensor unpack_rgb(const at::Tensor& packed_tensor) {
65
+ // Convert a "packed" tensor (typically RGBRGBRGB if channels_last) into
66
+ // RGBARGBARGBA format where A is hard-coded to 0. Each pixel is encoded
67
+ // into as 32 bits. This generalizes to num_channels <= 4 and also works for
68
+ // non-channels_last tensors.
69
+
70
+ const uint8_t* packed = (const uint8_t*)packed_tensor.const_data_ptr<uint8_t>();
71
+ auto num_pixels = packed_tensor.size(1) * packed_tensor.size(2);
72
+ auto num_channels = packed_tensor.size(0);
73
+
74
+ constexpr int rgba_size = 4;
75
+ auto unpacked_tensor = at::empty({rgba_size, packed_tensor.size(1), packed_tensor.size(2)}, at::CPU(at::kByte));
76
+ uint8_t* unpacked = (uint8_t*) unpacked_tensor.data_ptr<uint8_t>();
77
+
78
+ auto stride_i = packed_tensor.stride(2);
79
+ auto stride_j = packed_tensor.stride(0);
80
+
81
+ for (const auto i : c10::irange(num_pixels)) {
82
+ for (const auto j : c10::irange(rgba_size)) {
83
+ unpacked[rgba_size * i + j] = (j < num_channels) ? packed[stride_i * i + stride_j * j] : 0;
84
+ }
85
+ }
86
+ return unpacked_tensor;
87
+ }
88
+
89
+ void pack_rgb(
90
+ const at::Tensor& unpacked_tensor, // IN
91
+ const at::Tensor& packed_tensor // OUT
92
+ ) {
93
+ // Convert from unpacked channels last 3-channels or 4-channels tensor into original data layout.
94
+
95
+ uint8_t* unpacked = (uint8_t*)unpacked_tensor.data_ptr<uint8_t>();
96
+ uint8_t* packed = (uint8_t*)packed_tensor.data_ptr<uint8_t>();
97
+ auto num_pixels = packed_tensor.size(1) * packed_tensor.size(2);
98
+ auto num_channels = packed_tensor.size(0);
99
+
100
+ auto unpacked_increment = unpacked_tensor.size(0);
101
+ auto packed_increment = packed_tensor.stride(2);
102
+ auto packed_stride = packed_tensor.stride(0);
103
+
104
+ TORCH_INTERNAL_ASSERT(unpacked_increment == 3 || unpacked_increment == 4);
105
+
106
+ for ([[maybe_unused]] const auto i : c10::irange(num_pixels)) {
107
+ for (const auto j : c10::irange(num_channels)) {
108
+ packed[j * packed_stride] = unpacked[j];
109
+ }
110
+ unpacked += unpacked_increment;
111
+ packed += packed_increment;
112
+ }
113
+ }
114
+
115
+ void ImagingResampleHorizontalConvolution8u4x(
116
+ uint8_t* C10_RESTRICT lineOut0,
117
+ uint8_t* C10_RESTRICT lineOut1,
118
+ uint8_t* C10_RESTRICT lineOut2,
119
+ uint8_t* C10_RESTRICT lineOut3,
120
+ int64_t out_xsize,
121
+ const uint8_t* C10_RESTRICT lineIn0,
122
+ const uint8_t* C10_RESTRICT lineIn1,
123
+ const uint8_t* C10_RESTRICT lineIn2,
124
+ const uint8_t* C10_RESTRICT lineIn3,
125
+ int64_t in_xsize,
126
+ const int64_t* idx_ptr_xmin,
127
+ const int64_t* idx_ptr_size,
128
+ const int16_t* kk,
129
+ int kmax,
130
+ unsigned int coefs_precision,
131
+ int64_t num_channels,
132
+ bool is_last_line);
133
+
134
+ void ImagingResampleHorizontalConvolution8u(
135
+ uint8_t* C10_RESTRICT lineOut,
136
+ int64_t out_xsize,
137
+ const uint8_t* C10_RESTRICT lineIn,
138
+ int64_t in_xsize,
139
+ const int64_t* idx_ptr_xmin,
140
+ const int64_t* idx_ptr_size,
141
+ const int16_t* kk,
142
+ int kmax,
143
+ unsigned int coefs_precision,
144
+ int64_t num_channels,
145
+ bool is_last_line);
146
+
147
+ void ImagingResampleVerticalConvolution8u(
148
+ uint8_t* C10_RESTRICT lineOut,
149
+ const uint8_t* C10_RESTRICT lineIn,
150
+ int64_t xsize,
151
+ int64_t ids_min,
152
+ int64_t ids_size,
153
+ const int16_t* k,
154
+ unsigned int coefs_precision,
155
+ int64_t num_channels);
156
+
157
+ template<int num_channels>
158
+ void ImagingResampleHorizontal(
159
+ const at::Tensor & unpacked_output,
160
+ const at::Tensor & unpacked_input,
161
+ int ksize,
162
+ const std::vector<at::Tensor>& horiz_indices_weights,
163
+ unsigned int horiz_weights_precision) {
164
+
165
+ // Interpolation horizontal pass: we compute x-axis (image width) interpolation outputs.
166
+
167
+ // Input data is stored as
168
+ // input = [r[0], g[0], b[0], a[0], r[1], g[1], b[1], a[1], r[2], g[2], b[2], a[2], ...]
169
+ // Weights are float values computed for each output pixel and rescaled to uint16:
170
+ // weights[i] = [w[i, 0], w[i, 1], ..., w[i, K-1]]
171
+ // We want to compute the output as following:
172
+ // output = [oR[0], oG[0], oB[0], oA[0], oR[1], oG[1], oB[1], oA[1], ...]
173
+ // where
174
+ // oR[yoffset + i] = r[yoffset + xmin[i]] * w[i, 0] + ... + r[yoffset + xmin[i] + K-1] * w[i, K-1]
175
+ // oG[yoffset + i] = g[yoffset + xmin[i]] * w[i, 0] + ... + g[yoffset + xmin[i] + K-1] * w[i, K-1]
176
+ // oB[yoffset + i] = b[yoffset + xmin[i]] * w[i, 0] + ... + b[yoffset + xmin[i] + K-1] * w[i, K-1]
177
+ //
178
+
179
+ // TODO: we may want to merge that into the fallback code (currently called
180
+ // basic_loop_aa_horizontal<uint8_t>)
181
+ // Although this may not be needed if / when we port all this code to use
182
+ // Vec.h since this would potentially give us another fall-back implem
183
+
184
+ const int16_t* kk = (int16_t*)(horiz_indices_weights[3].const_data_ptr<double>());
185
+
186
+ auto xout = unpacked_output.size(2);
187
+ auto yout = unpacked_output.size(1);
188
+ auto xin = unpacked_input.size(2);
189
+ TORCH_INTERNAL_ASSERT(num_channels == unpacked_input.size(0));
190
+
191
+ const int64_t* idx_ptr_xmin = horiz_indices_weights[0].const_data_ptr<int64_t>();
192
+ const int64_t* idx_ptr_size = horiz_indices_weights[1].const_data_ptr<int64_t>();
193
+
194
+ uint8_t* unpacked_output_p = unpacked_output.data_ptr<uint8_t>();
195
+ const uint8_t* unpacked_input_p = unpacked_input.const_data_ptr<uint8_t>();
196
+
197
+ int64_t yy = 0;
198
+ auto xout_stride = xout * num_channels;
199
+ auto xin_stride = xin * num_channels;
200
+ for (; yy < yout - 3; yy += 4) {
201
+ ImagingResampleHorizontalConvolution8u4x(
202
+ unpacked_output_p + yy * xout_stride,
203
+ unpacked_output_p + (yy + 1) * xout_stride,
204
+ unpacked_output_p + (yy + 2) * xout_stride,
205
+ unpacked_output_p + (yy + 3) * xout_stride,
206
+ xout,
207
+ unpacked_input_p + yy * xin_stride,
208
+ unpacked_input_p + (yy + 1) * xin_stride,
209
+ unpacked_input_p + (yy + 2) * xin_stride,
210
+ unpacked_input_p + (yy + 3) * xin_stride,
211
+ xin,
212
+ idx_ptr_xmin,
213
+ idx_ptr_size,
214
+ kk,
215
+ ksize,
216
+ horiz_weights_precision,
217
+ num_channels,
218
+ yy + 3 == yout - 1);
219
+ }
220
+ for (; yy < yout; yy++) {
221
+ ImagingResampleHorizontalConvolution8u(
222
+ unpacked_output_p + yy * xout_stride,
223
+ xout,
224
+ unpacked_input_p + yy * xin_stride,
225
+ xin,
226
+ idx_ptr_xmin,
227
+ idx_ptr_size,
228
+ kk,
229
+ ksize,
230
+ horiz_weights_precision,
231
+ num_channels,
232
+ yy == yout - 1);
233
+ }
234
+ }
235
+
236
+ void ImagingResampleVertical(
237
+ const at::Tensor & unpacked_output,
238
+ const at::Tensor & unpacked_input,
239
+ int ksize,
240
+ const std::vector<at::Tensor>& vert_indices_weights,
241
+ unsigned int vert_weights_precision) {
242
+
243
+ // Interpolation vertical pass: we compute y-axis interpolation outputs.
244
+ // Input data is stored as
245
+ // input = [r[0], g[0], b[0], a[0], r[1], g[1], b[1], a[1], r[2], g[2], b[2], a[2], ...]
246
+ // Weights are float values computed for each output pixel and rescaled to uint16:
247
+ // weights[i] = [w[i, 0], w[i, 1], ..., w[i, K-1]]
248
+ // We want to compute the output as following:
249
+ // output = [oR[0], oG[0], oB[0], oA[0], oR[1], oG[1], oB[1], oA[1], ...]
250
+ // where
251
+ // oR[xoffset + i] = r[xoffset + ymin[i]] * w[i, 0] + ... + r[xoffset + ymin[i] + (K-1) * xsize] * w[i, K-1]
252
+ // oG[xoffset + i] = g[xoffset + ymin[i]] * w[i, 0] + ... + g[xoffset + ymin[i] + (K-1) * xsize] * w[i, K-1]
253
+ // oB[xoffset + i] = b[xoffset + ymin[i]] * w[i, 0] + ... + b[xoffset + ymin[i] + (K-1) * xsize] * w[i, K-1]
254
+
255
+ // TODO: we may want to merge that into the fallback code (currently called
256
+ // basic_loop_aa_vertical<uint8_t>)
257
+ // Although this may not be needed if / when we port all this code to use
258
+ // Vec.h since this would potentially give us another fall-back implem
259
+ const int16_t* kk = (int16_t*)(vert_indices_weights[3].const_data_ptr<double>());
260
+
261
+ const int64_t* idx_ptr_xmin = vert_indices_weights[0].const_data_ptr<int64_t>();
262
+ const int64_t* idx_ptr_size = vert_indices_weights[1].const_data_ptr<int64_t>();
263
+
264
+ uint8_t* unpacked_output_p = unpacked_output.data_ptr<uint8_t>();
265
+ const uint8_t* unpacked_input_p = unpacked_input.const_data_ptr<uint8_t>();
266
+
267
+ auto xout = unpacked_output.size(2);
268
+ auto yout = unpacked_output.size(1);
269
+ const auto num_channels = unpacked_input.size(0);
270
+ TORCH_INTERNAL_ASSERT(num_channels == unpacked_output.size(0));
271
+
272
+ auto xout_stride = xout * num_channels;
273
+ for (const auto yy : c10::irange(yout)) {
274
+ const auto* k = &kk[yy * ksize];
275
+ auto ids_min = idx_ptr_xmin[yy];
276
+ auto ids_size = idx_ptr_size[yy];
277
+ ImagingResampleVerticalConvolution8u(
278
+ unpacked_output_p + yy * xout_stride,
279
+ unpacked_input_p,
280
+ xout,
281
+ ids_min,
282
+ ids_size,
283
+ k,
284
+ vert_weights_precision,
285
+ num_channels);
286
+ }
287
+ }
288
+
289
+ // This is the only public entry point in this file. It supports bilinear or bicubic
290
+ // mode for uint8 dtype when C <= 4, with or without antialias. The
291
+ // implem is based on PIL-SIMD.
292
+ // Its equivalent implementation (fallback) for when AVX isn't supported or when
293
+ // C > 4 is separable_upsample_generic_Nd_kernel_impl() There are a bunch of
294
+ // future improvement that can be done: look for the TODOs in this file.
295
+ // For details on how the weights are computed and how the multiplications are
296
+ // run on int (instead of float weights), see
297
+ // [ Weights computation for uint8_t and multiplication trick ]
298
+ // For details on how the AVX kernels are implemented, see
299
+ // https://gist.github.com/NicolasHug/47c97d731f05eaad5694c173849b86f5
300
+ // See also [ Support for antialias=False as a subcase of antialias=True ] to
301
+ // learn more about how the antialias=False case is computed. The same holds
302
+ // here: all these kernels are general enough to handle an arbitrary number of
303
+ // weights, but when aa=False they could be optimized further.
304
+ template <typename scale_type, class F>
305
+ void upsample_avx_bilinear_bicubic_uint8(
306
+ const at::Tensor& input_,
307
+ const at::Tensor& output,
308
+ bool align_corners,
309
+ const scale_type& scales,
310
+ bool antialias) {
311
+ auto batch_size = input_.size(0);
312
+ auto num_channels = input_.size(1);
313
+ auto xin = input_.size(3);
314
+ auto yin = input_.size(2);
315
+ auto xout = output.size(3);
316
+ auto yout = output.size(2);
317
+
318
+ if (xin == xout && yin == yout) {
319
+ output.copy_(input_);
320
+ return;
321
+ }
322
+
323
+ at::Tensor input = input_;
324
+ if (!(input.is_contiguous() || input.is_contiguous(at::MemoryFormat::ChannelsLast))) {
325
+ // If input is not contiguous with memory format channels first or channels last,
326
+ // we explicitly convert the input to contiguous channels last memory format.
327
+ // This simplifies the rest of the code and let us assume that the format is only contiguous channels first or channels last,
328
+ // Most tensors going through this `if` block won't need to go through unpacking, but those having C < 3 may
329
+ // have to (this means 2 copies are made). We could avoid the extra copy by handling non-contiguous input
330
+ // directly within unpack_rgb() and pack_rgb(), but initial attempts showed that this is fairly complex.
331
+ input = input.contiguous(at::MemoryFormat::ChannelsLast);
332
+ }
333
+
334
+ auto need_horizontal = xout != xin;
335
+ auto need_vertical = yout != yin;
336
+
337
+ int ksize_horiz, ksize_vert;
338
+ std::vector<at::Tensor> horiz_indices_weights, vert_indices_weights;
339
+ unsigned int horiz_weights_precision, vert_weights_precision;
340
+
341
+ bool skip_unpacking = (num_channels == 3 || num_channels == 4) && input.is_contiguous(at::MemoryFormat::ChannelsLast);
342
+ bool skip_packing = (num_channels == 3 || num_channels == 4) && output.is_contiguous(at::MemoryFormat::ChannelsLast);
343
+
344
+ if (need_horizontal) {
345
+ int interp_dim = 3;
346
+ auto stride = skip_unpacking ? num_channels : 4;
347
+ std::tie(horiz_indices_weights, ksize_horiz, horiz_weights_precision) =
348
+ F::compute_index_ranges_int16_weights(
349
+ /*input_size=*/xin,
350
+ /*output_size=*/xout,
351
+ /*stride=*/stride,
352
+ /*ndims=*/4,
353
+ /*reshape_dim=*/interp_dim,
354
+ /*align_corners=*/align_corners,
355
+ /*opt_scale=*/scales[interp_dim - 2],
356
+ /*antialias=*/antialias,
357
+ /*align_i32=*/true);
358
+ }
359
+
360
+ if (need_vertical) {
361
+ int interp_dim = 2;
362
+ auto stride = skip_unpacking ? num_channels * xout : 4 * xout;
363
+ std::tie(vert_indices_weights, ksize_vert, vert_weights_precision) =
364
+ F::compute_index_ranges_int16_weights(
365
+ /*input_size=*/yin,
366
+ /*output_size=*/yout,
367
+ /*stride=*/stride,
368
+ /*ndims=*/4,
369
+ /*reshape_dim=*/interp_dim,
370
+ /*align_corners=*/align_corners,
371
+ /*opt_scale=*/scales[interp_dim - 2],
372
+ /*antialias=*/antialias,
373
+ /*align_i32=*/true);
374
+ }
375
+
376
+ at::Tensor buffer_horiz, buffer_vert;
377
+ // Minor optimization: we can avoid allocating an extra buffer if we're performing
378
+ // horizontal-only or vertical-only interpolation, and if the tensor doesn't
379
+ // need repacking
380
+ if (need_horizontal && (need_vertical || !skip_packing)) {
381
+ auto c = skip_unpacking ? num_channels : 4;
382
+ buffer_horiz = at::empty({c, yin, xout}, input.options());
383
+ }
384
+ if (need_vertical && !skip_packing) {
385
+ auto c = skip_unpacking ? num_channels : 4;
386
+ buffer_vert = at::empty({c, yout, xout}, input.options());
387
+ }
388
+
389
+ for (const auto i : c10::irange(batch_size)) {
390
+
391
+ at::Tensor unpacked_input = skip_unpacking ? input[i] : unpack_rgb(input[i]);
392
+ at::Tensor unpacked_output;
393
+
394
+ if (need_horizontal) {
395
+ at::Tensor unpacked_output_temp = (need_vertical || !skip_packing) ? buffer_horiz : output[i];
396
+
397
+ if (skip_unpacking && num_channels == 3) {
398
+ ImagingResampleHorizontal<3>(
399
+ unpacked_output_temp,
400
+ unpacked_input,
401
+ ksize_horiz,
402
+ horiz_indices_weights,
403
+ horiz_weights_precision);
404
+ } else {
405
+ ImagingResampleHorizontal<4>(
406
+ unpacked_output_temp,
407
+ unpacked_input,
408
+ ksize_horiz,
409
+ horiz_indices_weights,
410
+ horiz_weights_precision);
411
+ }
412
+ unpacked_output = unpacked_input = unpacked_output_temp;
413
+ }
414
+ if (need_vertical) {
415
+ unpacked_output = skip_packing ? output[i] : buffer_vert;
416
+
417
+ ImagingResampleVertical(
418
+ unpacked_output,
419
+ unpacked_input,
420
+ ksize_vert,
421
+ vert_indices_weights,
422
+ vert_weights_precision
423
+ );
424
+ }
425
+
426
+ TORCH_INTERNAL_ASSERT(unpacked_output.defined());
427
+
428
+ if (!skip_packing) {
429
+ pack_rgb(unpacked_output, output[i]);
430
+ }
431
+ }
432
+ }
433
+
434
+ void ImagingResampleHorizontalConvolution8u4x(
435
+ uint8_t* C10_RESTRICT lineOut0,
436
+ uint8_t* C10_RESTRICT lineOut1,
437
+ uint8_t* C10_RESTRICT lineOut2,
438
+ uint8_t* C10_RESTRICT lineOut3,
439
+ int64_t out_xsize,
440
+ const uint8_t* C10_RESTRICT lineIn0,
441
+ const uint8_t* C10_RESTRICT lineIn1,
442
+ const uint8_t* C10_RESTRICT lineIn2,
443
+ const uint8_t* C10_RESTRICT lineIn3,
444
+ int64_t in_xsize,
445
+ const int64_t* idx_ptr_xmin,
446
+ const int64_t* idx_ptr_size,
447
+ const int16_t* kk,
448
+ int kmax,
449
+ unsigned int coefs_precision,
450
+ int64_t num_channels,
451
+ bool is_last_line) {
452
+
453
+ // Interpolation horizontal pass processing together 4 vertical lines.
454
+ // - Input data format is RGBA or RGB with R,G,B,A being uint8. In case of RGBA
455
+ // we can encode 4 values as a single uint32 value.
456
+ // - We split the size of weight vector for a given output index as a sum:
457
+ // ids_size = num_blocks_4 * 4 + num_blocks_2 * 2 + num_blocks_1.
458
+ // - We load and process 4 weights values in a loop ("block 4") then we process 2 weights values
459
+ // in another loop ("block 2") and finally we process 1 weights value in the final loop ("block 1").
460
+
461
+ // Define shuffling masks (low/high) for num_channels 4 and 3
462
+ // Mask low casts lower half of each lane to epi16 and reorder RGBARGBA -> RRGGBBAA:
463
+ // [r1 g1 b1 a1 r2 g2 b2 a2 ... | R1 G1 B1 A1 R2 G2 B2 A2 ... ] ->
464
+ // [r1 0 r2 0 g1 0 g2 0 b1 0 b2 0 a1 0 a2 0 | R1 0 R2 0 G1 0 G2 0 B1 0 B2 0 A1 0 A2 0]
465
+ // Mask high casts upper half of each lane to epi16 and reorder RGBARGBA -> RRGGBBAA::
466
+ // [ ... r3 g3 b3 a3 r4 g4 b4 a4 | ... R3 G3 B3 A3 R4 G4 B4 A4 ] ->
467
+ // [r3 0 r4 0 g3 0 g4 0 b3 0 b4 0 a3 0 a4 0 | R3 0 R4 0 G3 0 G4 0 B3 0 B4 0 A3 0 A4 0]
468
+
469
+ const auto mask_low_c4 = _mm256_set_epi8(
470
+ -1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0,
471
+ -1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0);
472
+ const auto mask_high_c4 = _mm256_set_epi8(
473
+ -1, 15, -1, 11, -1, 14, -1, 10, -1, 13, -1, 9, -1, 12, -1, 8,
474
+ -1, 15, -1, 11, -1, 14, -1, 10, -1, 13, -1, 9, -1, 12, -1, 8);
475
+ const auto mask_low_c3 = _mm256_set_epi8(
476
+ -1, -1, -1, -1, -1, 5, -1, 2, -1, 4, -1, 1, -1, 3, -1, 0,
477
+ -1, -1, -1, -1, -1, 5, -1, 2, -1, 4, -1, 1, -1, 3, -1, 0);
478
+ const auto mask_high_c3 = _mm256_set_epi8(
479
+ -1, -1, -1, -1, -1, 11, -1, 8, -1, 10, -1, 7, -1, 9, -1, 6,
480
+ -1, -1, -1, -1, -1, 11, -1, 8, -1, 10, -1, 7, -1, 9, -1, 6);
481
+
482
+ const auto mask_low = (num_channels == 3) ? mask_low_c3 : mask_low_c4;
483
+ const auto mask_high = (num_channels == 3) ? mask_high_c3 : mask_high_c4;
484
+
485
+ const auto stride = num_channels * sizeof(uint8_t);
486
+
487
+ TORCH_INTERNAL_ASSERT(stride == 3 || stride == 4);
488
+
489
+ // out_xsize = output width, out_x = output x index
490
+ // ids_min is the input offset index corresponding to out_x
491
+ // ids_size is the interpolation size for out_x
492
+
493
+ // Let's precompute ids_size limits for block 4 and block 2.
494
+ //
495
+ // In block 4 (4 means we process 4 weight values together), we read input data
496
+ // with _mm_loadu_si128, i.e. 16 bytes, per one line:
497
+ // lineIn0 + stride * (i + ids_min) + 16 <= lineIn0 + stride * (ids_size + ids_min)
498
+ // --> i <= ids_size - 16.0 / stride
499
+ // Strict boundary:
500
+ // --> i < ids_size + 1 - int(ceil(16.0 / stride)) = ids_size - b4_delta
501
+ // Soft boundary for reading inside the buffer except its boundaries:
502
+ // --> i < ids_size + 1 - int(16.0 / stride) = ids_size - b4_delta_soft
503
+ // RGBA: b4_delta = b4_delta_soft = 3
504
+ // RGB : b4_delta = 5
505
+ // RGB : b4_delta_soft = 4
506
+ const auto b4_delta = (stride == 4) ? 3 : (is_last_line ? 5 : 4);
507
+
508
+ // In block 2 (2 means we process 2 weights values together), we read input data
509
+ // with _mm_loadl_epi64, i.e. 8 bytes, per one line:
510
+ // lineIn0 + stride * (i + ids_min) + 8 <= lineIn0 + stride * (ids_size + ids_min)
511
+ // --> i <= ids_size - 8.0 / stride
512
+ // Strict boundary:
513
+ // --> i < ids_size + 1 - int(ceil(8.0 / stride)) = ids_size - b2_delta
514
+ // Soft boundary for reading inside the buffer except its boundaries:
515
+ // --> i < ids_size + 1 - int(8.0 / stride) = ids_size - b2_delta_soft
516
+ // RGBA: b2_delta = b2_delta_soft = 1
517
+ // RGB : b2_delta = 2
518
+ // RGB : b2_delta_soft = 1
519
+ const auto b2_delta = (stride == 4) ? 1 : (is_last_line ? 2 : 1);
520
+
521
+ const auto max_out_x_strided = out_xsize * stride;
522
+ const auto max_in_x_strided = in_xsize * stride;
523
+
524
+ const auto zero = _mm256_setzero_si256();
525
+ const auto initial = _mm256_set1_epi32(1 << (coefs_precision - 1));
526
+
527
+ for (const auto out_x : c10::irange(out_xsize)) {
528
+ const auto ids_min = idx_ptr_xmin[out_x];
529
+ const auto ids_size = idx_ptr_size[out_x];
530
+ const auto * k = &kk[out_x * kmax];
531
+ int64_t i = 0;
532
+
533
+ auto sss0 = initial;
534
+ auto sss1 = initial;
535
+
536
+ const auto * lineIn0_min = lineIn0 + ids_min;
537
+ const auto * lineIn1_min = lineIn1 + ids_min;
538
+ const auto * lineIn2_min = lineIn2 + ids_min;
539
+ const auto * lineIn3_min = lineIn3 + ids_min;
540
+
541
+ // block 4
542
+ for (; i < ids_size - b4_delta; i += 4) {
543
+ // Load 4 values from weight vector
544
+ // mmk0 = [wl_0 wh_0 wl_1 wh_1 wl_0 wh_0 wl_1 wh_1 ...]
545
+ // mmk1 = [wl_2 wh_2 wl_3 wh_3 wl_2 wh_2 wl_3 wh_3 ...]
546
+ const auto mmk0 = _mm256_set1_epi32(*(int32_t*)&k[i]);
547
+ const auto mmk1 = _mm256_set1_epi32(*(int32_t*)&k[i + 2]);
548
+
549
+ // RGBA: Load 8 pixels (4 per line) from input lines 0 and 1:
550
+ // source = [
551
+ // r0 g0 b0 a0 r1 g1 b1 a1 r2 g2 b2 a2 r3 g3 b3 a3
552
+ // R0 G0 B0 A0 R1 G1 B1 A1 R2 G2 B2 A2 R3 G3 B3 A3
553
+ // ]
554
+ // RGB: Load 10 pixels (5 per line)
555
+ // source = [
556
+ // r0 g0 b0 r1 g1 b1 r2 g2 b2 r3 g3 b3 r4 g4 b4 r5
557
+ // R0 G0 B0 R1 G1 B1 R2 G2 B2 R3 G3 B3 R4 G4 B4 R5
558
+ // ]
559
+ auto source = _mm256_inserti128_si256(_mm256_castsi128_si256(
560
+ _mm_loadu_si128((__m128i *) (lineIn0_min + stride * i))),
561
+ _mm_loadu_si128((__m128i *) (lineIn1_min + stride * i)), 1);
562
+
563
+ // Apply mask_low:
564
+ // RGBA:
565
+ // [r0 0 r1 0 g0 0 g1 0 b0 0 b1 0 a0 0 a1 0 | R0 0 R1 0 G0 0 G1 0 B0 0 B1 0 A0 0 A1 0]
566
+ // RGB:
567
+ // [r0 0 r1 0 g0 0 g1 0 b0 0 b1 0 0 0 0 0 | R0 0 R1 0 G0 0 G1 0 B0 0 B1 0 0 0 0 0]
568
+ auto pix1 = _mm256_shuffle_epi8(source, mask_low);
569
+ // Compute output value as C += w0 * C0 + w1 * C1 for each channel in 32-bit precision
570
+ sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix1, mmk0));
571
+
572
+ // Apply mask_high:
573
+ // RGBA:
574
+ // [r2 0 r3 0 g2 0 g3 0 b2 0 b3 0 a2 0 a3 0 | R2 0 R3 0 G2 0 G3 0 B2 0 B3 0 A2 0 A3 0]
575
+ // RGB:
576
+ // [r2 0 r3 0 g2 0 g3 0 b2 0 b3 0 0 0 0 0 | R2 0 R3 0 G2 0 G3 0 B2 0 B3 0 0 0 0 0]
577
+ auto pix2 = _mm256_shuffle_epi8(source, mask_high);
578
+ // Compute output value as C += w2 * C2 + w3 * C3 for each channel in 32-bit precision
579
+ sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix2, mmk1));
580
+
581
+ // Same as above to next lines 2 and 3:
582
+ auto source2 = _mm256_inserti128_si256(_mm256_castsi128_si256(
583
+ _mm_loadu_si128((__m128i *) (lineIn2_min + stride * i))),
584
+ _mm_loadu_si128((__m128i *) (lineIn3_min + stride * i)), 1);
585
+ auto pix3 = _mm256_shuffle_epi8(source2, mask_low);
586
+ sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix3, mmk0));
587
+ auto pix4 = _mm256_shuffle_epi8(source2, mask_high);
588
+ sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix4, mmk1));
589
+ }
590
+
591
+ // block 2
592
+ for (; i < ids_size - b2_delta; i += 2) {
593
+ // Load 2 values from weight vector
594
+ // mmk = [wl_0 wh_0 wl_1 wh_1 wl_0 wh_0 wl_1 wh_1 ...]
595
+ const auto mmk = _mm256_set1_epi32(*(int32_t*)&k[i]);
596
+
597
+ // Load 4 pixels (2 per line) from input lines 0 and 1:
598
+ // RGBA: source1 = [
599
+ // r0 g0 b0 a0 r1 g1 b1 a1 0 0 0 0 0 0 0 0
600
+ // R0 G0 B0 A0 R1 G1 B1 A1 0 0 0 0 0 0 0 0
601
+ // ]
602
+ // RGB: source1 = [
603
+ // r0 g0 b0 r1 g1 b1 r2 0 0 0 0 0 0 0 0
604
+ // R0 G0 B0 R1 G1 B1 R2 0 0 0 0 0 0 0 0
605
+ // ]
606
+ auto source1 = _mm256_inserti128_si256(_mm256_castsi128_si256(
607
+ _mm_loadl_epi64((__m128i *) (lineIn0_min + stride * i))),
608
+ _mm_loadl_epi64((__m128i *) (lineIn1_min + stride * i)), 1);
609
+ // Apply mask_low:
610
+ // RGBA:
611
+ // [r0 0 r1 0 g0 0 g1 0 b0 0 b1 0 a0 0 a1 0 | R0 0 R1 0 G0 0 G1 0 B0 0 B1 0 A0 0 A1 0]
612
+ // RGB:
613
+ // [r0 0 r1 0 g0 0 g1 0 b0 0 b1 0 0 0 0 0 | R0 0 R1 0 G0 0 G1 0 B0 0 B1 0 0 0 0 0]
614
+ auto pix1 = _mm256_shuffle_epi8(source1, mask_low);
615
+ // Compute output value as C += w0 * C0 + w1 * C1 for each channel in 32-bit precision
616
+ sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix1, mmk));
617
+
618
+ // Same as above for lines 2 and 3:
619
+ auto source2 = _mm256_inserti128_si256(_mm256_castsi128_si256(
620
+ _mm_loadl_epi64((__m128i *) (lineIn2_min + stride * i))),
621
+ _mm_loadl_epi64((__m128i *) (lineIn3_min + stride * i)), 1);
622
+ auto pix2 = _mm256_shuffle_epi8(source2, mask_low);
623
+ sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix2, mmk));
624
+ }
625
+
626
+ // block 1
627
+ const auto i32_aligned = num_channels == 4;
628
+ for (; i < ids_size - 1; i++) {
629
+ // Load 1 value from weight vector
630
+ // mmk = [wl_0 wh_0 0 0 wl_0 wh_0 0 0 ...]
631
+ const auto mmk = _mm256_set1_epi32(k[i]);
632
+
633
+ // Load 2 pixels (one per line) from input lines 0 and 1:
634
+ // RGBA: pix1 = [
635
+ // r0 0 0 0 g0 0 0 0 b0 0 0 0 a0 0 0 0
636
+ // R0 0 0 0 G0 0 0 0 B0 0 0 0 A0 0 0 0
637
+ // ]
638
+ // RGB: pix1 = [
639
+ // r0 0 0 0 g0 0 0 0 b0 0 0 0 r1 0 0 0
640
+ // R0 0 0 0 G0 0 0 0 B0 0 0 0 R1 0 0 0
641
+ // ]
642
+ auto pix1 = _mm256_inserti128_si256(_mm256_castsi128_si256(
643
+ mm_cvtepu8_epi32(lineIn0_min + stride * i, i32_aligned)),
644
+ mm_cvtepu8_epi32(lineIn1_min + stride * i, i32_aligned), 1);
645
+ // Compute output value as C += w0 * C0 for each channel in 32-bit precision
646
+ sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix1, mmk));
647
+
648
+ // Same as above for lines 2 and 3
649
+ auto pix2 = _mm256_inserti128_si256(_mm256_castsi128_si256(
650
+ mm_cvtepu8_epi32(lineIn2_min + stride * i, i32_aligned)),
651
+ mm_cvtepu8_epi32(lineIn3_min + stride * i, i32_aligned), 1);
652
+ sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix2, mmk));
653
+ }
654
+
655
+ if (i == ids_size - 1) {
656
+ // last element
657
+ auto mmk = _mm256_set1_epi32(k[i]);
658
+ // For num_channels == 3 (3 bytes = one pixel) we tolerate to read 4 bytes
659
+ // lines 0, 1 and 2 won't go out of allocated memory bounds
660
+ auto pix = _mm256_inserti128_si256(_mm256_castsi128_si256(
661
+ mm_cvtepu8_epi32(lineIn0_min + stride * i, i32_aligned)),
662
+ mm_cvtepu8_epi32(lineIn1_min + stride * i, i32_aligned), 1);
663
+ sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix, mmk));
664
+
665
+ auto p0 = mm_cvtepu8_epi32(lineIn2_min + stride * i, i32_aligned);
666
+ __m128i p1;
667
+ if (num_channels == 3 && C10_UNLIKELY(is_last_line && ids_min + stride * i + 4 >= max_in_x_strided)) {
668
+ uint8_t input[4];
669
+ std::memcpy(input, lineIn3_min + stride * i, 3);
670
+ p1 = mm_cvtepu8_epi32(input, true);
671
+ } else {
672
+ p1 = mm_cvtepu8_epi32(lineIn3_min + stride * i, i32_aligned);
673
+ }
674
+ auto pix2 = _mm256_inserti128_si256(_mm256_castsi128_si256(p0), p1, 1);
675
+ sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix2, mmk));
676
+ }
677
+
678
+ // Convert fixed point values back to integers (truncating)
679
+ sss0 = _mm256_srai_epi32(sss0, coefs_precision);
680
+ sss1 = _mm256_srai_epi32(sss1, coefs_precision);
681
+ // Convert packed signed 32-bit integers to packed 16-bit integers using signed saturation
682
+ // (a a a a b b b b c c c c d d d d) -> (a a b b c c d d 0 0 0 0 0 0 0 0)
683
+ sss0 = _mm256_packs_epi32(sss0, zero);
684
+ sss1 = _mm256_packs_epi32(sss1, zero);
685
+ // Convert packed signed 16-bit integers to packed 8-bit integers using unsigned saturation
686
+ // (a a b b c c d d) -> (a b c d 0 0 0 0)
687
+ sss0 = _mm256_packus_epi16(sss0, zero);
688
+ sss1 = _mm256_packus_epi16(sss1, zero);
689
+
690
+ // Write the output into single uint32
691
+ // (a b c d) -> x_uint32
692
+ auto o0 = _mm_cvtsi128_si32(_mm256_castsi256_si128(sss0));
693
+ auto o1 = _mm_cvtsi128_si32(_mm256_extracti128_si256(sss0, 1));
694
+ auto o2 = _mm_cvtsi128_si32(_mm256_castsi256_si128(sss1));
695
+ auto o3 = _mm_cvtsi128_si32(_mm256_extracti128_si256(sss1, 1));
696
+
697
+ const auto out_x_strided = stride * out_x;
698
+
699
+ if (num_channels == 3 && C10_UNLIKELY(out_x_strided + 4 >= max_out_x_strided)) {
700
+ // Memcpy 4-bytes is faster than 3-bytes and this is a boundary case when we want to write
701
+ // 4 bytes (R G B | X) to the output buffer (X1 X2 X3 | R1).
702
+ // The 4th byte in the register (X) has a garbage value and 4th byte in the output buffer (R1) has a correct
703
+ // value which was previously computed by another line. In other words, it means that we can not overwrite
704
+ // it by simply writing 4 bytes from the register to the output. We'll do the following:
705
+ // v----------|
706
+ // Output = [... X1 X2 X3 | R1 G1 B1 R2 ...]
707
+ // First, we write R1 value to the 4th byte of (R G B | X) -> (R G B | R1)
708
+ // Second, we write 4 bytes from the register to the output: (X1 X2 X3 | R1) -> (R G B | R1)
709
+ // Output = [... R G B | R1 G1 B1 R2 ...]
710
+
711
+ _write_endline_rgb_as_uint32(lineOut0 + out_x_strided, o0);
712
+ _write_endline_rgb_as_uint32(lineOut1 + out_x_strided, o1);
713
+ _write_endline_rgb_as_uint32(lineOut2 + out_x_strided, o2);
714
+
715
+ if (C10_UNLIKELY(is_last_line)) {
716
+ // When we handle the last line, we can not access the next 4 bytes
717
+ // as they are out of memory bounds.
718
+ std::memcpy(lineOut3 + out_x_strided, (uint8_t *) &o3, num_channels);
719
+ } else {
720
+ _write_endline_rgb_as_uint32(lineOut3 + out_x_strided, o3);
721
+ }
722
+ } else if (num_channels == 3) {
723
+ // Memcpy 4-bytes is faster than 3-bytes and here
724
+ // we simply write 4 bytes (... R G B X 0 0 0 0 0 ...) where X is a garbage value
725
+ // that we will overwrite on the next iteration: (... R G B R G B X 0 0 ...)
726
+ std::memcpy(lineOut0 + out_x_strided, (uint8_t *) &o0, 4);
727
+ std::memcpy(lineOut1 + out_x_strided, (uint8_t *) &o1, 4);
728
+ std::memcpy(lineOut2 + out_x_strided, (uint8_t *) &o2, 4);
729
+ std::memcpy(lineOut3 + out_x_strided, (uint8_t *) &o3, 4);
730
+ } else {
731
+ // num_channels = 4 -> lineOutX + out_x_strided should be uint32 aligned
732
+ *(uint32_t *)(lineOut0 + out_x_strided) = o0;
733
+ *(uint32_t *)(lineOut1 + out_x_strided) = o1;
734
+ *(uint32_t *)(lineOut2 + out_x_strided) = o2;
735
+ *(uint32_t *)(lineOut3 + out_x_strided) = o3;
736
+ }
737
+ }
738
+ }
739
+
740
+ void ImagingResampleHorizontalConvolution8u(
741
+ uint8_t* C10_RESTRICT lineOut,
742
+ int64_t out_xsize,
743
+ const uint8_t* C10_RESTRICT lineIn,
744
+ int64_t in_xsize,
745
+ const int64_t* idx_ptr_xmin,
746
+ const int64_t* idx_ptr_size,
747
+ const int16_t* kk,
748
+ int kmax,
749
+ unsigned int coefs_precision,
750
+ int64_t num_channels,
751
+ bool is_last_line) {
752
+
753
+ // Interpolation horizontal pass processing only one vertical line.
754
+ // - Input data format is RGBA or RGB with R,G,B,A being uint8. In case of RGBA
755
+ // we can encode 4 values as a single uint32 value.
756
+ // - We split the size of weight vector for a given output index as a sum:
757
+ // ids_size = num_blocks_8 * 8 + num_blocks_4 * 4 + num_blocks_2 * 2 + num_blocks_1
758
+ // - We load and process 8 weights values in a loop ("block 8") then 4 weights and 2 weights values in
759
+ // in another loops ("block 4" and "block 2") and finally we process 1 weight value in the final loop ("block 1").
760
+
761
+ // Define various shuffling masks
762
+ const auto kmask_low = _mm256_set_epi8(
763
+ 11, 10, 9, 8, 11, 10, 9, 8, 11, 10, 9, 8, 11, 10, 9, 8,
764
+ 3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0);
765
+ const auto kmask_high = _mm256_set_epi8(
766
+ 15, 14, 13, 12, 15, 14, 13, 12, 15, 14, 13, 12, 15, 14, 13, 12,
767
+ 7, 6, 5, 4, 7, 6, 5, 4, 7, 6, 5, 4, 7, 6, 5, 4);
768
+ const auto kmask_hl = _mm256_set_epi8(
769
+ 7, 6, 5, 4, 7, 6, 5, 4, 7, 6, 5, 4, 7, 6, 5, 4,
770
+ 3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0);
771
+
772
+ const auto mask_low_c4 = _mm256_set_epi8(
773
+ -1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0,
774
+ -1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0);
775
+ const auto mask_high_c4 = _mm256_set_epi8(
776
+ -1, 15, -1, 11, -1, 14, -1, 10, -1, 13, -1, 9, -1, 12, -1, 8,
777
+ -1, 15, -1, 11, -1, 14, -1, 10, -1, 13, -1, 9, -1, 12, -1, 8);
778
+ const auto mask_low_c3 = _mm256_set_epi8(
779
+ -1, -1, -1, -1, -1, 5, -1, 2, -1, 4, -1, 1, -1, 3, -1, 0,
780
+ -1, -1, -1, -1, -1, 5, -1, 2, -1, 4, -1, 1, -1, 3, -1, 0);
781
+ const auto mask_high_c3 = _mm256_set_epi8(
782
+ -1, -1, -1, -1, -1, 11, -1, 8, -1, 10, -1, 7, -1, 9, -1, 6,
783
+ -1, -1, -1, -1, -1, 11, -1, 8, -1, 10, -1, 7, -1, 9, -1, 6);
784
+ const auto mask_hl_c3 = _mm256_set_epi8(
785
+ -1, -1, -1, -1, -1, 11, -1, 8, -1, 10, -1, 7, -1, 9, -1, 6,
786
+ -1, -1, -1, -1, -1, 5, -1, 2, -1, 4, -1, 1, -1, 3, -1, 0);
787
+ const auto mask_hl_c4 = _mm256_set_epi8(
788
+ -1, 15, -1, 11, -1, 14, -1, 10, -1, 13, -1, 9, -1, 12, -1, 8,
789
+ -1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0);
790
+
791
+ const auto mask_low128_c3 = _mm_set_epi8(
792
+ -1, -1, -1, -1, -1, 5, -1, 2, -1, 4, -1, 1, -1, 3, -1, 0);
793
+ const auto mask_low128_c4 = _mm_set_epi8(
794
+ -1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0);
795
+
796
+ const auto mask_low = (num_channels == 3) ? mask_low_c3 : mask_low_c4;
797
+ const auto mask_high = (num_channels == 3) ? mask_high_c3 : mask_high_c4;
798
+ const auto mask_hl = (num_channels == 3) ? mask_hl_c3 : mask_hl_c4;
799
+ const auto mask_low128 = (num_channels == 3) ? mask_low128_c3 : mask_low128_c4;
800
+
801
+ // out_xsize = output width, out_x = output x index
802
+ // ids_min is the input offset index corresponding to out_x
803
+ // ids_size is the interpolation size for out_x
804
+
805
+ const auto stride = num_channels * sizeof(uint8_t);
806
+ const auto zero = _mm_setzero_si128();
807
+
808
+ TORCH_INTERNAL_ASSERT(stride == 3 || stride == 4);
809
+
810
+ // Let's precompute ids_size limits for block 8, block 4 and block 2
811
+ //
812
+ // In block 8 (8 means we process 8 weight values together), we read at
813
+ // most 32 bytes input data (16 + 16 bytes for RGBA and 12 + 16 bytes for RGB)
814
+ // lineIn + stride * (i + ids_min) + 32 <= lineIn + stride * (ids_size + ids_min)
815
+ // --> i <= ids_size - 32.0 / stride
816
+ // Strict boundary:
817
+ // --> i < ids_size + 1 - int(ceil(32.0 / stride)) = ids_size - b8_delta
818
+ // Soft boundary for reading inside the buffer except its boundaries:
819
+ // --> i < ids_size + 1 - int(32.0 / stride) = ids_size - b8_delta_soft
820
+ // RGBA: b8_delta = b8_delta_soft = 7
821
+ // RGB : b8_delta = 10
822
+ // RGB : b8_delta_soft = 9
823
+ const auto b8_delta = (stride == 4) ? 7 : (is_last_line ? 10 : 9);
824
+
825
+ // In block 4 (4 means we process 4 weight values together), we read
826
+ // 16 bytes of input data.
827
+ // lineIn + stride * (i + ids_min) + 16 <= lineIn0 + stride * (ids_size + ids_min)
828
+ // --> i <= ids_size - 16.0 / stride
829
+ // Strict boundary:
830
+ // --> i < ids_size + 1 - int(ceil(16.0 / stride)) = ids_size - b4_delta
831
+ // Soft boundary for reading inside the buffer except its boundaries:
832
+ // --> i < ids_size + 1 - int(16.0 / stride) = ids_size - b4_delta_soft
833
+ // RGBA: b4_delta = b4_delta_soft = 3
834
+ // RGB : b4_delta = 5
835
+ // RGB : b4_delta_soft = 4
836
+ const auto b4_delta = (stride == 4) ? 3 : (is_last_line ? 5 : 4);
837
+
838
+ // In block 2 (2 means we process 2 weight values together), we read
839
+ // 8 bytes of input data.
840
+ // lineIn0 + stride * (i + ids_min) + 8 <= lineIn0 + stride * (ids_size + ids_min)
841
+ // --> i <= ids_size - 8.0 / stride
842
+ // Strict boundary:
843
+ // --> i < ids_size + 1 - int(ceil(8.0 / stride)) = ids_size - b2_delta
844
+ // Soft boundary for reading inside the buffer except its boundaries:
845
+ // --> i < ids_size + 1 - int(8.0 / stride) = ids_size - b2_delta_soft
846
+ // RGBA: b2_delta = b2_delta_soft = 1
847
+ // RGB : b2_delta = 2
848
+ // RGB : b2_delta_soft = 1
849
+ const auto b2_delta = (stride == 4) ? 1 : (is_last_line ? 2 : 1);
850
+
851
+ const auto max_out_x_strided = out_xsize * stride;
852
+ const auto max_in_x_strided = in_xsize * stride;
853
+
854
+ for (const auto out_x : c10::irange(out_xsize)) {
855
+ __m128i sss;
856
+ const auto ids_min = idx_ptr_xmin[out_x];
857
+ const auto ids_size = idx_ptr_size[out_x];
858
+ const auto * k = &kk[out_x * kmax];
859
+ int64_t i = 0;
860
+
861
+ const auto * lineIn_min = lineIn + ids_min;
862
+
863
+ if (ids_size < 8) {
864
+ sss = _mm_set1_epi32(1 << (coefs_precision - 1));
865
+ } else {
866
+ // Lower part will be added to higher, use only half of the error
867
+ auto sss256 = _mm256_set1_epi32(1 << (coefs_precision - 2));
868
+
869
+ // block 8
870
+ for (; i < ids_size - b8_delta; i += 8) {
871
+ // Load 8 values from weight vector
872
+ auto tmp = _mm_loadu_si128((__m128i*)&k[i]);
873
+ // ksource = [
874
+ // wl_0 wh_0 wl_1 wh_1 wl_2 wh_2 wl_3 wh_3 wl_4 wh_4 wl_5 wh_5 wl_6 wh_6 wl_7 wh_7
875
+ // wl_0 wh_0 wl_1 wh_1 wl_2 wh_2 wl_3 wh_3 wl_4 wh_4 wl_5 wh_5 wl_6 wh_6 wl_7 wh_7
876
+ // ]
877
+ auto ksource = _mm256_insertf128_si256(_mm256_castsi128_si256(tmp), tmp, 1);
878
+
879
+ // RGBA: Load 8 pixels from input:
880
+ // source = [
881
+ // r0 g0 b0 a0 r1 g1 b1 a1 r2 g2 b2 a2 r3 g3 b3 a3
882
+ // r4 g4 b4 a4 r5 g5 b5 a5 r6 g6 b6 a6 r7 g7 b7 a7
883
+ // ]
884
+ // RGB: Load 10 pixels from input (however we can process only 8 pixels):
885
+ // source = [
886
+ // r0 g0 b0 r1 g1 b1 r2 g2 b2 r3 g3 b3 r4 g4 b4 r5
887
+ // r4 g4 b4 r5 g5 b5 r6 g6 b6 r7 g7 b7 r8 g8 b8 r9
888
+ // ]
889
+ auto source = _mm256_inserti128_si256(_mm256_castsi128_si256(
890
+ _mm_loadu_si128((__m128i *) (lineIn_min + stride * i))),
891
+ _mm_loadu_si128((__m128i *) (lineIn_min + stride * (i + 4))), 1);
892
+
893
+ // Extract lower part of each lane, cast to epi16 and reorder RGBARGBA -> RRGGBBAA
894
+ // RGBA: pix1 = [
895
+ // r0 0 r1 0 g0 0 g1 0 b0 0 b1 0 a0 0 a1 0
896
+ // r4 0 r5 0 g4 0 g5 0 b4 0 b5 0 a4 0 a5 0
897
+ // ]
898
+ // RGB: pix1 = [
899
+ // r0 0 r1 0 g0 0 g1 0 b0 0 b1 0 0 0 0 0
900
+ // r4 0 r5 0 g4 0 g5 0 b4 0 b5 0 0 0 0 0
901
+ // ]
902
+ auto pix1 = _mm256_shuffle_epi8(source, mask_low);
903
+ // mmk1 = [
904
+ // wl_0 wh_0 wl_1 wh_1 wl_0 wh_0 wl_1 wh_1 ... ...
905
+ // wl_4 wh_4 wl_5 wh_5 wl_4 wh_4 wl_5 wh_5 ... ...
906
+ // ]
907
+ auto mmk1 = _mm256_shuffle_epi8(ksource, kmask_low);
908
+ // Compute output value as
909
+ // C += w0 * C0 + w1 * C1
910
+ // C += w4 * C4 + w5 * C5 for each channel in 32-bit precision
911
+ sss256 = _mm256_add_epi32(sss256, _mm256_madd_epi16(pix1, mmk1));
912
+
913
+ // Same as above for higher part of each lane
914
+ auto pix2 = _mm256_shuffle_epi8(source, mask_high);
915
+ auto mmk2 = _mm256_shuffle_epi8(ksource, kmask_high);
916
+ // Compute output value as
917
+ // C += w2 * C2 + w3 * C3
918
+ // C += w6 * C6 + w7 * C7 for each channel in 32-bit precision
919
+ sss256 = _mm256_add_epi32(sss256, _mm256_madd_epi16(pix2, mmk2));
920
+ }
921
+
922
+ // block 4
923
+ for (; i < ids_size - b4_delta; i += 4) {
924
+ // Load 4 values from weight vector
925
+ auto tmp = _mm_loadl_epi64((__m128i *) &k[i]);
926
+ // ksource = [
927
+ // wl_0 wh_0 wl_1 wh_1 wl_2 wh_2 wl_3 wh_3 0 0 0 0 0 0 0 0
928
+ // wl_0 wh_0 wl_1 wh_1 wl_2 wh_2 wl_3 wh_3 0 0 0 0 0 0 0 0
929
+ // ]
930
+ auto ksource = _mm256_insertf128_si256(_mm256_castsi128_si256(tmp), tmp, 1);
931
+
932
+ // Load pixels from input line
933
+ tmp = _mm_loadu_si128((__m128i *) (lineIn_min + stride * i));
934
+ // RGBA: source = [
935
+ // r0 g0 b0 a0 r1 g1 b1 a1 r2 g2 b2 a2 r3 g3 b3 a3
936
+ // r0 g0 b0 a0 r1 g1 b1 a1 r2 g2 b2 a2 r3 g3 b3 a3
937
+ // ]
938
+ // RGB: source = [
939
+ // r0 g0 b0 r1 g1 b1 r2 g2 b2 r3 g3 b3 r4 g4 b4 r5
940
+ // r0 g0 b0 r1 g1 b1 r2 g2 b2 r3 g3 b3 r4 g4 b4 r5
941
+ // ]
942
+ auto source = _mm256_insertf128_si256(_mm256_castsi128_si256(tmp), tmp, 1);
943
+
944
+ // Cast source to epi16 and reorder RGBARGBA -> RRGGBBAA
945
+ // RGBA: pix = [
946
+ // r0 0 r1 0 g0 0 g1 0 b0 0 b1 0 a0 0 a1 0
947
+ // r2 0 r3 0 g2 0 g3 0 b2 0 b3 0 a2 0 a3 0
948
+ // ]
949
+ // RGB: pix = [
950
+ // r0 0 r1 0 g0 0 g1 0 b0 0 b1 0 0 0 0 0
951
+ // r2 0 r3 0 g2 0 g3 0 b2 0 b3 0 0 0 0 0
952
+ // ]
953
+ auto pix = _mm256_shuffle_epi8(source, mask_hl);
954
+ // mmk = [
955
+ // wl_0 wh_0 wl_1 wh_1 wl_0 wh_0 wl_1 wh_1 ... ...
956
+ // wl_2 wh_2 wl_3 wh_3 wl_2 wh_2 wl_3 wh_3 ... ...
957
+ // ]
958
+ auto mmk = _mm256_shuffle_epi8(ksource, kmask_hl);
959
+ // Compute output value as
960
+ // C += w0 * C0 + w1 * C1
961
+ // C += w2 * C2 + w3 * C3 for each channel in 32-bit precision
962
+ sss256 = _mm256_add_epi32(sss256, _mm256_madd_epi16(pix, mmk));
963
+ }
964
+
965
+ // Sum results between the lanes
966
+ sss = _mm_add_epi32(
967
+ _mm256_extracti128_si256(sss256, 0),
968
+ _mm256_extracti128_si256(sss256, 1));
969
+ }
970
+
971
+ // block 2
972
+ for (; i < ids_size - b2_delta; i += 2) {
973
+ // Load 2 values from weight vector
974
+ // mmk = [wl_0 wh_0 wl_1 wh_1 wl_0 wh_0 wl_1 wh_1 ...]
975
+ auto mmk = _mm_set1_epi32(*(int32_t*)&k[i]);
976
+ // Load pixels from input line
977
+ // RGBA: source = [
978
+ // r0 g0 b0 a0 r1 g1 b1 a1 0 0 0 0 0 0 0 0
979
+ // ]
980
+ // RGB: source = [
981
+ // r0 g0 b0 r1 g1 b1 r2 g2 0 0 0 0 0 0 0 0
982
+ // ]
983
+ auto source = _mm_loadl_epi64((__m128i *) (lineIn_min + stride * i));
984
+ // Cast source to epi16 and reorder RGBARGBA -> RRGGBBAA
985
+ auto pix = _mm_shuffle_epi8(source, mask_low128);
986
+ // Compute output value as C += w0 * C0 + w1 * C1 for each channel in 32-bit precision
987
+ sss = _mm_add_epi32(sss, _mm_madd_epi16(pix, mmk));
988
+ }
989
+
990
+ // block 1
991
+ const auto i32_aligned = num_channels == 4;
992
+ for (; i < ids_size - 1; i++) {
993
+ // Load 1 value from weight vector
994
+ // mmk = [wl_0 wh_0 0 0 wl_0 wh_0 0 0 ...]
995
+ auto mmk = _mm_set1_epi32(k[i]);
996
+ // Load one pixel from input line
997
+ // RGBA: pix = [
998
+ // r0 0 0 0 g0 0 0 0 b0 0 0 0 a0 0 0 0
999
+ // ]
1000
+ // RGB: pix = [
1001
+ // r0 0 0 0 g0 0 0 0 b0 0 0 0 r1 0 0 0
1002
+ // ]
1003
+ auto pix = mm_cvtepu8_epi32(lineIn_min + stride * i, i32_aligned);
1004
+ // Compute output value as C += w0 * C0 for each channel in 32-bit precision
1005
+ sss = _mm_add_epi32(sss, _mm_madd_epi16(pix, mmk));
1006
+ }
1007
+
1008
+ if (i == ids_size - 1) {
1009
+ // last element
1010
+ auto mmk = _mm_set1_epi32(k[i]);
1011
+ __m128i pix;
1012
+ auto p = lineIn_min + stride * i;
1013
+ if (num_channels == 3 && C10_UNLIKELY(is_last_line && ids_min + stride * i + 4 >= max_in_x_strided)) {
1014
+ uint8_t input[4];
1015
+ std::memcpy(input, p, 3);
1016
+ pix = mm_cvtepu8_epi32(input, true);
1017
+ } else {
1018
+ pix = mm_cvtepu8_epi32(p, i32_aligned);
1019
+ }
1020
+ sss = _mm_add_epi32(sss, _mm_madd_epi16(pix, mmk));
1021
+ }
1022
+
1023
+ // Convert fixed point values back to integers (truncating)
1024
+ sss = _mm_srai_epi32(sss, coefs_precision);
1025
+ // Convert packed signed 32-bit integers to packed 16-bit integers using signed saturation
1026
+ // (a a a a b b b b c c c c d d d d) -> (a a b b c c d d 0 0 0 0 0 0 0 0)
1027
+ sss = _mm_packs_epi32(sss, zero);
1028
+ // Convert packed signed 16-bit integers to packed 8-bit integers using unsigned saturation
1029
+ // (a a b b c c d d) -> (a b c d 0 0 0 0)
1030
+ sss = _mm_packus_epi16(sss, zero);
1031
+ // Write the output into single uint32
1032
+ // (a b c d) -> x_uint32
1033
+ auto o = _mm_cvtsi128_si32(sss);
1034
+ const auto out_x_strided = stride * out_x;
1035
+ if (num_channels == 3 && C10_UNLIKELY(out_x_strided + 4 >= max_out_x_strided)) {
1036
+ if (C10_UNLIKELY(is_last_line)) {
1037
+ // When we handle the last line, we can not access the next 4 bytes
1038
+ // as they are out of memory bounds.
1039
+ std::memcpy(lineOut + out_x_strided, (uint8_t *) &o, 3);
1040
+ } else {
1041
+ // Memcpy 4-bytes is faster than 3-bytes and this is a boundary case when we want to write
1042
+ // 4 bytes (R G B | X) to the output buffer (X1 X2 X3 | R1).
1043
+ // The 4th byte in the register (X) has a garbage value and 4th byte in the output buffer (R1) has a correct
1044
+ // value which was previously computed by another line. In other words, it means that we can not overwrite
1045
+ // it by simply writing 4 bytes from the register to the output. We'll do the following:
1046
+ // v----------|
1047
+ // Output = [... X1 X2 X3 | R1 G1 B1 R2 ...]
1048
+ // First, we write R1 value to the 4th byte of (R G B | X) -> (R G B | R1)
1049
+ // Second, we write 4 bytes from the register to the output: (X1 X2 X3 | R1) -> (R G B | R1)
1050
+ // Output = [... R G B | R1 G1 B1 R2 ...]
1051
+ _write_endline_rgb_as_uint32(lineOut + out_x_strided, o);
1052
+ }
1053
+ } else if (num_channels == 3) {
1054
+ // Memcpy 4-bytes is faster than 3-bytes and here
1055
+ // we simply write 4 bytes (... R G B X 0 0 0 0 0 ...) where X is a garbage value
1056
+ // that we will overwrite on the next iteration: (... R G B R G B X 0 0 ...)
1057
+ std::memcpy(lineOut + out_x_strided, (uint8_t *) &o, 4);
1058
+ } else {
1059
+ // num_channels = 4 -> lineOut + out_x_strided should be uint32 aligned
1060
+ *(uint32_t *)(lineOut + out_x_strided) = o;
1061
+ }
1062
+ }
1063
+ }
1064
+
1065
+ void ImagingResampleVerticalConvolution8u(
1066
+ uint8_t* C10_RESTRICT lineOut,
1067
+ const uint8_t* C10_RESTRICT lineIn,
1068
+ int64_t xsize,
1069
+ int64_t ids_min,
1070
+ int64_t ids_size,
1071
+ const int16_t* k,
1072
+ unsigned int coefs_precision,
1073
+ int64_t num_channels) {
1074
+
1075
+ // Interpolation vertical pass processing one line.
1076
+ // - We process x-axis data with blocks of 8, 2 and 1
1077
+ // - We split the size of weight vector for a given output index as a sum: K = n * 2 + m.
1078
+
1079
+ // xsize = output width, also equals to input width
1080
+ // ids_size = interpolation size
1081
+ // ids_min = input y start index
1082
+ const auto stride = num_channels * sizeof(uint8_t);
1083
+
1084
+ TORCH_INTERNAL_ASSERT(stride == 3 || stride == 4);
1085
+
1086
+ const int64_t data_size = xsize * stride;
1087
+ const int64_t data_stride = stride;
1088
+ constexpr auto vec_size = 256 / 8;
1089
+
1090
+ const auto initial = _mm_set1_epi32(1 << (coefs_precision - 1));
1091
+ const auto initial_256 = _mm256_set1_epi32(1 << (coefs_precision - 1));
1092
+ const auto zero = _mm_setzero_si128();
1093
+ const auto zero_256 = _mm256_setzero_si256();
1094
+
1095
+ int64_t j = 0;
1096
+ // block 8
1097
+ const auto b8_usable_vec_stride = (vec_size / data_stride) * data_stride;
1098
+ for (; j < data_size - vec_size; j += b8_usable_vec_stride) {
1099
+ auto sss0 = initial_256;
1100
+ auto sss1 = initial_256;
1101
+ auto sss2 = initial_256;
1102
+ auto sss3 = initial_256;
1103
+ int64_t i = 0;
1104
+ const auto * lineIn_min = lineIn + j + ids_min;
1105
+
1106
+ for (; i < ids_size - 1; i += 2) {
1107
+ // Load 2 values from weight vector
1108
+ auto mmk = _mm256_set1_epi32(*(int32_t*)&k[i]);
1109
+
1110
+ // RGBA: Load 8 pixels per line
1111
+ // source1 = [
1112
+ // r0 g0 b0 a0 r1 g1 b1 a1 r2 g2 b2 a2 r3 g3 b3 a3
1113
+ // r4 g4 b4 a4 r5 g5 b5 a5 r6 g6 b6 a6 r7 g7 b7 a7
1114
+ // ]
1115
+ // RGB: Load 10 pixels per line (however we can process only 8 pixels):
1116
+ // source1 = [
1117
+ // r0 g0 b0 r1 g1 b1 r2 g2 b2 r3 g3 b3 r4 g4 b4 r5
1118
+ // r4 g4 b4 r5 g5 b5 r6 g6 b6 r7 g7 b7 r8 g8 b8 r9
1119
+ // ]
1120
+ auto source1 =
1121
+ _mm256_loadu_si256((__m256i*)(lineIn_min + data_size * i));
1122
+ auto source2 =
1123
+ _mm256_loadu_si256((__m256i*)(lineIn_min + data_size * (i + 1)));
1124
+
1125
+ // Interleave source1 and source2 from the low half of each 128-bit lane
1126
+ // and cast the result to epi16
1127
+ // RGBA: pix1 = [
1128
+ // r0 0 R0 0 g0 0 G0 0 b0 0 B0 0 a0 0 A0 0
1129
+ // r1 0 R1 0 g1 0 G1 0 b1 0 B1 0 a1 0 A1 0
1130
+ // ]
1131
+ // RGB: pix1 = [
1132
+ // r0 0 R0 0 g0 0 G0 0 b0 0 B0 0 0 0 0 0
1133
+ // r1 0 R1 0 g1 0 G1 0 b1 0 B1 0 0 0 0 0
1134
+ // ]
1135
+ auto source_lo = _mm256_unpacklo_epi8(source1, source2);
1136
+ auto pix1 = _mm256_unpacklo_epi8(source_lo, zero_256);
1137
+ // Compute output value as
1138
+ // C += w0 * c0 + w1 * C0
1139
+ // C += w0 * c1 + w1 * C1 for each channel in 32-bit precision
1140
+ sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix1, mmk));
1141
+
1142
+ // RGBA: pix2 = [
1143
+ // r2 0 R2 0 g2 0 G2 0 b2 0 B2 0 a2 0 A2 0
1144
+ // r3 0 R3 0 g3 0 G3 0 b3 0 B3 0 a3 0 A3 0
1145
+ // ]
1146
+ // RGB: pix2 = [
1147
+ // r2 0 R2 0 g2 0 G2 0 b2 0 B2 0 0 0 0 0
1148
+ // r3 0 R3 0 g3 0 G3 0 b3 0 B3 0 0 0 0 0
1149
+ // ]
1150
+ auto pix2 = _mm256_unpackhi_epi8(source_lo, zero_256);
1151
+ // Compute output value as
1152
+ // C += w0 * c2 + w1 * C2
1153
+ // C += w0 * c3 + w1 * C3 for each channel in 32-bit precision
1154
+ sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix2, mmk));
1155
+
1156
+ // Same as above for the high half of each 128-bit lane
1157
+ auto source_hi = _mm256_unpackhi_epi8(source1, source2);
1158
+ auto pix3 = _mm256_unpacklo_epi8(source_hi, zero_256);
1159
+ sss2 = _mm256_add_epi32(sss2, _mm256_madd_epi16(pix3, mmk));
1160
+ auto pix4 = _mm256_unpackhi_epi8(source_hi, zero_256);
1161
+ sss3 = _mm256_add_epi32(sss3, _mm256_madd_epi16(pix4, mmk));
1162
+ }
1163
+ // Same processing as above but with a single weight value
1164
+ for (; i < ids_size; i += 1) {
1165
+ auto mmk = _mm256_set1_epi32(k[i]);
1166
+
1167
+ auto source1 = _mm256_loadu_si256((__m256i*)(lineIn_min + i * data_size));
1168
+
1169
+ auto source_lo = _mm256_unpacklo_epi8(source1, zero_256);
1170
+ auto pix1 = _mm256_unpacklo_epi8(source_lo, zero_256);
1171
+ sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix1, mmk));
1172
+ auto pix2 = _mm256_unpackhi_epi8(source_lo, zero_256);
1173
+ sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix2, mmk));
1174
+
1175
+ auto source_hi = _mm256_unpackhi_epi8(source1, zero_256);
1176
+ auto pix3 = _mm256_unpacklo_epi8(source_hi, _mm256_setzero_si256());
1177
+ sss2 = _mm256_add_epi32(sss2, _mm256_madd_epi16(pix3, mmk));
1178
+ auto pix4 = _mm256_unpackhi_epi8(source_hi, _mm256_setzero_si256());
1179
+ sss3 = _mm256_add_epi32(sss3, _mm256_madd_epi16(pix4, mmk));
1180
+ }
1181
+ // Convert fixed point values back to integers (truncating)
1182
+ sss0 = _mm256_srai_epi32(sss0, coefs_precision);
1183
+ sss1 = _mm256_srai_epi32(sss1, coefs_precision);
1184
+ sss2 = _mm256_srai_epi32(sss2, coefs_precision);
1185
+ sss3 = _mm256_srai_epi32(sss3, coefs_precision);
1186
+ // Convert packed signed 32-bit integers to packed 16-bit integers using signed saturation
1187
+ // (a a a a b b b b c c c c d d d d) -> (a a b b c c d d)
1188
+ sss0 = _mm256_packs_epi32(sss0, sss1);
1189
+ sss2 = _mm256_packs_epi32(sss2, sss3);
1190
+ // Convert packed signed 16-bit integers to packed 8-bit integers using unsigned saturation
1191
+ // (a a b b c c d d) -> (a b c d)
1192
+ sss0 = _mm256_packus_epi16(sss0, sss2);
1193
+
1194
+ // Stores 32 bytes
1195
+ _mm256_storeu_si256((__m256i*)(lineOut + j), sss0);
1196
+ }
1197
+
1198
+ // TODO: Do we also need block 4 ???
1199
+ // block 2
1200
+ const auto b2_usable_vec_stride = (8 / data_stride) * data_stride;
1201
+ for (; j < data_size - vec_size / 4; j += b2_usable_vec_stride) {
1202
+ auto sss0 = initial;
1203
+ auto sss1 = initial;
1204
+ int64_t i = 0;
1205
+ const auto * lineIn_min = lineIn + j + ids_min;
1206
+
1207
+ for (; i < ids_size - 1; i += 2) {
1208
+ // Load 2 values from weight vector
1209
+ // mmk = [wl_0 wh_0 wl_1 wh_1 wl_0 wh_0 wl_1 wh_1 ... ]
1210
+ auto mmk = _mm_set1_epi32(*(int32_t*)&k[i]);
1211
+
1212
+ // Load 2 pixels per line
1213
+ // RGBA: source1 = [
1214
+ // r0 g0 b0 a0 r1 g1 b1 a1 0 0 0 0 0 0 0 0
1215
+ // ]
1216
+ // RGB: source1 = [
1217
+ // r0 g0 b0 r1 g1 b1 r2 g2 0 0 0 0 0 0 0 0
1218
+ // ]
1219
+ auto source1 = _mm_loadl_epi64((__m128i *) (lineIn_min + i * data_size));
1220
+ auto source2 = _mm_loadl_epi64((__m128i *) (lineIn_min + (i + 1) * data_size));
1221
+ // Interleave source1 and source2 and cast the result to epi16
1222
+ // RGBA: pix = [
1223
+ // r0 0 R0 0 g0 0 G0 0 b0 0 B0 0 a0 0 A0 0
1224
+ // ]
1225
+ // RGB: pix = [
1226
+ // r0 0 R0 0 g0 0 G0 0 b0 0 B0 0 0 0 0 0
1227
+ // ]
1228
+ auto source = _mm_unpacklo_epi8(source1, source2);
1229
+ auto pix = _mm_unpacklo_epi8(source, zero);
1230
+ // Compute output value as C += w0 * c0 + w1 * C0 for each channel in 32-bit precision
1231
+ sss0 = _mm_add_epi32(sss0, _mm_madd_epi16(pix, mmk));
1232
+ // RGBA: pix = [
1233
+ // r1 0 R1 0 g1 0 G1 0 b1 0 B1 0 a1 0 A1 0
1234
+ // ]
1235
+ // RGB: pix = [
1236
+ // r1 0 R1 0 g1 0 G1 0 b1 0 B1 0 0 0 0 0
1237
+ // ]
1238
+ pix = _mm_unpackhi_epi8(source, zero);
1239
+ // Compute output value as C += w0 * c1 + w1 * C1 for each channel in 32-bit precision
1240
+ sss1 = _mm_add_epi32(sss1, _mm_madd_epi16(pix, mmk));
1241
+ }
1242
+ // Same processing as above but with a single weight value
1243
+ for (; i < ids_size; i += 1) {
1244
+ auto mmk = _mm_set1_epi32(k[i]);
1245
+
1246
+ auto source1 = _mm_loadl_epi64((__m128i*) (lineIn_min + i * data_size));
1247
+
1248
+ auto source = _mm_unpacklo_epi8(source1, zero);
1249
+ auto pix1 = _mm_unpacklo_epi8(source, zero);
1250
+ sss0 = _mm_add_epi32(sss0, _mm_madd_epi16(pix1, mmk));
1251
+ auto pix2 = _mm_unpackhi_epi8(source, zero);
1252
+ sss1 = _mm_add_epi32(sss1, _mm_madd_epi16(pix2, mmk));
1253
+ }
1254
+ // Convert fixed point values back to integers (truncating)
1255
+ sss0 = _mm_srai_epi32(sss0, coefs_precision);
1256
+ sss1 = _mm_srai_epi32(sss1, coefs_precision);
1257
+ // Convert packed signed 32-bit integers to packed 16-bit integers using signed saturation
1258
+ // (a a a a b b b b c c c c d d d d) -> (a a b b c c d d)
1259
+ sss0 = _mm_packs_epi32(sss0, sss1);
1260
+ // Convert packed signed 16-bit integers to packed 8-bit integers using unsigned saturation
1261
+ // (a a b b c c d d) -> (a b c d)
1262
+ sss0 = _mm_packus_epi16(sss0, sss0);
1263
+ // Store 2 pixels to the output
1264
+ _mm_storel_epi64((__m128i*)(lineOut + j), sss0);
1265
+ }
1266
+
1267
+ // block 1
1268
+ const auto b1_usable_vec_stride = (4 / data_stride) * data_stride;
1269
+ const auto i32_aligned = num_channels == 4;
1270
+ for (; j < data_size - 4; j += b1_usable_vec_stride) {
1271
+ auto sss = initial;
1272
+ int64_t i = 0;
1273
+ const auto * lineIn_min = lineIn + j + ids_min;
1274
+
1275
+ for (; i < ids_size - 1; i += 2) {
1276
+ // Load 2 values from weight vector
1277
+ // mmk = [wl_0 wh_0 wl_1 wh_1 wl_0 wh_0 wl_1 wh_1 ... ]
1278
+ auto mmk = _mm_set1_epi32(*(int32_t*)&k[i]);
1279
+
1280
+ // Load one pixel per line
1281
+ // RGBA: source1 = [
1282
+ // r0 g0 b0 a0 0 0 0 0 0 0 0 0 0 0 0 0
1283
+ // ]
1284
+ // RGB: source1 = [
1285
+ // r0 g0 b0 r1 0 0 0 0 0 0 0 0 0 0 0 0
1286
+ // ]
1287
+ auto source1 = mm_cvtsi32_si128(lineIn_min + i * data_size, i32_aligned);
1288
+ auto source2 = mm_cvtsi32_si128(lineIn_min + (i + 1) * data_size, i32_aligned);
1289
+
1290
+ // Interleave source1 and source2 and cast the result to epi16
1291
+ // RGBA: pix = [
1292
+ // r0 0 R0 0 g0 0 G0 0 b0 0 B0 0 a0 0 A0 0
1293
+ // ]
1294
+ // RGB: pix = [
1295
+ // r0 0 R0 0 g0 0 G0 0 b0 0 B0 0 0 0 0 0
1296
+ // ]
1297
+ auto source = _mm_unpacklo_epi8(source1, source2);
1298
+ auto pix = _mm_unpacklo_epi8(source, zero);
1299
+ // Compute output value as C += w0 * c0 + w1 * C0 for each channel in 32-bit precision
1300
+ sss = _mm_add_epi32(sss, _mm_madd_epi16(pix, mmk));
1301
+ }
1302
+
1303
+ for (; i < ids_size; i++) {
1304
+ auto mmk = _mm_set1_epi32(k[i]);
1305
+ auto pix = mm_cvtepu8_epi32(lineIn_min + i * data_size, i32_aligned);
1306
+ sss = _mm_add_epi32(sss, _mm_madd_epi16(pix, mmk));
1307
+ }
1308
+ sss = _mm_srai_epi32(sss, coefs_precision);
1309
+ sss = _mm_packs_epi32(sss, zero);
1310
+ sss = _mm_packus_epi16(sss, zero);
1311
+
1312
+ auto o = _mm_cvtsi128_si32(sss);
1313
+
1314
+ // Here we write 4 bytes to the output even if num_channels < 4, e.g o = {r,g,b,X} for num_channels=3
1315
+ // It is OK to write 4th byte (e.g. X) as on the next step we will overwrite it with new data.
1316
+ // We also won't go out of bounds of lineOut memory allocation
1317
+ std::memcpy(lineOut + j, (uint8_t *) &o, 4);
1318
+ }
1319
+
1320
+ for (; j < data_size; j += data_stride) {
1321
+ auto sss = initial;
1322
+ int64_t i = 0;
1323
+ const auto * lineIn_min = lineIn + j + ids_min;
1324
+ // For RGBA we can use (ids_size - 1) as tighter limit but for RGB we can read outside memory boundary
1325
+ // for the last remaining line
1326
+ for (; i < ids_size - 2; i += 2) {
1327
+ // Load two coefficients at once
1328
+ auto mmk = _mm_set1_epi32(*(int32_t*)&k[i]);
1329
+
1330
+ // Load 2 lines
1331
+ auto source1 = mm_cvtsi32_si128(lineIn_min + i * data_size, i32_aligned);
1332
+ auto source2 = mm_cvtsi32_si128(lineIn_min + (i + 1) * data_size, i32_aligned);
1333
+
1334
+ auto source = _mm_unpacklo_epi8(source1, source2);
1335
+ auto pix = _mm_unpacklo_epi8(source, zero);
1336
+ sss = _mm_add_epi32(sss, _mm_madd_epi16(pix, mmk));
1337
+ }
1338
+
1339
+ // Same processing as above but with a single weight value
1340
+ for (; i < ids_size; i++) {
1341
+ auto mmk = _mm_set1_epi32(k[i]);
1342
+
1343
+ const uint8_t * p = lineIn_min + i * data_size;
1344
+ __m128i pix;
1345
+ // There is no much perf gain using more detailed condition like
1346
+ // num_channels == 3 && ids_min + j + data_size * i + 4 >= in_max_size
1347
+ // const int64_t in_max_size = data_size * in_ysize;
1348
+ if (num_channels == 3) {
1349
+ uint8_t input[4];
1350
+ std::memcpy(input, p, 3);
1351
+ pix = mm_cvtepu8_epi32(input, true);
1352
+ } else {
1353
+ pix = mm_cvtepu8_epi32(p, true);
1354
+ }
1355
+ sss = _mm_add_epi32(sss, _mm_madd_epi16(pix, mmk));
1356
+ }
1357
+
1358
+ // Convert fixed point values back to integers (truncating)
1359
+ sss = _mm_srai_epi32(sss, coefs_precision);
1360
+ // Convert packed signed 32-bit integers to packed 16-bit integers using signed saturation
1361
+ // (a a a a b b b b c c c c d d d d) -> (a a b b c c d d)
1362
+ sss = _mm_packs_epi32(sss, zero);
1363
+ // Convert packed signed 16-bit integers to packed 8-bit integers using unsigned saturation
1364
+ // (a a b b c c d d) -> (a b c d)
1365
+ sss = _mm_packus_epi16(sss, zero);
1366
+ // Store one pixel to the output
1367
+ auto o = _mm_cvtsi128_si32(sss);
1368
+ if (num_channels == 3 && C10_UNLIKELY(j + 4 >= data_size)) {
1369
+ std::memcpy(lineOut + j, (uint8_t *) &o, 3);
1370
+ } else {
1371
+ std::memcpy(lineOut + j, (uint8_t *) &o, 4);
1372
+ }
1373
+ }
1374
+ }
1375
+
1376
+ } // anonymous namespace
1377
+ #endif // CPU_CAPABILITY_AVX2
1378
+
1379
+ #else
1380
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
1381
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/avx_mathfun.h ADDED
@@ -0,0 +1,527 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+ /*
4
+ AVX implementation of sin, cos, sincos, exp and log
5
+
6
+ Based on "sse_mathfun.h", by Julien Pommier
7
+ http://gruntthepeon.free.fr/ssemath/
8
+
9
+ Copyright (C) 2012 Giovanni Garberoglio
10
+ Interdisciplinary Laboratory for Computational Science (LISC)
11
+ Fondazione Bruno Kessler and University of Trento
12
+ via Sommarive, 18
13
+ I-38123 Trento (Italy)
14
+
15
+ This software is provided 'as-is', without any express or implied
16
+ warranty. In no event will the authors be held liable for any damages
17
+ arising from the use of this software.
18
+
19
+ Permission is granted to anyone to use this software for any purpose,
20
+ including commercial applications, and to alter it and redistribute it
21
+ freely, subject to the following restrictions:
22
+
23
+ 1. The origin of this software must not be misrepresented; you must not
24
+ claim that you wrote the original software. If you use this software
25
+ in a product, an acknowledgment in the product documentation would be
26
+ appreciated but is not required.
27
+ 2. Altered source versions must be plainly marked as such, and must not be
28
+ misrepresented as being the original software.
29
+ 3. This notice may not be removed or altered from any source distribution.
30
+
31
+ (this is the zlib license)
32
+ */
33
+
34
+ #include <ATen/native/cpu/Intrinsics.h>
35
+
36
+ /* The original source of this file has been modified. */
37
+ #if defined(CPU_CAPABILITY_AVX2)
38
+
39
+ #if defined(__GNUC__)
40
+ # define ALIGN32_BEG __attribute__((aligned(32)))
41
+ #elif defined(_WIN32)
42
+ # define ALIGN32_BEG __declspec(align(32))
43
+ #endif
44
+
45
+ typedef __m256 v8sf; // vector of 8 float (avx2)
46
+ typedef __m256i v8si; // vector of 8 int (avx2)
47
+
48
+ /* declare some AVX constants -- why can't I figure a better way to do that? */
49
+ #define _PS256_CONST(Name, Val) \
50
+ static const ALIGN32_BEG float _ps256_##Name[8] = { Val, Val, Val, Val, Val, Val, Val, Val }
51
+ #define _PI32_CONST256(Name, Val) \
52
+ static const ALIGN32_BEG int _pi32_256_##Name[8] = { Val, Val, Val, Val, Val, Val, Val, Val }
53
+ #define _PS256_CONST_TYPE(Name, Type, Val) \
54
+ static const ALIGN32_BEG Type _ps256_##Name[8] = { Val, Val, Val, Val, Val, Val, Val, Val }
55
+
56
+ _PS256_CONST(1 , 1.0f);
57
+ _PS256_CONST(0p5, 0.5f);
58
+ /* the smallest non denormalized float number */
59
+ _PS256_CONST_TYPE(min_norm_pos, int, 0x00800000);
60
+ _PS256_CONST_TYPE(mant_mask, int, 0x7f800000);
61
+ _PS256_CONST_TYPE(inv_mant_mask, int, ~0x7f800000);
62
+
63
+ _PS256_CONST_TYPE(sign_mask, int, (int)0x80000000);
64
+ _PS256_CONST_TYPE(inv_sign_mask, int, ~0x80000000);
65
+
66
+ _PI32_CONST256(0, 0);
67
+ _PI32_CONST256(1, 1);
68
+ _PI32_CONST256(inv1, ~1);
69
+ _PI32_CONST256(2, 2);
70
+ _PI32_CONST256(4, 4);
71
+ _PI32_CONST256(0x7f, 0x7f);
72
+
73
+ _PS256_CONST(cephes_SQRTHF, 0.707106781186547524);
74
+ _PS256_CONST(cephes_log_p0, 7.0376836292E-2);
75
+ _PS256_CONST(cephes_log_p1, - 1.1514610310E-1);
76
+ _PS256_CONST(cephes_log_p2, 1.1676998740E-1);
77
+ _PS256_CONST(cephes_log_p3, - 1.2420140846E-1);
78
+ _PS256_CONST(cephes_log_p4, + 1.4249322787E-1);
79
+ _PS256_CONST(cephes_log_p5, - 1.6668057665E-1);
80
+ _PS256_CONST(cephes_log_p6, + 2.0000714765E-1);
81
+ _PS256_CONST(cephes_log_p7, - 2.4999993993E-1);
82
+ _PS256_CONST(cephes_log_p8, + 3.3333331174E-1);
83
+ _PS256_CONST(cephes_log_q1, -2.12194440e-4);
84
+ _PS256_CONST(cephes_log_q2, 0.693359375);
85
+
86
+
87
+ /* natural logarithm computed for 8 simultaneous float
88
+ return NaN for x <= 0
89
+ */
90
+ inline v8sf log256_ps(v8sf x) {
91
+ v8si imm0;
92
+ v8sf one = *(v8sf*)_ps256_1;
93
+
94
+ //v8sf invalid_mask = _mm256_cmple_ps(x, _mm256_setzero_ps());
95
+ v8sf invalid_mask = _mm256_cmp_ps(x, _mm256_setzero_ps(), _CMP_LE_OS);
96
+
97
+ x = _mm256_max_ps(x, *(v8sf*)_ps256_min_norm_pos); /* cut off denormalized stuff */
98
+
99
+ // can be done with AVX2
100
+ imm0 = _mm256_srli_epi32(_mm256_castps_si256(x), 23);
101
+
102
+ /* keep only the fractional part */
103
+ x = _mm256_and_ps(x, *(v8sf*)_ps256_inv_mant_mask);
104
+ x = _mm256_or_ps(x, *(v8sf*)_ps256_0p5);
105
+
106
+ // this is again another AVX2 instruction
107
+ imm0 = _mm256_sub_epi32(imm0, *(v8si*)_pi32_256_0x7f);
108
+ v8sf e = _mm256_cvtepi32_ps(imm0);
109
+
110
+ e = _mm256_add_ps(e, one);
111
+
112
+ /* part2:
113
+ if( x < SQRTHF ) {
114
+ e -= 1;
115
+ x = x + x - 1.0;
116
+ } else { x = x - 1.0; }
117
+ */
118
+ //v8sf mask = _mm256_cmplt_ps(x, *(v8sf*)_ps256_cephes_SQRTHF);
119
+ v8sf mask = _mm256_cmp_ps(x, *(v8sf*)_ps256_cephes_SQRTHF, _CMP_LT_OS);
120
+ v8sf tmp = _mm256_and_ps(x, mask);
121
+ x = _mm256_sub_ps(x, one);
122
+ e = _mm256_sub_ps(e, _mm256_and_ps(one, mask));
123
+ x = _mm256_add_ps(x, tmp);
124
+
125
+ v8sf z = _mm256_mul_ps(x,x);
126
+
127
+ v8sf y = *(v8sf*)_ps256_cephes_log_p0;
128
+ y = _mm256_mul_ps(y, x);
129
+ y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p1);
130
+ y = _mm256_mul_ps(y, x);
131
+ y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p2);
132
+ y = _mm256_mul_ps(y, x);
133
+ y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p3);
134
+ y = _mm256_mul_ps(y, x);
135
+ y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p4);
136
+ y = _mm256_mul_ps(y, x);
137
+ y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p5);
138
+ y = _mm256_mul_ps(y, x);
139
+ y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p6);
140
+ y = _mm256_mul_ps(y, x);
141
+ y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p7);
142
+ y = _mm256_mul_ps(y, x);
143
+ y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p8);
144
+ y = _mm256_mul_ps(y, x);
145
+
146
+ y = _mm256_mul_ps(y, z);
147
+
148
+ tmp = _mm256_mul_ps(e, *(v8sf*)_ps256_cephes_log_q1);
149
+ y = _mm256_add_ps(y, tmp);
150
+
151
+
152
+ tmp = _mm256_mul_ps(z, *(v8sf*)_ps256_0p5);
153
+ y = _mm256_sub_ps(y, tmp);
154
+
155
+ tmp = _mm256_mul_ps(e, *(v8sf*)_ps256_cephes_log_q2);
156
+ x = _mm256_add_ps(x, y);
157
+ x = _mm256_add_ps(x, tmp);
158
+ x = _mm256_or_ps(x, invalid_mask); // negative arg will be NAN
159
+ return x;
160
+ }
161
+
162
+ _PS256_CONST(exp_hi, 88.3762626647949f);
163
+ _PS256_CONST(exp_lo, -88.3762626647949f);
164
+
165
+ _PS256_CONST(cephes_LOG2EF, 1.44269504088896341);
166
+ _PS256_CONST(cephes_exp_C1, 0.693359375);
167
+ _PS256_CONST(cephes_exp_C2, -2.12194440e-4);
168
+
169
+ _PS256_CONST(cephes_exp_p0, 1.9875691500E-4);
170
+ _PS256_CONST(cephes_exp_p1, 1.3981999507E-3);
171
+ _PS256_CONST(cephes_exp_p2, 8.3334519073E-3);
172
+ _PS256_CONST(cephes_exp_p3, 4.1665795894E-2);
173
+ _PS256_CONST(cephes_exp_p4, 1.6666665459E-1);
174
+ _PS256_CONST(cephes_exp_p5, 5.0000001201E-1);
175
+
176
+ inline v8sf exp256_ps(v8sf x) {
177
+ v8sf tmp = _mm256_setzero_ps(), fx;
178
+ v8si imm0;
179
+ v8sf one = *(v8sf*)_ps256_1;
180
+
181
+ x = _mm256_min_ps(x, *(v8sf*)_ps256_exp_hi);
182
+ x = _mm256_max_ps(x, *(v8sf*)_ps256_exp_lo);
183
+
184
+ /* express exp(x) as exp(g + n*log(2)) */
185
+ fx = _mm256_mul_ps(x, *(v8sf*)_ps256_cephes_LOG2EF);
186
+ fx = _mm256_add_ps(fx, *(v8sf*)_ps256_0p5);
187
+
188
+ /* how to perform a floorf with SSE: just below */
189
+ //imm0 = _mm256_cvttps_epi32(fx);
190
+ //tmp = _mm256_cvtepi32_ps(imm0);
191
+
192
+ tmp = _mm256_floor_ps(fx);
193
+
194
+ /* if greater, subtract 1 */
195
+ //v8sf mask = _mm256_cmpgt_ps(tmp, fx);
196
+ v8sf mask = _mm256_cmp_ps(tmp, fx, _CMP_GT_OS);
197
+ mask = _mm256_and_ps(mask, one);
198
+ fx = _mm256_sub_ps(tmp, mask);
199
+
200
+ tmp = _mm256_mul_ps(fx, *(v8sf*)_ps256_cephes_exp_C1);
201
+ v8sf z = _mm256_mul_ps(fx, *(v8sf*)_ps256_cephes_exp_C2);
202
+ x = _mm256_sub_ps(x, tmp);
203
+ x = _mm256_sub_ps(x, z);
204
+
205
+ z = _mm256_mul_ps(x,x);
206
+
207
+ v8sf y = *(v8sf*)_ps256_cephes_exp_p0;
208
+ y = _mm256_mul_ps(y, x);
209
+ y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_exp_p1);
210
+ y = _mm256_mul_ps(y, x);
211
+ y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_exp_p2);
212
+ y = _mm256_mul_ps(y, x);
213
+ y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_exp_p3);
214
+ y = _mm256_mul_ps(y, x);
215
+ y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_exp_p4);
216
+ y = _mm256_mul_ps(y, x);
217
+ y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_exp_p5);
218
+ y = _mm256_mul_ps(y, z);
219
+ y = _mm256_add_ps(y, x);
220
+ y = _mm256_add_ps(y, one);
221
+
222
+ /* build 2^n */
223
+ imm0 = _mm256_cvttps_epi32(fx);
224
+ // another two AVX2 instructions
225
+ imm0 = _mm256_add_epi32(imm0, *(v8si*)_pi32_256_0x7f);
226
+ imm0 = _mm256_slli_epi32(imm0, 23);
227
+ v8sf pow2n = _mm256_castsi256_ps(imm0);
228
+ y = _mm256_mul_ps(y, pow2n);
229
+ return y;
230
+ }
231
+
232
+ _PS256_CONST(minus_cephes_DP1, -0.78515625);
233
+ _PS256_CONST(minus_cephes_DP2, -2.4187564849853515625e-4);
234
+ _PS256_CONST(minus_cephes_DP3, -3.77489497744594108e-8);
235
+ _PS256_CONST(sincof_p0, -1.9515295891E-4);
236
+ _PS256_CONST(sincof_p1, 8.3321608736E-3);
237
+ _PS256_CONST(sincof_p2, -1.6666654611E-1);
238
+ _PS256_CONST(coscof_p0, 2.443315711809948E-005);
239
+ _PS256_CONST(coscof_p1, -1.388731625493765E-003);
240
+ _PS256_CONST(coscof_p2, 4.166664568298827E-002);
241
+ _PS256_CONST(cephes_FOPI, 1.27323954473516); // 4 / M_PI
242
+
243
+
244
+ /* evaluation of 8 sines at once using AVX intrinsics
245
+
246
+ The code is the exact rewriting of the cephes sinf function.
247
+ Precision is excellent as long as x < 8192 (I did not bother to
248
+ take into account the special handling they have for greater values
249
+ -- it does not return garbage for arguments over 8192, though, but
250
+ the extra precision is missing).
251
+
252
+ Note that it is such that sinf((float)M_PI) = 8.74e-8, which is the
253
+ surprising but correct result.
254
+
255
+ */
256
+ inline v8sf sin256_ps(v8sf x) { // any x
257
+ v8sf xmm1, xmm2 = _mm256_setzero_ps(), xmm3, sign_bit, y;
258
+ v8si imm0, imm2;
259
+
260
+ sign_bit = x;
261
+ /* take the absolute value */
262
+ x = _mm256_and_ps(x, *(v8sf*)_ps256_inv_sign_mask);
263
+ /* extract the sign bit (upper one) */
264
+ sign_bit = _mm256_and_ps(sign_bit, *(v8sf*)_ps256_sign_mask);
265
+
266
+ /* scale by 4/Pi */
267
+ y = _mm256_mul_ps(x, *(v8sf*)_ps256_cephes_FOPI);
268
+
269
+ /*
270
+ Here we start a series of integer operations, which are in the
271
+ realm of AVX2.
272
+ If we don't have AVX, let's perform them using SSE2 directives
273
+ */
274
+
275
+ /* store the integer part of y in mm0 */
276
+ imm2 = _mm256_cvttps_epi32(y);
277
+ /* j=(j+1) & (~1) (see the cephes sources) */
278
+ // another two AVX2 instruction
279
+ imm2 = _mm256_add_epi32(imm2, *(v8si*)_pi32_256_1);
280
+ imm2 = _mm256_and_si256(imm2, *(v8si*)_pi32_256_inv1);
281
+ y = _mm256_cvtepi32_ps(imm2);
282
+
283
+ /* get the swap sign flag */
284
+ imm0 = _mm256_and_si256(imm2, *(v8si*)_pi32_256_4);
285
+ imm0 = _mm256_slli_epi32(imm0, 29);
286
+ /* get the polynom selection mask
287
+ there is one polynom for 0 <= x <= Pi/4
288
+ and another one for Pi/4<x<=Pi/2
289
+
290
+ Both branches will be computed.
291
+ */
292
+ imm2 = _mm256_and_si256(imm2, *(v8si*)_pi32_256_2);
293
+ imm2 = _mm256_cmpeq_epi32(imm2,*(v8si*)_pi32_256_0);
294
+
295
+ v8sf swap_sign_bit = _mm256_castsi256_ps(imm0);
296
+ v8sf poly_mask = _mm256_castsi256_ps(imm2);
297
+ sign_bit = _mm256_xor_ps(sign_bit, swap_sign_bit);
298
+
299
+ /* The magic pass: "Extended precision modular arithmetic"
300
+ x = ((x - y * DP1) - y * DP2) - y * DP3; */
301
+ xmm1 = *(v8sf*)_ps256_minus_cephes_DP1;
302
+ xmm2 = *(v8sf*)_ps256_minus_cephes_DP2;
303
+ xmm3 = *(v8sf*)_ps256_minus_cephes_DP3;
304
+ xmm1 = _mm256_mul_ps(y, xmm1);
305
+ xmm2 = _mm256_mul_ps(y, xmm2);
306
+ xmm3 = _mm256_mul_ps(y, xmm3);
307
+ x = _mm256_add_ps(x, xmm1);
308
+ x = _mm256_add_ps(x, xmm2);
309
+ x = _mm256_add_ps(x, xmm3);
310
+
311
+ /* Evaluate the first polynom (0 <= x <= Pi/4) */
312
+ y = *(v8sf*)_ps256_coscof_p0;
313
+ v8sf z = _mm256_mul_ps(x,x);
314
+
315
+ y = _mm256_mul_ps(y, z);
316
+ y = _mm256_add_ps(y, *(v8sf*)_ps256_coscof_p1);
317
+ y = _mm256_mul_ps(y, z);
318
+ y = _mm256_add_ps(y, *(v8sf*)_ps256_coscof_p2);
319
+ y = _mm256_mul_ps(y, z);
320
+ y = _mm256_mul_ps(y, z);
321
+ v8sf tmp = _mm256_mul_ps(z, *(v8sf*)_ps256_0p5);
322
+ y = _mm256_sub_ps(y, tmp);
323
+ y = _mm256_add_ps(y, *(v8sf*)_ps256_1);
324
+
325
+ /* Evaluate the second polynom (Pi/4 <= x <= 0) */
326
+
327
+ v8sf y2 = *(v8sf*)_ps256_sincof_p0;
328
+ y2 = _mm256_mul_ps(y2, z);
329
+ y2 = _mm256_add_ps(y2, *(v8sf*)_ps256_sincof_p1);
330
+ y2 = _mm256_mul_ps(y2, z);
331
+ y2 = _mm256_add_ps(y2, *(v8sf*)_ps256_sincof_p2);
332
+ y2 = _mm256_mul_ps(y2, z);
333
+ y2 = _mm256_mul_ps(y2, x);
334
+ y2 = _mm256_add_ps(y2, x);
335
+
336
+ /* select the correct result from the two polynoms */
337
+ xmm3 = poly_mask;
338
+ y2 = _mm256_and_ps(xmm3, y2); //, xmm3);
339
+ y = _mm256_andnot_ps(xmm3, y);
340
+ y = _mm256_add_ps(y,y2);
341
+ /* update the sign */
342
+ y = _mm256_xor_ps(y, sign_bit);
343
+
344
+ return y;
345
+ }
346
+
347
+ /* almost the same as sin_ps */
348
+ inline v8sf cos256_ps(v8sf x) { // any x
349
+ v8sf xmm1, xmm2 = _mm256_setzero_ps(), xmm3, y;
350
+ v8si imm0, imm2;
351
+
352
+ /* take the absolute value */
353
+ x = _mm256_and_ps(x, *(v8sf*)_ps256_inv_sign_mask);
354
+
355
+ /* scale by 4/Pi */
356
+ y = _mm256_mul_ps(x, *(v8sf*)_ps256_cephes_FOPI);
357
+
358
+ /* store the integer part of y in mm0 */
359
+ imm2 = _mm256_cvttps_epi32(y);
360
+ /* j=(j+1) & (~1) (see the cephes sources) */
361
+ imm2 = _mm256_add_epi32(imm2, *(v8si*)_pi32_256_1);
362
+ imm2 = _mm256_and_si256(imm2, *(v8si*)_pi32_256_inv1);
363
+ y = _mm256_cvtepi32_ps(imm2);
364
+ imm2 = _mm256_sub_epi32(imm2, *(v8si*)_pi32_256_2);
365
+
366
+ /* get the swap sign flag */
367
+ imm0 = _mm256_andnot_si256(imm2, *(v8si*)_pi32_256_4);
368
+ imm0 = _mm256_slli_epi32(imm0, 29);
369
+ /* get the polynom selection mask */
370
+ imm2 = _mm256_and_si256(imm2, *(v8si*)_pi32_256_2);
371
+ imm2 = _mm256_cmpeq_epi32(imm2, *(v8si*)_pi32_256_0);
372
+
373
+ v8sf sign_bit = _mm256_castsi256_ps(imm0);
374
+ v8sf poly_mask = _mm256_castsi256_ps(imm2);
375
+
376
+ /* The magic pass: "Extended precision modular arithmetic"
377
+ x = ((x - y * DP1) - y * DP2) - y * DP3; */
378
+ xmm1 = *(v8sf*)_ps256_minus_cephes_DP1;
379
+ xmm2 = *(v8sf*)_ps256_minus_cephes_DP2;
380
+ xmm3 = *(v8sf*)_ps256_minus_cephes_DP3;
381
+ xmm1 = _mm256_mul_ps(y, xmm1);
382
+ xmm2 = _mm256_mul_ps(y, xmm2);
383
+ xmm3 = _mm256_mul_ps(y, xmm3);
384
+ x = _mm256_add_ps(x, xmm1);
385
+ x = _mm256_add_ps(x, xmm2);
386
+ x = _mm256_add_ps(x, xmm3);
387
+
388
+ /* Evaluate the first polynom (0 <= x <= Pi/4) */
389
+ y = *(v8sf*)_ps256_coscof_p0;
390
+ v8sf z = _mm256_mul_ps(x,x);
391
+
392
+ y = _mm256_mul_ps(y, z);
393
+ y = _mm256_add_ps(y, *(v8sf*)_ps256_coscof_p1);
394
+ y = _mm256_mul_ps(y, z);
395
+ y = _mm256_add_ps(y, *(v8sf*)_ps256_coscof_p2);
396
+ y = _mm256_mul_ps(y, z);
397
+ y = _mm256_mul_ps(y, z);
398
+ v8sf tmp = _mm256_mul_ps(z, *(v8sf*)_ps256_0p5);
399
+ y = _mm256_sub_ps(y, tmp);
400
+ y = _mm256_add_ps(y, *(v8sf*)_ps256_1);
401
+
402
+ /* Evaluate the second polynom (Pi/4 <= x <= 0) */
403
+
404
+ v8sf y2 = *(v8sf*)_ps256_sincof_p0;
405
+ y2 = _mm256_mul_ps(y2, z);
406
+ y2 = _mm256_add_ps(y2, *(v8sf*)_ps256_sincof_p1);
407
+ y2 = _mm256_mul_ps(y2, z);
408
+ y2 = _mm256_add_ps(y2, *(v8sf*)_ps256_sincof_p2);
409
+ y2 = _mm256_mul_ps(y2, z);
410
+ y2 = _mm256_mul_ps(y2, x);
411
+ y2 = _mm256_add_ps(y2, x);
412
+
413
+ /* select the correct result from the two polynoms */
414
+ xmm3 = poly_mask;
415
+ y2 = _mm256_and_ps(xmm3, y2); //, xmm3);
416
+ y = _mm256_andnot_ps(xmm3, y);
417
+ y = _mm256_add_ps(y,y2);
418
+ /* update the sign */
419
+ y = _mm256_xor_ps(y, sign_bit);
420
+
421
+ return y;
422
+ }
423
+
424
+ /* since sin256_ps and cos256_ps are almost identical, sincos256_ps could replace both of them..
425
+ it is almost as fast, and gives you a free cosine with your sine */
426
+ inline void sincos256_ps(v8sf x, v8sf *s, v8sf *c) {
427
+
428
+ v8sf xmm1, xmm2, xmm3 = _mm256_setzero_ps(), sign_bit_sin, y;
429
+ v8si imm0, imm2, imm4;
430
+
431
+ sign_bit_sin = x;
432
+ /* take the absolute value */
433
+ x = _mm256_and_ps(x, *(v8sf*)_ps256_inv_sign_mask);
434
+ /* extract the sign bit (upper one) */
435
+ sign_bit_sin = _mm256_and_ps(sign_bit_sin, *(v8sf*)_ps256_sign_mask);
436
+
437
+ /* scale by 4/Pi */
438
+ y = _mm256_mul_ps(x, *(v8sf*)_ps256_cephes_FOPI);
439
+
440
+ /* store the integer part of y in imm2 */
441
+ imm2 = _mm256_cvttps_epi32(y);
442
+
443
+ /* j=(j+1) & (~1) (see the cephes sources) */
444
+ imm2 = _mm256_add_epi32(imm2, *(v8si*)_pi32_256_1);
445
+ imm2 = _mm256_and_si256(imm2, *(v8si*)_pi32_256_inv1);
446
+
447
+ y = _mm256_cvtepi32_ps(imm2);
448
+ imm4 = imm2;
449
+
450
+ /* get the swap sign flag for the sine */
451
+ imm0 = _mm256_and_si256(imm2, *(v8si*)_pi32_256_4);
452
+ imm0 = _mm256_slli_epi32(imm0, 29);
453
+ //v8sf swap_sign_bit_sin = _mm256_castsi256_ps(imm0);
454
+
455
+ /* get the polynom selection mask for the sine*/
456
+ imm2 = _mm256_and_si256(imm2, *(v8si*)_pi32_256_2);
457
+ imm2 = _mm256_cmpeq_epi32(imm2, *(v8si*)_pi32_256_0);
458
+ //v8sf poly_mask = _mm256_castsi256_ps(imm2);
459
+
460
+ v8sf swap_sign_bit_sin = _mm256_castsi256_ps(imm0);
461
+ v8sf poly_mask = _mm256_castsi256_ps(imm2);
462
+
463
+ /* The magic pass: "Extended precision modular arithmetic"
464
+ x = ((x - y * DP1) - y * DP2) - y * DP3; */
465
+ xmm1 = *(v8sf*)_ps256_minus_cephes_DP1;
466
+ xmm2 = *(v8sf*)_ps256_minus_cephes_DP2;
467
+ xmm3 = *(v8sf*)_ps256_minus_cephes_DP3;
468
+ xmm1 = _mm256_mul_ps(y, xmm1);
469
+ xmm2 = _mm256_mul_ps(y, xmm2);
470
+ xmm3 = _mm256_mul_ps(y, xmm3);
471
+ x = _mm256_add_ps(x, xmm1);
472
+ x = _mm256_add_ps(x, xmm2);
473
+ x = _mm256_add_ps(x, xmm3);
474
+
475
+ imm4 = _mm256_sub_epi32(imm4, *(v8si*)_pi32_256_2);
476
+ imm4 = _mm256_andnot_si256(imm4, *(v8si*)_pi32_256_4);
477
+ imm4 = _mm256_slli_epi32(imm4, 29);
478
+
479
+ v8sf sign_bit_cos = _mm256_castsi256_ps(imm4);
480
+
481
+ sign_bit_sin = _mm256_xor_ps(sign_bit_sin, swap_sign_bit_sin);
482
+
483
+ /* Evaluate the first polynom (0 <= x <= Pi/4) */
484
+ v8sf z = _mm256_mul_ps(x,x);
485
+ y = *(v8sf*)_ps256_coscof_p0;
486
+
487
+ y = _mm256_mul_ps(y, z);
488
+ y = _mm256_add_ps(y, *(v8sf*)_ps256_coscof_p1);
489
+ y = _mm256_mul_ps(y, z);
490
+ y = _mm256_add_ps(y, *(v8sf*)_ps256_coscof_p2);
491
+ y = _mm256_mul_ps(y, z);
492
+ y = _mm256_mul_ps(y, z);
493
+ v8sf tmp = _mm256_mul_ps(z, *(v8sf*)_ps256_0p5);
494
+ y = _mm256_sub_ps(y, tmp);
495
+ y = _mm256_add_ps(y, *(v8sf*)_ps256_1);
496
+
497
+ /* Evaluate the second polynom (Pi/4 <= x <= 0) */
498
+
499
+ v8sf y2 = *(v8sf*)_ps256_sincof_p0;
500
+ y2 = _mm256_mul_ps(y2, z);
501
+ y2 = _mm256_add_ps(y2, *(v8sf*)_ps256_sincof_p1);
502
+ y2 = _mm256_mul_ps(y2, z);
503
+ y2 = _mm256_add_ps(y2, *(v8sf*)_ps256_sincof_p2);
504
+ y2 = _mm256_mul_ps(y2, z);
505
+ y2 = _mm256_mul_ps(y2, x);
506
+ y2 = _mm256_add_ps(y2, x);
507
+
508
+ /* select the correct result from the two polynoms */
509
+ xmm3 = poly_mask;
510
+ v8sf ysin2 = _mm256_and_ps(xmm3, y2);
511
+ v8sf ysin1 = _mm256_andnot_ps(xmm3, y);
512
+ y2 = _mm256_sub_ps(y2,ysin2);
513
+ y = _mm256_sub_ps(y, ysin1);
514
+
515
+ xmm1 = _mm256_add_ps(ysin1,ysin2);
516
+ xmm2 = _mm256_add_ps(y,y2);
517
+
518
+ /* update the sign */
519
+ *s = _mm256_xor_ps(xmm1, sign_bit_sin);
520
+ *c = _mm256_xor_ps(xmm2, sign_bit_cos);
521
+ }
522
+
523
+ #endif // CPU_CAPABILITY_AVX2
524
+
525
+ #else
526
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
527
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/int_mm_kernel.h ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <ATen/core/Tensor.h>
5
+ #include <ATen/native/DispatchStub.h>
6
+
7
+ namespace at::native {
8
+
9
+ using weight_to_int4pack_fn = void (*)(const Tensor&, const Tensor&);
10
+ using int4pack_mm_fn =
11
+ void (*)(const Tensor&, const Tensor&, const Tensor&, int, const Tensor&);
12
+ using int8pack_mm_fn =
13
+ void (*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&);
14
+ using dyn_quant_pack_4bit_weight_fn = void (*)(
15
+ Tensor&,
16
+ const Tensor&,
17
+ const Tensor&,
18
+ const std::optional<Tensor>& bias,
19
+ const int64_t,
20
+ const int64_t,
21
+ const int64_t);
22
+ using dyn_quant_matmul_4bit_fn = void (*)(
23
+ const Tensor&,
24
+ const Tensor&,
25
+ const Tensor&,
26
+ const int64_t,
27
+ const int64_t,
28
+ const int64_t,
29
+ const int64_t);
30
+
31
+ DECLARE_DISPATCH(weight_to_int4pack_fn, weight_to_int4pack_stub)
32
+ DECLARE_DISPATCH(int4pack_mm_fn, int4pack_mm_stub)
33
+ DECLARE_DISPATCH(int8pack_mm_fn, int8pack_mm_stub)
34
+ DECLARE_DISPATCH(
35
+ dyn_quant_pack_4bit_weight_fn,
36
+ dyn_quant_pack_4bit_weight_stub)
37
+ DECLARE_DISPATCH(dyn_quant_matmul_4bit_fn, dyn_quant_matmul_4bit_stub)
38
+
39
+ } // namespace at::native
40
+
41
+ #else
42
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
43
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/mixed_data_type.h ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <ATen/core/Tensor.h>
5
+
6
+ namespace at::native {
7
+
8
+ inline ScalarType first_type() {
9
+ return ScalarType::Undefined;
10
+ }
11
+
12
+ template <typename... Args>
13
+ inline ScalarType first_type(const Tensor& arg, const Args&... parameters) {
14
+ return arg.defined() ? arg.scalar_type() : first_type(parameters...);
15
+ }
16
+
17
+ template <typename... Args>
18
+ inline bool is_mixed_type(const Tensor& input, const Args&... parameters) {
19
+ const auto parameter_type = first_type(parameters...);
20
+ return ((parameter_type != ScalarType::Undefined) &&
21
+ (parameter_type != input.scalar_type()));
22
+ }
23
+
24
+ // currently on CPU, mixed data type is only supported
25
+ // when input is 'BFloat16' or 'Half' and parameters are 'Float'
26
+ inline void check_mixed_data_type(const Tensor& input) {
27
+ TORCH_CHECK(at::isReducedFloatingType(input.scalar_type()),
28
+ "mixed dtype (CPU): all inputs must share same datatype.");
29
+ }
30
+
31
+ template <typename... Args>
32
+ inline void check_mixed_data_type(const Tensor& input, const Tensor& parameter, const Args&... parameters) {
33
+ TORCH_CHECK(!parameter.defined() || parameter.scalar_type() == ScalarType::Float,
34
+ "mixed dtype (CPU): expect parameter to have scalar type of Float");
35
+ check_mixed_data_type(input, parameters...);
36
+ }
37
+
38
+ inline ScalarType param_scalar_type(const Tensor& t, bool is_mixed_type) {
39
+ return is_mixed_type ? ScalarType::Float : t.scalar_type();
40
+ }
41
+
42
+ } // namespace at::native
43
+
44
+ #else
45
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
46
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/utils.h ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ #include <ATen/Parallel.h>
5
+ #include <ATen/core/TensorAccessor.h>
6
+ #include <ATen/cpu/vec/vec.h>
7
+ #include <c10/util/llvmMathExtras.h>
8
+
9
+ #ifdef USE_FBGEMM
10
+ C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wextra-semi")
11
+ #include <fbgemm/Fbgemm.h>
12
+ C10_DIAGNOSTIC_POP()
13
+ #endif
14
+
15
+ namespace at::native {
16
+
17
+ template <typename T>
18
+ inline void _store(T* dst, at::vec::Vectorized<T> src) {
19
+ src.store(dst);
20
+ }
21
+
22
+ inline void _store(at::BFloat16* dst, at::vec::Vectorized<float> src) {
23
+ auto res = at::vec::convert_float_bfloat16(src, src);
24
+ res.store(dst, at::vec::Vectorized<float>::size());
25
+ }
26
+
27
+ inline void _store(at::Half* dst, at::vec::Vectorized<float> src) {
28
+ auto res = at::vec::convert_float_half(src, src);
29
+ res.store(dst, at::vec::Vectorized<float>::size());
30
+ }
31
+
32
+ inline namespace CPU_CAPABILITY {
33
+
34
+ template <typename T>
35
+ inline T data_index_init(T offset) {
36
+ return offset;
37
+ }
38
+
39
+ template <typename T, typename... Args>
40
+ inline T data_index_init(T offset, T& x, const T& X, Args&&... args) {
41
+ offset = data_index_init(offset, std::forward<Args>(args)...);
42
+ x = offset % X;
43
+ return offset / X;
44
+ }
45
+
46
+ inline bool data_index_step() {
47
+ return true;
48
+ }
49
+
50
+ template <typename T, typename... Args>
51
+ inline bool data_index_step(T& x, const T& X, Args&&... args) {
52
+ if (data_index_step(std::forward<Args>(args)...)) {
53
+ x = ((x + 1) == X) ? 0 : (x + 1);
54
+ return x == 0;
55
+ }
56
+ return false;
57
+ }
58
+
59
+ // Helper struct for bfloat16/float16 vectorization
60
+ // Useful when you need float as immediate dtype or accumulate dtype
61
+ using namespace vec;
62
+ struct Vec2 {
63
+ Vectorized<float> val0, val1;
64
+ Vec2(Vectorized<float> v0, Vectorized<float> v1) : val0(v0), val1(v1) {}
65
+ Vec2(float v) : val0(v), val1(v) {}
66
+ static Vec2 loadu(const BFloat16* ptr) {
67
+ auto [v0, v1] = convert_bfloat16_float(Vectorized<BFloat16>::loadu(ptr));
68
+ return {v0, v1};
69
+ }
70
+ static Vec2 loadu(const Half* ptr) {
71
+ auto [v0, v1] = convert_half_float(Vectorized<Half>::loadu(ptr));
72
+ return {v0, v1};
73
+ }
74
+ static Vec2 loadu(const float* ptr) {
75
+ return {Vectorized<float>::loadu(ptr), Vectorized<float>::loadu(ptr + Vectorized<float>::size())};
76
+ }
77
+ void store(BFloat16* ptr) const {
78
+ Vectorized<BFloat16> val = convert_float_bfloat16(val0, val1);
79
+ val.store(ptr);
80
+ }
81
+ void store(Half* ptr) const {
82
+ Vectorized<Half> val = convert_float_half(val0, val1);
83
+ val.store(ptr);
84
+ }
85
+ void store(float* ptr) const {
86
+ val0.store(ptr);
87
+ val1.store(ptr + Vectorized<float>::size());
88
+ }
89
+ };
90
+ inline Vec2 operator+(const Vec2& a, const Vec2& b) { return {a.val0 + b.val0, a.val1 + b.val1}; }
91
+ inline Vec2 operator*(const Vec2& a, const Vec2& b) { return {a.val0 * b.val0, a.val1 * b.val1}; }
92
+ inline Vec2 operator-(const Vec2& a, const Vec2& b) { return {a.val0 - b.val0, a.val1 - b.val1}; }
93
+ inline Vec2 operator/(const Vec2& a, const Vec2& b) { return {a.val0 / b.val0, a.val1 / b.val1}; }
94
+ inline Vec2 maximum(const Vec2& a, const Vec2& b) { return {vec::maximum(a.val0, b.val0), vec::maximum(a.val1, b.val1)}; }
95
+ inline Vec2 minimum(const Vec2& a, const Vec2& b) { return {vec::minimum(a.val0, b.val0), vec::minimum(a.val1, b.val1)}; }
96
+
97
+ template <typename scalar_t> struct VectorizedType { using type = Vectorized<scalar_t>; };
98
+ template <> struct VectorizedType<BFloat16> { using type = Vec2; };
99
+ template <> struct VectorizedType<Half> { using type = Vec2; };
100
+ template <typename scalar_t> using VecType = typename VectorizedType<scalar_t>::type;
101
+
102
+ // Helper for mixed data type parameter Vec::load
103
+ inline std::tuple<Vectorized<float>, Vectorized<float>> load2f(const BFloat16* ptr) {
104
+ return convert_bfloat16_float(Vectorized<BFloat16>::loadu(ptr));
105
+ }
106
+
107
+ inline std::tuple<Vectorized<float>, Vectorized<float>> load2f(const Half* ptr) {
108
+ return convert_half_float(Vectorized<Half>::loadu(ptr));
109
+ }
110
+
111
+ inline std::tuple<Vectorized<float>, Vectorized<float>> load2f(const float* ptr) {
112
+ using Vec = Vectorized<float>;
113
+ return std::make_tuple(Vec::loadu(ptr), Vec::loadu(ptr + Vec::size()));
114
+ }
115
+
116
+ inline std::tuple<Vectorized<float>, Vectorized<float>> load2f(const BFloat16* ptr, int64_t count) {
117
+ return convert_bfloat16_float(Vectorized<BFloat16>::loadu(ptr, count));
118
+ }
119
+
120
+ inline std::tuple<Vectorized<float>, Vectorized<float>> load2f(const Half* ptr, int64_t count) {
121
+ return convert_half_float(Vectorized<Half>::loadu(ptr, count));
122
+ }
123
+
124
+ inline std::tuple<Vectorized<float>, Vectorized<float>> load2f(const float* ptr, int64_t count) {
125
+ using Vec = Vectorized<float>;
126
+ if (count > Vec::size()) {
127
+ return std::make_tuple(Vec::loadu(ptr), Vec::loadu(ptr + Vec::size(), count - Vec::size()));
128
+ } else {
129
+ return std::make_tuple(Vec::loadu(ptr, count), Vec(0));
130
+ }
131
+ }
132
+
133
+ } // namespace
134
+
135
+ namespace utils {
136
+
137
+ template <typename T>
138
+ T CeilLog2(const T& x) {
139
+ if (x <= 2) {
140
+ return 1;
141
+ }
142
+ // Last set bit is floor(log2(x)), floor + 1 is ceil
143
+ // except when x is an exact powers of 2, so subtract 1 first
144
+ return static_cast<T>(llvm::findLastSet(static_cast<uint64_t>(x) - 1)) + 1;
145
+ }
146
+
147
+ // matrix transpose:
148
+ // src has shape of M by N, with leading dimension of ld_src
149
+ // dst has shape of N by M, with leading dimension of ld_dst
150
+ template <typename T>
151
+ inline void transpose(int64_t M, int64_t N, const T* src, int64_t ld_src, T* dst, int64_t ld_dst) {
152
+ for (int64_t j = 0; j < N; j++) {
153
+ for (int64_t i = 0; i < M; i++) {
154
+ dst[j * ld_dst + i] = c10::load(&(src[i * ld_src + j]));
155
+ }
156
+ }
157
+ }
158
+
159
+ #ifdef USE_FBGEMM
160
+ template <>
161
+ inline void transpose<float>(int64_t M, int64_t N, const float* src, int64_t ld_src, float* dst, int64_t ld_dst) {
162
+ TORCH_CHECK(fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM.");
163
+ fbgemm::transpose_simd<float>(M, N, src, ld_src, dst, ld_dst);
164
+ }
165
+
166
+ template <>
167
+ inline void transpose<uint16_t>(int64_t M, int64_t N, const uint16_t* src, int64_t ld_src, uint16_t* dst, int64_t ld_dst) {
168
+ TORCH_CHECK(fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM.");
169
+ fbgemm::transpose_simd<uint16_t>(M, N, src, ld_src, dst, ld_dst);
170
+ }
171
+
172
+ template <>
173
+ inline void transpose<uint8_t>(int64_t M, int64_t N, const uint8_t* src, int64_t ld_src, uint8_t* dst, int64_t ld_dst) {
174
+ TORCH_CHECK(fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM.");
175
+ fbgemm::transpose_simd<uint8_t>(M, N, src, ld_src, dst, ld_dst);
176
+ }
177
+ #endif
178
+
179
+ template <typename index_t, typename F>
180
+ inline void parallel_sparse_csr(
181
+ const TensorAccessor<index_t, 1>& crow_acc,
182
+ const int64_t M,
183
+ const int64_t nnz,
184
+ const F& f) {
185
+ TORCH_CHECK(crow_acc.size(0) == M + 1);
186
+
187
+ // directly parallel on `M` may lead to load imbalance,
188
+ // statically determine thread partition here to average payload
189
+ // for each thread.
190
+ int num_threads = at::get_num_threads();
191
+ std::vector<int64_t> thread_splits(num_threads + 1, M);
192
+
193
+ int64_t thread_averge_payload = std::max((int64_t)1, divup(nnz, num_threads));
194
+
195
+ thread_splits[0] = 0;
196
+ int64_t sum = 0;
197
+ int64_t t = 1;
198
+ for (const auto m : c10::irange(M)) {
199
+ int64_t row_start = crow_acc[m];
200
+ int64_t row_end = crow_acc[m + 1];
201
+ sum += row_end - row_start;
202
+ if (sum > t * thread_averge_payload) {
203
+ thread_splits[t] = m;
204
+ t++;
205
+ }
206
+ }
207
+ // need to restore the last index,
208
+ // due to rounding error when calculating `thread_averge_payload`.
209
+ thread_splits[num_threads] = M;
210
+
211
+ at::parallel_for(0, num_threads, 1, [&](int64_t cbegin, int64_t cend) {
212
+ int tid = at::get_thread_num();
213
+ int64_t begin = thread_splits[tid];
214
+ int64_t end = thread_splits[tid + 1];
215
+ f(begin, end);
216
+ });
217
+ }
218
+
219
+ } // namespace utils
220
+
221
+ } // namespace at::native
222
+
223
+ #else
224
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
225
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/zmath.h ADDED
@@ -0,0 +1,255 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ // Complex number math operations that act as no-ops for other dtypes.
5
+ #include <c10/util/complex.h>
6
+ #include <c10/util/MathConstants.h>
7
+ #include<ATen/NumericUtils.h>
8
+
9
+ namespace at::native {
10
+ inline namespace CPU_CAPABILITY {
11
+
12
+ template <typename SCALAR_TYPE, typename VALUE_TYPE=SCALAR_TYPE>
13
+ inline VALUE_TYPE zabs (SCALAR_TYPE z) {
14
+ return z;
15
+ }
16
+
17
+ template<>
18
+ inline c10::complex<float> zabs <c10::complex<float>> (c10::complex<float> z) {
19
+ return c10::complex<float>(std::abs(z));
20
+ }
21
+
22
+ template<>
23
+ inline float zabs <c10::complex<float>, float> (c10::complex<float> z) {
24
+ return std::abs(z);
25
+ }
26
+
27
+ template<>
28
+ inline c10::complex<double> zabs <c10::complex<double>> (c10::complex<double> z) {
29
+ return c10::complex<double>(std::abs(z));
30
+ }
31
+
32
+ template<>
33
+ inline double zabs <c10::complex<double>, double> (c10::complex<double> z) {
34
+ return std::abs(z);
35
+ }
36
+
37
+ // This overload corresponds to non-complex dtypes.
38
+ // The function is consistent with its NumPy equivalent
39
+ // for non-complex dtypes where `pi` is returned for
40
+ // negative real numbers and `0` is returned for 0 or positive
41
+ // real numbers.
42
+ // Note: `nan` is propagated.
43
+ template <typename SCALAR_TYPE, typename VALUE_TYPE=SCALAR_TYPE>
44
+ inline VALUE_TYPE angle_impl (SCALAR_TYPE z) {
45
+ if (at::_isnan(z)) {
46
+ return z;
47
+ }
48
+ return z < 0 ? c10::pi<double> : 0;
49
+ }
50
+
51
+ template<>
52
+ inline c10::complex<float> angle_impl <c10::complex<float>> (c10::complex<float> z) {
53
+ return c10::complex<float>(std::arg(z), 0.0);
54
+ }
55
+
56
+ template<>
57
+ inline float angle_impl <c10::complex<float>, float> (c10::complex<float> z) {
58
+ return std::arg(z);
59
+ }
60
+
61
+ template<>
62
+ inline c10::complex<double> angle_impl <c10::complex<double>> (c10::complex<double> z) {
63
+ return c10::complex<double>(std::arg(z), 0.0);
64
+ }
65
+
66
+ template<>
67
+ inline double angle_impl <c10::complex<double>, double> (c10::complex<double> z) {
68
+ return std::arg(z);
69
+ }
70
+
71
+ template <typename SCALAR_TYPE, typename VALUE_TYPE=SCALAR_TYPE>
72
+ constexpr VALUE_TYPE real_impl (SCALAR_TYPE z) {
73
+ return z; //No-Op
74
+ }
75
+
76
+ template<>
77
+ constexpr c10::complex<float> real_impl <c10::complex<float>> (c10::complex<float> z) {
78
+ return c10::complex<float>(z.real(), 0.0);
79
+ }
80
+
81
+ template<>
82
+ constexpr float real_impl <c10::complex<float>, float> (c10::complex<float> z) {
83
+ return z.real();
84
+ }
85
+
86
+ template<>
87
+ constexpr c10::complex<double> real_impl <c10::complex<double>> (c10::complex<double> z) {
88
+ return c10::complex<double>(z.real(), 0.0);
89
+ }
90
+
91
+ template<>
92
+ constexpr double real_impl <c10::complex<double>, double> (c10::complex<double> z) {
93
+ return z.real();
94
+ }
95
+
96
+ template <typename SCALAR_TYPE, typename VALUE_TYPE=SCALAR_TYPE>
97
+ constexpr VALUE_TYPE imag_impl (SCALAR_TYPE /*z*/) {
98
+ return 0;
99
+ }
100
+
101
+ template<>
102
+ constexpr c10::complex<float> imag_impl <c10::complex<float>> (c10::complex<float> z) {
103
+ return c10::complex<float>(z.imag(), 0.0);
104
+ }
105
+
106
+ template<>
107
+ constexpr float imag_impl <c10::complex<float>, float> (c10::complex<float> z) {
108
+ return z.imag();
109
+ }
110
+
111
+ template<>
112
+ constexpr c10::complex<double> imag_impl <c10::complex<double>> (c10::complex<double> z) {
113
+ return c10::complex<double>(z.imag(), 0.0);
114
+ }
115
+
116
+ template<>
117
+ constexpr double imag_impl <c10::complex<double>, double> (c10::complex<double> z) {
118
+ return z.imag();
119
+ }
120
+
121
+ template <typename TYPE>
122
+ inline TYPE conj_impl (TYPE z) {
123
+ return z; //No-Op
124
+ }
125
+
126
+ template<>
127
+ inline c10::complex<at::Half> conj_impl <c10::complex<at::Half>> (c10::complex<at::Half> z) {
128
+ return c10::complex<at::Half>{z.real(), -z.imag()};
129
+ }
130
+
131
+ template<>
132
+ inline c10::complex<float> conj_impl <c10::complex<float>> (c10::complex<float> z) {
133
+ return c10::complex<float>(z.real(), -z.imag());
134
+ }
135
+
136
+ template<>
137
+ inline c10::complex<double> conj_impl <c10::complex<double>> (c10::complex<double> z) {
138
+ return c10::complex<double>(z.real(), -z.imag());
139
+ }
140
+
141
+ template <typename TYPE>
142
+ inline TYPE ceil_impl (TYPE z) {
143
+ return std::ceil(z);
144
+ }
145
+
146
+ template <>
147
+ inline c10::complex<float> ceil_impl (c10::complex<float> z) {
148
+ return c10::complex<float>(std::ceil(z.real()), std::ceil(z.imag()));
149
+ }
150
+
151
+ template <>
152
+ inline c10::complex<double> ceil_impl (c10::complex<double> z) {
153
+ return c10::complex<double>(std::ceil(z.real()), std::ceil(z.imag()));
154
+ }
155
+
156
+ template<typename T>
157
+ inline c10::complex<T> sgn_impl (c10::complex<T> z) {
158
+ if (z == c10::complex<T>(0, 0)) {
159
+ return c10::complex<T>(0, 0);
160
+ } else {
161
+ return z / zabs(z);
162
+ }
163
+ }
164
+
165
+ template <typename TYPE>
166
+ inline TYPE floor_impl (TYPE z) {
167
+ return std::floor(z);
168
+ }
169
+
170
+ template <>
171
+ inline c10::complex<float> floor_impl (c10::complex<float> z) {
172
+ return c10::complex<float>(std::floor(z.real()), std::floor(z.imag()));
173
+ }
174
+
175
+ template <>
176
+ inline c10::complex<double> floor_impl (c10::complex<double> z) {
177
+ return c10::complex<double>(std::floor(z.real()), std::floor(z.imag()));
178
+ }
179
+
180
+ template <typename TYPE>
181
+ inline TYPE round_impl (TYPE z) {
182
+ return std::nearbyint(z);
183
+ }
184
+
185
+ template <>
186
+ inline c10::complex<float> round_impl (c10::complex<float> z) {
187
+ return c10::complex<float>(std::nearbyint(z.real()), std::nearbyint(z.imag()));
188
+ }
189
+
190
+ template <>
191
+ inline c10::complex<double> round_impl (c10::complex<double> z) {
192
+ return c10::complex<double>(std::nearbyint(z.real()), std::nearbyint(z.imag()));
193
+ }
194
+
195
+ template <typename TYPE>
196
+ inline TYPE trunc_impl (TYPE z) {
197
+ return std::trunc(z);
198
+ }
199
+
200
+ template <>
201
+ inline c10::complex<float> trunc_impl (c10::complex<float> z) {
202
+ return c10::complex<float>(std::trunc(z.real()), std::trunc(z.imag()));
203
+ }
204
+
205
+ template <>
206
+ inline c10::complex<double> trunc_impl (c10::complex<double> z) {
207
+ return c10::complex<double>(std::trunc(z.real()), std::trunc(z.imag()));
208
+ }
209
+
210
+ template <typename TYPE, std::enable_if_t<!c10::is_complex<TYPE>::value, int> = 0>
211
+ inline TYPE max_impl (TYPE a, TYPE b) {
212
+ if (_isnan<TYPE>(a) || _isnan<TYPE>(b)) {
213
+ return std::numeric_limits<TYPE>::quiet_NaN();
214
+ } else {
215
+ return std::max(a, b);
216
+ }
217
+ }
218
+
219
+ template <typename TYPE, std::enable_if_t<c10::is_complex<TYPE>::value, int> = 0>
220
+ inline TYPE max_impl (TYPE a, TYPE b) {
221
+ if (_isnan<TYPE>(a)) {
222
+ return a;
223
+ } else if (_isnan<TYPE>(b)) {
224
+ return b;
225
+ } else {
226
+ return std::abs(a) > std::abs(b) ? a : b;
227
+ }
228
+ }
229
+
230
+ template <typename TYPE, std::enable_if_t<!c10::is_complex<TYPE>::value, int> = 0>
231
+ inline TYPE min_impl (TYPE a, TYPE b) {
232
+ if (_isnan<TYPE>(a) || _isnan<TYPE>(b)) {
233
+ return std::numeric_limits<TYPE>::quiet_NaN();
234
+ } else {
235
+ return std::min(a, b);
236
+ }
237
+ }
238
+
239
+ template <typename TYPE, std::enable_if_t<c10::is_complex<TYPE>::value, int> = 0>
240
+ inline TYPE min_impl (TYPE a, TYPE b) {
241
+ if (_isnan<TYPE>(a)) {
242
+ return a;
243
+ } else if (_isnan<TYPE>(b)) {
244
+ return b;
245
+ } else {
246
+ return std::abs(a) < std::abs(b) ? a : b;
247
+ }
248
+ }
249
+
250
+ } // end namespace
251
+ } //end at::native
252
+
253
+ #else
254
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
255
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/CUDAJitLoops.cuh ADDED
@@ -0,0 +1,332 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+ #include <ATen/jit_macros.h>
4
+
5
+ // Jiterator functions are guarded behind this macro
6
+ #if AT_USE_JITERATOR()
7
+
8
+ #include <ATen/OpMathType.h>
9
+ #include <ATen/TensorIterator.h>
10
+ #include <ATen/cuda/CUDAContext.h>
11
+ #include <ATen/cuda/detail/OffsetCalculator.cuh>
12
+ #include <ATen/native/cuda/jit_utils.h>
13
+ #include <ATen/native/cuda/MemoryAccess.cuh>
14
+ #include <ATen/native/cuda/thread_constants.h>
15
+
16
+ #include <ATen/native/cuda/Loops.cuh>
17
+
18
+ #include <c10/macros/Macros.h>
19
+ #include <c10/core/ScalarType.h>
20
+ #include <c10/util/SmallBuffer.h>
21
+
22
+ #include <array>
23
+ #include <initializer_list>
24
+ #include <type_traits>
25
+ #include <tuple>
26
+ #include <mutex>
27
+
28
+ namespace at::native {
29
+
30
+ template <typename Tuple, std::size_t... I>
31
+ // warning : unused parameter when tuple is empty.
32
+ constexpr auto tuple_to_array_helper(const Tuple& t [[maybe_unused]], std::index_sequence<I...> seq) {
33
+ constexpr auto size = seq.size();
34
+ return std::array<const void*, size>{static_cast<const void*>(&std::get<I>(t))...};
35
+ }
36
+
37
+ // Helper function convert tuple to std::array<const void*, N>
38
+ // for passing the arguments to CUDA Kernel
39
+ // NOTE: We capture tuple by reference,
40
+ // so the pointers in returned array are only valid
41
+ // till tuple is alive.
42
+ template <typename ...Args>
43
+ constexpr auto tuple_to_array(const std::tuple<Args...>& extra_args) {
44
+ constexpr auto tuple_size = sizeof...(Args);
45
+ return tuple_to_array_helper(extra_args, std::make_index_sequence<tuple_size>{});
46
+ }
47
+
48
+ struct JittedVecKernelCache {
49
+ // Different kernels are compiled depending on what we're vectorizing up to (1, 2 or 4 elements)
50
+ at::cuda::jit::NvrtcFunction vec1;
51
+ at::cuda::jit::NvrtcFunction vec2;
52
+ at::cuda::jit::NvrtcFunction vec4;
53
+ at::cuda::jit::NvrtcFunction vec8;
54
+ #ifdef USE_ROCM
55
+ at::cuda::jit::NvrtcFunction vec16;
56
+ #endif
57
+
58
+ };
59
+
60
+ struct JittedKernelVariantCache {
61
+ JittedVecKernelCache vec;
62
+ at::cuda::jit::NvrtcFunction noncontiguous;
63
+ at::cuda::jit::NvrtcFunction dynamic_contiguous;
64
+ at::cuda::jit::NvrtcFunction dynamic_noncontiguous;
65
+ };
66
+
67
+ inline c10::SmallBuffer<const void*, 64> pack_kernel_args(
68
+ std::initializer_list<const void*> args,
69
+ c10::ArrayRef<const void*> extra_args) {
70
+ c10::SmallBuffer<const void*, 64> ret(args.size() + extra_args.size());
71
+ std::copy(args.begin(), args.end(), ret.data());
72
+ std::copy(extra_args.begin(), extra_args.end(), ret.data() + args.size());
73
+ return ret;
74
+ }
75
+
76
+ template<typename array_t,
77
+ typename inp_calc_t,
78
+ typename out_calc_t,
79
+ typename loader_t,
80
+ typename storer_t>
81
+ void launch_jitted_unrolled_kernel(
82
+ std::mutex &jiterator_mutex,
83
+ at::cuda::jit::NvrtcFunction &fn_cache,
84
+ const at::cuda::jit::KernelDescriptor &desc,
85
+ int64_t N,
86
+ array_t data,
87
+ inp_calc_t ic,
88
+ out_calc_t oc,
89
+ loader_t l,
90
+ storer_t s,
91
+ bool contiguous,
92
+ at::cuda::jit::BinaryFuncVariant scalar_pos,
93
+ const void* scalar_val,
94
+ c10::ArrayRef<const void*> extra_args) {
95
+
96
+ TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
97
+
98
+ int tws = at::cuda::jit::calc_thread_work_size(desc.nInputs, desc.nOutputs, desc.f_inputs_type, desc.result_type);
99
+ int bws = tws * num_threads();
100
+ //casting result to int is always safe, intermediate is int64 and won't overflow
101
+ const uint32_t grid = (N + bws - 1) / bws;
102
+
103
+ if (!fn_cache.function) {
104
+ const std::lock_guard<std::mutex> lock{jiterator_mutex};
105
+ if (!fn_cache.function) {
106
+ constexpr bool dynamic_casting = !std::is_same<decltype(l), memory::LoadWithoutCast>() ||
107
+ !std::is_same<decltype(s), memory::StoreWithoutCast>();
108
+ auto code = at::cuda::jit::generate_code(
109
+ desc, contiguous, dynamic_casting, scalar_pos, tws);
110
+ fn_cache = at::cuda::jit::jit_pwise_function(code, desc.name);
111
+ }
112
+ }
113
+
114
+ auto args = pack_kernel_args({&N, &data, &ic, &oc, &l, &s, scalar_val}, extra_args);
115
+ at::cuda::jit::launch_jitted_pwise_function(fn_cache, args.data(), {grid, 1u, 1u},
116
+ {num_threads(), 1u, 1u});
117
+ }
118
+
119
+ template<int arity, typename array_t>
120
+ void launch_jitted_vectorized_kernel(
121
+ std::mutex &jiterator_mutex, JittedVecKernelCache &fn_cache,
122
+ const at::cuda::jit::KernelDescriptor &desc, int64_t N, array_t data,
123
+ at::cuda::jit::BinaryFuncVariant scalar_pos,
124
+ const void *scalar_val, c10::ArrayRef<const void*> extra_args) {
125
+ TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
126
+
127
+ int tws = at::cuda::jit::calc_thread_work_size(desc.nInputs, desc.nOutputs, desc.f_inputs_type, desc.result_type);
128
+ int bws = tws * num_threads();
129
+ // N is still int64_t for the computation, but it's always safe to cast result to int
130
+ const uint32_t grid = (N + bws - 1) / bws;
131
+
132
+ int vec_size = at::cuda::jit::can_vectorize_up_to(
133
+ desc, c10::ArrayRef<char*>(data.data(), data.size()));
134
+
135
+ #ifndef USE_ROCM
136
+ const auto input_size = c10::scalarTypeToTypeMeta(desc.f_inputs_type).itemsize();
137
+ const int optimal_vec_size = 16 / static_cast<int>(input_size);
138
+ vec_size = std::min<int>(optimal_vec_size, vec_size);
139
+ // Here we purposely omit vec8 for 1-byte data because of a bug in NVCC
140
+ // that causes some numerical mismatches with uint8 on sm80 and sm90.
141
+ // TODO: Revisit this after CUDA 12.8 update.
142
+ if (input_size < 2) {
143
+ vec_size = std::min<int>(vec_size, 4);
144
+ }
145
+ #endif
146
+
147
+ // Different kernels are compiled depending on what we're vectorizing up to (1, 2 or 4 elements)
148
+ // fn_ptr is set to the appropriate function based on the vec size and GPU used
149
+ at::cuda::jit::NvrtcFunction* fn_ptr = nullptr;
150
+
151
+ #ifdef USE_ROCM
152
+ if (vec_size == 16) {
153
+ fn_ptr = &fn_cache.vec16;
154
+ } else
155
+ #endif
156
+ if (vec_size == 8) {
157
+ fn_ptr = &fn_cache.vec8;
158
+ } else if (vec_size == 4) {
159
+ fn_ptr = &fn_cache.vec4;
160
+ } else if (vec_size == 2) {
161
+ fn_ptr = &fn_cache.vec2;
162
+ } else if (vec_size ==1) {
163
+ fn_ptr = &fn_cache.vec1;
164
+ } else {
165
+ TORCH_INTERNAL_ASSERT(false, "unexpected vec_size for jitter vectorized kernel");
166
+ }
167
+
168
+ bool vectorized = vec_size > 1;
169
+
170
+ if (!fn_ptr->function) {
171
+ const std::lock_guard<std::mutex> lock{jiterator_mutex};
172
+ if (!fn_ptr->function) { // cache miss!
173
+
174
+ // Generates program
175
+ auto code = at::cuda::jit::generate_code(
176
+ desc, /*contiguous=*/true, /*dynamic_casting=*/false,
177
+ scalar_pos, tws, vectorized, vec_size);
178
+ std::string kernel_name = vectorized ? desc.name + "_vectorized" + std::to_string(vec_size) : desc.name;
179
+
180
+ // Acquires the program
181
+ *fn_ptr = at::cuda::jit::jit_pwise_function(code, kernel_name);
182
+ }
183
+ }
184
+
185
+ if (vectorized) {
186
+ auto args = pack_kernel_args({&N, &data, scalar_val}, extra_args);
187
+ at::cuda::jit::launch_jitted_pwise_function(
188
+ *fn_ptr, args.data(), {grid, 1u, 1u}, {num_threads(), 1u, 1u});
189
+ } else {
190
+ // NVCC complains about unused variables l and s.
191
+ // It should be false positive in most cases, so we suppress the warnings.
192
+ #pragma nv_diagnostic push
193
+ #pragma nv_diag_suppress 177
194
+ auto ic = TrivialOffsetCalculator<arity>();
195
+ auto oc = TrivialOffsetCalculator<1>();
196
+ auto l = memory::LoadWithoutCast();
197
+ auto s = memory::StoreWithoutCast();
198
+
199
+ auto args = pack_kernel_args(
200
+ {&N, &data, &ic, &oc, &l, &s, scalar_val}, extra_args);
201
+ at::cuda::jit::launch_jitted_pwise_function(
202
+ *fn_ptr, args.data(), {grid, 1u, 1u}, {num_threads(), 1u, 1u});
203
+ #pragma nv_diagnostic pop
204
+ }
205
+ }
206
+
207
+ template <int arity>
208
+ void jitted_gpu_kernel_generic(
209
+ std::mutex &jiterator_mutex,
210
+ JittedKernelVariantCache &cache,
211
+ const at::cuda::jit::KernelDescriptor &desc,
212
+ at::cuda::jit::BinaryFuncVariant scalar_pos,
213
+ c10::ArrayRef<const void*> extra_args,
214
+ TensorIteratorBase& iter,
215
+ const bool dynamic_casting,
216
+ const void *scalar_val) {
217
+ TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing());
218
+ TORCH_INTERNAL_ASSERT(iter.ninputs() == arity);
219
+ TORCH_INTERNAL_ASSERT(iter.noutputs() == 1);
220
+
221
+ constexpr int ntensors = arity + 1;
222
+ std::array<char*, ntensors> data;
223
+ for (auto i : c10::irange(ntensors)) {
224
+ data[i] = (char*)iter.data_ptr(i);
225
+ }
226
+
227
+ int64_t numel = iter.numel();
228
+ bool contiguous = iter.is_contiguous();
229
+
230
+ // Decides which of 4 kernel types to launch
231
+ // Variations are:
232
+ // - Case 1: no dynamic casting and contiguous
233
+ // - Case 2: no dynamic casting and noncontiguous
234
+ // - Case 3: dynamic casting and contiguous
235
+ // - Case 4: dynamic casting and noncontiguous
236
+ // These cases align with the non-jitted CUDALoops.cuh cases in gpu_kernel_impl
237
+
238
+ if (!dynamic_casting) {
239
+ if (contiguous) {
240
+ // Case 1: no dynamic casting and contiguous
241
+ launch_jitted_vectorized_kernel<arity>(
242
+ jiterator_mutex, cache.vec, desc,
243
+ numel, data, scalar_pos, scalar_val, extra_args);
244
+ return;
245
+ }
246
+
247
+ // Case 2: no dynamic casting and noncontiguous
248
+ auto input_offset_calculator = make_input_offset_calculator<arity>(iter);
249
+ auto output_offset_calculator = make_output_offset_calculator(iter);
250
+ auto loader = memory::LoadWithoutCast();
251
+ auto storer = memory::StoreWithoutCast();
252
+ launch_jitted_unrolled_kernel(
253
+ jiterator_mutex, cache.noncontiguous, desc, numel, data,
254
+ input_offset_calculator, output_offset_calculator, loader,
255
+ storer, contiguous, scalar_pos, scalar_val, extra_args);
256
+ return;
257
+ }
258
+
259
+ // Cases 3 and 4 are handled below
260
+ // Both require construction of a storer (this asserts 1 output) and one or more loaders
261
+
262
+ // Creates store cast to output (the zeroth tensor in TensorIterator)
263
+ auto storer = memory::StoreWithCast<1>(iter);
264
+
265
+ // Creates load casts from inputs (note offset indexing into the iterators 1...n tensors)
266
+ auto loader = memory::LoadWithCast<arity>(iter);
267
+
268
+ if (contiguous) {
269
+ // Case 3: dynamic casting and contiguous
270
+ auto input_offset_calculator = TrivialOffsetCalculator<arity>();
271
+ auto output_offset_calculator = TrivialOffsetCalculator<1>();
272
+ launch_jitted_unrolled_kernel(
273
+ jiterator_mutex, cache.dynamic_contiguous, desc, numel, data, input_offset_calculator,
274
+ output_offset_calculator, loader, storer, contiguous, scalar_pos, scalar_val, extra_args);
275
+ return;
276
+ }
277
+
278
+ // Case 4: dynamic casting and noncontiguous
279
+ auto input_offset_calculator = make_input_offset_calculator<arity>(iter);
280
+ auto output_offset_calculator = make_output_offset_calculator(iter);
281
+ launch_jitted_unrolled_kernel(
282
+ jiterator_mutex, cache.dynamic_noncontiguous, desc, numel, data, input_offset_calculator,
283
+ output_offset_calculator, loader, storer, contiguous, scalar_pos, scalar_val, extra_args);
284
+ }
285
+
286
+ // NOTE: static to reduce chances of name collision.
287
+ template <
288
+ char const* name,
289
+ typename result_type,
290
+ typename f_inputs_type,
291
+ int arity,
292
+ at::cuda::jit::BinaryFuncVariant scalar_pos =
293
+ at::cuda::jit::BinaryFuncVariant::NoScalar,
294
+ typename... ExtraArgs>
295
+ static void jitted_gpu_kernel_impl(
296
+ TensorIteratorBase& iter,
297
+ const std::string &f,
298
+ const bool dynamic_casting,
299
+ at::opmath_type<f_inputs_type> scalar_val,
300
+ const std::tuple<ExtraArgs...>& extra_args) {
301
+
302
+ // TODO: Memory use can probably be optimized by reusing kernels across GPUs with
303
+ // the same compute capability
304
+ static std::mutex jiterator_mutex;
305
+ static std::vector<JittedKernelVariantCache> device_caches(c10::cuda::device_count());
306
+
307
+ constexpr int nInputs = arity;
308
+ constexpr int nOutputs = 1; // TODO: Support more than 1 output
309
+ static const auto desc = at::cuda::jit::make_kernel_descriptor<
310
+ result_type, f_inputs_type, ExtraArgs...>(name, f, nInputs, nOutputs);
311
+
312
+ auto &cache = device_caches[iter.device().index()];
313
+ auto extra_args_array = tuple_to_array(extra_args);
314
+ return jitted_gpu_kernel_generic<arity>(
315
+ jiterator_mutex,
316
+ cache,
317
+ desc,
318
+ scalar_pos,
319
+ extra_args_array,
320
+ iter,
321
+ dynamic_casting,
322
+ &scalar_val
323
+ );
324
+ }
325
+
326
+ } // at::native
327
+
328
+ #endif // AT_USE_JITERATOR()
329
+
330
+ #else
331
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
332
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)