koichi12 commited on
Commit
edce735
·
verified ·
1 Parent(s): f52c26c

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/AtomicAddFloat.h +37 -0
  2. .venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/ChannelShuffleKernel.h +14 -0
  3. .venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/CopyKernel.h +14 -0
  4. .venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/DepthwiseConvKernel.h +21 -0
  5. .venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/DistributionTemplates.h +425 -0
  6. .venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/GridSamplerKernel.h +34 -0
  7. .venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/IndexKernelUtils.h +87 -0
  8. .venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/Intrinsics.h +33 -0
  9. .venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/IsContiguous.h +62 -0
  10. .venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/LogAddExp.h +61 -0
  11. .venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/MaxUnpoolKernel.h +14 -0
  12. .venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/PixelShuffleKernel.h +14 -0
  13. .venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/Reduce.h +314 -0
  14. .venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/SampledAddmmKernel.h +12 -0
  15. .venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/SerialStackImpl.h +146 -0
  16. .venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/SoftmaxKernel.h +28 -0
  17. .venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/SpmmReduceKernel.h +22 -0
  18. .venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/StackKernel.h +12 -0
  19. .venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/UpSampleKernelAVXAntialias.h +1376 -0
  20. .venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/WeightNormKernel.h +20 -0
  21. .venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/avx_mathfun.h +522 -0
  22. .venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/int_mm_kernel.h +16 -0
  23. .venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/mixed_data_type.h +41 -0
  24. .venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/moments_utils.h +202 -0
  25. .venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/utils.h +212 -0
  26. .venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/zmath.h +250 -0
  27. .venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Activation.h +20 -0
  28. .venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/BinaryInternal.h +48 -0
  29. .venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/CUDAJitLoops.cuh +296 -0
  30. .venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/CUDALoops.cuh +348 -0
  31. .venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/CompositeRandomAccessor.h +35 -0
  32. .venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Copy.h +10 -0
  33. .venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/CuFFTUtils.h +73 -0
  34. .venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/DeviceSqrt.cuh +25 -0
  35. .venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/DistributionTemplates.h +671 -0
  36. .venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Distributions.h +25 -0
  37. .venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/ForeachFunctors.cuh +681 -0
  38. .venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/GridSampler.cuh +321 -0
  39. .venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/IndexKernel.h +16 -0
  40. .venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/KernelUtils.cuh +149 -0
  41. .venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/LaunchUtils.h +18 -0
  42. .venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/MemoryAccess.cuh +389 -0
  43. .venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/MiscUtils.h +32 -0
  44. .venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/MultiTensorApply.cuh +379 -0
  45. .venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Normalization.cuh +1742 -0
  46. .venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Pow.cuh +58 -0
  47. .venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Randperm.cuh +58 -0
  48. .venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Resize.h +53 -0
  49. .venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/RowwiseScaledMM.h +15 -0
  50. .venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/ScanUtils.cuh +459 -0
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/AtomicAddFloat.h ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #ifndef ATOMIC_ADD_FLOAT
2
+ #define ATOMIC_ADD_FLOAT
3
+
4
+ #if (defined(__x86_64__) || defined(__i386__) || defined(__aarch64__))
5
+ #include <ATen/native/cpu/Intrinsics.h>
6
+ #else
7
+ #define _mm_pause()
8
+ #endif
9
+
10
+ #include <atomic>
11
+
12
+ static inline void cpu_atomic_add_float(float* dst, float fvalue)
13
+ {
14
+ typedef union {
15
+ unsigned intV;
16
+ float floatV;
17
+ } uf32_t;
18
+
19
+ uf32_t new_value, old_value;
20
+ std::atomic<unsigned>* dst_intV = (std::atomic<unsigned>*)(dst);
21
+
22
+ old_value.floatV = *dst;
23
+ new_value.floatV = old_value.floatV + fvalue;
24
+
25
+ unsigned* old_intV = (unsigned*)(&old_value.intV);
26
+ while (!std::atomic_compare_exchange_strong(dst_intV, old_intV, new_value.intV)) {
27
+ #ifdef __aarch64__
28
+ __asm__ __volatile__("yield;" : : : "memory");
29
+ #else
30
+ _mm_pause();
31
+ #endif
32
+ old_value.floatV = *dst;
33
+ new_value.floatV = old_value.floatV + fvalue;
34
+ }
35
+ }
36
+
37
+ #endif
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/ChannelShuffleKernel.h ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <ATen/native/DispatchStub.h>
3
+ #include <cstdint>
4
+
5
+ namespace at {
6
+ class TensorBase;
7
+ }
8
+
9
+ namespace at::native {
10
+
11
+ using channel_shuffle_fn = void(*)(TensorBase&, const TensorBase&, int64_t);
12
+ DECLARE_DISPATCH(channel_shuffle_fn, channel_shuffle_kernel);
13
+
14
+ } // at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/CopyKernel.h ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/native/TensorIterator.h>
4
+
5
+ namespace at {
6
+ struct TensorIteratorBase;
7
+
8
+ namespace native {
9
+ inline namespace CPU_CAPABILITY {
10
+
11
+ void direct_copy_kernel(TensorIteratorBase &iter);
12
+ void copy_kernel(TensorIterator& iter, bool /*non_blocking*/);
13
+
14
+ }}} // namespace at::native::CPU_CAPABILITY
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/DepthwiseConvKernel.h ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/native/DispatchStub.h>
4
+ #include <c10/util/ArrayRef.h>
5
+
6
+ /*
7
+ Depthwise 3x3 Winograd convolution operator
8
+ */
9
+
10
+ namespace at {
11
+ class Tensor;
12
+
13
+ namespace native {
14
+
15
+ using convolution_depthwise3x3_winograd_fn =
16
+ Tensor (*)(const Tensor &, const Tensor &, const Tensor &, IntArrayRef, IntArrayRef, int64_t);
17
+
18
+ DECLARE_DISPATCH(convolution_depthwise3x3_winograd_fn, convolution_depthwise3x3_winograd_stub);
19
+
20
+ } // namespace native
21
+ } // namespace at
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/DistributionTemplates.h ADDED
@@ -0,0 +1,425 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/CPUApplyUtils.h>
4
+ #include <ATen/Dispatch.h>
5
+ #include <ATen/Dispatch_v2.h>
6
+ #include <ATen/ExpandBase.h>
7
+ #include <ATen/core/DistributionsHelper.h>
8
+ #include <ATen/native/TensorIterator.h>
9
+ #include <ATen/native/cpu/Loops.h>
10
+ #include <mutex>
11
+
12
+ #ifdef CPU_CAPABILITY_AVX2
13
+ #include <ATen/native/cpu/avx_mathfun.h>
14
+ #include <c10/util/irange.h>
15
+ #endif
16
+
17
+
18
+
19
+
20
+ namespace at::native::templates::cpu {
21
+ namespace {
22
+
23
+ // ==================================================== Random ========================================================
24
+
25
+ template<typename RNG>
26
+ void random_from_to_kernel(TensorIteratorBase& iter, uint64_t range, int64_t base, RNG generator) {
27
+ AT_DISPATCH_V2(iter.dtype(), "random_from_to_kernel_cpu", AT_WRAP([&] {
28
+ std::lock_guard<std::mutex> lock(generator->mutex_);
29
+ cpu_serial_kernel(iter, [range, base, generator]() -> scalar_t {
30
+ uniform_int_from_to_distribution<scalar_t> random(range, base);
31
+ return random(generator);
32
+ });
33
+ }), kBool, kHalf, kBFloat16, AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
34
+ }
35
+
36
+ // This is the special kernel to handle single specific case:
37
+ // from(inclusive) = std::numeric_limits<int64_t>::lowest()
38
+ // to(exclusive) = None (= std::numeric_limits<int64_t>::max() + 1)
39
+ template<typename RNG>
40
+ void random_full_64_bits_range_kernel(TensorIteratorBase& iter, RNG generator) {
41
+ AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::BFloat16, iter.dtype(), "random_full_64_bits_range_kernel_cpu", [&] {
42
+ if constexpr (std::is_same_v<scalar_t, int64_t> ||
43
+ std::is_same_v<scalar_t, double> ||
44
+ std::is_same_v<scalar_t, float> ||
45
+ std::is_same_v<scalar_t, at::BFloat16>) {
46
+ std::lock_guard<std::mutex> lock(generator->mutex_);
47
+ cpu_serial_kernel(iter, [generator]() -> scalar_t {
48
+ uniform_int_full_range_distribution<scalar_t> random;
49
+ return random(generator);
50
+ });
51
+ } else {
52
+ TORCH_CHECK(false, "random_full_64_bits_range_kernel_cpu handles only int64, double, float and bfloat16");
53
+ }
54
+ });
55
+ }
56
+
57
+ template<typename RNG>
58
+ struct RandomFromToKernel {
59
+ void operator()(TensorIteratorBase& iter, uint64_t range, int64_t base, std::optional<Generator> gen) {
60
+ random_from_to_kernel(iter, range, base, check_generator<RNG>(gen));
61
+ }
62
+ void operator()(TensorIteratorBase& iter, std::optional<Generator> gen) {
63
+ random_full_64_bits_range_kernel(iter, check_generator<RNG>(gen));
64
+ }
65
+ };
66
+
67
+ template<typename RNG>
68
+ void random_kernel(TensorIteratorBase& iter, RNG generator) {
69
+ std::lock_guard<std::mutex> lock(generator->mutex_);
70
+ AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, iter.dtype(), "random_kernel_cpu", [&] {
71
+ cpu_serial_kernel(iter, [generator]() -> scalar_t {
72
+ uniform_int_distribution<scalar_t> random;
73
+ return random(generator);
74
+ });
75
+ });
76
+ }
77
+
78
+ template<typename RNG>
79
+ struct RandomKernel {
80
+ void operator()(TensorIteratorBase& iter, std::optional<Generator> gen) {
81
+ random_kernel(iter, check_generator<RNG>(gen));
82
+ }
83
+ };
84
+
85
+ // ==================================================== Normal ========================================================
86
+
87
+ #ifdef CPU_CAPABILITY_AVX2
88
+ static void normal_fill_16_AVX2(float *data,
89
+ const __m256* two_pi,
90
+ const __m256* one,
91
+ const __m256* minus_two,
92
+ const __m256* mean,
93
+ const __m256* std_v) {
94
+ const __m256 u1 = _mm256_sub_ps(*one, _mm256_loadu_ps(data));
95
+ const __m256 u2 = _mm256_loadu_ps(data + 8);
96
+ // sincos256_ps and log256_ps are from avx_mathfun.h
97
+ const __m256 radius = _mm256_sqrt_ps(_mm256_mul_ps(*minus_two, log256_ps(u1)));
98
+ const __m256 theta = _mm256_mul_ps(*two_pi, u2);
99
+ __m256 sintheta, costheta;
100
+ sincos256_ps(theta, &sintheta, &costheta);
101
+ const __m256 n1 = _mm256_mul_ps(radius, costheta);
102
+ const __m256 n2 = _mm256_mul_ps(radius, sintheta);
103
+ _mm256_storeu_ps(data, _mm256_fmadd_ps(n1, *std_v, *mean));
104
+ _mm256_storeu_ps(data + 8, _mm256_fmadd_ps(n2, *std_v, *mean));
105
+ }
106
+
107
+ template<typename RNG>
108
+ void normal_fill_AVX2(const TensorBase &self, const float mean, const float std, RNG generator) {
109
+ float *data = self.data_ptr<float>();
110
+ auto size = self.numel();
111
+ std::lock_guard<std::mutex> lock(generator->mutex_);
112
+ for (const auto i : c10::irange(size)) {
113
+ at::uniform_real_distribution<float> uniform(0, 1);
114
+ data[i] = uniform(generator);
115
+ }
116
+ const __m256 two_pi = _mm256_set1_ps(2.0f * c10::pi<double>);
117
+ const __m256 one = _mm256_set1_ps(1.0f);
118
+ const __m256 minus_two = _mm256_set1_ps(-2.0f);
119
+ const __m256 mean_v = _mm256_set1_ps(mean);
120
+ const __m256 std_v = _mm256_set1_ps(std);
121
+
122
+ for (int64_t i = 0; i < size - 15; i += 16) {
123
+ normal_fill_16_AVX2(data + i, &two_pi, &one, &minus_two, &mean_v, &std_v);
124
+ }
125
+
126
+ if (size % 16 != 0) {
127
+ // Recompute the last 16 values.
128
+ data = data + size - 16;
129
+ for (const auto i : c10::irange(16)) {
130
+ at::uniform_real_distribution<float> uniform(0, 1);
131
+ data[i] = uniform(generator);
132
+ }
133
+ normal_fill_16_AVX2(data, &two_pi, &one, &minus_two, &mean_v, &std_v);
134
+ }
135
+ }
136
+ #endif
137
+
138
+ template <typename scalar_t>
139
+ static void normal_fill_16(scalar_t *data, const scalar_t mean, const scalar_t std) {
140
+ for (const auto j : c10::irange(8)) {
141
+ const scalar_t u1 = 1 - data[j]; // [0, 1) -> (0, 1] for log.
142
+ const scalar_t u2 = data[j + 8];
143
+ const scalar_t radius = std::sqrt(-2 * std::log(u1));
144
+ const scalar_t theta = 2.0f * c10::pi<double> * u2;
145
+ data[j] = radius * std::cos(theta) * std + mean;
146
+ data[j + 8] = radius * std::sin(theta) * std + mean;
147
+ }
148
+ }
149
+
150
+ #if defined(__VSX__) || defined(CPU_CAPABILITY_VSX)
151
+ static void normal_fill_16_VSX(float *data,const Vectorized<float> &two_pi,const Vectorized<float> &one,const Vectorized<float> &minus_two,const Vectorized<float> &mean,const Vectorized<float> &std) {
152
+ using Vec = Vectorized<float>;
153
+ Vec u1=one-Vec::loadu(data);
154
+ Vec u2=Vec::loadu(data+8);
155
+ Vec radius=(minus_two * u1.log());
156
+ radius=radius.sqrt();
157
+ Vec theta=two_pi * u2;
158
+ Vec output_vec=radius * theta.cos() * std + mean;
159
+ Vec output_vec2=radius * theta.sin() * std + mean;
160
+ output_vec.store(data);
161
+ output_vec2.store(data+8);
162
+ }
163
+
164
+ template <typename scalar_t, typename RNG>
165
+ void normal_fill_VSX(const TensorBase &self, const scalar_t mean, const scalar_t std, RNG generator) {
166
+ float *data = self.data_ptr<float>();
167
+ auto size = self.numel();
168
+ std::lock_guard<std::mutex> lock(generator->mutex_);
169
+ for (const auto i : c10::irange(size)) {
170
+ at::uniform_real_distribution<scalar_t> uniform(0, 1);
171
+ data[i] = uniform(generator);
172
+ }
173
+
174
+ using Vec = Vectorized<float>;
175
+ const Vec two_pi = Vec(2.0f * c10::pi<double>);
176
+ const Vec one = Vec(1.0f);
177
+ const Vec minus_two = Vec(-2.0f);
178
+ const Vec var_vec = Vec(std);
179
+ const Vec mean_vec = Vec(mean);
180
+
181
+ for (int64_t i = 0; i < size - 15; i += 16) {
182
+ if(Vec::size()==8) {
183
+ normal_fill_16_VSX(data + i, two_pi, one, minus_two, mean_vec, var_vec);
184
+ }
185
+ else{
186
+ normal_fill_16<scalar_t>(data + i, mean, std);
187
+ }
188
+ }
189
+ if (size % 16 != 0) {
190
+ // Recompute the last 16 values.
191
+ data = data + size - 16;
192
+ for (const auto i : c10::irange(16)) {
193
+ at::uniform_real_distribution<scalar_t> uniform(0, 1);
194
+ data[i] = uniform(generator);
195
+ }
196
+ if(Vec::size()==8){
197
+ normal_fill_16_VSX(data, two_pi, one, minus_two, mean_vec, var_vec);
198
+ }
199
+ else{
200
+ normal_fill_16<scalar_t>(data, mean, std);
201
+ }
202
+ }
203
+ }
204
+ #endif //VSX
205
+
206
+ template <typename scalar_t, typename RNG>
207
+ void normal_fill(const TensorBase &self, const scalar_t mean, const scalar_t std, RNG generator) {
208
+ scalar_t *data = self.data_ptr<scalar_t>();
209
+ auto size = self.numel();
210
+ std::lock_guard<std::mutex> lock(generator->mutex_);
211
+ for (const auto i : c10::irange(size)) {
212
+ at::uniform_real_distribution<scalar_t> uniform(0, 1);
213
+ data[i] = uniform(generator);
214
+ }
215
+
216
+ for (int64_t i = 0; i < size - 15; i += 16) {
217
+ normal_fill_16<scalar_t>(data + i, mean, std);
218
+ }
219
+ if (size % 16 != 0) {
220
+ // Recompute the last 16 values.
221
+ data = data + size - 16;
222
+ for (const auto i : c10::irange(16)) {
223
+ at::uniform_real_distribution<scalar_t> uniform(0, 1);
224
+ data[i] = uniform(generator);
225
+ }
226
+ normal_fill_16<scalar_t>(data, mean, std);
227
+ }
228
+ }
229
+
230
+ template<typename RNG>
231
+ void normal_kernel(const TensorBase &self, double mean, double std, RNG generator) {
232
+ auto size = self.numel();
233
+ if (self.scalar_type() == ScalarType::Float && size >= 16 && self.is_contiguous()) {
234
+ #ifdef CPU_CAPABILITY_AVX2
235
+ normal_fill_AVX2(self, static_cast<float>(mean), static_cast<float>(std), generator);
236
+ #elif defined(__VSX__) || defined(CPU_CAPABILITY_VSX)
237
+ normal_fill_VSX(self, static_cast<float>(mean), static_cast<float>(std), generator);
238
+ #else
239
+ normal_fill(self, static_cast<float>(mean), static_cast<float>(std), generator);
240
+ #endif
241
+ } else {
242
+ AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, self.scalar_type(), "normal_kernel_cpu", [&] {
243
+ if (size >= 16 && self.is_contiguous()) {
244
+ normal_fill<scalar_t>(self, static_cast<scalar_t>(mean), static_cast<scalar_t>(std), generator);
245
+ } else {
246
+ auto iter = TensorIterator::borrowing_nullary_op(self);
247
+ std::lock_guard<std::mutex> lock(generator->mutex_);
248
+ cpu_serial_kernel(iter, [mean, std, generator]() -> scalar_t {
249
+ at::normal_distribution<double> normal(mean, std);
250
+ return static_cast<scalar_t>(normal(generator));
251
+ });
252
+ }
253
+ });
254
+ }
255
+ }
256
+
257
+ template<typename RNG>
258
+ struct NormalKernel {
259
+ void operator()(Tensor& self, double mean, double std, std::optional<Generator> gen) {
260
+ normal_kernel(self, mean, std, check_generator<RNG>(gen));
261
+ }
262
+ };
263
+
264
+ // ==================================================== Uniform =======================================================
265
+
266
+ template<typename RNG>
267
+ void uniform_kernel(TensorIteratorBase& iter, double from_, double to_, RNG generator) {
268
+ AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "uniform_kernel_cpu", [&]() {
269
+ std::lock_guard<std::mutex> lock(generator->mutex_);
270
+ auto from = static_cast<scalar_t>(from_);
271
+ auto to = static_cast<scalar_t>(to_);
272
+ at::uniform_real_distribution<scalar_t> uniform(from, to);
273
+ cpu_serial_kernel(iter, [&uniform, generator]() -> scalar_t {
274
+ return static_cast<scalar_t>(uniform(generator));
275
+ });
276
+ });
277
+ }
278
+
279
+ template<typename RNG>
280
+ struct UniformKernel {
281
+ void operator()(TensorIteratorBase& iter, double from, double to, std::optional<Generator> gen) {
282
+ uniform_kernel(iter, from, to, check_generator<RNG>(gen));
283
+ }
284
+ };
285
+
286
+ // ==================================================== Cauchy ========================================================
287
+
288
+ template<typename RNG>
289
+ void cauchy_kernel(TensorIteratorBase& iter, double median, double sigma, RNG generator) {
290
+ AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "cauchy_cpu", [&]() {
291
+ std::lock_guard<std::mutex> lock(generator->mutex_);
292
+ at::cauchy_distribution<double> cauchy(median, sigma);
293
+ cpu_serial_kernel(iter, [&cauchy, generator]() -> scalar_t {
294
+ return static_cast<scalar_t>(cauchy(generator));
295
+ });
296
+ });
297
+ }
298
+
299
+ template<typename RNG>
300
+ struct CauchyKernel {
301
+ void operator()(TensorIteratorBase& iter, double median, double sigma, std::optional<Generator> gen) {
302
+ cauchy_kernel(iter, median, sigma, check_generator<RNG>(gen));
303
+ }
304
+ };
305
+
306
+ // ================================================== LogNormal =======================================================
307
+
308
+ template<typename RNG>
309
+ void log_normal_kernel(TensorIteratorBase& iter, double mean, double std, RNG generator) {
310
+ AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "log_normal_cpu", [&]() {
311
+ std::lock_guard<std::mutex> lock(generator->mutex_);
312
+ at::lognormal_distribution<double> logNormal(mean, std);
313
+ cpu_serial_kernel(iter, [&logNormal, generator]() -> scalar_t {
314
+ return static_cast<scalar_t>(logNormal(generator));
315
+ });
316
+ });
317
+ }
318
+
319
+ template<typename RNG>
320
+ struct LogNormalKernel {
321
+ void operator()(TensorIteratorBase& iter, double mean, double std, std::optional<Generator> gen) {
322
+ log_normal_kernel(iter, mean, std, check_generator<RNG>(gen));
323
+ }
324
+ };
325
+
326
+ // =================================================== Geometric ======================================================
327
+
328
+ template<typename RNG>
329
+ void geometric_kernel(TensorIteratorBase& iter, double p, RNG generator) {
330
+ AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "geometric_cpu", [&]() {
331
+ std::lock_guard<std::mutex> lock(generator->mutex_);
332
+ at::geometric_distribution<double> geometric(p);
333
+ cpu_serial_kernel(iter, [&geometric, generator]() -> scalar_t {
334
+ return static_cast<scalar_t>(geometric(generator));
335
+ });
336
+ });
337
+ }
338
+
339
+ template<typename RNG>
340
+ struct GeometricKernel {
341
+ void operator()(TensorIteratorBase& iter, double p, std::optional<Generator> gen) {
342
+ geometric_kernel(iter, p, check_generator<RNG>(gen));
343
+ }
344
+ };
345
+
346
+ // ================================================== Exponential =====================================================
347
+
348
+ template<typename RNG>
349
+ void exponential_kernel(TensorIteratorBase& iter, double lambda, RNG generator) {
350
+ TORCH_CHECK(isFloatingType(iter.dtype()), "Exponential distribution is a continuous probability distribution. dtype must be a floating point but you specified ", iter.dtype());
351
+ AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "exponential_cpu", [&]() {
352
+ std::lock_guard<std::mutex> lock(generator->mutex_);
353
+ at::exponential_distribution<double> exponential(lambda);
354
+ cpu_serial_kernel(iter, [&exponential, generator]() -> scalar_t {
355
+ return static_cast<scalar_t>(exponential(generator));
356
+ });
357
+ });
358
+ }
359
+
360
+ template<typename RNG>
361
+ struct ExponentialKernel {
362
+ void operator()(TensorIteratorBase& iter, double lambda, std::optional<Generator> gen) {
363
+ exponential_kernel(iter, lambda, check_generator<RNG>(gen));
364
+ }
365
+ };
366
+
367
+ // ================================================== Bernoulli =======================================================
368
+
369
+ template<typename RNG>
370
+ void bernoulli_kernel(const TensorBase &self, const TensorBase &p_, RNG generator) {
371
+ AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half,
372
+ self.scalar_type(), "bernoulli_tensor_cpu_self_", [&] {
373
+ // See Note [Acquire lock when using random generators]
374
+ std::lock_guard<std::mutex> lock(generator->mutex_);
375
+ using self_t = scalar_t;
376
+ auto p_cpu = p_.to(kCPU);
377
+ auto p = expand_inplace(self, p_cpu);
378
+ auto iter = TensorIteratorConfig()
379
+ .add_output(self)
380
+ .add_const_input(*p)
381
+ .check_all_same_dtype(false)
382
+ .build();
383
+ if (p->scalar_type() == kDouble) {
384
+ cpu_serial_kernel(iter, [&](const double p_val) -> self_t {
385
+ at::bernoulli_distribution<double> bernoulli(p_val);
386
+ return static_cast<self_t>(bernoulli(generator));
387
+ });
388
+ } else {
389
+ AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::BFloat16, at::ScalarType::Half,
390
+ p->scalar_type(), "bernoulli_tensor_cpu_p_", [&] {
391
+ using p_t = scalar_t;
392
+ cpu_serial_kernel(iter, [&](const p_t p_val) -> self_t {
393
+ at::bernoulli_distribution<float> bernoulli(p_val);
394
+ return static_cast<self_t>(bernoulli(generator));
395
+ });
396
+ });
397
+ }
398
+ });
399
+ }
400
+
401
+ template<typename RNG>
402
+ void bernoulli_kernel(const TensorBase &self, double p, RNG generator) {
403
+ AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half,
404
+ self.scalar_type(), "bernoulli_scalar_cpu_", [&] {
405
+ // See Note [Acquire lock when using random generators]
406
+ std::lock_guard<std::mutex> lock(generator->mutex_);
407
+ auto iter = TensorIterator::borrowing_nullary_op(self);
408
+ cpu_serial_kernel(iter, [p, generator]() -> scalar_t {
409
+ at::bernoulli_distribution<double> bernoulli(p);
410
+ return static_cast<scalar_t>(bernoulli(generator));
411
+ });
412
+ });
413
+ }
414
+
415
+ template<typename RNG>
416
+ struct BernoulliKernel {
417
+ void operator()(const TensorBase &self, double p, std::optional<Generator> gen) {
418
+ bernoulli_kernel(self, p, check_generator<RNG>(gen));
419
+ }
420
+ void operator()(const TensorBase &self, const TensorBase &p_, std::optional<Generator> gen) {
421
+ bernoulli_kernel(self, p_, check_generator<RNG>(gen));
422
+ }
423
+ };
424
+
425
+ }}
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/GridSamplerKernel.h ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/native/DispatchStub.h>
4
+
5
+ #include <array>
6
+ #include <cstdint>
7
+
8
+ namespace at {
9
+ class TensorBase;
10
+ }
11
+
12
+ namespace at::native {
13
+
14
+ using forward_2d_fn = void (*) (
15
+ const TensorBase &output,
16
+ const TensorBase &input,
17
+ const TensorBase &grid,
18
+ int64_t interpolation_mode,
19
+ int64_t padding_mode,
20
+ bool align_corners);
21
+ using backward_2d_fn = void (*) (
22
+ const TensorBase &grad_input,
23
+ const TensorBase &grad_grid,
24
+ const TensorBase &grad_output,
25
+ const TensorBase &input,
26
+ const TensorBase &grid,
27
+ int64_t interpolation_mode,
28
+ int64_t padding_mode,
29
+ bool align_corners,
30
+ std::array<bool, 2> output_mask);
31
+ DECLARE_DISPATCH(forward_2d_fn, grid_sampler_2d_cpu_kernel);
32
+ DECLARE_DISPATCH(backward_2d_fn, grid_sampler_2d_backward_cpu_kernel);
33
+
34
+ } // namespace at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/IndexKernelUtils.h ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <ATen/native/TensorIterator.h>
3
+ #include <c10/util/irange.h>
4
+
5
+ namespace at::native {
6
+
7
+ namespace {
8
+ static bool is_constant_index(int ntensor, const int64_t* strides) {
9
+ AT_ASSERT(ntensor >= 3);
10
+ for (const auto arg : c10::irange(2, ntensor)) {
11
+ if (strides[arg] != 0) {
12
+ return false;
13
+ }
14
+ }
15
+ return true;
16
+ }
17
+
18
+
19
+ struct Indexer {
20
+ Indexer(int64_t num_indexers, char** indexers, const int64_t* indexer_strides,
21
+ IntArrayRef original_sizes, IntArrayRef original_strides)
22
+ : num_indexers(num_indexers)
23
+ , indexers(indexers)
24
+ , indexer_strides(indexer_strides)
25
+ , original_strides(original_strides.data())
26
+ , original_sizes(original_sizes.data()) {
27
+ AT_ASSERT(static_cast<int64_t>(original_strides.size()) == num_indexers);
28
+ AT_ASSERT(static_cast<int64_t>(original_sizes.size()) == num_indexers);
29
+ }
30
+
31
+ int64_t num_indexers;
32
+ char** indexers;
33
+ const int64_t* indexer_strides;
34
+ const int64_t* original_strides;
35
+ const int64_t* original_sizes;
36
+
37
+ int64_t get(int64_t idx) {
38
+ int64_t offset = 0;
39
+ for (const auto j : c10::irange(num_indexers)) {
40
+ int64_t value = *(int64_t*)&indexers[j][idx * indexer_strides[j]];
41
+ int64_t size = original_sizes[j];
42
+ TORCH_CHECK_INDEX(value >= -size && value < size,
43
+ "index ", value, " is out of bounds for dimension ", j, " with size ", size);
44
+ if (value < 0) {
45
+ value += size;
46
+ }
47
+ offset += value * original_strides[j];
48
+ }
49
+ return offset;
50
+ }
51
+ };
52
+ } // anonymous namespace
53
+
54
+ template <typename scalar_t, typename func_t>
55
+ void cpu_index_kernel(TensorIteratorBase& iter, IntArrayRef index_size, IntArrayRef index_stride,
56
+ const func_t& f, bool serial_execution=false)
57
+ {
58
+ int ntensor = iter.ntensors();
59
+ // When launch the index parallel version, set a relative small grain size less than the INTERNAL::GRAIN_SIZE
60
+ // to make the whole available thread numbers get more balanced work load and a better cache location.
61
+ // The grain size here is chosen by the op benchmark to overcome the thread launch overhead
62
+ const int index_parallel_grain_size = 3000;
63
+ auto loop = [&](char** data, const int64_t* strides, int64_t n) {
64
+ auto indexer = Indexer(ntensor - 2, &data[2], &strides[2], index_size, index_stride);
65
+ char* dst = data[0];
66
+ char* src = data[1];
67
+ if (is_constant_index(ntensor, strides)) {
68
+ // specialization for when every element uses the same index
69
+ int64_t offset = indexer.get(0);
70
+ for (const auto i : c10::irange(n)) {
71
+ f(dst + strides[0] * i, src + strides[1] * i, offset);
72
+ }
73
+ } else {
74
+ for (const auto i : c10::irange(n)) {
75
+ int64_t offset = indexer.get(i);
76
+ f(dst + strides[0] * i, src + strides[1] * i, offset);
77
+ }
78
+ }
79
+ };
80
+ if (serial_execution) {
81
+ iter.serial_for_each(loop, {0, iter.numel()});
82
+ } else {
83
+ iter.for_each(loop, index_parallel_grain_size);
84
+ }
85
+ }
86
+ } // at
87
+ // native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/Intrinsics.h ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #if defined(__clang__) && (defined(__x86_64__) || defined(__i386__))
4
+ /* Clang-compatible compiler, targeting x86/x86-64 */
5
+ #include <x86intrin.h>
6
+ #elif defined(_MSC_VER)
7
+ /* Microsoft C/C++-compatible compiler */
8
+ #include <intrin.h>
9
+ #if _MSC_VER <= 1900
10
+ #define _mm256_extract_epi64(X, Y) (((uint64_t*)&X)[Y])
11
+ #endif
12
+ #elif defined(__GNUC__) && (defined(__x86_64__) || defined(__i386__))
13
+ /* GCC-compatible compiler, targeting x86/x86-64 */
14
+ #include <x86intrin.h>
15
+ #elif defined(__GNUC__) && defined(__ARM_NEON__)
16
+ /* GCC-compatible compiler, targeting ARM with NEON */
17
+ #include <arm_neon.h>
18
+ #elif defined(__GNUC__) && defined(__IWMMXT__)
19
+ /* GCC-compatible compiler, targeting ARM with WMMX */
20
+ #include <mmintrin.h>
21
+ #elif (defined(__GNUC__) || defined(__xlC__)) && \
22
+ (defined(__VEC__) || defined(__ALTIVEC__))
23
+ /* XLC or GCC-compatible compiler, targeting PowerPC with VMX/VSX */
24
+ #include <altivec.h>
25
+ /* We need to undef those tokens defined by <altivec.h> to avoid conflicts
26
+ with the C++ types. => Can still use __bool/__vector */
27
+ #undef bool
28
+ #undef vector
29
+ #undef pixel
30
+ #elif defined(__GNUC__) && defined(__SPE__)
31
+ /* GCC-compatible compiler, targeting PowerPC with SPE */
32
+ #include <spe.h>
33
+ #endif
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/IsContiguous.h ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ namespace at::native { inline namespace CPU_CAPABILITY {
4
+
5
+ // n: number of function arguments (arity)
6
+ // traits: function_traits (see FunctionTraits.h)
7
+ // s: index of scalar argument or -1
8
+ template <int n, int stride_index, typename traits, int s=-1>
9
+ struct IsContiguous {
10
+ static bool eval(const int64_t* strides) {
11
+ using type = typename traits::template arg<n - 1>::type;
12
+ return strides[stride_index] == (s == n ? 0 : sizeof(type)) &&
13
+ IsContiguous<n - 1, stride_index - 1, traits, s>::eval(strides);
14
+ }
15
+ };
16
+
17
+ // will be called when there is an output exists
18
+ template <typename traits, int s>
19
+ struct IsContiguous<0, 0, traits, s> {
20
+ static bool eval(const int64_t* strides) {
21
+ return strides[0] == sizeof(typename traits::result_type);
22
+ }
23
+ };
24
+
25
+ // will be called when there is no output
26
+ template <typename traits, int s>
27
+ struct IsContiguous<0, -1, traits, s> {
28
+ static bool eval(const int64_t* /*strides*/) {
29
+ return true;
30
+ }
31
+ };
32
+
33
+ // output and all inputs are contiguous
34
+ template <typename traits,
35
+ typename std::enable_if<std::is_void<typename traits::result_type>::value>::type* = nullptr>
36
+ static inline bool is_contiguous(const int64_t* strides) {
37
+ return IsContiguous<traits::arity, traits::arity - 1, traits>::eval(strides);
38
+ }
39
+
40
+ template <typename traits,
41
+ typename std::enable_if<!std::is_void<typename traits::result_type>::value>::type* = nullptr>
42
+ static inline bool is_contiguous(const int64_t* strides) {
43
+ return IsContiguous<traits::arity, traits::arity, traits>::eval(strides);
44
+ }
45
+
46
+ // input at `s` is scalar (stride 0); output and other inputs are contiguous
47
+ // NB: output is typically at strides[0] so first input corresponds to s=1
48
+ template <typename traits, int s,
49
+ typename std::enable_if<std::is_void<typename traits::result_type>::value>::type* = nullptr>
50
+ static inline bool is_contiguous_scalar(const int64_t* strides) {
51
+ static_assert(s > 0 && s <= traits::arity, "scalar argument index out of bounds");
52
+ return IsContiguous<traits::arity, traits::arity - 1, traits, s>::eval(strides);
53
+ }
54
+
55
+ template <typename traits, int s,
56
+ typename std::enable_if<!std::is_void<typename traits::result_type>::value>::type* = nullptr>
57
+ static inline bool is_contiguous_scalar(const int64_t* strides) {
58
+ static_assert(s > 0 && s <= traits::arity, "scalar argument index out of bounds");
59
+ return IsContiguous<traits::arity, traits::arity, traits, s>::eval(strides);
60
+ }
61
+
62
+ }}
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/LogAddExp.h ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <c10/util/complex.h>
4
+ #include <ATen/NumericUtils.h>
5
+
6
+ namespace at::native {
7
+ inline namespace CPU_CAPABILITY {
8
+
9
+ // custom min and max to be used in logcumsumexp for complex arguments
10
+ template <typename scalar_t>
11
+ std::pair<c10::complex<scalar_t>, c10::complex<scalar_t>> _logcumsumexp_minmax(c10::complex<scalar_t> x, c10::complex<scalar_t> y) {
12
+ if (at::_isnan(y)) { // either real is nan or imag is nan
13
+ return std::make_pair(y, y);
14
+ } else if (at::_isnan(x)) { // either real is nan or imag is nan
15
+ return std::make_pair(x, x);
16
+ } else {
17
+ return (x.real() < y.real()) ? std::make_pair(x, y) : std::make_pair(y, x);
18
+ }
19
+ }
20
+
21
+ template <typename scalar_t>
22
+ scalar_t _log_add_exp_helper(scalar_t x, scalar_t y) {
23
+ // Reference : https://www.tensorflow.org/api_docs/python/tf/math/cumulative_logsumexp
24
+ scalar_t min = at::_isnan(y) ? y : std::min(x, y); // std::min returns first arg if one of the args is nan
25
+ scalar_t max = at::_isnan(y) ? y : std::max(x, y); // std::max returns first arg if one of the args is nan
26
+ if (min != max || std::isfinite(min)) {
27
+ // nan will be propagated here
28
+ return std::log1p(std::exp(min - max)) + max;
29
+ } else {
30
+ // special case to correctly handle infinite cases
31
+ return x;
32
+ }
33
+ }
34
+
35
+ template <typename scalar_t>
36
+ c10::complex<scalar_t> _log_add_exp_helper(const c10::complex<scalar_t>& x, const c10::complex<scalar_t>& y) {
37
+ auto [min, max] = _logcumsumexp_minmax<scalar_t>(x, y);
38
+ auto min_real = std::real(min);
39
+ auto max_real = std::real(max);
40
+
41
+ if (at::_isnan(min)) { // either real is nan or imag is nan
42
+ // handling the "infectious" NaNs
43
+ return {std::numeric_limits<scalar_t>::quiet_NaN(), std::numeric_limits<scalar_t>::quiet_NaN()};
44
+ } else if (!std::isfinite(min_real) && (min_real == max_real)) {
45
+ if (min_real < 0) {
46
+ // handle the -inf case, the imaginary part here does not really matter as the exp(value)
47
+ // will be around 0.0 and the angle (i.e. the imaginary part) cannot be determined.
48
+ // It does not matter if we're taking the exp of this value
49
+ return min;
50
+ } else {
51
+ // handle the +inf case, we don't need the special precision for log1p for small values
52
+ // and to avoid producing nan in case of real(max) == real(min) == +inf
53
+ return std::log(std::exp(min) + std::exp(max));
54
+ }
55
+ } else {
56
+ return std::log1p(std::exp(min - max)) + max;
57
+ }
58
+ }
59
+
60
+ } // end namespace
61
+ } //end at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/MaxUnpoolKernel.h ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <ATen/native/DispatchStub.h>
3
+
4
+ namespace at {
5
+ class Tensor;
6
+
7
+ namespace native {
8
+
9
+ using max_unpooling_fn = void(*)(Tensor&, const Tensor&, const Tensor&);
10
+
11
+ DECLARE_DISPATCH(max_unpooling_fn, max_unpool2d_kernel);
12
+ DECLARE_DISPATCH(max_unpooling_fn, max_unpool3d_kernel);
13
+
14
+ }} // at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/PixelShuffleKernel.h ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <ATen/native/DispatchStub.h>
3
+
4
+ namespace at {
5
+ class TensorBase;
6
+ }
7
+
8
+ namespace at::native {
9
+
10
+ using pixel_shuffle_fn = void(*)(TensorBase&, const TensorBase&, int64_t);
11
+ DECLARE_DISPATCH(pixel_shuffle_fn, pixel_shuffle_kernel);
12
+ DECLARE_DISPATCH(pixel_shuffle_fn, pixel_unshuffle_kernel);
13
+
14
+ } // at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/Reduce.h ADDED
@@ -0,0 +1,314 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/native/cpu/Loops.h>
4
+ #include <ATen/Parallel.h>
5
+ #include <c10/util/TypeList.h>
6
+ #include <c10/core/Scalar.h>
7
+ #include <c10/util/irange.h>
8
+
9
+ #include <sstream>
10
+ #include <type_traits>
11
+
12
+ namespace at { namespace native { inline namespace CPU_CAPABILITY {
13
+
14
+ using namespace vec;
15
+
16
+ #define VEC_LOOP_HEADER(func_t, data) \
17
+ using scalar_t = typename function_traits<func_t>::result_type; \
18
+ using Vec = Vectorized<scalar_t>; \
19
+ char* out_ptr = data[0]; \
20
+ (void) out_ptr;
21
+
22
+ // reduction that is contiguous over the input in dim 0
23
+ template <typename traits>
24
+ inline bool is_contiguous_reduction(const int64_t* strides) {
25
+ return strides[0] == 0 &&
26
+ strides[1] == sizeof(typename traits::arg2_t);
27
+ }
28
+
29
+ // reduction that is contiguous over the input in dim 1
30
+ template <typename traits>
31
+ inline bool is_outer_reduction(const int64_t* strides) {
32
+ return strides[0] == 0 &&
33
+ strides[2] == sizeof(typename traits::result_type) &&
34
+ strides[3] == sizeof(typename traits::arg2_t);
35
+ }
36
+
37
+ template <typename func_t, typename vec_func_t>
38
+ inline void vectorized_reduction(char** data, int64_t n, int64_t stride,
39
+ func_t op, vec_func_t vop, bool reduce) {
40
+ VEC_LOOP_HEADER(func_t, data)
41
+ const char* in1_ptr = data[1];
42
+ Vec acc[4];
43
+ for (const auto j : c10::irange(4)) {
44
+ acc[j] = Vec::loadu(in1_ptr + j * Vec::size() * sizeof(scalar_t));
45
+ }
46
+ for (const auto i : c10::irange(1, n)) {
47
+ const char* ptr = in1_ptr + stride * i;
48
+ acc[0] = vop(acc[0], Vec::loadu(ptr + (0 * Vec::size() * sizeof(scalar_t))));
49
+ acc[1] = vop(acc[1], Vec::loadu(ptr + (1 * Vec::size() * sizeof(scalar_t))));
50
+ acc[2] = vop(acc[2], Vec::loadu(ptr + (2 * Vec::size() * sizeof(scalar_t))));
51
+ acc[3] = vop(acc[3], Vec::loadu(ptr + (3 * Vec::size() * sizeof(scalar_t))));
52
+ }
53
+ if (reduce) {
54
+ scalar_t buffer[Vec::size()];
55
+ acc[0] = vop(vop(acc[0], acc[1]), vop(acc[2], acc[3]));
56
+ acc[0].store(buffer);
57
+ for (const auto j : c10::irange(1, Vec::size())) {
58
+ buffer[0] = op(buffer[0], buffer[j]);
59
+ }
60
+ auto dst = (scalar_t*)out_ptr;
61
+ *dst = op(*dst, buffer[0]);
62
+ } else {
63
+ for (const auto j : c10::irange(4)) {
64
+ auto dst = out_ptr + j * Vec::size() * sizeof(scalar_t);
65
+ acc[j] = vop(acc[j], Vec::loadu(dst));
66
+ acc[j].store(dst);
67
+ }
68
+ }
69
+ }
70
+
71
+ template <typename F>
72
+ inline void UNARY_OUTER_LOOP(char* data[2], const int64_t strides[2], int64_t n, F f) {
73
+ for (const auto j C10_UNUSED : c10::irange(n)) {
74
+ f();
75
+ data[0] += strides[0];
76
+ data[1] += strides[1];
77
+ }
78
+ }
79
+
80
+ // computes the reduction out = op(out, in)
81
+ template <typename func_t, typename vec_func_t>
82
+ inline void vectorized_inner_reduction(char** data, int64_t n, func_t op, vec_func_t vop) {
83
+ VEC_LOOP_HEADER(func_t, data)
84
+ int64_t vector_stride = 4 * Vec::size() * sizeof(scalar_t);
85
+ int64_t count = n / (4 * Vec::size());
86
+ if (count > 0) {
87
+ vectorized_reduction(data, count, vector_stride, op, vop, /*reduce=*/true);
88
+ }
89
+ char* ptrs[3] = { data[0], data[0], data[1] };
90
+ int64_t strides[] = { 0, 0, sizeof(scalar_t) };
91
+ basic_loop(ptrs, strides, count * 4 * Vec::size(), n, op);
92
+ }
93
+
94
+ // computes the reduction out = op(out, in)
95
+ template <typename func_t, typename vec_func_t>
96
+ inline void vectorized_outer_reduction(char** data, int64_t inner_stride, int64_t size0, int64_t size1, func_t op, vec_func_t vop) {
97
+ VEC_LOOP_HEADER(func_t, data)
98
+
99
+ // reduce down each column of 4 * Vec::size() elements (128 or 256 bytes)
100
+ #if defined(CPU_CAPABILITY_AVX512)
101
+ int64_t outer_stride[2] = { 256, 256 };
102
+ #else
103
+ int64_t outer_stride[2] = { 128, 128 };
104
+ #endif
105
+ UNARY_OUTER_LOOP(data, outer_stride, size1 / (4 * Vec::size()), [&] {
106
+ vectorized_reduction(data, size0, inner_stride, op, vop, /*reduce=*/false);
107
+ });
108
+
109
+ // reduce down the remaining columns
110
+ int64_t step[] = { sizeof(scalar_t), sizeof(scalar_t) };
111
+ int64_t remaining = size1 % (4 * Vec::size());
112
+ UNARY_OUTER_LOOP(data, step, remaining, [&] {
113
+ char* ptrs[3] = { data[0], data[0], data[1] };
114
+ int64_t strides[] = { 0, 0, inner_stride };
115
+ basic_loop(ptrs, strides, 0, size0, op);
116
+ });
117
+ }
118
+
119
+ template<typename traits, typename res_t>
120
+ static void set_result(const int index, const res_t result, const TensorIteratorBase &iter, const int num_outputs) {
121
+ // static_assert(std::is_same<res_t, typename traits::arg2_t>::value, "data types must match");
122
+ if (index < num_outputs) {
123
+ char *out = (char *) iter.data_ptr(index);
124
+ *(res_t *) out = result;
125
+ }
126
+ }
127
+
128
+ template<typename traits, typename res_t>
129
+ static void set_results(const res_t result, const TensorIteratorBase &iter, const int num_outputs) {
130
+ AT_ASSERT(num_outputs == 1);
131
+ set_result<traits>(0, result, iter, num_outputs);
132
+ }
133
+
134
+ template<typename traits, std::size_t i = 0, typename... tuple_t>
135
+ inline typename std::enable_if<i == sizeof...(tuple_t), std::size_t>::type
136
+ for_each_in_tuple(const std::tuple<tuple_t...>& /*t*/, const TensorIteratorBase& /*iter*/, const int /*num_outputs*/) {
137
+ return i;
138
+ }
139
+
140
+ template<typename traits, std::size_t i = 0, typename... tuple_t>
141
+ inline typename std::enable_if<i < sizeof...(tuple_t), std::size_t>::type
142
+ for_each_in_tuple(const std::tuple<tuple_t...>& t, const TensorIteratorBase &iter, const int num_outputs) {
143
+ if (i < (size_t)num_outputs) {
144
+ set_result<traits>(i, std::get<i>(t), iter, num_outputs);
145
+ return for_each_in_tuple<traits, i + 1, tuple_t...>(t, iter, num_outputs);
146
+ }
147
+ return i;
148
+ }
149
+
150
+ template<typename traits, typename... res_t>
151
+ static void set_results(const std::tuple<res_t...>& result, const TensorIteratorBase &iter, const int num_outputs) {
152
+ AT_ASSERT(num_outputs >= 1);
153
+ std::size_t result_size = for_each_in_tuple<traits>(result, iter, num_outputs);
154
+ AT_ASSERT((size_t)num_outputs == result_size);
155
+ }
156
+
157
+ template <typename T, typename... Args>
158
+ struct all_same : std::conjunction<
159
+ std::is_same<T, Args>...
160
+ > {};
161
+
162
+ // data_t is the input/output data type.
163
+ // acc_t is a type that contains all the necessary data
164
+ // to continue reducing.
165
+ // index_t is a one-dimensional index
166
+ //
167
+ // ops_t is such that &ops_t::reduce, &ops_t::combine, and &ops_t::project exist and satisfy
168
+ // the following.
169
+ // reduce: (acc_t, data_t, index_t) -> acc_t adds one data point to the accumulated value.
170
+ // combine: (acc_t, acc_t) -> acc_t combines two accumulated values into one.
171
+ // project: acc_t -> out_t finishes the reduction, getting the required output.
172
+ //
173
+ // Additionally, acc_t must be default-constructible:
174
+ // acc_t {} is an identity for combine,
175
+ // and project(acc_t {}) is the value of the operation on zero elements.
176
+ //
177
+ // The point of `combine` is to support parallelization -
178
+ // the idea is to one sequence of `reduce` calls per thread of execution,
179
+ // and then to combine them at the end with `combine`.
180
+ //
181
+ // If there is more than one output element,
182
+ // our parallelization strategy is to use one thread for each of them,
183
+ // which means that `combine` will never be called.
184
+ //
185
+ // If, on the other hand, there is only one, then we split the input into
186
+ // into several pieces, reduce each separately, and then combine them.
187
+
188
+ template <typename ops_t, typename init_t>
189
+ void binary_kernel_reduce(TensorIteratorBase& iter, ops_t ops, init_t init) {
190
+ using rf_t = decltype(&ops_t::reduce);
191
+ using cf_t = decltype(&ops_t::combine);
192
+ using pf_t = decltype(&ops_t::project);
193
+ using r_traits = binary_function_traits<rf_t>;
194
+ using c_traits = binary_function_traits<cf_t>;
195
+ using p_traits = unary_function_traits<pf_t>;
196
+ using acc_t = typename p_traits::arg1_t;
197
+ using data_t = typename r_traits::arg2_t;
198
+ static_assert(
199
+ all_same<
200
+ acc_t,
201
+ init_t,
202
+ typename r_traits::arg1_t,
203
+ typename r_traits::result_type,
204
+ typename c_traits::arg1_t,
205
+ typename c_traits::arg2_t,
206
+ typename c_traits::result_type>::value,
207
+ "all accumulate types must match");
208
+ static_assert(
209
+ std::is_default_constructible<acc_t>::value,
210
+ "the accumulate type must be default-constructible"
211
+ );
212
+ const int num_outputs = iter.noutputs();
213
+ iter.foreach_reduced_elt([&ops, &init, num_outputs](TensorIteratorBase &sub_iter) {
214
+ auto reduction_body = [&ops, &sub_iter, num_outputs](acc_t acc, int64_t begin, int64_t end) -> acc_t {
215
+ int ntensors = sub_iter.ntensors();
216
+ sub_iter.serial_for_each([&acc, &ops, num_outputs, ntensors, begin](char** data, const int64_t* strides, int64_t size) {
217
+ AT_ASSERT(ntensors - num_outputs == 1);
218
+ char *in = data[ntensors - 1];
219
+ int64_t stride = strides[ntensors - 1];
220
+ for (const auto i : c10::irange(size)) {
221
+ acc = ops.reduce(acc, c10::load<data_t>(in), begin + i);
222
+ in += stride;
223
+ }
224
+ }, {begin, end});
225
+ return ops.translate_idx(acc, sub_iter.view_offsets()[0]);
226
+ };
227
+ acc_t total_acc = init;
228
+ auto numel = sub_iter.numel();
229
+ if (numel < at::internal::GRAIN_SIZE || at::get_num_threads() == 1 ||
230
+ at::in_parallel_region()) {
231
+ total_acc = reduction_body(total_acc, 0, numel);
232
+ } else {
233
+ int max_threads = at::get_num_threads();
234
+ AT_ASSERT(max_threads > 0);
235
+ static_assert(
236
+ !std::is_same<acc_t, bool>::value,
237
+ "Concurrently modifying different references into std::vector<bool> is UB."
238
+ );
239
+ std::vector<acc_t> buffer((unsigned)max_threads, init);
240
+ at::parallel_for(0, numel, internal::GRAIN_SIZE,
241
+ [&](int64_t begin, int64_t end) {
242
+ auto& acc = buffer[at::get_thread_num()];
243
+ acc = reduction_body(acc, begin, end);
244
+ }
245
+ );
246
+ for (const auto i : c10::irange(max_threads)) {
247
+ total_acc = ops.combine(total_acc, buffer[i]);
248
+ }
249
+ }
250
+ set_results<r_traits>(ops.project(total_acc), sub_iter, num_outputs);
251
+ });
252
+ }
253
+
254
+ template <typename func_t, typename vec_func_t>
255
+ void binary_kernel_reduce_vec(TensorIteratorBase& iter, func_t op, vec_func_t vop, double ident = 0) {
256
+ using traits = binary_function_traits<func_t>;
257
+ static_assert(
258
+ all_same<
259
+ typename traits::result_type,
260
+ typename traits::arg1_t,
261
+ typename traits::arg2_t>::value,
262
+ "all types must match");
263
+
264
+ iter.output_base().fill_(ident);
265
+ iter.parallel_reduce([&](char** data, const int64_t* strides, int64_t size0, int64_t size1) {
266
+ int64_t outer_strides[] = { strides[2], strides[3] };
267
+ if (is_contiguous_reduction<traits>(strides)) {
268
+ // input is contiguous in dim 0, output is reduced in dim 0
269
+ UNARY_OUTER_LOOP(data, outer_strides, size1, [&] {
270
+ vectorized_inner_reduction(data, size0, op, vop);
271
+ });
272
+ } else if (is_outer_reduction<traits>(strides)) {
273
+ // input and output are contiguous in dim 1
274
+ int64_t inner_stride = strides[1]; // stride of input in dim 0
275
+ vectorized_outer_reduction(data, inner_stride, size0, size1, op, vop);
276
+ } else {
277
+ UNARY_OUTER_LOOP(data, outer_strides, size1, [&] {
278
+ char* ptrs[3] = { data[0], data[0], data[1] };
279
+ int64_t inner_strides[3] = { strides[0], strides[0], strides[1] };
280
+ basic_loop(ptrs, inner_strides, 0, size0, op);
281
+ });
282
+ }
283
+ });
284
+ }
285
+
286
+ // when reduction is on most inner dimension (dim 0 in TensorIterator)
287
+ // and input has contiguous most inner dimension, `binary_kernel_reduce_lastdim`
288
+ // can be used.
289
+ inline bool is_reduce_lastdim(TensorIteratorBase& iter) {
290
+ return iter.num_reduce_dims() == 1 && iter.is_dim_reduced(0)
291
+ && iter.ninputs() == 1 && iter.strides(1)[0] == iter.element_size(1);
292
+ }
293
+
294
+ template <typename reduce_func_t>
295
+ void binary_kernel_reduce_lastdim(TensorIteratorBase& iter, reduce_func_t reduce_op) {
296
+ auto shape = iter.shape();
297
+ int64_t dim_size = shape[0];
298
+ int64_t grain_size = std::max((int64_t) 1, at::internal::GRAIN_SIZE / dim_size);
299
+ TensorIterator sub_iter(iter);
300
+ // create sub iterator to parallel on all non-reduce-dims
301
+ sub_iter.narrow(0, 0, 1);
302
+ auto loop = [&](char** data, const int64_t* strides, int64_t size) {
303
+ char* out = data[0];
304
+ char* in = data[1];
305
+ for (int64_t i = 0; i < size; ++i) {
306
+ reduce_op(out, in, dim_size);
307
+ out += strides[0];
308
+ in += strides[1];
309
+ }
310
+ };
311
+ sub_iter.for_each(loop, grain_size);
312
+ }
313
+
314
+ }}} // namespace at::native::<anonymous>
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/SampledAddmmKernel.h ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/core/Tensor.h>
4
+ #include <ATen/native/DispatchStub.h>
5
+
6
+ namespace at::native {
7
+
8
+ using sampled_addmm_sparse_csr_fn = void(*)(const Tensor&, const Tensor&, const Scalar&, const Scalar&, const Tensor&);
9
+
10
+ DECLARE_DISPATCH(sampled_addmm_sparse_csr_fn, sampled_addmm_sparse_csr_stub);
11
+
12
+ } // at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/SerialStackImpl.h ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright 2004-present Facebook. All Rights Reserved.
2
+ #pragma once
3
+
4
+ #include <ATen/core/Tensor.h>
5
+
6
+ #include <ATen/MemoryOverlap.h>
7
+ #include <ATen/Parallel.h>
8
+ #include <ATen/TensorIterator.h>
9
+ #include <ATen/cpu/vec/functional.h>
10
+ #include <ATen/cpu/vec/vec.h>
11
+ #include <c10/util/irange.h>
12
+
13
+ namespace at::native::detail {
14
+
15
+ struct InputMeta {
16
+ void* data_ptr;
17
+ int64_t inner_size;
18
+
19
+ InputMeta(const Tensor& t, int64_t dim, int64_t inner)
20
+ : data_ptr(t.data_ptr()), inner_size(t.sizes()[dim] * inner) {}
21
+ };
22
+
23
+ // This kernel is used by two TensorList types:
24
+ // 1. stack_serial_kernel uses at::ArrayRef<Tensor>
25
+ // 2. Static runtime calls this kernel directly (csrc/jit/runtime/static/ops.cpp) with
26
+ // ProcessedNodeInputWrapper.
27
+ // When making changes, make sure that they are compatible with both types!
28
+ template <typename scalar_t, typename TensorListType>
29
+ void stack_serial_kernel_impl(Tensor& result, TensorListType tensors, int64_t dim) {
30
+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
31
+ dim >= 0 && dim <= result.dim(),
32
+ "dim out of range in stack_serial_kernel_impl");
33
+ int64_t outer =
34
+ result.numel() / (result.sizes()[dim] * result.strides()[dim]);
35
+ scalar_t* result_data = result.data_ptr<scalar_t>();
36
+ int64_t ninputs = tensors.size();
37
+ std::vector<InputMeta> inputs;
38
+ inputs.reserve(ninputs);
39
+ for (const auto& tensor : tensors) {
40
+ inputs.emplace_back(tensor, dim, tensor.strides()[dim]);
41
+ }
42
+
43
+ using Vec = vec::Vectorized<scalar_t>;
44
+ scalar_t* result_ptr = result_data;
45
+ for (const auto i : c10::irange(outer)) {
46
+ for (const auto j : c10::irange(ninputs)) {
47
+ int64_t local_inner = inputs[j].inner_size;
48
+ scalar_t* input_ptr = (scalar_t*)(inputs[j].data_ptr) + i * local_inner;
49
+
50
+ if (local_inner < Vec::size()) {
51
+ for (const auto k : c10::irange(local_inner)) {
52
+ result_ptr[k] = input_ptr[k];
53
+ }
54
+ } else {
55
+ vec::map(
56
+ [](Vec x) { return x; }, result_ptr, input_ptr, local_inner);
57
+ }
58
+ result_ptr += local_inner;
59
+ }
60
+ }
61
+ }
62
+
63
+ // Checks to see whether native stack can be invoked under these conditions:
64
+ // - result and input tensors are contiguous
65
+ // - only one thread is used
66
+ // - no type promotion has to occur
67
+ // - tensors dtype is Double or Float
68
+ template <typename TensorListType>
69
+ bool can_use_native_serial_stack_impl(Tensor& result, TensorListType tensors, int64_t dim) {
70
+ TORCH_CHECK(tensors.size() > 0, "expected a non-empty list of Tensors");
71
+ const Tensor& first_tensor = tensors[0];
72
+ // stack dimension should be in range [0,firstTensor.dim())
73
+ // dim == firstTensor.dim() is a valid input, but it is handled by default code path
74
+ // that uses unsqueeze
75
+ if (dim >= first_tensor.dim()) return false;
76
+ // Native stack doesn't apply any tensor is skipped.
77
+ if (first_tensor.numel() == 0 && first_tensor.dim() == 1) return false;
78
+ // there should be no type promotion
79
+ if (result.dtype() != first_tensor.dtype()) return false;
80
+
81
+ auto first_tensor_mem_format = first_tensor.suggest_memory_format();
82
+ ScalarType dtype = first_tensor.scalar_type();
83
+
84
+ if (!result.is_contiguous(first_tensor_mem_format)) {
85
+ return false;
86
+ }
87
+
88
+ // fast path only works for Double and Float
89
+ if (dtype != ScalarType::Double && dtype != ScalarType::Float) {
90
+ return false;
91
+ }
92
+
93
+ // check remainder of inputs
94
+ #ifndef STRIP_ERROR_MESSAGES
95
+ auto const &first_tensor_shape = first_tensor.sizes();
96
+ #endif
97
+ for (const auto i : c10::irange(1, tensors.size())) {
98
+ auto const &tensor = tensors[i];
99
+ TORCH_CHECK(tensors[i].sizes() == first_tensor.sizes(),
100
+ "stack expects each tensor to be equal size, but got ", first_tensor_shape,
101
+ " at entry 0 and ", tensor.sizes(), " at entry ", i);
102
+
103
+ // every tensor must be contiguous
104
+ // tensor sizes and strides must be the same
105
+ // there should be no type promotion
106
+ if (!tensor.is_contiguous(first_tensor_mem_format) ||
107
+ tensor.strides() != first_tensor.strides() ||
108
+ tensor.dtype() != dtype) {
109
+ return false;
110
+ }
111
+ }
112
+
113
+ // fast native stack should only be used when it is not worth using multiple threads
114
+ // or there is only one thread. Note that we aren't checking result.numel() here because
115
+ // it may not have been resized and we want to defer that cost till later.
116
+ int64_t numel_in_stack = first_tensor.numel() * tensors.size();
117
+ return numel_in_stack < at::internal::GRAIN_SIZE || at::get_num_threads() == 1;
118
+ }
119
+
120
+ template <typename TensorListType, bool should_skip_overlap_check>
121
+ struct CanUseNativeSerialStack;
122
+
123
+ template <typename TensorListType>
124
+ struct CanUseNativeSerialStack<TensorListType, false> {
125
+ static bool call(Tensor& result, TensorListType tensors, int64_t dim) {
126
+ // Inputs cannot alias the output tensor
127
+ for (const auto i : c10::irange(tensors.size())) {
128
+ auto lap = at::get_overlap_status(result, tensors[i]);
129
+ TORCH_CHECK(lap != at::MemOverlapStatus::Partial &&
130
+ lap != at::MemOverlapStatus::Full, 0,
131
+ "unsupported operation: the input tensors cannot refer to any of the "
132
+ "output memory locations. Found overlap in input tensor ", i);
133
+ }
134
+
135
+ return can_use_native_serial_stack_impl(result, tensors, dim);
136
+ }
137
+ };
138
+
139
+ template <typename TensorListType>
140
+ struct CanUseNativeSerialStack<TensorListType, true> {
141
+ static bool call(Tensor& result, TensorListType tensors, int64_t dim) {
142
+ return can_use_native_serial_stack_impl(result, tensors, dim);
143
+ }
144
+ };
145
+
146
+ } // namespace at::native::detail
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/SoftmaxKernel.h ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/native/DispatchStub.h>
4
+ #include <cstdint>
5
+
6
+ namespace at {
7
+ class Tensor;
8
+
9
+ namespace native {
10
+
11
+ using forward_fn = void (*)(const Tensor&, const Tensor&);
12
+ using backward_fn = void(*)(const Tensor &, const Tensor &, const Tensor&);
13
+
14
+ DECLARE_DISPATCH(forward_fn, softmax_lastdim_kernel);
15
+ DECLARE_DISPATCH(forward_fn, log_softmax_lastdim_kernel);
16
+ DECLARE_DISPATCH(backward_fn, softmax_backward_lastdim_kernel);
17
+ DECLARE_DISPATCH(backward_fn, log_softmax_backward_lastdim_kernel);
18
+
19
+ using forward_fn_with_dim = void(*)(const Tensor &, const Tensor &, const int64_t);
20
+ using backward_fn_with_dim =
21
+ void (*)(const Tensor&, const Tensor&, const Tensor&, const int64_t);
22
+
23
+ DECLARE_DISPATCH(forward_fn_with_dim, softmax_kernel);
24
+ DECLARE_DISPATCH(forward_fn_with_dim, log_softmax_kernel);
25
+ DECLARE_DISPATCH(backward_fn_with_dim, softmax_backward_kernel);
26
+ DECLARE_DISPATCH(backward_fn_with_dim, log_softmax_backward_kernel);
27
+ }
28
+ }
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/SpmmReduceKernel.h ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/core/Tensor.h>
4
+ #include <ATen/native/DispatchStub.h>
5
+ #include <ATen/native/ReductionType.h>
6
+
7
+ namespace at::native {
8
+
9
+ using spmm_reduce_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, ReductionType op);
10
+ using spmm_reduce_arg_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, ReductionType op);
11
+ using spmm_reduce_backward_input_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, ReductionType op);
12
+ using spmm_reduce_backward_input_arg_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, ReductionType op);
13
+ using spmm_reduce_backward_other_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, ReductionType op);
14
+
15
+ DECLARE_DISPATCH(spmm_reduce_fn, spmm_reduce_stub);
16
+ DECLARE_DISPATCH(spmm_reduce_arg_fn, spmm_reduce_arg_stub);
17
+ DECLARE_DISPATCH(spmm_reduce_backward_input_fn, spmm_reduce_backward_input_stub);
18
+ DECLARE_DISPATCH(spmm_reduce_backward_input_arg_fn, spmm_reduce_backward_input_arg_stub);
19
+ DECLARE_DISPATCH(spmm_reduce_backward_other_fn, spmm_reduce_backward_other_stub);
20
+ DECLARE_DISPATCH(spmm_reduce_backward_input_arg_fn, spmm_reduce_backward_other_arg_stub);
21
+
22
+ } // at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/StackKernel.h ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright 2004-present Facebook. All Rights Reserved.
2
+ #pragma once
3
+
4
+ #include <ATen/core/Tensor.h>
5
+ #include <ATen/native/DispatchStub.h>
6
+
7
+ namespace at::native {
8
+
9
+ using stack_serial_fn = void(*)(Tensor &, TensorList, int64_t);
10
+ DECLARE_DISPATCH(stack_serial_fn, stack_serial_stub);
11
+
12
+ } // namespace at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/UpSampleKernelAVXAntialias.h ADDED
@@ -0,0 +1,1376 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ The Python Imaging Library (PIL) is
3
+
4
+ Copyright © 1997-2011 by Secret Labs AB
5
+ Copyright © 1995-2011 by Fredrik Lundh
6
+
7
+ Pillow is the friendly PIL fork. It is
8
+
9
+ Copyright © 2010-2022 by Alex Clark and contributors
10
+
11
+ Like PIL, Pillow is licensed under the open source HPND License
12
+ */
13
+
14
+ // This code is heavily inspired from PILLOW-SIMD's implementation:
15
+ // https://github.com/uploadcare/pillow-simd/blob/simd/master/src/libImaging/Resample.c
16
+
17
+ #pragma once
18
+ #ifdef CPU_CAPABILITY_AVX2
19
+ // TODO: This file only supports AVX2. We could split the AVX kernels into
20
+ // smaller logical blocks in order to port them into the Vec.h logic. This would
21
+ // allow to support other vectorization architectures and perhaps also support
22
+ // the non-vectorized fallback (we'd need to make sure it's not slower than the
23
+ // current fallback).
24
+
25
+ #include <ATen/core/Tensor.h>
26
+ #include <ATen/cpu/vec/intrinsics.h>
27
+ #include <c10/util/irange.h>
28
+
29
+ #ifndef AT_PER_OPERATOR_HEADERS
30
+ #include <ATen/Functions.h>
31
+ #else
32
+ #include <ATen/ops/empty.h>
33
+ #endif
34
+
35
+
36
+ namespace {
37
+
38
+ static inline __m128i mm_cvtsi32_si128(const uint8_t* C10_RESTRICT ptr, bool i32_aligned) {
39
+ int32_t v;
40
+ if (i32_aligned) {
41
+ v = *(const int32_t*)ptr;
42
+ } else {
43
+ std::memcpy(&v, ptr, 4);
44
+ }
45
+ return _mm_cvtsi32_si128(v);
46
+ }
47
+
48
+ static inline __m128i mm_cvtepu8_epi32(const uint8_t* C10_RESTRICT ptr, bool i32_aligned) {
49
+ return _mm_cvtepu8_epi32(mm_cvtsi32_si128(ptr, i32_aligned));
50
+ }
51
+
52
+ static inline void _write_endline_rgb_as_uint32(
53
+ uint8_t* C10_RESTRICT output,
54
+ uint32_t data
55
+ ) {
56
+ // data is (R G B X), output is (X1 X2 X3 | R1 B1 G1 R2 ...)
57
+ // Here we explicitly set X as R1
58
+ uint8_t* data_ptr = reinterpret_cast<uint8_t*>(&data);
59
+ data_ptr[3] = output[3];
60
+ std::memcpy(output, data_ptr, 4);
61
+ }
62
+
63
+ at::Tensor unpack_rgb(const at::Tensor& packed_tensor) {
64
+ // Convert a "packed" tensor (typically RGBRGBRGB if channels_last) into
65
+ // RGBARGBARGBA format where A is hard-coded to 0. Each pixel is encoded
66
+ // into as 32 bits. This generalizes to num_channels <= 4 and also works for
67
+ // non-channels_last tensors.
68
+
69
+ const uint8_t* packed = (const uint8_t*)packed_tensor.const_data_ptr<uint8_t>();
70
+ auto num_pixels = packed_tensor.size(1) * packed_tensor.size(2);
71
+ auto num_channels = packed_tensor.size(0);
72
+
73
+ constexpr int rgba_size = 4;
74
+ auto unpacked_tensor = at::empty({rgba_size, packed_tensor.size(1), packed_tensor.size(2)}, at::CPU(at::kByte));
75
+ uint8_t* unpacked = (uint8_t*) unpacked_tensor.data_ptr<uint8_t>();
76
+
77
+ auto stride_i = packed_tensor.stride(2);
78
+ auto stride_j = packed_tensor.stride(0);
79
+
80
+ for (const auto i : c10::irange(num_pixels)) {
81
+ for (const auto j : c10::irange(rgba_size)) {
82
+ unpacked[rgba_size * i + j] = (j < num_channels) ? packed[stride_i * i + stride_j * j] : 0;
83
+ }
84
+ }
85
+ return unpacked_tensor;
86
+ }
87
+
88
+ void pack_rgb(
89
+ const at::Tensor& unpacked_tensor, // IN
90
+ const at::Tensor& packed_tensor // OUT
91
+ ) {
92
+ // Convert from unpacked channels last 3-channels or 4-channels tensor into original data layout.
93
+
94
+ uint8_t* unpacked = (uint8_t*)unpacked_tensor.data_ptr<uint8_t>();
95
+ uint8_t* packed = (uint8_t*)packed_tensor.data_ptr<uint8_t>();
96
+ auto num_pixels = packed_tensor.size(1) * packed_tensor.size(2);
97
+ auto num_channels = packed_tensor.size(0);
98
+
99
+ auto unpacked_increment = unpacked_tensor.size(0);
100
+ auto packed_increment = packed_tensor.stride(2);
101
+ auto packed_stride = packed_tensor.stride(0);
102
+
103
+ TORCH_INTERNAL_ASSERT(unpacked_increment == 3 || unpacked_increment == 4);
104
+
105
+ for (const auto i C10_UNUSED : c10::irange(num_pixels)) {
106
+ for (const auto j : c10::irange(num_channels)) {
107
+ packed[j * packed_stride] = unpacked[j];
108
+ }
109
+ unpacked += unpacked_increment;
110
+ packed += packed_increment;
111
+ }
112
+ }
113
+
114
+ void ImagingResampleHorizontalConvolution8u4x(
115
+ uint8_t* C10_RESTRICT lineOut0,
116
+ uint8_t* C10_RESTRICT lineOut1,
117
+ uint8_t* C10_RESTRICT lineOut2,
118
+ uint8_t* C10_RESTRICT lineOut3,
119
+ int64_t out_xsize,
120
+ const uint8_t* C10_RESTRICT lineIn0,
121
+ const uint8_t* C10_RESTRICT lineIn1,
122
+ const uint8_t* C10_RESTRICT lineIn2,
123
+ const uint8_t* C10_RESTRICT lineIn3,
124
+ int64_t in_xsize,
125
+ const int64_t* idx_ptr_xmin,
126
+ const int64_t* idx_ptr_size,
127
+ const int16_t* kk,
128
+ int kmax,
129
+ unsigned int coefs_precision,
130
+ int64_t num_channels,
131
+ bool is_last_line);
132
+
133
+ void ImagingResampleHorizontalConvolution8u(
134
+ uint8_t* C10_RESTRICT lineOut,
135
+ int64_t out_xsize,
136
+ const uint8_t* C10_RESTRICT lineIn,
137
+ int64_t in_xsize,
138
+ const int64_t* idx_ptr_xmin,
139
+ const int64_t* idx_ptr_size,
140
+ const int16_t* kk,
141
+ int kmax,
142
+ unsigned int coefs_precision,
143
+ int64_t num_channels,
144
+ bool is_last_line);
145
+
146
+ void ImagingResampleVerticalConvolution8u(
147
+ uint8_t* C10_RESTRICT lineOut,
148
+ const uint8_t* C10_RESTRICT lineIn,
149
+ int64_t xsize,
150
+ int64_t ids_min,
151
+ int64_t ids_size,
152
+ const int16_t* k,
153
+ unsigned int coefs_precision,
154
+ int64_t num_channels);
155
+
156
+ template<int num_channels>
157
+ void ImagingResampleHorizontal(
158
+ const at::Tensor & unpacked_output,
159
+ const at::Tensor & unpacked_input,
160
+ int ksize,
161
+ const std::vector<at::Tensor>& horiz_indices_weights,
162
+ unsigned int horiz_weights_precision) {
163
+
164
+ // Interpolation horizontal pass: we compute x-axis (image width) interpolation outputs.
165
+
166
+ // Input data is stored as
167
+ // input = [r[0], g[0], b[0], a[0], r[1], g[1], b[1], a[1], r[2], g[2], b[2], a[2], ...]
168
+ // Weights are float values computed for each output pixel and rescaled to uint16:
169
+ // weights[i] = [w[i, 0], w[i, 1], ..., w[i, K-1]]
170
+ // We want to compute the output as following:
171
+ // output = [oR[0], oG[0], oB[0], oA[0], oR[1], oG[1], oB[1], oA[1], ...]
172
+ // where
173
+ // oR[yoffset + i] = r[yoffset + xmin[i]] * w[i, 0] + ... + r[yoffset + xmin[i] + K-1] * w[i, K-1]
174
+ // oG[yoffset + i] = g[yoffset + xmin[i]] * w[i, 0] + ... + g[yoffset + xmin[i] + K-1] * w[i, K-1]
175
+ // oB[yoffset + i] = b[yoffset + xmin[i]] * w[i, 0] + ... + b[yoffset + xmin[i] + K-1] * w[i, K-1]
176
+ //
177
+
178
+ // TODO: we may want to merge that into the fallback code (currently called
179
+ // basic_loop_aa_horizontal<uint8_t>)
180
+ // Although this may not be needed if / when we port all this code to use
181
+ // Vec.h since this would potentially give us another fall-back implem
182
+
183
+ const int16_t* kk = (int16_t*)(horiz_indices_weights[3].const_data_ptr<double>());
184
+
185
+ auto xout = unpacked_output.size(2);
186
+ auto yout = unpacked_output.size(1);
187
+ auto xin = unpacked_input.size(2);
188
+ TORCH_INTERNAL_ASSERT(num_channels == unpacked_input.size(0));
189
+
190
+ const int64_t* idx_ptr_xmin = horiz_indices_weights[0].const_data_ptr<int64_t>();
191
+ const int64_t* idx_ptr_size = horiz_indices_weights[1].const_data_ptr<int64_t>();
192
+
193
+ uint8_t* unpacked_output_p = unpacked_output.data_ptr<uint8_t>();
194
+ const uint8_t* unpacked_input_p = unpacked_input.const_data_ptr<uint8_t>();
195
+
196
+ int64_t yy = 0;
197
+ auto xout_stride = xout * num_channels;
198
+ auto xin_stride = xin * num_channels;
199
+ for (; yy < yout - 3; yy += 4) {
200
+ ImagingResampleHorizontalConvolution8u4x(
201
+ unpacked_output_p + yy * xout_stride,
202
+ unpacked_output_p + (yy + 1) * xout_stride,
203
+ unpacked_output_p + (yy + 2) * xout_stride,
204
+ unpacked_output_p + (yy + 3) * xout_stride,
205
+ xout,
206
+ unpacked_input_p + yy * xin_stride,
207
+ unpacked_input_p + (yy + 1) * xin_stride,
208
+ unpacked_input_p + (yy + 2) * xin_stride,
209
+ unpacked_input_p + (yy + 3) * xin_stride,
210
+ xin,
211
+ idx_ptr_xmin,
212
+ idx_ptr_size,
213
+ kk,
214
+ ksize,
215
+ horiz_weights_precision,
216
+ num_channels,
217
+ yy + 3 == yout - 1);
218
+ }
219
+ for (; yy < yout; yy++) {
220
+ ImagingResampleHorizontalConvolution8u(
221
+ unpacked_output_p + yy * xout_stride,
222
+ xout,
223
+ unpacked_input_p + yy * xin_stride,
224
+ xin,
225
+ idx_ptr_xmin,
226
+ idx_ptr_size,
227
+ kk,
228
+ ksize,
229
+ horiz_weights_precision,
230
+ num_channels,
231
+ yy == yout - 1);
232
+ }
233
+ }
234
+
235
+ void ImagingResampleVertical(
236
+ const at::Tensor & unpacked_output,
237
+ const at::Tensor & unpacked_input,
238
+ int ksize,
239
+ const std::vector<at::Tensor>& vert_indices_weights,
240
+ unsigned int vert_weights_precision) {
241
+
242
+ // Interpolation vertical pass: we compute y-axis interpolation outputs.
243
+ // Input data is stored as
244
+ // input = [r[0], g[0], b[0], a[0], r[1], g[1], b[1], a[1], r[2], g[2], b[2], a[2], ...]
245
+ // Weights are float values computed for each output pixel and rescaled to uint16:
246
+ // weights[i] = [w[i, 0], w[i, 1], ..., w[i, K-1]]
247
+ // We want to compute the output as following:
248
+ // output = [oR[0], oG[0], oB[0], oA[0], oR[1], oG[1], oB[1], oA[1], ...]
249
+ // where
250
+ // oR[xoffset + i] = r[xoffset + ymin[i]] * w[i, 0] + ... + r[xoffset + ymin[i] + (K-1) * xsize] * w[i, K-1]
251
+ // oG[xoffset + i] = g[xoffset + ymin[i]] * w[i, 0] + ... + g[xoffset + ymin[i] + (K-1) * xsize] * w[i, K-1]
252
+ // oB[xoffset + i] = b[xoffset + ymin[i]] * w[i, 0] + ... + b[xoffset + ymin[i] + (K-1) * xsize] * w[i, K-1]
253
+
254
+ // TODO: we may want to merge that into the fallback code (currently called
255
+ // basic_loop_aa_vertical<uint8_t>)
256
+ // Although this may not be needed if / when we port all this code to use
257
+ // Vec.h since this would potentially give us another fall-back implem
258
+ const int16_t* kk = (int16_t*)(vert_indices_weights[3].const_data_ptr<double>());
259
+
260
+ const int64_t* idx_ptr_xmin = vert_indices_weights[0].const_data_ptr<int64_t>();
261
+ const int64_t* idx_ptr_size = vert_indices_weights[1].const_data_ptr<int64_t>();
262
+
263
+ uint8_t* unpacked_output_p = unpacked_output.data_ptr<uint8_t>();
264
+ const uint8_t* unpacked_input_p = unpacked_input.const_data_ptr<uint8_t>();
265
+
266
+ auto xout = unpacked_output.size(2);
267
+ auto yout = unpacked_output.size(1);
268
+ const auto num_channels = unpacked_input.size(0);
269
+ TORCH_INTERNAL_ASSERT(num_channels == unpacked_output.size(0));
270
+
271
+ auto xout_stride = xout * num_channels;
272
+ for (const auto yy : c10::irange(yout)) {
273
+ const auto* k = &kk[yy * ksize];
274
+ auto ids_min = idx_ptr_xmin[yy];
275
+ auto ids_size = idx_ptr_size[yy];
276
+ ImagingResampleVerticalConvolution8u(
277
+ unpacked_output_p + yy * xout_stride,
278
+ unpacked_input_p,
279
+ xout,
280
+ ids_min,
281
+ ids_size,
282
+ k,
283
+ vert_weights_precision,
284
+ num_channels);
285
+ }
286
+ }
287
+
288
+ // This is the only public entry point in this file. It supports bilinear or bicubic
289
+ // mode for uint8 dtype when C <= 4, with or without antialias. The
290
+ // implem is based on PIL-SIMD.
291
+ // Its equivalent implementation (fallback) for when AVX isn't supported or when
292
+ // C > 4 is separable_upsample_generic_Nd_kernel_impl() There are a bunch of
293
+ // future improvement that can be done: look for the TODOs in this file.
294
+ // For details on how the weights are computed and how the multiplications are
295
+ // run on int (instead of float weights), see
296
+ // [ Weights computation for uint8_t and multiplication trick ]
297
+ // For details on how the AVX kernels are implemented, see
298
+ // https://gist.github.com/NicolasHug/47c97d731f05eaad5694c173849b86f5
299
+ // See also [ Support for antialias=False as a subcase of antialias=True ] to
300
+ // learn more about how the antialias=False case is computed. The same holds
301
+ // here: all these kernels are general enough to handle an arbitrary number of
302
+ // weights, but when aa=False they could be optimized further.
303
+ template <typename scale_type, class F>
304
+ void upsample_avx_bilinear_bicubic_uint8(
305
+ const at::Tensor& input_,
306
+ const at::Tensor& output,
307
+ bool align_corners,
308
+ const scale_type& scales,
309
+ bool antialias) {
310
+ auto batch_size = input_.size(0);
311
+ auto num_channels = input_.size(1);
312
+ auto xin = input_.size(3);
313
+ auto yin = input_.size(2);
314
+ auto xout = output.size(3);
315
+ auto yout = output.size(2);
316
+
317
+ if (xin == xout && yin == yout) {
318
+ output.copy_(input_);
319
+ return;
320
+ }
321
+
322
+ at::Tensor input = input_;
323
+ if (!(input.is_contiguous() || input.is_contiguous(at::MemoryFormat::ChannelsLast))) {
324
+ // If input is not contiguous with memory format channels first or channels last,
325
+ // we explicitly convert the input to contiguous channels last memory format.
326
+ // This simplifies the rest of the code and let us assume that the format is only contiguous channels first or channels last,
327
+ // Most tensors going through this `if` block won't need to go through unpacking, but those having C < 3 may
328
+ // have to (this means 2 copies are made). We could avoid the extra copy by handling non-contiguous input
329
+ // directly within unpack_rgb() and pack_rgb(), but initial attempts showed that this is fairly complex.
330
+ input = input.contiguous(at::MemoryFormat::ChannelsLast);
331
+ }
332
+
333
+ auto need_horizontal = xout != xin;
334
+ auto need_vertical = yout != yin;
335
+
336
+ int ksize_horiz, ksize_vert;
337
+ std::vector<at::Tensor> horiz_indices_weights, vert_indices_weights;
338
+ unsigned int horiz_weights_precision, vert_weights_precision;
339
+
340
+ bool skip_unpacking = (num_channels == 3 || num_channels == 4) && input.is_contiguous(at::MemoryFormat::ChannelsLast);
341
+ bool skip_packing = (num_channels == 3 || num_channels == 4) && output.is_contiguous(at::MemoryFormat::ChannelsLast);
342
+
343
+ if (need_horizontal) {
344
+ int interp_dim = 3;
345
+ auto stride = (skip_unpacking) ? num_channels : 4;
346
+ std::tie(horiz_indices_weights, ksize_horiz, horiz_weights_precision) =
347
+ F::compute_index_ranges_int16_weights(
348
+ /*input_size=*/xin,
349
+ /*output_size=*/xout,
350
+ /*stride=*/stride,
351
+ /*ndims=*/4,
352
+ /*reshape_dim=*/interp_dim,
353
+ /*align_corners=*/align_corners,
354
+ /*opt_scale=*/scales[interp_dim - 2],
355
+ /*antialias=*/antialias,
356
+ /*align_i32=*/true);
357
+ }
358
+
359
+ if (need_vertical) {
360
+ int interp_dim = 2;
361
+ auto stride = (skip_unpacking) ? num_channels * xout : 4 * xout;
362
+ std::tie(vert_indices_weights, ksize_vert, vert_weights_precision) =
363
+ F::compute_index_ranges_int16_weights(
364
+ /*input_size=*/yin,
365
+ /*output_size=*/yout,
366
+ /*stride=*/stride,
367
+ /*ndims=*/4,
368
+ /*reshape_dim=*/interp_dim,
369
+ /*align_corners=*/align_corners,
370
+ /*opt_scale=*/scales[interp_dim - 2],
371
+ /*antialias=*/antialias,
372
+ /*align_i32=*/true);
373
+ }
374
+
375
+ at::Tensor buffer_horiz, buffer_vert;
376
+ // Minor optimization: we can avoid allocating an extra buffer if we're performing
377
+ // horizontal-only or vertical-only interpolation, and if the tensor doesn't
378
+ // need repacking
379
+ if (need_horizontal && (need_vertical || !skip_packing)) {
380
+ auto c = (skip_unpacking) ? num_channels : 4;
381
+ buffer_horiz = at::empty({c, yin, xout}, input.options());
382
+ }
383
+ if (need_vertical && !skip_packing) {
384
+ auto c = (skip_unpacking) ? num_channels : 4;
385
+ buffer_vert = at::empty({c, yout, xout}, input.options());
386
+ }
387
+
388
+ for (const auto i : c10::irange(batch_size)) {
389
+
390
+ at::Tensor unpacked_input = (skip_unpacking) ? input[i] : unpack_rgb(input[i]);
391
+ at::Tensor unpacked_output;
392
+
393
+ if (need_horizontal) {
394
+ at::Tensor unpacked_output_temp = (need_vertical || !skip_packing) ? buffer_horiz : output[i];
395
+
396
+ if (skip_unpacking && num_channels == 3) {
397
+ ImagingResampleHorizontal<3>(
398
+ unpacked_output_temp,
399
+ unpacked_input,
400
+ ksize_horiz,
401
+ horiz_indices_weights,
402
+ horiz_weights_precision);
403
+ } else {
404
+ ImagingResampleHorizontal<4>(
405
+ unpacked_output_temp,
406
+ unpacked_input,
407
+ ksize_horiz,
408
+ horiz_indices_weights,
409
+ horiz_weights_precision);
410
+ }
411
+ unpacked_output = unpacked_input = unpacked_output_temp;
412
+ }
413
+ if (need_vertical) {
414
+ unpacked_output = (skip_packing) ? output[i] : buffer_vert;
415
+
416
+ ImagingResampleVertical(
417
+ unpacked_output,
418
+ unpacked_input,
419
+ ksize_vert,
420
+ vert_indices_weights,
421
+ vert_weights_precision
422
+ );
423
+ }
424
+
425
+ TORCH_INTERNAL_ASSERT(unpacked_output.defined());
426
+
427
+ if (!skip_packing) {
428
+ pack_rgb(unpacked_output, output[i]);
429
+ }
430
+ }
431
+ }
432
+
433
+ void ImagingResampleHorizontalConvolution8u4x(
434
+ uint8_t* C10_RESTRICT lineOut0,
435
+ uint8_t* C10_RESTRICT lineOut1,
436
+ uint8_t* C10_RESTRICT lineOut2,
437
+ uint8_t* C10_RESTRICT lineOut3,
438
+ int64_t out_xsize,
439
+ const uint8_t* C10_RESTRICT lineIn0,
440
+ const uint8_t* C10_RESTRICT lineIn1,
441
+ const uint8_t* C10_RESTRICT lineIn2,
442
+ const uint8_t* C10_RESTRICT lineIn3,
443
+ int64_t in_xsize,
444
+ const int64_t* idx_ptr_xmin,
445
+ const int64_t* idx_ptr_size,
446
+ const int16_t* kk,
447
+ int kmax,
448
+ unsigned int coefs_precision,
449
+ int64_t num_channels,
450
+ bool is_last_line) {
451
+
452
+ // Interpolation horizontal pass processing together 4 vertical lines.
453
+ // - Input data format is RGBA or RGB with R,G,B,A being uint8. In case of RGBA
454
+ // we can encode 4 values as a single uint32 value.
455
+ // - We split the size of weight vector for a given output index as a sum:
456
+ // ids_size = num_blocks_4 * 4 + num_blocks_2 * 2 + num_blocks_1.
457
+ // - We load and process 4 weights values in a loop ("block 4") then we process 2 weights values
458
+ // in another loop ("block 2") and finally we process 1 weights value in the final loop ("block 1").
459
+
460
+ // Define shuffling masks (low/high) for num_channels 4 and 3
461
+ // Mask low casts lower half of each lane to epi16 and reorder RGBARGBA -> RRGGBBAA:
462
+ // [r1 g1 b1 a1 r2 g2 b2 a2 ... | R1 G1 B1 A1 R2 G2 B2 A2 ... ] ->
463
+ // [r1 0 r2 0 g1 0 g2 0 b1 0 b2 0 a1 0 a2 0 | R1 0 R2 0 G1 0 G2 0 B1 0 B2 0 A1 0 A2 0]
464
+ // Mask high casts upper half of each lane to epi16 and reorder RGBARGBA -> RRGGBBAA::
465
+ // [ ... r3 g3 b3 a3 r4 g4 b4 a4 | ... R3 G3 B3 A3 R4 G4 B4 A4 ] ->
466
+ // [r3 0 r4 0 g3 0 g4 0 b3 0 b4 0 a3 0 a4 0 | R3 0 R4 0 G3 0 G4 0 B3 0 B4 0 A3 0 A4 0]
467
+
468
+ const auto mask_low_c4 = _mm256_set_epi8(
469
+ -1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0,
470
+ -1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0);
471
+ const auto mask_high_c4 = _mm256_set_epi8(
472
+ -1, 15, -1, 11, -1, 14, -1, 10, -1, 13, -1, 9, -1, 12, -1, 8,
473
+ -1, 15, -1, 11, -1, 14, -1, 10, -1, 13, -1, 9, -1, 12, -1, 8);
474
+ const auto mask_low_c3 = _mm256_set_epi8(
475
+ -1, -1, -1, -1, -1, 5, -1, 2, -1, 4, -1, 1, -1, 3, -1, 0,
476
+ -1, -1, -1, -1, -1, 5, -1, 2, -1, 4, -1, 1, -1, 3, -1, 0);
477
+ const auto mask_high_c3 = _mm256_set_epi8(
478
+ -1, -1, -1, -1, -1, 11, -1, 8, -1, 10, -1, 7, -1, 9, -1, 6,
479
+ -1, -1, -1, -1, -1, 11, -1, 8, -1, 10, -1, 7, -1, 9, -1, 6);
480
+
481
+ const auto mask_low = (num_channels == 3) ? mask_low_c3 : mask_low_c4;
482
+ const auto mask_high = (num_channels == 3) ? mask_high_c3 : mask_high_c4;
483
+
484
+ const auto stride = num_channels * sizeof(uint8_t);
485
+
486
+ TORCH_INTERNAL_ASSERT(stride == 3 || stride == 4);
487
+
488
+ // out_xsize = output width, out_x = output x index
489
+ // ids_min is the input offset index corresponding to out_x
490
+ // ids_size is the interpolation size for out_x
491
+
492
+ // Let's precompute ids_size limits for block 4 and block 2.
493
+ //
494
+ // In block 4 (4 means we process 4 weight values together), we read input data
495
+ // with _mm_loadu_si128, i.e. 16 bytes, per one line:
496
+ // lineIn0 + stride * (i + ids_min) + 16 <= lineIn0 + stride * (ids_size + ids_min)
497
+ // --> i <= ids_size - 16.0 / stride
498
+ // Strict boundary:
499
+ // --> i < ids_size + 1 - int(ceil(16.0 / stride)) = ids_size - b4_delta
500
+ // Soft boundary for reading inside the buffer except its boundaries:
501
+ // --> i < ids_size + 1 - int(16.0 / stride) = ids_size - b4_delta_soft
502
+ // RGBA: b4_delta = b4_delta_soft = 3
503
+ // RGB : b4_delta = 5
504
+ // RGB : b4_delta_soft = 4
505
+ const auto b4_delta = (stride == 4) ? 3 : ((is_last_line) ? 5 : 4);
506
+
507
+ // In block 2 (2 means we process 2 weights values together), we read input data
508
+ // with _mm_loadl_epi64, i.e. 8 bytes, per one line:
509
+ // lineIn0 + stride * (i + ids_min) + 8 <= lineIn0 + stride * (ids_size + ids_min)
510
+ // --> i <= ids_size - 8.0 / stride
511
+ // Strict boundary:
512
+ // --> i < ids_size + 1 - int(ceil(8.0 / stride)) = ids_size - b2_delta
513
+ // Soft boundary for reading inside the buffer except its boundaries:
514
+ // --> i < ids_size + 1 - int(8.0 / stride) = ids_size - b2_delta_soft
515
+ // RGBA: b2_delta = b2_delta_soft = 1
516
+ // RGB : b2_delta = 2
517
+ // RGB : b2_delta_soft = 1
518
+ const auto b2_delta = (stride == 4) ? 1 : ((is_last_line) ? 2 : 1);
519
+
520
+ const auto max_out_x_strided = out_xsize * stride;
521
+ const auto max_in_x_strided = in_xsize * stride;
522
+
523
+ const auto zero = _mm256_setzero_si256();
524
+ const auto initial = _mm256_set1_epi32(1 << (coefs_precision - 1));
525
+
526
+ for (const auto out_x : c10::irange(out_xsize)) {
527
+ const auto ids_min = idx_ptr_xmin[out_x];
528
+ const auto ids_size = idx_ptr_size[out_x];
529
+ const auto * k = &kk[out_x * kmax];
530
+ int64_t i = 0;
531
+
532
+ auto sss0 = initial;
533
+ auto sss1 = initial;
534
+
535
+ const auto * lineIn0_min = lineIn0 + ids_min;
536
+ const auto * lineIn1_min = lineIn1 + ids_min;
537
+ const auto * lineIn2_min = lineIn2 + ids_min;
538
+ const auto * lineIn3_min = lineIn3 + ids_min;
539
+
540
+ // block 4
541
+ for (; i < ids_size - b4_delta; i += 4) {
542
+ // Load 4 values from weight vector
543
+ // mmk0 = [wl_0 wh_0 wl_1 wh_1 wl_0 wh_0 wl_1 wh_1 ...]
544
+ // mmk1 = [wl_2 wh_2 wl_3 wh_3 wl_2 wh_2 wl_3 wh_3 ...]
545
+ const auto mmk0 = _mm256_set1_epi32(*(int32_t*)&k[i]);
546
+ const auto mmk1 = _mm256_set1_epi32(*(int32_t*)&k[i + 2]);
547
+
548
+ // RGBA: Load 8 pixels (4 per line) from input lines 0 and 1:
549
+ // source = [
550
+ // r0 g0 b0 a0 r1 g1 b1 a1 r2 g2 b2 a2 r3 g3 b3 a3
551
+ // R0 G0 B0 A0 R1 G1 B1 A1 R2 G2 B2 A2 R3 G3 B3 A3
552
+ // ]
553
+ // RGB: Load 10 pixels (5 per line)
554
+ // source = [
555
+ // r0 g0 b0 r1 g1 b1 r2 g2 b2 r3 g3 b3 r4 g4 b4 r5
556
+ // R0 G0 B0 R1 G1 B1 R2 G2 B2 R3 G3 B3 R4 G4 B4 R5
557
+ // ]
558
+ auto source = _mm256_inserti128_si256(_mm256_castsi128_si256(
559
+ _mm_loadu_si128((__m128i *) (lineIn0_min + stride * i))),
560
+ _mm_loadu_si128((__m128i *) (lineIn1_min + stride * i)), 1);
561
+
562
+ // Apply mask_low:
563
+ // RGBA:
564
+ // [r0 0 r1 0 g0 0 g1 0 b0 0 b1 0 a0 0 a1 0 | R0 0 R1 0 G0 0 G1 0 B0 0 B1 0 A0 0 A1 0]
565
+ // RGB:
566
+ // [r0 0 r1 0 g0 0 g1 0 b0 0 b1 0 0 0 0 0 | R0 0 R1 0 G0 0 G1 0 B0 0 B1 0 0 0 0 0]
567
+ auto pix1 = _mm256_shuffle_epi8(source, mask_low);
568
+ // Compute output value as C += w0 * C0 + w1 * C1 for each channel in 32-bit precision
569
+ sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix1, mmk0));
570
+
571
+ // Apply mask_high:
572
+ // RGBA:
573
+ // [r2 0 r3 0 g2 0 g3 0 b2 0 b3 0 a2 0 a3 0 | R2 0 R3 0 G2 0 G3 0 B2 0 B3 0 A2 0 A3 0]
574
+ // RGB:
575
+ // [r2 0 r3 0 g2 0 g3 0 b2 0 b3 0 0 0 0 0 | R2 0 R3 0 G2 0 G3 0 B2 0 B3 0 0 0 0 0]
576
+ auto pix2 = _mm256_shuffle_epi8(source, mask_high);
577
+ // Compute output value as C += w2 * C2 + w3 * C3 for each channel in 32-bit precision
578
+ sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix2, mmk1));
579
+
580
+ // Same as above to next lines 2 and 3:
581
+ auto source2 = _mm256_inserti128_si256(_mm256_castsi128_si256(
582
+ _mm_loadu_si128((__m128i *) (lineIn2_min + stride * i))),
583
+ _mm_loadu_si128((__m128i *) (lineIn3_min + stride * i)), 1);
584
+ auto pix3 = _mm256_shuffle_epi8(source2, mask_low);
585
+ sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix3, mmk0));
586
+ auto pix4 = _mm256_shuffle_epi8(source2, mask_high);
587
+ sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix4, mmk1));
588
+ }
589
+
590
+ // block 2
591
+ for (; i < ids_size - b2_delta; i += 2) {
592
+ // Load 2 values from weight vector
593
+ // mmk = [wl_0 wh_0 wl_1 wh_1 wl_0 wh_0 wl_1 wh_1 ...]
594
+ const auto mmk = _mm256_set1_epi32(*(int32_t*)&k[i]);
595
+
596
+ // Load 4 pixels (2 per line) from input lines 0 and 1:
597
+ // RGBA: source1 = [
598
+ // r0 g0 b0 a0 r1 g1 b1 a1 0 0 0 0 0 0 0 0
599
+ // R0 G0 B0 A0 R1 G1 B1 A1 0 0 0 0 0 0 0 0
600
+ // ]
601
+ // RGB: source1 = [
602
+ // r0 g0 b0 r1 g1 b1 r2 0 0 0 0 0 0 0 0
603
+ // R0 G0 B0 R1 G1 B1 R2 0 0 0 0 0 0 0 0
604
+ // ]
605
+ auto source1 = _mm256_inserti128_si256(_mm256_castsi128_si256(
606
+ _mm_loadl_epi64((__m128i *) (lineIn0_min + stride * i))),
607
+ _mm_loadl_epi64((__m128i *) (lineIn1_min + stride * i)), 1);
608
+ // Apply mask_low:
609
+ // RGBA:
610
+ // [r0 0 r1 0 g0 0 g1 0 b0 0 b1 0 a0 0 a1 0 | R0 0 R1 0 G0 0 G1 0 B0 0 B1 0 A0 0 A1 0]
611
+ // RGB:
612
+ // [r0 0 r1 0 g0 0 g1 0 b0 0 b1 0 0 0 0 0 | R0 0 R1 0 G0 0 G1 0 B0 0 B1 0 0 0 0 0]
613
+ auto pix1 = _mm256_shuffle_epi8(source1, mask_low);
614
+ // Compute output value as C += w0 * C0 + w1 * C1 for each channel in 32-bit precision
615
+ sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix1, mmk));
616
+
617
+ // Same as above for lines 2 and 3:
618
+ auto source2 = _mm256_inserti128_si256(_mm256_castsi128_si256(
619
+ _mm_loadl_epi64((__m128i *) (lineIn2_min + stride * i))),
620
+ _mm_loadl_epi64((__m128i *) (lineIn3_min + stride * i)), 1);
621
+ auto pix2 = _mm256_shuffle_epi8(source2, mask_low);
622
+ sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix2, mmk));
623
+ }
624
+
625
+ // block 1
626
+ const auto i32_aligned = num_channels == 4;
627
+ for (; i < ids_size - 1; i++) {
628
+ // Load 1 value from weight vector
629
+ // mmk = [wl_0 wh_0 0 0 wl_0 wh_0 0 0 ...]
630
+ const auto mmk = _mm256_set1_epi32(k[i]);
631
+
632
+ // Load 2 pixels (one per line) from input lines 0 and 1:
633
+ // RGBA: pix1 = [
634
+ // r0 0 0 0 g0 0 0 0 b0 0 0 0 a0 0 0 0
635
+ // R0 0 0 0 G0 0 0 0 B0 0 0 0 A0 0 0 0
636
+ // ]
637
+ // RGB: pix1 = [
638
+ // r0 0 0 0 g0 0 0 0 b0 0 0 0 r1 0 0 0
639
+ // R0 0 0 0 G0 0 0 0 B0 0 0 0 R1 0 0 0
640
+ // ]
641
+ auto pix1 = _mm256_inserti128_si256(_mm256_castsi128_si256(
642
+ mm_cvtepu8_epi32(lineIn0_min + stride * i, i32_aligned)),
643
+ mm_cvtepu8_epi32(lineIn1_min + stride * i, i32_aligned), 1);
644
+ // Compute output value as C += w0 * C0 for each channel in 32-bit precision
645
+ sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix1, mmk));
646
+
647
+ // Same as above for lines 2 and 3
648
+ auto pix2 = _mm256_inserti128_si256(_mm256_castsi128_si256(
649
+ mm_cvtepu8_epi32(lineIn2_min + stride * i, i32_aligned)),
650
+ mm_cvtepu8_epi32(lineIn3_min + stride * i, i32_aligned), 1);
651
+ sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix2, mmk));
652
+ }
653
+
654
+ if (i == ids_size - 1) {
655
+ // last element
656
+ auto mmk = _mm256_set1_epi32(k[i]);
657
+ // For num_channels == 3 (3 bytes = one pixel) we tolerate to read 4 bytes
658
+ // lines 0, 1 and 2 wont go out of allocated memory bounds
659
+ auto pix = _mm256_inserti128_si256(_mm256_castsi128_si256(
660
+ mm_cvtepu8_epi32(lineIn0_min + stride * i, i32_aligned)),
661
+ mm_cvtepu8_epi32(lineIn1_min + stride * i, i32_aligned), 1);
662
+ sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix, mmk));
663
+
664
+ auto p0 = mm_cvtepu8_epi32(lineIn2_min + stride * i, i32_aligned);
665
+ __m128i p1;
666
+ if (num_channels == 3 && C10_UNLIKELY(is_last_line && ids_min + stride * i + 4 >= max_in_x_strided)) {
667
+ uint8_t input[4];
668
+ std::memcpy(input, lineIn3_min + stride * i, 3);
669
+ p1 = mm_cvtepu8_epi32(input, true);
670
+ } else {
671
+ p1 = mm_cvtepu8_epi32(lineIn3_min + stride * i, i32_aligned);
672
+ }
673
+ auto pix2 = _mm256_inserti128_si256(_mm256_castsi128_si256(p0), p1, 1);
674
+ sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix2, mmk));
675
+ }
676
+
677
+ // Convert fixed point values back to integers (truncating)
678
+ sss0 = _mm256_srai_epi32(sss0, coefs_precision);
679
+ sss1 = _mm256_srai_epi32(sss1, coefs_precision);
680
+ // Convert packed signed 32-bit integers to packed 16-bit integers using signed saturation
681
+ // (a a a a b b b b c c c c d d d d) -> (a a b b c c d d 0 0 0 0 0 0 0 0)
682
+ sss0 = _mm256_packs_epi32(sss0, zero);
683
+ sss1 = _mm256_packs_epi32(sss1, zero);
684
+ // Convert packed signed 16-bit integers to packed 8-bit integers using unsigned saturation
685
+ // (a a b b c c d d) -> (a b c d 0 0 0 0)
686
+ sss0 = _mm256_packus_epi16(sss0, zero);
687
+ sss1 = _mm256_packus_epi16(sss1, zero);
688
+
689
+ // Write the output into single uint32
690
+ // (a b c d) -> x_uint32
691
+ auto o0 = _mm_cvtsi128_si32(_mm256_castsi256_si128(sss0));
692
+ auto o1 = _mm_cvtsi128_si32(_mm256_extracti128_si256(sss0, 1));
693
+ auto o2 = _mm_cvtsi128_si32(_mm256_castsi256_si128(sss1));
694
+ auto o3 = _mm_cvtsi128_si32(_mm256_extracti128_si256(sss1, 1));
695
+
696
+ const auto out_x_strided = stride * out_x;
697
+
698
+ if (num_channels == 3 && C10_UNLIKELY(out_x_strided + 4 >= max_out_x_strided)) {
699
+ // Memcpy 4-bytes is faster than 3-bytes and this is a boundary case when we want to write
700
+ // 4 bytes (R G B | X) to the output buffer (X1 X2 X3 | R1).
701
+ // The 4th byte in the register (X) has a garbage value and 4th byte in the output buffer (R1) has a correct
702
+ // value which was previously computed by another line. In other words, it means that we can not overwrite
703
+ // it by simply writing 4 bytes from the register to the output. We'll do the following:
704
+ // v----------|
705
+ // Output = [... X1 X2 X3 | R1 G1 B1 R2 ...]
706
+ // First, we write R1 value to the 4th byte of (R G B | X) -> (R G B | R1)
707
+ // Second, we write 4 bytes from the register to the output: (X1 X2 X3 | R1) -> (R G B | R1)
708
+ // Output = [... R G B | R1 G1 B1 R2 ...]
709
+
710
+ _write_endline_rgb_as_uint32(lineOut0 + out_x_strided, o0);
711
+ _write_endline_rgb_as_uint32(lineOut1 + out_x_strided, o1);
712
+ _write_endline_rgb_as_uint32(lineOut2 + out_x_strided, o2);
713
+
714
+ if (C10_UNLIKELY(is_last_line)) {
715
+ // When we handle the last line, we can not access the next 4 bytes
716
+ // as they are out of memory bounds.
717
+ std::memcpy(lineOut3 + out_x_strided, (uint8_t *) &o3, num_channels);
718
+ } else {
719
+ _write_endline_rgb_as_uint32(lineOut3 + out_x_strided, o3);
720
+ }
721
+ } else if (num_channels == 3) {
722
+ // Memcpy 4-bytes is faster than 3-bytes and here
723
+ // we simply write 4 bytes (... R G B X 0 0 0 0 0 ...) where X is a garbage value
724
+ // that we will overwrite on the next iteration: (... R G B R G B X 0 0 ...)
725
+ std::memcpy(lineOut0 + out_x_strided, (uint8_t *) &o0, 4);
726
+ std::memcpy(lineOut1 + out_x_strided, (uint8_t *) &o1, 4);
727
+ std::memcpy(lineOut2 + out_x_strided, (uint8_t *) &o2, 4);
728
+ std::memcpy(lineOut3 + out_x_strided, (uint8_t *) &o3, 4);
729
+ } else {
730
+ // num_channels = 4 -> lineOutX + out_x_strided should be uint32 aligned
731
+ *(uint32_t *)(lineOut0 + out_x_strided) = o0;
732
+ *(uint32_t *)(lineOut1 + out_x_strided) = o1;
733
+ *(uint32_t *)(lineOut2 + out_x_strided) = o2;
734
+ *(uint32_t *)(lineOut3 + out_x_strided) = o3;
735
+ }
736
+ }
737
+ }
738
+
739
+ void ImagingResampleHorizontalConvolution8u(
740
+ uint8_t* C10_RESTRICT lineOut,
741
+ int64_t out_xsize,
742
+ const uint8_t* C10_RESTRICT lineIn,
743
+ int64_t in_xsize,
744
+ const int64_t* idx_ptr_xmin,
745
+ const int64_t* idx_ptr_size,
746
+ const int16_t* kk,
747
+ int kmax,
748
+ unsigned int coefs_precision,
749
+ int64_t num_channels,
750
+ bool is_last_line) {
751
+
752
+ // Interpolation horizontal pass processing only one vertical line.
753
+ // - Input data format is RGBA or RGB with R,G,B,A being uint8. In case of RGBA
754
+ // we can encode 4 values as a single uint32 value.
755
+ // - We split the size of weight vector for a given output index as a sum:
756
+ // ids_size = num_blocks_8 * 8 + num_blocks_4 * 4 + num_blocks_2 * 2 + num_blocks_1
757
+ // - We load and process 8 weights values in a loop ("block 8") then 4 weights and 2 weights values in
758
+ // in another loops ("block 4" and "block 2") and finally we process 1 weight value in the final loop ("block 1").
759
+
760
+ // Define various shuffling masks
761
+ const auto kmask_low = _mm256_set_epi8(
762
+ 11, 10, 9, 8, 11, 10, 9, 8, 11, 10, 9, 8, 11, 10, 9, 8,
763
+ 3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0);
764
+ const auto kmask_high = _mm256_set_epi8(
765
+ 15, 14, 13, 12, 15, 14, 13, 12, 15, 14, 13, 12, 15, 14, 13, 12,
766
+ 7, 6, 5, 4, 7, 6, 5, 4, 7, 6, 5, 4, 7, 6, 5, 4);
767
+ const auto kmask_hl = _mm256_set_epi8(
768
+ 7, 6, 5, 4, 7, 6, 5, 4, 7, 6, 5, 4, 7, 6, 5, 4,
769
+ 3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0);
770
+
771
+ const auto mask_low_c4 = _mm256_set_epi8(
772
+ -1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0,
773
+ -1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0);
774
+ const auto mask_high_c4 = _mm256_set_epi8(
775
+ -1, 15, -1, 11, -1, 14, -1, 10, -1, 13, -1, 9, -1, 12, -1, 8,
776
+ -1, 15, -1, 11, -1, 14, -1, 10, -1, 13, -1, 9, -1, 12, -1, 8);
777
+ const auto mask_low_c3 = _mm256_set_epi8(
778
+ -1, -1, -1, -1, -1, 5, -1, 2, -1, 4, -1, 1, -1, 3, -1, 0,
779
+ -1, -1, -1, -1, -1, 5, -1, 2, -1, 4, -1, 1, -1, 3, -1, 0);
780
+ const auto mask_high_c3 = _mm256_set_epi8(
781
+ -1, -1, -1, -1, -1, 11, -1, 8, -1, 10, -1, 7, -1, 9, -1, 6,
782
+ -1, -1, -1, -1, -1, 11, -1, 8, -1, 10, -1, 7, -1, 9, -1, 6);
783
+ const auto mask_hl_c3 = _mm256_set_epi8(
784
+ -1, -1, -1, -1, -1, 11, -1, 8, -1, 10, -1, 7, -1, 9, -1, 6,
785
+ -1, -1, -1, -1, -1, 5, -1, 2, -1, 4, -1, 1, -1, 3, -1, 0);
786
+ const auto mask_hl_c4 = _mm256_set_epi8(
787
+ -1, 15, -1, 11, -1, 14, -1, 10, -1, 13, -1, 9, -1, 12, -1, 8,
788
+ -1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0);
789
+
790
+ const auto mask_low128_c3 = _mm_set_epi8(
791
+ -1, -1, -1, -1, -1, 5, -1, 2, -1, 4, -1, 1, -1, 3, -1, 0);
792
+ const auto mask_low128_c4 = _mm_set_epi8(
793
+ -1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0);
794
+
795
+ const auto mask_low = (num_channels == 3) ? mask_low_c3 : mask_low_c4;
796
+ const auto mask_high = (num_channels == 3) ? mask_high_c3 : mask_high_c4;
797
+ const auto mask_hl = (num_channels == 3) ? mask_hl_c3 : mask_hl_c4;
798
+ const auto mask_low128 = (num_channels == 3) ? mask_low128_c3 : mask_low128_c4;
799
+
800
+ // out_xsize = output width, out_x = output x index
801
+ // ids_min is the input offset index corresponding to out_x
802
+ // ids_size is the interpolation size for out_x
803
+
804
+ const auto stride = num_channels * sizeof(uint8_t);
805
+ const auto zero = _mm_setzero_si128();
806
+
807
+ TORCH_INTERNAL_ASSERT(stride == 3 || stride == 4);
808
+
809
+ // Let's precompute ids_size limits for block 8, block 4 and block 2
810
+ //
811
+ // In block 8 (8 means we process 8 weight values together), we read at
812
+ // most 32 bytes input data (16 + 16 bytes for RGBA and 12 + 16 bytes for RGB)
813
+ // lineIn + stride * (i + ids_min) + 32 <= lineIn + stride * (ids_size + ids_min)
814
+ // --> i <= ids_size - 32.0 / stride
815
+ // Strict boundary:
816
+ // --> i < ids_size + 1 - int(ceil(32.0 / stride)) = ids_size - b8_delta
817
+ // Soft boundary for reading inside the buffer except its boundaries:
818
+ // --> i < ids_size + 1 - int(32.0 / stride) = ids_size - b8_delta_soft
819
+ // RGBA: b8_delta = b8_delta_soft = 7
820
+ // RGB : b8_delta = 10
821
+ // RGB : b8_delta_soft = 9
822
+ const auto b8_delta = (stride == 4) ? 7 : ((is_last_line) ? 10 : 9);
823
+
824
+ // In block 4 (4 means we process 4 weight values together), we read
825
+ // 16 bytes of input data.
826
+ // lineIn + stride * (i + ids_min) + 16 <= lineIn0 + stride * (ids_size + ids_min)
827
+ // --> i <= ids_size - 16.0 / stride
828
+ // Strict boundary:
829
+ // --> i < ids_size + 1 - int(ceil(16.0 / stride)) = ids_size - b4_delta
830
+ // Soft boundary for reading inside the buffer except its boundaries:
831
+ // --> i < ids_size + 1 - int(16.0 / stride) = ids_size - b4_delta_soft
832
+ // RGBA: b4_delta = b4_delta_soft = 3
833
+ // RGB : b4_delta = 5
834
+ // RGB : b4_delta_soft = 4
835
+ const auto b4_delta = (stride == 4) ? 3 : ((is_last_line) ? 5 : 4);
836
+
837
+ // In block 2 (2 means we process 2 weight values together), we read
838
+ // 8 bytes of input data.
839
+ // lineIn0 + stride * (i + ids_min) + 8 <= lineIn0 + stride * (ids_size + ids_min)
840
+ // --> i <= ids_size - 8.0 / stride
841
+ // Strict boundary:
842
+ // --> i < ids_size + 1 - int(ceil(8.0 / stride)) = ids_size - b2_delta
843
+ // Soft boundary for reading inside the buffer except its boundaries:
844
+ // --> i < ids_size + 1 - int(8.0 / stride) = ids_size - b2_delta_soft
845
+ // RGBA: b2_delta = b2_delta_soft = 1
846
+ // RGB : b2_delta = 2
847
+ // RGB : b2_delta_soft = 1
848
+ const auto b2_delta = (stride == 4) ? 1 : ((is_last_line) ? 2 : 1);
849
+
850
+ const auto max_out_x_strided = out_xsize * stride;
851
+ const auto max_in_x_strided = in_xsize * stride;
852
+
853
+ for (const auto out_x : c10::irange(out_xsize)) {
854
+ __m128i sss;
855
+ const auto ids_min = idx_ptr_xmin[out_x];
856
+ const auto ids_size = idx_ptr_size[out_x];
857
+ const auto * k = &kk[out_x * kmax];
858
+ int64_t i = 0;
859
+
860
+ const auto * lineIn_min = lineIn + ids_min;
861
+
862
+ if (ids_size < 8) {
863
+ sss = _mm_set1_epi32(1 << (coefs_precision - 1));
864
+ } else {
865
+ // Lower part will be added to higher, use only half of the error
866
+ auto sss256 = _mm256_set1_epi32(1 << (coefs_precision - 2));
867
+
868
+ // block 8
869
+ for (; i < ids_size - b8_delta; i += 8) {
870
+ // Load 8 values from weight vector
871
+ auto tmp = _mm_loadu_si128((__m128i*)&k[i]);
872
+ // ksource = [
873
+ // wl_0 wh_0 wl_1 wh_1 wl_2 wh_2 wl_3 wh_3 wl_4 wh_4 wl_5 wh_5 wl_6 wh_6 wl_7 wh_7
874
+ // wl_0 wh_0 wl_1 wh_1 wl_2 wh_2 wl_3 wh_3 wl_4 wh_4 wl_5 wh_5 wl_6 wh_6 wl_7 wh_7
875
+ // ]
876
+ auto ksource = _mm256_insertf128_si256(_mm256_castsi128_si256(tmp), tmp, 1);
877
+
878
+ // RGBA: Load 8 pixels from input:
879
+ // source = [
880
+ // r0 g0 b0 a0 r1 g1 b1 a1 r2 g2 b2 a2 r3 g3 b3 a3
881
+ // r4 g4 b4 a4 r5 g5 b5 a5 r6 g6 b6 a6 r7 g7 b7 a7
882
+ // ]
883
+ // RGB: Load 10 pixels from input (however we can process only 8 pixels):
884
+ // source = [
885
+ // r0 g0 b0 r1 g1 b1 r2 g2 b2 r3 g3 b3 r4 g4 b4 r5
886
+ // r4 g4 b4 r5 g5 b5 r6 g6 b6 r7 g7 b7 r8 g8 b8 r9
887
+ // ]
888
+ auto source = _mm256_inserti128_si256(_mm256_castsi128_si256(
889
+ _mm_loadu_si128((__m128i *) (lineIn_min + stride * i))),
890
+ _mm_loadu_si128((__m128i *) (lineIn_min + stride * (i + 4))), 1);
891
+
892
+ // Extract lower part of each lane, cast to epi16 and reoder RGBARGBA -> RRGGBBAA
893
+ // RGBA: pix1 = [
894
+ // r0 0 r1 0 g0 0 g1 0 b0 0 b1 0 a0 0 a1 0
895
+ // r4 0 r5 0 g4 0 g5 0 b4 0 b5 0 a4 0 a5 0
896
+ // ]
897
+ // RGB: pix1 = [
898
+ // r0 0 r1 0 g0 0 g1 0 b0 0 b1 0 0 0 0 0
899
+ // r4 0 r5 0 g4 0 g5 0 b4 0 b5 0 0 0 0 0
900
+ // ]
901
+ auto pix1 = _mm256_shuffle_epi8(source, mask_low);
902
+ // mmk1 = [
903
+ // wl_0 wh_0 wl_1 wh_1 wl_0 wh_0 wl_1 wh_1 ... ...
904
+ // wl_4 wh_4 wl_5 wh_5 wl_4 wh_4 wl_5 wh_5 ... ...
905
+ // ]
906
+ auto mmk1 = _mm256_shuffle_epi8(ksource, kmask_low);
907
+ // Compute output value as
908
+ // C += w0 * C0 + w1 * C1
909
+ // C += w4 * C4 + w5 * C5 for each channel in 32-bit precision
910
+ sss256 = _mm256_add_epi32(sss256, _mm256_madd_epi16(pix1, mmk1));
911
+
912
+ // Same as above for higher part of each lane
913
+ auto pix2 = _mm256_shuffle_epi8(source, mask_high);
914
+ auto mmk2 = _mm256_shuffle_epi8(ksource, kmask_high);
915
+ // Compute output value as
916
+ // C += w2 * C2 + w3 * C3
917
+ // C += w6 * C6 + w7 * C7 for each channel in 32-bit precision
918
+ sss256 = _mm256_add_epi32(sss256, _mm256_madd_epi16(pix2, mmk2));
919
+ }
920
+
921
+ // block 4
922
+ for (; i < ids_size - b4_delta; i += 4) {
923
+ // Load 4 values from weight vector
924
+ auto tmp = _mm_loadl_epi64((__m128i *) &k[i]);
925
+ // ksource = [
926
+ // wl_0 wh_0 wl_1 wh_1 wl_2 wh_2 wl_3 wh_3 0 0 0 0 0 0 0 0
927
+ // wl_0 wh_0 wl_1 wh_1 wl_2 wh_2 wl_3 wh_3 0 0 0 0 0 0 0 0
928
+ // ]
929
+ auto ksource = _mm256_insertf128_si256(_mm256_castsi128_si256(tmp), tmp, 1);
930
+
931
+ // Load pixels from input line
932
+ tmp = _mm_loadu_si128((__m128i *) (lineIn_min + stride * i));
933
+ // RGBA: source = [
934
+ // r0 g0 b0 a0 r1 g1 b1 a1 r2 g2 b2 a2 r3 g3 b3 a3
935
+ // r0 g0 b0 a0 r1 g1 b1 a1 r2 g2 b2 a2 r3 g3 b3 a3
936
+ // ]
937
+ // RGB: source = [
938
+ // r0 g0 b0 r1 g1 b1 r2 g2 b2 r3 g3 b3 r4 g4 b4 r5
939
+ // r0 g0 b0 r1 g1 b1 r2 g2 b2 r3 g3 b3 r4 g4 b4 r5
940
+ // ]
941
+ auto source = _mm256_insertf128_si256(_mm256_castsi128_si256(tmp), tmp, 1);
942
+
943
+ // Cast source to epi16 and reorder RGBARGBA -> RRGGBBAA
944
+ // RGBA: pix = [
945
+ // r0 0 r1 0 g0 0 g1 0 b0 0 b1 0 a0 0 a1 0
946
+ // r2 0 r3 0 g2 0 g3 0 b2 0 b3 0 a2 0 a3 0
947
+ // ]
948
+ // RGB: pix = [
949
+ // r0 0 r1 0 g0 0 g1 0 b0 0 b1 0 0 0 0 0
950
+ // r2 0 r3 0 g2 0 g3 0 b2 0 b3 0 0 0 0 0
951
+ // ]
952
+ auto pix = _mm256_shuffle_epi8(source, mask_hl);
953
+ // mmk = [
954
+ // wl_0 wh_0 wl_1 wh_1 wl_0 wh_0 wl_1 wh_1 ... ...
955
+ // wl_2 wh_2 wl_3 wh_3 wl_2 wh_2 wl_3 wh_3 ... ...
956
+ // ]
957
+ auto mmk = _mm256_shuffle_epi8(ksource, kmask_hl);
958
+ // Compute output value as
959
+ // C += w0 * C0 + w1 * C1
960
+ // C += w2 * C2 + w3 * C3 for each channel in 32-bit precision
961
+ sss256 = _mm256_add_epi32(sss256, _mm256_madd_epi16(pix, mmk));
962
+ }
963
+
964
+ // Sum results between the lanes
965
+ sss = _mm_add_epi32(
966
+ _mm256_extracti128_si256(sss256, 0),
967
+ _mm256_extracti128_si256(sss256, 1));
968
+ }
969
+
970
+ // block 2
971
+ for (; i < ids_size - b2_delta; i += 2) {
972
+ // Load 2 values from weight vector
973
+ // mmk = [wl_0 wh_0 wl_1 wh_1 wl_0 wh_0 wl_1 wh_1 ...]
974
+ auto mmk = _mm_set1_epi32(*(int32_t*)&k[i]);
975
+ // Load pixels from input line
976
+ // RGBA: source = [
977
+ // r0 g0 b0 a0 r1 g1 b1 a1 0 0 0 0 0 0 0 0
978
+ // ]
979
+ // RGB: source = [
980
+ // r0 g0 b0 r1 g1 b1 r2 g2 0 0 0 0 0 0 0 0
981
+ // ]
982
+ auto source = _mm_loadl_epi64((__m128i *) (lineIn_min + stride * i));
983
+ // Cast source to epi16 and reorder RGBARGBA -> RRGGBBAA
984
+ auto pix = _mm_shuffle_epi8(source, mask_low128);
985
+ // Compute output value as C += w0 * C0 + w1 * C1 for each channel in 32-bit precision
986
+ sss = _mm_add_epi32(sss, _mm_madd_epi16(pix, mmk));
987
+ }
988
+
989
+ // block 1
990
+ const auto i32_aligned = num_channels == 4;
991
+ for (; i < ids_size - 1; i++) {
992
+ // Load 1 value from weight vector
993
+ // mmk = [wl_0 wh_0 0 0 wl_0 wh_0 0 0 ...]
994
+ auto mmk = _mm_set1_epi32(k[i]);
995
+ // Load one pixel from input line
996
+ // RGBA: pix = [
997
+ // r0 0 0 0 g0 0 0 0 b0 0 0 0 a0 0 0 0
998
+ // ]
999
+ // RGB: pix = [
1000
+ // r0 0 0 0 g0 0 0 0 b0 0 0 0 r1 0 0 0
1001
+ // ]
1002
+ auto pix = mm_cvtepu8_epi32(lineIn_min + stride * i, i32_aligned);
1003
+ // Compute output value as C += w0 * C0 for each channel in 32-bit precision
1004
+ sss = _mm_add_epi32(sss, _mm_madd_epi16(pix, mmk));
1005
+ }
1006
+
1007
+ if (i == ids_size - 1) {
1008
+ // last element
1009
+ auto mmk = _mm_set1_epi32(k[i]);
1010
+ __m128i pix;
1011
+ auto p = lineIn_min + stride * i;
1012
+ if (num_channels == 3 && C10_UNLIKELY(is_last_line && ids_min + stride * i + 4 >= max_in_x_strided)) {
1013
+ uint8_t input[4];
1014
+ std::memcpy(input, p, 3);
1015
+ pix = mm_cvtepu8_epi32(input, true);
1016
+ } else {
1017
+ pix = mm_cvtepu8_epi32(p, i32_aligned);
1018
+ }
1019
+ sss = _mm_add_epi32(sss, _mm_madd_epi16(pix, mmk));
1020
+ }
1021
+
1022
+ // Convert fixed point values back to integers (truncating)
1023
+ sss = _mm_srai_epi32(sss, coefs_precision);
1024
+ // Convert packed signed 32-bit integers to packed 16-bit integers using signed saturation
1025
+ // (a a a a b b b b c c c c d d d d) -> (a a b b c c d d 0 0 0 0 0 0 0 0)
1026
+ sss = _mm_packs_epi32(sss, zero);
1027
+ // Convert packed signed 16-bit integers to packed 8-bit integers using unsigned saturation
1028
+ // (a a b b c c d d) -> (a b c d 0 0 0 0)
1029
+ sss = _mm_packus_epi16(sss, zero);
1030
+ // Write the output into single uint32
1031
+ // (a b c d) -> x_uint32
1032
+ auto o = _mm_cvtsi128_si32(sss);
1033
+ const auto out_x_strided = stride * out_x;
1034
+ if (num_channels == 3 && C10_UNLIKELY(out_x_strided + 4 >= max_out_x_strided)) {
1035
+ if (C10_UNLIKELY(is_last_line)) {
1036
+ // When we handle the last line, we can not access the next 4 bytes
1037
+ // as they are out of memory bounds.
1038
+ std::memcpy(lineOut + out_x_strided, (uint8_t *) &o, 3);
1039
+ } else {
1040
+ // Memcpy 4-bytes is faster than 3-bytes and this is a boundary case when we want to write
1041
+ // 4 bytes (R G B | X) to the output buffer (X1 X2 X3 | R1).
1042
+ // The 4th byte in the register (X) has a garbage value and 4th byte in the output buffer (R1) has a correct
1043
+ // value which was previously computed by another line. In other words, it means that we can not overwrite
1044
+ // it by simply writing 4 bytes from the register to the output. We'll do the following:
1045
+ // v----------|
1046
+ // Output = [... X1 X2 X3 | R1 G1 B1 R2 ...]
1047
+ // First, we write R1 value to the 4th byte of (R G B | X) -> (R G B | R1)
1048
+ // Second, we write 4 bytes from the register to the output: (X1 X2 X3 | R1) -> (R G B | R1)
1049
+ // Output = [... R G B | R1 G1 B1 R2 ...]
1050
+ _write_endline_rgb_as_uint32(lineOut + out_x_strided, o);
1051
+ }
1052
+ } else if (num_channels == 3) {
1053
+ // Memcpy 4-bytes is faster than 3-bytes and here
1054
+ // we simply write 4 bytes (... R G B X 0 0 0 0 0 ...) where X is a garbage value
1055
+ // that we will overwrite on the next iteration: (... R G B R G B X 0 0 ...)
1056
+ std::memcpy(lineOut + out_x_strided, (uint8_t *) &o, 4);
1057
+ } else {
1058
+ // num_channels = 4 -> lineOut + out_x_strided should be uint32 aligned
1059
+ *(uint32_t *)(lineOut + out_x_strided) = o;
1060
+ }
1061
+ }
1062
+ }
1063
+
1064
+ void ImagingResampleVerticalConvolution8u(
1065
+ uint8_t* C10_RESTRICT lineOut,
1066
+ const uint8_t* C10_RESTRICT lineIn,
1067
+ int64_t xsize,
1068
+ int64_t ids_min,
1069
+ int64_t ids_size,
1070
+ const int16_t* k,
1071
+ unsigned int coefs_precision,
1072
+ int64_t num_channels) {
1073
+
1074
+ // Interpolation vertical pass processing one line.
1075
+ // - We process x-axis data with blocks of 8, 2 and 1
1076
+ // - We split the size of weight vector for a given output index as a sum: K = n * 2 + m.
1077
+
1078
+ // xsize = output width, also equals to input width
1079
+ // ids_size = interpolation size
1080
+ // ids_min = input y start index
1081
+ const auto stride = num_channels * sizeof(uint8_t);
1082
+
1083
+ TORCH_INTERNAL_ASSERT(stride == 3 || stride == 4);
1084
+
1085
+ const int64_t data_size = xsize * stride;
1086
+ const int64_t data_stride = stride;
1087
+ constexpr auto vec_size = 256 / 8;
1088
+
1089
+ const auto initial = _mm_set1_epi32(1 << (coefs_precision - 1));
1090
+ const auto initial_256 = _mm256_set1_epi32(1 << (coefs_precision - 1));
1091
+ const auto zero = _mm_setzero_si128();
1092
+ const auto zero_256 = _mm256_setzero_si256();
1093
+
1094
+ int64_t j = 0;
1095
+ // block 8
1096
+ const auto b8_usable_vec_stride = (vec_size / data_stride) * data_stride;
1097
+ for (; j < data_size - vec_size; j += b8_usable_vec_stride) {
1098
+ auto sss0 = initial_256;
1099
+ auto sss1 = initial_256;
1100
+ auto sss2 = initial_256;
1101
+ auto sss3 = initial_256;
1102
+ int64_t i = 0;
1103
+ const auto * lineIn_min = lineIn + j + ids_min;
1104
+
1105
+ for (; i < ids_size - 1; i += 2) {
1106
+ // Load 2 values from weight vector
1107
+ auto mmk = _mm256_set1_epi32(*(int32_t*)&k[i]);
1108
+
1109
+ // RGBA: Load 8 pixels per line
1110
+ // source1 = [
1111
+ // r0 g0 b0 a0 r1 g1 b1 a1 r2 g2 b2 a2 r3 g3 b3 a3
1112
+ // r4 g4 b4 a4 r5 g5 b5 a5 r6 g6 b6 a6 r7 g7 b7 a7
1113
+ // ]
1114
+ // RGB: Load 10 pixels per line (however we can process only 8 pixels):
1115
+ // source1 = [
1116
+ // r0 g0 b0 r1 g1 b1 r2 g2 b2 r3 g3 b3 r4 g4 b4 r5
1117
+ // r4 g4 b4 r5 g5 b5 r6 g6 b6 r7 g7 b7 r8 g8 b8 r9
1118
+ // ]
1119
+ auto source1 =
1120
+ _mm256_loadu_si256((__m256i*)(lineIn_min + data_size * i));
1121
+ auto source2 =
1122
+ _mm256_loadu_si256((__m256i*)(lineIn_min + data_size * (i + 1)));
1123
+
1124
+ // Interleave source1 and source2 from the low half of each 128-bit lane
1125
+ // and cast the result to epi16
1126
+ // RGBA: pix1 = [
1127
+ // r0 0 R0 0 g0 0 G0 0 b0 0 B0 0 a0 0 A0 0
1128
+ // r1 0 R1 0 g1 0 G1 0 b1 0 B1 0 a1 0 A1 0
1129
+ // ]
1130
+ // RGB: pix1 = [
1131
+ // r0 0 R0 0 g0 0 G0 0 b0 0 B0 0 0 0 0 0
1132
+ // r1 0 R1 0 g1 0 G1 0 b1 0 B1 0 0 0 0 0
1133
+ // ]
1134
+ auto source_lo = _mm256_unpacklo_epi8(source1, source2);
1135
+ auto pix1 = _mm256_unpacklo_epi8(source_lo, zero_256);
1136
+ // Compute output value as
1137
+ // C += w0 * c0 + w1 * C0
1138
+ // C += w0 * c1 + w1 * C1 for each channel in 32-bit precision
1139
+ sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix1, mmk));
1140
+
1141
+ // RGBA: pix2 = [
1142
+ // r2 0 R2 0 g2 0 G2 0 b2 0 B2 0 a2 0 A2 0
1143
+ // r3 0 R3 0 g3 0 G3 0 b3 0 B3 0 a3 0 A3 0
1144
+ // ]
1145
+ // RGB: pix2 = [
1146
+ // r2 0 R2 0 g2 0 G2 0 b2 0 B2 0 0 0 0 0
1147
+ // r3 0 R3 0 g3 0 G3 0 b3 0 B3 0 0 0 0 0
1148
+ // ]
1149
+ auto pix2 = _mm256_unpackhi_epi8(source_lo, zero_256);
1150
+ // Compute output value as
1151
+ // C += w0 * c2 + w1 * C2
1152
+ // C += w0 * c3 + w1 * C3 for each channel in 32-bit precision
1153
+ sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix2, mmk));
1154
+
1155
+ // Same as above for the high half of each 128-bit lane
1156
+ auto source_hi = _mm256_unpackhi_epi8(source1, source2);
1157
+ auto pix3 = _mm256_unpacklo_epi8(source_hi, zero_256);
1158
+ sss2 = _mm256_add_epi32(sss2, _mm256_madd_epi16(pix3, mmk));
1159
+ auto pix4 = _mm256_unpackhi_epi8(source_hi, zero_256);
1160
+ sss3 = _mm256_add_epi32(sss3, _mm256_madd_epi16(pix4, mmk));
1161
+ }
1162
+ // Same processing as above but with a single weight value
1163
+ for (; i < ids_size; i += 1) {
1164
+ auto mmk = _mm256_set1_epi32(k[i]);
1165
+
1166
+ auto source1 = _mm256_loadu_si256((__m256i*)(lineIn_min + i * data_size));
1167
+
1168
+ auto source_lo = _mm256_unpacklo_epi8(source1, zero_256);
1169
+ auto pix1 = _mm256_unpacklo_epi8(source_lo, zero_256);
1170
+ sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix1, mmk));
1171
+ auto pix2 = _mm256_unpackhi_epi8(source_lo, zero_256);
1172
+ sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix2, mmk));
1173
+
1174
+ auto source_hi = _mm256_unpackhi_epi8(source1, zero_256);
1175
+ auto pix3 = _mm256_unpacklo_epi8(source_hi, _mm256_setzero_si256());
1176
+ sss2 = _mm256_add_epi32(sss2, _mm256_madd_epi16(pix3, mmk));
1177
+ auto pix4 = _mm256_unpackhi_epi8(source_hi, _mm256_setzero_si256());
1178
+ sss3 = _mm256_add_epi32(sss3, _mm256_madd_epi16(pix4, mmk));
1179
+ }
1180
+ // Convert fixed point values back to integers (truncating)
1181
+ sss0 = _mm256_srai_epi32(sss0, coefs_precision);
1182
+ sss1 = _mm256_srai_epi32(sss1, coefs_precision);
1183
+ sss2 = _mm256_srai_epi32(sss2, coefs_precision);
1184
+ sss3 = _mm256_srai_epi32(sss3, coefs_precision);
1185
+ // Convert packed signed 32-bit integers to packed 16-bit integers using signed saturation
1186
+ // (a a a a b b b b c c c c d d d d) -> (a a b b c c d d)
1187
+ sss0 = _mm256_packs_epi32(sss0, sss1);
1188
+ sss2 = _mm256_packs_epi32(sss2, sss3);
1189
+ // Convert packed signed 16-bit integers to packed 8-bit integers using unsigned saturation
1190
+ // (a a b b c c d d) -> (a b c d)
1191
+ sss0 = _mm256_packus_epi16(sss0, sss2);
1192
+
1193
+ // Stores 32 bytes
1194
+ _mm256_storeu_si256((__m256i*)(lineOut + j), sss0);
1195
+ }
1196
+
1197
+ // TODO: Do we also need block 4 ???
1198
+ // block 2
1199
+ const auto b2_usable_vec_stride = (8 / data_stride) * data_stride;
1200
+ for (; j < data_size - vec_size / 4; j += b2_usable_vec_stride) {
1201
+ auto sss0 = initial;
1202
+ auto sss1 = initial;
1203
+ int64_t i = 0;
1204
+ const auto * lineIn_min = lineIn + j + ids_min;
1205
+
1206
+ for (; i < ids_size - 1; i += 2) {
1207
+ // Load 2 values from weight vector
1208
+ // mmk = [wl_0 wh_0 wl_1 wh_1 wl_0 wh_0 wl_1 wh_1 ... ]
1209
+ auto mmk = _mm_set1_epi32(*(int32_t*)&k[i]);
1210
+
1211
+ // Load 2 pixels per line
1212
+ // RGBA: source1 = [
1213
+ // r0 g0 b0 a0 r1 g1 b1 a1 0 0 0 0 0 0 0 0
1214
+ // ]
1215
+ // RGB: source1 = [
1216
+ // r0 g0 b0 r1 g1 b1 r2 g2 0 0 0 0 0 0 0 0
1217
+ // ]
1218
+ auto source1 = _mm_loadl_epi64((__m128i *) (lineIn_min + i * data_size));
1219
+ auto source2 = _mm_loadl_epi64((__m128i *) (lineIn_min + (i + 1) * data_size));
1220
+ // Interleave source1 and source2 and cast the result to epi16
1221
+ // RGBA: pix = [
1222
+ // r0 0 R0 0 g0 0 G0 0 b0 0 B0 0 a0 0 A0 0
1223
+ // ]
1224
+ // RGB: pix = [
1225
+ // r0 0 R0 0 g0 0 G0 0 b0 0 B0 0 0 0 0 0
1226
+ // ]
1227
+ auto source = _mm_unpacklo_epi8(source1, source2);
1228
+ auto pix = _mm_unpacklo_epi8(source, zero);
1229
+ // Compute output value as C += w0 * c0 + w1 * C0 for each channel in 32-bit precision
1230
+ sss0 = _mm_add_epi32(sss0, _mm_madd_epi16(pix, mmk));
1231
+ // RGBA: pix = [
1232
+ // r1 0 R1 0 g1 0 G1 0 b1 0 B1 0 a1 0 A1 0
1233
+ // ]
1234
+ // RGB: pix = [
1235
+ // r1 0 R1 0 g1 0 G1 0 b1 0 B1 0 0 0 0 0
1236
+ // ]
1237
+ pix = _mm_unpackhi_epi8(source, zero);
1238
+ // Compute output value as C += w0 * c1 + w1 * C1 for each channel in 32-bit precision
1239
+ sss1 = _mm_add_epi32(sss1, _mm_madd_epi16(pix, mmk));
1240
+ }
1241
+ // Same processing as above but with a single weight value
1242
+ for (; i < ids_size; i += 1) {
1243
+ auto mmk = _mm_set1_epi32(k[i]);
1244
+
1245
+ auto source1 = _mm_loadl_epi64((__m128i*) (lineIn_min + i * data_size));
1246
+
1247
+ auto source = _mm_unpacklo_epi8(source1, zero);
1248
+ auto pix1 = _mm_unpacklo_epi8(source, zero);
1249
+ sss0 = _mm_add_epi32(sss0, _mm_madd_epi16(pix1, mmk));
1250
+ auto pix2 = _mm_unpackhi_epi8(source, zero);
1251
+ sss1 = _mm_add_epi32(sss1, _mm_madd_epi16(pix2, mmk));
1252
+ }
1253
+ // Convert fixed point values back to integers (truncating)
1254
+ sss0 = _mm_srai_epi32(sss0, coefs_precision);
1255
+ sss1 = _mm_srai_epi32(sss1, coefs_precision);
1256
+ // Convert packed signed 32-bit integers to packed 16-bit integers using signed saturation
1257
+ // (a a a a b b b b c c c c d d d d) -> (a a b b c c d d)
1258
+ sss0 = _mm_packs_epi32(sss0, sss1);
1259
+ // Convert packed signed 16-bit integers to packed 8-bit integers using unsigned saturation
1260
+ // (a a b b c c d d) -> (a b c d)
1261
+ sss0 = _mm_packus_epi16(sss0, sss0);
1262
+ // Store 2 pixels to the output
1263
+ _mm_storel_epi64((__m128i*)(lineOut + j), sss0);
1264
+ }
1265
+
1266
+ // block 1
1267
+ const auto b1_usable_vec_stride = (4 / data_stride) * data_stride;
1268
+ const auto i32_aligned = num_channels == 4;
1269
+ for (; j < data_size - 4; j += b1_usable_vec_stride) {
1270
+ auto sss = initial;
1271
+ int64_t i = 0;
1272
+ const auto * lineIn_min = lineIn + j + ids_min;
1273
+
1274
+ for (; i < ids_size - 1; i += 2) {
1275
+ // Load 2 values from weight vector
1276
+ // mmk = [wl_0 wh_0 wl_1 wh_1 wl_0 wh_0 wl_1 wh_1 ... ]
1277
+ auto mmk = _mm_set1_epi32(*(int32_t*)&k[i]);
1278
+
1279
+ // Load one pixel per line
1280
+ // RGBA: source1 = [
1281
+ // r0 g0 b0 a0 0 0 0 0 0 0 0 0 0 0 0 0
1282
+ // ]
1283
+ // RGB: source1 = [
1284
+ // r0 g0 b0 r1 0 0 0 0 0 0 0 0 0 0 0 0
1285
+ // ]
1286
+ auto source1 = mm_cvtsi32_si128(lineIn_min + i * data_size, i32_aligned);
1287
+ auto source2 = mm_cvtsi32_si128(lineIn_min + (i + 1) * data_size, i32_aligned);
1288
+
1289
+ // Interleave source1 and source2 and cast the result to epi16
1290
+ // RGBA: pix = [
1291
+ // r0 0 R0 0 g0 0 G0 0 b0 0 B0 0 a0 0 A0 0
1292
+ // ]
1293
+ // RGB: pix = [
1294
+ // r0 0 R0 0 g0 0 G0 0 b0 0 B0 0 0 0 0 0
1295
+ // ]
1296
+ auto source = _mm_unpacklo_epi8(source1, source2);
1297
+ auto pix = _mm_unpacklo_epi8(source, zero);
1298
+ // Compute output value as C += w0 * c0 + w1 * C0 for each channel in 32-bit precision
1299
+ sss = _mm_add_epi32(sss, _mm_madd_epi16(pix, mmk));
1300
+ }
1301
+
1302
+ for (; i < ids_size; i++) {
1303
+ auto mmk = _mm_set1_epi32(k[i]);
1304
+ auto pix = mm_cvtepu8_epi32(lineIn_min + i * data_size, i32_aligned);
1305
+ sss = _mm_add_epi32(sss, _mm_madd_epi16(pix, mmk));
1306
+ }
1307
+ sss = _mm_srai_epi32(sss, coefs_precision);
1308
+ sss = _mm_packs_epi32(sss, zero);
1309
+ sss = _mm_packus_epi16(sss, zero);
1310
+
1311
+ auto o = _mm_cvtsi128_si32(sss);
1312
+
1313
+ // Here we write 4 bytes to the output even if num_channels < 4, e.g o = {r,g,b,X} for num_channels=3
1314
+ // It is OK to write 4th byte (e.g. X) as on the next step we will overwrite it with new data.
1315
+ // We also wont go out of bounds of lineOut memory allocation
1316
+ std::memcpy(lineOut + j, (uint8_t *) &o, 4);
1317
+ }
1318
+
1319
+ for (; j < data_size; j += data_stride) {
1320
+ auto sss = initial;
1321
+ int64_t i = 0;
1322
+ const auto * lineIn_min = lineIn + j + ids_min;
1323
+ // For RGBA we can use (ids_size - 1) as tighter limit but for RGB we can read outside memory boundary
1324
+ // for the last remaining line
1325
+ for (; i < ids_size - 2; i += 2) {
1326
+ // Load two coefficients at once
1327
+ auto mmk = _mm_set1_epi32(*(int32_t*)&k[i]);
1328
+
1329
+ // Load 2 lines
1330
+ auto source1 = mm_cvtsi32_si128(lineIn_min + i * data_size, i32_aligned);
1331
+ auto source2 = mm_cvtsi32_si128(lineIn_min + (i + 1) * data_size, i32_aligned);
1332
+
1333
+ auto source = _mm_unpacklo_epi8(source1, source2);
1334
+ auto pix = _mm_unpacklo_epi8(source, zero);
1335
+ sss = _mm_add_epi32(sss, _mm_madd_epi16(pix, mmk));
1336
+ }
1337
+
1338
+ // Same processing as above but with a single weight value
1339
+ for (; i < ids_size; i++) {
1340
+ auto mmk = _mm_set1_epi32(k[i]);
1341
+
1342
+ const uint8_t * p = lineIn_min + i * data_size;
1343
+ __m128i pix;
1344
+ // There is no much perf gain using more detailed condition like
1345
+ // num_channels == 3 && ids_min + j + data_size * i + 4 >= in_max_size
1346
+ // const int64_t in_max_size = data_size * in_ysize;
1347
+ if (num_channels == 3) {
1348
+ uint8_t input[4];
1349
+ std::memcpy(input, p, 3);
1350
+ pix = mm_cvtepu8_epi32(input, true);
1351
+ } else {
1352
+ pix = mm_cvtepu8_epi32(p, true);
1353
+ }
1354
+ sss = _mm_add_epi32(sss, _mm_madd_epi16(pix, mmk));
1355
+ }
1356
+
1357
+ // Convert fixed point values back to integers (truncating)
1358
+ sss = _mm_srai_epi32(sss, coefs_precision);
1359
+ // Convert packed signed 32-bit integers to packed 16-bit integers using signed saturation
1360
+ // (a a a a b b b b c c c c d d d d) -> (a a b b c c d d)
1361
+ sss = _mm_packs_epi32(sss, zero);
1362
+ // Convert packed signed 16-bit integers to packed 8-bit integers using unsigned saturation
1363
+ // (a a b b c c d d) -> (a b c d)
1364
+ sss = _mm_packus_epi16(sss, zero);
1365
+ // Store one pixel to the output
1366
+ auto o = _mm_cvtsi128_si32(sss);
1367
+ if (num_channels == 3 && C10_UNLIKELY(j + 4 >= data_size)) {
1368
+ std::memcpy(lineOut + j, (uint8_t *) &o, 3);
1369
+ } else {
1370
+ std::memcpy(lineOut + j, (uint8_t *) &o, 4);
1371
+ }
1372
+ }
1373
+ }
1374
+
1375
+ } // anonymous namespace
1376
+ #endif // CPU_CAPABILITY_AVX2
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/WeightNormKernel.h ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <ATen/native/DispatchStub.h>
3
+ #include <cstdint>
4
+
5
+ namespace at {
6
+ class TensorBase;
7
+ }
8
+
9
+ namespace at::native {
10
+
11
+ using weight_norm_fn = void(*)(
12
+ TensorBase&, TensorBase&, const TensorBase&, const TensorBase&, int64_t);
13
+ using weight_norm_backward_fn = void(*)(
14
+ TensorBase&, TensorBase&, const TensorBase&, const TensorBase&,
15
+ const TensorBase&, const TensorBase&, int64_t);
16
+
17
+ DECLARE_DISPATCH(weight_norm_fn, weight_norm_stub);
18
+ DECLARE_DISPATCH(weight_norm_backward_fn, weight_norm_backward_stub);
19
+
20
+ } // namespace at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/avx_mathfun.h ADDED
@@ -0,0 +1,522 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ /*
3
+ AVX implementation of sin, cos, sincos, exp and log
4
+
5
+ Based on "sse_mathfun.h", by Julien Pommier
6
+ http://gruntthepeon.free.fr/ssemath/
7
+
8
+ Copyright (C) 2012 Giovanni Garberoglio
9
+ Interdisciplinary Laboratory for Computational Science (LISC)
10
+ Fondazione Bruno Kessler and University of Trento
11
+ via Sommarive, 18
12
+ I-38123 Trento (Italy)
13
+
14
+ This software is provided 'as-is', without any express or implied
15
+ warranty. In no event will the authors be held liable for any damages
16
+ arising from the use of this software.
17
+
18
+ Permission is granted to anyone to use this software for any purpose,
19
+ including commercial applications, and to alter it and redistribute it
20
+ freely, subject to the following restrictions:
21
+
22
+ 1. The origin of this software must not be misrepresented; you must not
23
+ claim that you wrote the original software. If you use this software
24
+ in a product, an acknowledgment in the product documentation would be
25
+ appreciated but is not required.
26
+ 2. Altered source versions must be plainly marked as such, and must not be
27
+ misrepresented as being the original software.
28
+ 3. This notice may not be removed or altered from any source distribution.
29
+
30
+ (this is the zlib license)
31
+ */
32
+
33
+ #include <ATen/native/cpu/Intrinsics.h>
34
+
35
+ /* The original source of this file has been modified. */
36
+ #if defined(CPU_CAPABILITY_AVX2)
37
+
38
+ #if defined(__GNUC__)
39
+ # define ALIGN32_BEG __attribute__((aligned(32)))
40
+ #elif defined(_WIN32)
41
+ # define ALIGN32_BEG __declspec(align(32))
42
+ #endif
43
+
44
+ typedef __m256 v8sf; // vector of 8 float (avx2)
45
+ typedef __m256i v8si; // vector of 8 int (avx2)
46
+
47
+ /* declare some AVX constants -- why can't I figure a better way to do that? */
48
+ #define _PS256_CONST(Name, Val) \
49
+ static const ALIGN32_BEG float _ps256_##Name[8] = { Val, Val, Val, Val, Val, Val, Val, Val }
50
+ #define _PI32_CONST256(Name, Val) \
51
+ static const ALIGN32_BEG int _pi32_256_##Name[8] = { Val, Val, Val, Val, Val, Val, Val, Val }
52
+ #define _PS256_CONST_TYPE(Name, Type, Val) \
53
+ static const ALIGN32_BEG Type _ps256_##Name[8] = { Val, Val, Val, Val, Val, Val, Val, Val }
54
+
55
+ _PS256_CONST(1 , 1.0f);
56
+ _PS256_CONST(0p5, 0.5f);
57
+ /* the smallest non denormalized float number */
58
+ _PS256_CONST_TYPE(min_norm_pos, int, 0x00800000);
59
+ _PS256_CONST_TYPE(mant_mask, int, 0x7f800000);
60
+ _PS256_CONST_TYPE(inv_mant_mask, int, ~0x7f800000);
61
+
62
+ _PS256_CONST_TYPE(sign_mask, int, (int)0x80000000);
63
+ _PS256_CONST_TYPE(inv_sign_mask, int, ~0x80000000);
64
+
65
+ _PI32_CONST256(0, 0);
66
+ _PI32_CONST256(1, 1);
67
+ _PI32_CONST256(inv1, ~1);
68
+ _PI32_CONST256(2, 2);
69
+ _PI32_CONST256(4, 4);
70
+ _PI32_CONST256(0x7f, 0x7f);
71
+
72
+ _PS256_CONST(cephes_SQRTHF, 0.707106781186547524);
73
+ _PS256_CONST(cephes_log_p0, 7.0376836292E-2);
74
+ _PS256_CONST(cephes_log_p1, - 1.1514610310E-1);
75
+ _PS256_CONST(cephes_log_p2, 1.1676998740E-1);
76
+ _PS256_CONST(cephes_log_p3, - 1.2420140846E-1);
77
+ _PS256_CONST(cephes_log_p4, + 1.4249322787E-1);
78
+ _PS256_CONST(cephes_log_p5, - 1.6668057665E-1);
79
+ _PS256_CONST(cephes_log_p6, + 2.0000714765E-1);
80
+ _PS256_CONST(cephes_log_p7, - 2.4999993993E-1);
81
+ _PS256_CONST(cephes_log_p8, + 3.3333331174E-1);
82
+ _PS256_CONST(cephes_log_q1, -2.12194440e-4);
83
+ _PS256_CONST(cephes_log_q2, 0.693359375);
84
+
85
+
86
+ /* natural logarithm computed for 8 simultaneous float
87
+ return NaN for x <= 0
88
+ */
89
+ inline v8sf log256_ps(v8sf x) {
90
+ v8si imm0;
91
+ v8sf one = *(v8sf*)_ps256_1;
92
+
93
+ //v8sf invalid_mask = _mm256_cmple_ps(x, _mm256_setzero_ps());
94
+ v8sf invalid_mask = _mm256_cmp_ps(x, _mm256_setzero_ps(), _CMP_LE_OS);
95
+
96
+ x = _mm256_max_ps(x, *(v8sf*)_ps256_min_norm_pos); /* cut off denormalized stuff */
97
+
98
+ // can be done with AVX2
99
+ imm0 = _mm256_srli_epi32(_mm256_castps_si256(x), 23);
100
+
101
+ /* keep only the fractional part */
102
+ x = _mm256_and_ps(x, *(v8sf*)_ps256_inv_mant_mask);
103
+ x = _mm256_or_ps(x, *(v8sf*)_ps256_0p5);
104
+
105
+ // this is again another AVX2 instruction
106
+ imm0 = _mm256_sub_epi32(imm0, *(v8si*)_pi32_256_0x7f);
107
+ v8sf e = _mm256_cvtepi32_ps(imm0);
108
+
109
+ e = _mm256_add_ps(e, one);
110
+
111
+ /* part2:
112
+ if( x < SQRTHF ) {
113
+ e -= 1;
114
+ x = x + x - 1.0;
115
+ } else { x = x - 1.0; }
116
+ */
117
+ //v8sf mask = _mm256_cmplt_ps(x, *(v8sf*)_ps256_cephes_SQRTHF);
118
+ v8sf mask = _mm256_cmp_ps(x, *(v8sf*)_ps256_cephes_SQRTHF, _CMP_LT_OS);
119
+ v8sf tmp = _mm256_and_ps(x, mask);
120
+ x = _mm256_sub_ps(x, one);
121
+ e = _mm256_sub_ps(e, _mm256_and_ps(one, mask));
122
+ x = _mm256_add_ps(x, tmp);
123
+
124
+ v8sf z = _mm256_mul_ps(x,x);
125
+
126
+ v8sf y = *(v8sf*)_ps256_cephes_log_p0;
127
+ y = _mm256_mul_ps(y, x);
128
+ y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p1);
129
+ y = _mm256_mul_ps(y, x);
130
+ y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p2);
131
+ y = _mm256_mul_ps(y, x);
132
+ y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p3);
133
+ y = _mm256_mul_ps(y, x);
134
+ y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p4);
135
+ y = _mm256_mul_ps(y, x);
136
+ y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p5);
137
+ y = _mm256_mul_ps(y, x);
138
+ y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p6);
139
+ y = _mm256_mul_ps(y, x);
140
+ y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p7);
141
+ y = _mm256_mul_ps(y, x);
142
+ y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p8);
143
+ y = _mm256_mul_ps(y, x);
144
+
145
+ y = _mm256_mul_ps(y, z);
146
+
147
+ tmp = _mm256_mul_ps(e, *(v8sf*)_ps256_cephes_log_q1);
148
+ y = _mm256_add_ps(y, tmp);
149
+
150
+
151
+ tmp = _mm256_mul_ps(z, *(v8sf*)_ps256_0p5);
152
+ y = _mm256_sub_ps(y, tmp);
153
+
154
+ tmp = _mm256_mul_ps(e, *(v8sf*)_ps256_cephes_log_q2);
155
+ x = _mm256_add_ps(x, y);
156
+ x = _mm256_add_ps(x, tmp);
157
+ x = _mm256_or_ps(x, invalid_mask); // negative arg will be NAN
158
+ return x;
159
+ }
160
+
161
+ _PS256_CONST(exp_hi, 88.3762626647949f);
162
+ _PS256_CONST(exp_lo, -88.3762626647949f);
163
+
164
+ _PS256_CONST(cephes_LOG2EF, 1.44269504088896341);
165
+ _PS256_CONST(cephes_exp_C1, 0.693359375);
166
+ _PS256_CONST(cephes_exp_C2, -2.12194440e-4);
167
+
168
+ _PS256_CONST(cephes_exp_p0, 1.9875691500E-4);
169
+ _PS256_CONST(cephes_exp_p1, 1.3981999507E-3);
170
+ _PS256_CONST(cephes_exp_p2, 8.3334519073E-3);
171
+ _PS256_CONST(cephes_exp_p3, 4.1665795894E-2);
172
+ _PS256_CONST(cephes_exp_p4, 1.6666665459E-1);
173
+ _PS256_CONST(cephes_exp_p5, 5.0000001201E-1);
174
+
175
+ inline v8sf exp256_ps(v8sf x) {
176
+ v8sf tmp = _mm256_setzero_ps(), fx;
177
+ v8si imm0;
178
+ v8sf one = *(v8sf*)_ps256_1;
179
+
180
+ x = _mm256_min_ps(x, *(v8sf*)_ps256_exp_hi);
181
+ x = _mm256_max_ps(x, *(v8sf*)_ps256_exp_lo);
182
+
183
+ /* express exp(x) as exp(g + n*log(2)) */
184
+ fx = _mm256_mul_ps(x, *(v8sf*)_ps256_cephes_LOG2EF);
185
+ fx = _mm256_add_ps(fx, *(v8sf*)_ps256_0p5);
186
+
187
+ /* how to perform a floorf with SSE: just below */
188
+ //imm0 = _mm256_cvttps_epi32(fx);
189
+ //tmp = _mm256_cvtepi32_ps(imm0);
190
+
191
+ tmp = _mm256_floor_ps(fx);
192
+
193
+ /* if greater, subtract 1 */
194
+ //v8sf mask = _mm256_cmpgt_ps(tmp, fx);
195
+ v8sf mask = _mm256_cmp_ps(tmp, fx, _CMP_GT_OS);
196
+ mask = _mm256_and_ps(mask, one);
197
+ fx = _mm256_sub_ps(tmp, mask);
198
+
199
+ tmp = _mm256_mul_ps(fx, *(v8sf*)_ps256_cephes_exp_C1);
200
+ v8sf z = _mm256_mul_ps(fx, *(v8sf*)_ps256_cephes_exp_C2);
201
+ x = _mm256_sub_ps(x, tmp);
202
+ x = _mm256_sub_ps(x, z);
203
+
204
+ z = _mm256_mul_ps(x,x);
205
+
206
+ v8sf y = *(v8sf*)_ps256_cephes_exp_p0;
207
+ y = _mm256_mul_ps(y, x);
208
+ y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_exp_p1);
209
+ y = _mm256_mul_ps(y, x);
210
+ y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_exp_p2);
211
+ y = _mm256_mul_ps(y, x);
212
+ y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_exp_p3);
213
+ y = _mm256_mul_ps(y, x);
214
+ y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_exp_p4);
215
+ y = _mm256_mul_ps(y, x);
216
+ y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_exp_p5);
217
+ y = _mm256_mul_ps(y, z);
218
+ y = _mm256_add_ps(y, x);
219
+ y = _mm256_add_ps(y, one);
220
+
221
+ /* build 2^n */
222
+ imm0 = _mm256_cvttps_epi32(fx);
223
+ // another two AVX2 instructions
224
+ imm0 = _mm256_add_epi32(imm0, *(v8si*)_pi32_256_0x7f);
225
+ imm0 = _mm256_slli_epi32(imm0, 23);
226
+ v8sf pow2n = _mm256_castsi256_ps(imm0);
227
+ y = _mm256_mul_ps(y, pow2n);
228
+ return y;
229
+ }
230
+
231
+ _PS256_CONST(minus_cephes_DP1, -0.78515625);
232
+ _PS256_CONST(minus_cephes_DP2, -2.4187564849853515625e-4);
233
+ _PS256_CONST(minus_cephes_DP3, -3.77489497744594108e-8);
234
+ _PS256_CONST(sincof_p0, -1.9515295891E-4);
235
+ _PS256_CONST(sincof_p1, 8.3321608736E-3);
236
+ _PS256_CONST(sincof_p2, -1.6666654611E-1);
237
+ _PS256_CONST(coscof_p0, 2.443315711809948E-005);
238
+ _PS256_CONST(coscof_p1, -1.388731625493765E-003);
239
+ _PS256_CONST(coscof_p2, 4.166664568298827E-002);
240
+ _PS256_CONST(cephes_FOPI, 1.27323954473516); // 4 / M_PI
241
+
242
+
243
+ /* evaluation of 8 sines at onces using AVX intrinsics
244
+
245
+ The code is the exact rewriting of the cephes sinf function.
246
+ Precision is excellent as long as x < 8192 (I did not bother to
247
+ take into account the special handling they have for greater values
248
+ -- it does not return garbage for arguments over 8192, though, but
249
+ the extra precision is missing).
250
+
251
+ Note that it is such that sinf((float)M_PI) = 8.74e-8, which is the
252
+ surprising but correct result.
253
+
254
+ */
255
+ inline v8sf sin256_ps(v8sf x) { // any x
256
+ v8sf xmm1, xmm2 = _mm256_setzero_ps(), xmm3, sign_bit, y;
257
+ v8si imm0, imm2;
258
+
259
+ sign_bit = x;
260
+ /* take the absolute value */
261
+ x = _mm256_and_ps(x, *(v8sf*)_ps256_inv_sign_mask);
262
+ /* extract the sign bit (upper one) */
263
+ sign_bit = _mm256_and_ps(sign_bit, *(v8sf*)_ps256_sign_mask);
264
+
265
+ /* scale by 4/Pi */
266
+ y = _mm256_mul_ps(x, *(v8sf*)_ps256_cephes_FOPI);
267
+
268
+ /*
269
+ Here we start a series of integer operations, which are in the
270
+ realm of AVX2.
271
+ If we don't have AVX, let's perform them using SSE2 directives
272
+ */
273
+
274
+ /* store the integer part of y in mm0 */
275
+ imm2 = _mm256_cvttps_epi32(y);
276
+ /* j=(j+1) & (~1) (see the cephes sources) */
277
+ // another two AVX2 instruction
278
+ imm2 = _mm256_add_epi32(imm2, *(v8si*)_pi32_256_1);
279
+ imm2 = _mm256_and_si256(imm2, *(v8si*)_pi32_256_inv1);
280
+ y = _mm256_cvtepi32_ps(imm2);
281
+
282
+ /* get the swap sign flag */
283
+ imm0 = _mm256_and_si256(imm2, *(v8si*)_pi32_256_4);
284
+ imm0 = _mm256_slli_epi32(imm0, 29);
285
+ /* get the polynom selection mask
286
+ there is one polynom for 0 <= x <= Pi/4
287
+ and another one for Pi/4<x<=Pi/2
288
+
289
+ Both branches will be computed.
290
+ */
291
+ imm2 = _mm256_and_si256(imm2, *(v8si*)_pi32_256_2);
292
+ imm2 = _mm256_cmpeq_epi32(imm2,*(v8si*)_pi32_256_0);
293
+
294
+ v8sf swap_sign_bit = _mm256_castsi256_ps(imm0);
295
+ v8sf poly_mask = _mm256_castsi256_ps(imm2);
296
+ sign_bit = _mm256_xor_ps(sign_bit, swap_sign_bit);
297
+
298
+ /* The magic pass: "Extended precision modular arithmetic"
299
+ x = ((x - y * DP1) - y * DP2) - y * DP3; */
300
+ xmm1 = *(v8sf*)_ps256_minus_cephes_DP1;
301
+ xmm2 = *(v8sf*)_ps256_minus_cephes_DP2;
302
+ xmm3 = *(v8sf*)_ps256_minus_cephes_DP3;
303
+ xmm1 = _mm256_mul_ps(y, xmm1);
304
+ xmm2 = _mm256_mul_ps(y, xmm2);
305
+ xmm3 = _mm256_mul_ps(y, xmm3);
306
+ x = _mm256_add_ps(x, xmm1);
307
+ x = _mm256_add_ps(x, xmm2);
308
+ x = _mm256_add_ps(x, xmm3);
309
+
310
+ /* Evaluate the first polynom (0 <= x <= Pi/4) */
311
+ y = *(v8sf*)_ps256_coscof_p0;
312
+ v8sf z = _mm256_mul_ps(x,x);
313
+
314
+ y = _mm256_mul_ps(y, z);
315
+ y = _mm256_add_ps(y, *(v8sf*)_ps256_coscof_p1);
316
+ y = _mm256_mul_ps(y, z);
317
+ y = _mm256_add_ps(y, *(v8sf*)_ps256_coscof_p2);
318
+ y = _mm256_mul_ps(y, z);
319
+ y = _mm256_mul_ps(y, z);
320
+ v8sf tmp = _mm256_mul_ps(z, *(v8sf*)_ps256_0p5);
321
+ y = _mm256_sub_ps(y, tmp);
322
+ y = _mm256_add_ps(y, *(v8sf*)_ps256_1);
323
+
324
+ /* Evaluate the second polynom (Pi/4 <= x <= 0) */
325
+
326
+ v8sf y2 = *(v8sf*)_ps256_sincof_p0;
327
+ y2 = _mm256_mul_ps(y2, z);
328
+ y2 = _mm256_add_ps(y2, *(v8sf*)_ps256_sincof_p1);
329
+ y2 = _mm256_mul_ps(y2, z);
330
+ y2 = _mm256_add_ps(y2, *(v8sf*)_ps256_sincof_p2);
331
+ y2 = _mm256_mul_ps(y2, z);
332
+ y2 = _mm256_mul_ps(y2, x);
333
+ y2 = _mm256_add_ps(y2, x);
334
+
335
+ /* select the correct result from the two polynoms */
336
+ xmm3 = poly_mask;
337
+ y2 = _mm256_and_ps(xmm3, y2); //, xmm3);
338
+ y = _mm256_andnot_ps(xmm3, y);
339
+ y = _mm256_add_ps(y,y2);
340
+ /* update the sign */
341
+ y = _mm256_xor_ps(y, sign_bit);
342
+
343
+ return y;
344
+ }
345
+
346
+ /* almost the same as sin_ps */
347
+ inline v8sf cos256_ps(v8sf x) { // any x
348
+ v8sf xmm1, xmm2 = _mm256_setzero_ps(), xmm3, y;
349
+ v8si imm0, imm2;
350
+
351
+ /* take the absolute value */
352
+ x = _mm256_and_ps(x, *(v8sf*)_ps256_inv_sign_mask);
353
+
354
+ /* scale by 4/Pi */
355
+ y = _mm256_mul_ps(x, *(v8sf*)_ps256_cephes_FOPI);
356
+
357
+ /* store the integer part of y in mm0 */
358
+ imm2 = _mm256_cvttps_epi32(y);
359
+ /* j=(j+1) & (~1) (see the cephes sources) */
360
+ imm2 = _mm256_add_epi32(imm2, *(v8si*)_pi32_256_1);
361
+ imm2 = _mm256_and_si256(imm2, *(v8si*)_pi32_256_inv1);
362
+ y = _mm256_cvtepi32_ps(imm2);
363
+ imm2 = _mm256_sub_epi32(imm2, *(v8si*)_pi32_256_2);
364
+
365
+ /* get the swap sign flag */
366
+ imm0 = _mm256_andnot_si256(imm2, *(v8si*)_pi32_256_4);
367
+ imm0 = _mm256_slli_epi32(imm0, 29);
368
+ /* get the polynom selection mask */
369
+ imm2 = _mm256_and_si256(imm2, *(v8si*)_pi32_256_2);
370
+ imm2 = _mm256_cmpeq_epi32(imm2, *(v8si*)_pi32_256_0);
371
+
372
+ v8sf sign_bit = _mm256_castsi256_ps(imm0);
373
+ v8sf poly_mask = _mm256_castsi256_ps(imm2);
374
+
375
+ /* The magic pass: "Extended precision modular arithmetic"
376
+ x = ((x - y * DP1) - y * DP2) - y * DP3; */
377
+ xmm1 = *(v8sf*)_ps256_minus_cephes_DP1;
378
+ xmm2 = *(v8sf*)_ps256_minus_cephes_DP2;
379
+ xmm3 = *(v8sf*)_ps256_minus_cephes_DP3;
380
+ xmm1 = _mm256_mul_ps(y, xmm1);
381
+ xmm2 = _mm256_mul_ps(y, xmm2);
382
+ xmm3 = _mm256_mul_ps(y, xmm3);
383
+ x = _mm256_add_ps(x, xmm1);
384
+ x = _mm256_add_ps(x, xmm2);
385
+ x = _mm256_add_ps(x, xmm3);
386
+
387
+ /* Evaluate the first polynom (0 <= x <= Pi/4) */
388
+ y = *(v8sf*)_ps256_coscof_p0;
389
+ v8sf z = _mm256_mul_ps(x,x);
390
+
391
+ y = _mm256_mul_ps(y, z);
392
+ y = _mm256_add_ps(y, *(v8sf*)_ps256_coscof_p1);
393
+ y = _mm256_mul_ps(y, z);
394
+ y = _mm256_add_ps(y, *(v8sf*)_ps256_coscof_p2);
395
+ y = _mm256_mul_ps(y, z);
396
+ y = _mm256_mul_ps(y, z);
397
+ v8sf tmp = _mm256_mul_ps(z, *(v8sf*)_ps256_0p5);
398
+ y = _mm256_sub_ps(y, tmp);
399
+ y = _mm256_add_ps(y, *(v8sf*)_ps256_1);
400
+
401
+ /* Evaluate the second polynom (Pi/4 <= x <= 0) */
402
+
403
+ v8sf y2 = *(v8sf*)_ps256_sincof_p0;
404
+ y2 = _mm256_mul_ps(y2, z);
405
+ y2 = _mm256_add_ps(y2, *(v8sf*)_ps256_sincof_p1);
406
+ y2 = _mm256_mul_ps(y2, z);
407
+ y2 = _mm256_add_ps(y2, *(v8sf*)_ps256_sincof_p2);
408
+ y2 = _mm256_mul_ps(y2, z);
409
+ y2 = _mm256_mul_ps(y2, x);
410
+ y2 = _mm256_add_ps(y2, x);
411
+
412
+ /* select the correct result from the two polynoms */
413
+ xmm3 = poly_mask;
414
+ y2 = _mm256_and_ps(xmm3, y2); //, xmm3);
415
+ y = _mm256_andnot_ps(xmm3, y);
416
+ y = _mm256_add_ps(y,y2);
417
+ /* update the sign */
418
+ y = _mm256_xor_ps(y, sign_bit);
419
+
420
+ return y;
421
+ }
422
+
423
+ /* since sin256_ps and cos256_ps are almost identical, sincos256_ps could replace both of them..
424
+ it is almost as fast, and gives you a free cosine with your sine */
425
+ inline void sincos256_ps(v8sf x, v8sf *s, v8sf *c) {
426
+
427
+ v8sf xmm1, xmm2, xmm3 = _mm256_setzero_ps(), sign_bit_sin, y;
428
+ v8si imm0, imm2, imm4;
429
+
430
+ sign_bit_sin = x;
431
+ /* take the absolute value */
432
+ x = _mm256_and_ps(x, *(v8sf*)_ps256_inv_sign_mask);
433
+ /* extract the sign bit (upper one) */
434
+ sign_bit_sin = _mm256_and_ps(sign_bit_sin, *(v8sf*)_ps256_sign_mask);
435
+
436
+ /* scale by 4/Pi */
437
+ y = _mm256_mul_ps(x, *(v8sf*)_ps256_cephes_FOPI);
438
+
439
+ /* store the integer part of y in imm2 */
440
+ imm2 = _mm256_cvttps_epi32(y);
441
+
442
+ /* j=(j+1) & (~1) (see the cephes sources) */
443
+ imm2 = _mm256_add_epi32(imm2, *(v8si*)_pi32_256_1);
444
+ imm2 = _mm256_and_si256(imm2, *(v8si*)_pi32_256_inv1);
445
+
446
+ y = _mm256_cvtepi32_ps(imm2);
447
+ imm4 = imm2;
448
+
449
+ /* get the swap sign flag for the sine */
450
+ imm0 = _mm256_and_si256(imm2, *(v8si*)_pi32_256_4);
451
+ imm0 = _mm256_slli_epi32(imm0, 29);
452
+ //v8sf swap_sign_bit_sin = _mm256_castsi256_ps(imm0);
453
+
454
+ /* get the polynom selection mask for the sine*/
455
+ imm2 = _mm256_and_si256(imm2, *(v8si*)_pi32_256_2);
456
+ imm2 = _mm256_cmpeq_epi32(imm2, *(v8si*)_pi32_256_0);
457
+ //v8sf poly_mask = _mm256_castsi256_ps(imm2);
458
+
459
+ v8sf swap_sign_bit_sin = _mm256_castsi256_ps(imm0);
460
+ v8sf poly_mask = _mm256_castsi256_ps(imm2);
461
+
462
+ /* The magic pass: "Extended precision modular arithmetic"
463
+ x = ((x - y * DP1) - y * DP2) - y * DP3; */
464
+ xmm1 = *(v8sf*)_ps256_minus_cephes_DP1;
465
+ xmm2 = *(v8sf*)_ps256_minus_cephes_DP2;
466
+ xmm3 = *(v8sf*)_ps256_minus_cephes_DP3;
467
+ xmm1 = _mm256_mul_ps(y, xmm1);
468
+ xmm2 = _mm256_mul_ps(y, xmm2);
469
+ xmm3 = _mm256_mul_ps(y, xmm3);
470
+ x = _mm256_add_ps(x, xmm1);
471
+ x = _mm256_add_ps(x, xmm2);
472
+ x = _mm256_add_ps(x, xmm3);
473
+
474
+ imm4 = _mm256_sub_epi32(imm4, *(v8si*)_pi32_256_2);
475
+ imm4 = _mm256_andnot_si256(imm4, *(v8si*)_pi32_256_4);
476
+ imm4 = _mm256_slli_epi32(imm4, 29);
477
+
478
+ v8sf sign_bit_cos = _mm256_castsi256_ps(imm4);
479
+
480
+ sign_bit_sin = _mm256_xor_ps(sign_bit_sin, swap_sign_bit_sin);
481
+
482
+ /* Evaluate the first polynom (0 <= x <= Pi/4) */
483
+ v8sf z = _mm256_mul_ps(x,x);
484
+ y = *(v8sf*)_ps256_coscof_p0;
485
+
486
+ y = _mm256_mul_ps(y, z);
487
+ y = _mm256_add_ps(y, *(v8sf*)_ps256_coscof_p1);
488
+ y = _mm256_mul_ps(y, z);
489
+ y = _mm256_add_ps(y, *(v8sf*)_ps256_coscof_p2);
490
+ y = _mm256_mul_ps(y, z);
491
+ y = _mm256_mul_ps(y, z);
492
+ v8sf tmp = _mm256_mul_ps(z, *(v8sf*)_ps256_0p5);
493
+ y = _mm256_sub_ps(y, tmp);
494
+ y = _mm256_add_ps(y, *(v8sf*)_ps256_1);
495
+
496
+ /* Evaluate the second polynom (Pi/4 <= x <= 0) */
497
+
498
+ v8sf y2 = *(v8sf*)_ps256_sincof_p0;
499
+ y2 = _mm256_mul_ps(y2, z);
500
+ y2 = _mm256_add_ps(y2, *(v8sf*)_ps256_sincof_p1);
501
+ y2 = _mm256_mul_ps(y2, z);
502
+ y2 = _mm256_add_ps(y2, *(v8sf*)_ps256_sincof_p2);
503
+ y2 = _mm256_mul_ps(y2, z);
504
+ y2 = _mm256_mul_ps(y2, x);
505
+ y2 = _mm256_add_ps(y2, x);
506
+
507
+ /* select the correct result from the two polynoms */
508
+ xmm3 = poly_mask;
509
+ v8sf ysin2 = _mm256_and_ps(xmm3, y2);
510
+ v8sf ysin1 = _mm256_andnot_ps(xmm3, y);
511
+ y2 = _mm256_sub_ps(y2,ysin2);
512
+ y = _mm256_sub_ps(y, ysin1);
513
+
514
+ xmm1 = _mm256_add_ps(ysin1,ysin2);
515
+ xmm2 = _mm256_add_ps(y,y2);
516
+
517
+ /* update the sign */
518
+ *s = _mm256_xor_ps(xmm1, sign_bit_sin);
519
+ *c = _mm256_xor_ps(xmm2, sign_bit_cos);
520
+ }
521
+
522
+ #endif // CPU_CAPABILITY_AVX2
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/int_mm_kernel.h ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/core/Tensor.h>
4
+ #include <ATen/native/DispatchStub.h>
5
+
6
+ namespace at::native {
7
+
8
+ using weight_to_int4pack_fn = void(*)(const Tensor&, const Tensor&, int, int);
9
+ using int4pack_mm_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, int, const Tensor&, int, int);
10
+ using int8pack_mm_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&);
11
+
12
+ DECLARE_DISPATCH(weight_to_int4pack_fn, weight_to_int4pack_stub);
13
+ DECLARE_DISPATCH(int4pack_mm_fn, int4pack_mm_stub);
14
+ DECLARE_DISPATCH(int8pack_mm_fn, int8pack_mm_stub);
15
+
16
+ } // namespace at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/mixed_data_type.h ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/core/Tensor.h>
4
+
5
+ namespace at::native {
6
+
7
+ inline ScalarType first_type() {
8
+ return ScalarType::Undefined;
9
+ }
10
+
11
+ template <typename... Args>
12
+ inline ScalarType first_type(const Tensor& arg, const Args&... parameters) {
13
+ return arg.defined() ? arg.scalar_type() : first_type(parameters...);
14
+ }
15
+
16
+ template <typename... Args>
17
+ inline bool is_mixed_type(const Tensor& input, const Args&... parameters) {
18
+ const auto parameter_type = first_type(parameters...);
19
+ return ((parameter_type != ScalarType::Undefined) &&
20
+ (parameter_type != input.scalar_type()));
21
+ }
22
+
23
+ // currently on CPU, mixed data type is only supported
24
+ // when input is 'BFloat16' or 'Half' and parameters are 'Float'
25
+ inline void check_mixed_data_type(const Tensor& input) {
26
+ TORCH_CHECK(at::isReducedFloatingType(input.scalar_type()),
27
+ "mixed dtype (CPU): all inputs must share same datatype.");
28
+ }
29
+
30
+ template <typename... Args>
31
+ inline void check_mixed_data_type(const Tensor& input, const Tensor& parameter, const Args&... parameters) {
32
+ TORCH_CHECK(!parameter.defined() || parameter.scalar_type() == ScalarType::Float,
33
+ "mixed dtype (CPU): expect parameter to have scalar type of Float");
34
+ check_mixed_data_type(input, parameters...);
35
+ }
36
+
37
+ inline ScalarType param_scalar_type(const Tensor& t, bool is_mixed_type) {
38
+ return is_mixed_type ? ScalarType::Float : t.scalar_type();
39
+ }
40
+
41
+ } // namespace at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/moments_utils.h ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <array>
4
+ #include <cstring>
5
+ #include <utility>
6
+
7
+ #include <ATen/Parallel.h>
8
+ #include <ATen/OpMathType.h>
9
+ #include <ATen/cpu/vec/vec.h>
10
+ #include <ATen/native/cpu/utils.h>
11
+ #include <c10/util/SmallVector.h>
12
+ #include <c10/util/irange.h>
13
+
14
+ namespace at::native {
15
+ inline namespace CPU_CAPABILITY {
16
+
17
+ template<typename T> using opmath_t = at::opmath_type<T>;
18
+
19
+ constexpr int64_t kChunkSize = 16;
20
+
21
+ template <typename T>
22
+ void AddMoments(
23
+ int64_t m0_add,
24
+ const T& m1_add,
25
+ const T& m2_add,
26
+ int64_t& m0,
27
+ T& m1,
28
+ T& m2) {
29
+ const int64_t n = m0 + m0_add;
30
+ const T c = n == 0 ? static_cast<T>(0) : static_cast<T>(m0_add) / static_cast<T>(n);
31
+ const T delta = m1_add - m1;
32
+ m1 += c * delta;
33
+ m2 += m2_add + delta * delta * c * static_cast<T>(m0);
34
+ m0 = n;
35
+ }
36
+
37
+ template <typename T>
38
+ C10_ALWAYS_INLINE void AddMomentsVec(
39
+ int64_t m0_add,
40
+ const vec::Vectorized<T>& m1_add,
41
+ const vec::Vectorized<T>& m2_add,
42
+ int64_t& m0,
43
+ vec::Vectorized<T>& m1,
44
+ vec::Vectorized<T>& m2) {
45
+ using Vec = vec::Vectorized<T>;
46
+ const int64_t n = m0 + m0_add;
47
+ const T c = n == 0 ? static_cast<T>(0) : static_cast<T>(m0_add) / static_cast<T>(n);
48
+ const Vec c_vec(c);
49
+ const Vec delta = m1_add - m1;
50
+ m1 += c_vec * delta;
51
+ m2 += m2_add + delta * delta * c_vec * Vec(static_cast<T>(m0));
52
+ m0 = n;
53
+ }
54
+
55
+ template <typename T>
56
+ inline std::enable_if_t<std::is_same_v<T, opmath_t<T>>, void>
57
+ UpdateMomentsVec(
58
+ int64_t m0,
59
+ const T* X_ptr,
60
+ const std::array<vec::Vectorized<opmath_t<T>>, kChunkSize>& c_vecs,
61
+ int64_t& m0_stk0,
62
+ vec::Vectorized<opmath_t<T>>& m1_stk0,
63
+ vec::Vectorized<opmath_t<T>>& m2_stk0) {
64
+ using Vec = vec::Vectorized<opmath_t<T>>;
65
+ Vec m1_vec(0);
66
+ Vec m2_vec(0);
67
+ for (const auto j : c10::irange(m0)) {
68
+ const Vec x_vec = Vec::loadu(X_ptr + j * Vec::size());
69
+ const Vec delta_vec = x_vec - m1_vec;
70
+ m1_vec += delta_vec * c_vecs[j];
71
+ m2_vec += delta_vec * (x_vec - m1_vec);
72
+ }
73
+ AddMomentsVec(m0, m1_vec, m2_vec, m0_stk0, m1_stk0, m2_stk0);
74
+ }
75
+
76
+ // each bfloat16/half vector will be converted to two float vectors,
77
+ // and accumulated successively on m1_stk0/m2_stk0.
78
+ template <typename T>
79
+ inline std::enable_if_t<!std::is_same_v<T, at::opmath_type<T>>, void>
80
+ UpdateMomentsVec(
81
+ int64_t m0,
82
+ const T* X_ptr,
83
+ const std::array<vec::Vectorized<at::opmath_type<T>>, kChunkSize>& c_vecs,
84
+ int64_t& m0_stk0,
85
+ vec::Vectorized<at::opmath_type<T>>& m1_stk0,
86
+ vec::Vectorized<at::opmath_type<T>>& m2_stk0) {
87
+ using Vec = vec::Vectorized<T>;
88
+ using fVec = vec::Vectorized<at::opmath_type<T>>;
89
+ fVec m1_fvec0(0), m1_fvec1(0);
90
+ fVec m2_fvec0(0), m2_fvec1(0);
91
+ for (const auto j : c10::irange(m0)) {
92
+ const Vec x_bvec = Vec::loadu(X_ptr + j * Vec::size());
93
+ auto [x_fvec0, x_fvec1] = convert_to_float<T>(x_bvec);
94
+ const fVec delta_fvec0 = x_fvec0 - m1_fvec0;
95
+ const fVec delta_fvec1 = x_fvec1 - m1_fvec1;
96
+ m1_fvec0 += delta_fvec0 * c_vecs[j];
97
+ m1_fvec1 += delta_fvec1 * c_vecs[j];
98
+ m2_fvec0 += delta_fvec0 * (x_fvec0 - m1_fvec0);
99
+ m2_fvec1 += delta_fvec1 * (x_fvec1 - m1_fvec1);
100
+ }
101
+ AddMomentsVec(m0, m1_fvec0, m2_fvec0, m0_stk0, m1_stk0, m2_stk0);
102
+ AddMomentsVec(m0, m1_fvec1, m2_fvec1, m0_stk0, m1_stk0, m2_stk0);
103
+ }
104
+
105
+ // Compute rowwise moments by Welford algorithm and cascade sum to improve
106
+ // numerical stability.
107
+ // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
108
+ // https://en.wikipedia.org/wiki/Pairwise_summation
109
+ template <typename T, int64_t kMaxDepth>
110
+ std::pair<opmath_t<T>, opmath_t<T>> RowwiseMomentsImpl(const T* X, int64_t N, int64_t ddof = 0) {
111
+ using math_t = opmath_t<T>;
112
+
113
+ constexpr int64_t kVecSize = vec::Vectorized<T>::size();
114
+ constexpr int64_t kAccVecSize = vec::Vectorized<math_t>::size();
115
+ const int64_t n = N / kVecSize;
116
+ const int64_t m = divup(n, kChunkSize);
117
+ const int64_t depth = utils::CeilLog2(m);
118
+
119
+ using Vec = vec::Vectorized<math_t>;
120
+ const Vec kZeroVec(math_t(0));
121
+ c10::SmallVector<int64_t, kMaxDepth> m0_stk(depth, 0);
122
+ c10::SmallVector<Vec, kMaxDepth> m1_stk(depth, kZeroVec);
123
+ c10::SmallVector<Vec, kMaxDepth> m2_stk(depth, kZeroVec);
124
+
125
+ for (const auto i : c10::irange(m)) {
126
+ const T* X_ptr = X + i * kChunkSize * kVecSize;
127
+ const int64_t m0 = std::min(kChunkSize, n - i * kChunkSize);
128
+ static std::array<Vec, kChunkSize> c_vecs = ([]() {
129
+ std::array<Vec, kChunkSize> result;
130
+ for (const auto i : c10::irange(kChunkSize)) {
131
+ result[i] = Vec(math_t(1) / static_cast<math_t>(i + 1));
132
+ }
133
+ return result;
134
+ })();
135
+ UpdateMomentsVec(m0, X_ptr, c_vecs, m0_stk[0], m1_stk[0], m2_stk[0]);
136
+
137
+ int64_t mask = i + 1;
138
+ for (int64_t j = 1; j < depth && (mask & 1) == 0; ++j) {
139
+ AddMomentsVec(
140
+ m0_stk[j - 1],
141
+ m1_stk[j - 1],
142
+ m2_stk[j - 1],
143
+ m0_stk[j],
144
+ m1_stk[j],
145
+ m2_stk[j]);
146
+ m0_stk[j - 1] = 0;
147
+ m1_stk[j - 1] = kZeroVec;
148
+ m2_stk[j - 1] = kZeroVec;
149
+ mask >>= 1;
150
+ }
151
+ }
152
+ for (const auto i : c10::irange(1, depth)) {
153
+ AddMomentsVec(
154
+ m0_stk[i], m1_stk[i], m2_stk[i], m0_stk[0], m1_stk[0], m2_stk[0]);
155
+ }
156
+
157
+ std::array<math_t, kAccVecSize> m1_arr{};
158
+ std::array<math_t, kAccVecSize> m2_arr{};
159
+ m1_stk[0].store(m1_arr.data());
160
+ m2_stk[0].store(m2_arr.data());
161
+
162
+ int64_t m0 = 0;
163
+ math_t m1 = 0;
164
+ math_t m2 = 0;
165
+ for (int64_t i = n * kVecSize; i < N; ++i) {
166
+ math_t x = static_cast<math_t>(X[i]);
167
+ const math_t delta = x - m1;
168
+ ++m0;
169
+ m1 += delta / static_cast<math_t>(m0);
170
+ m2 += delta * (x - m1);
171
+ }
172
+ // for BFloat16, each vector in m1_arr/m2_arr holds 2*n accumulated result
173
+ int64_t m0_add = n * kVecSize / kAccVecSize;
174
+ for (const auto i : c10::irange(kAccVecSize)) {
175
+ AddMoments(m0_add, m1_arr[i], m2_arr[i], m0, m1, m2);
176
+ }
177
+
178
+ return std::make_pair(m1, m2 / static_cast<math_t>(N - ddof));
179
+ }
180
+
181
+ template <typename T>
182
+ std::pair<opmath_t<T>, opmath_t<T>> RowwiseMoments(const T* X, int64_t N, int64_t ddof = 0) {
183
+ using Vec = vec::Vectorized<T>;
184
+ constexpr int64_t kVecSize = Vec::size();
185
+ const int64_t n = N / kVecSize;
186
+ const int64_t m = divup(n, kChunkSize);
187
+ const int64_t depth = utils::CeilLog2(m);
188
+ if (depth <= 4) {
189
+ return RowwiseMomentsImpl<T, 4>(X, N, ddof);
190
+ } else if (depth <= 8) {
191
+ return RowwiseMomentsImpl<T, 8>(X, N, ddof);
192
+ } else if (depth <= 16) {
193
+ return RowwiseMomentsImpl<T, 16>(X, N, ddof);
194
+ } else if (depth <= 32) {
195
+ return RowwiseMomentsImpl<T, 32>(X, N, ddof);
196
+ } else {
197
+ return RowwiseMomentsImpl<T, 64>(X, N, ddof);
198
+ }
199
+ }
200
+
201
+ } // namespace CPU_CAPABILITY
202
+ } // namespace at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/utils.h ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/Parallel.h>
4
+ #include <ATen/core/TensorAccessor.h>
5
+ #include <ATen/cpu/vec/vec.h>
6
+ #include <c10/util/llvmMathExtras.h>
7
+
8
+ #ifdef USE_FBGEMM
9
+ #include <fbgemm/Fbgemm.h>
10
+ #endif
11
+
12
+ namespace at::native {
13
+
14
+ template <typename T>
15
+ inline void _store(T* dst, at::vec::Vectorized<T> src) {
16
+ src.store(dst);
17
+ }
18
+
19
+ inline void _store(at::BFloat16* dst, at::vec::Vectorized<float> src) {
20
+ auto res = at::vec::convert_float_bfloat16(src, src);
21
+ res.store(dst, at::vec::Vectorized<float>::size());
22
+ }
23
+
24
+ inline void _store(at::Half* dst, at::vec::Vectorized<float> src) {
25
+ auto res = at::vec::convert_float_half(src, src);
26
+ res.store(dst, at::vec::Vectorized<float>::size());
27
+ }
28
+
29
+ inline namespace CPU_CAPABILITY {
30
+
31
+ template <typename T>
32
+ inline T data_index_init(T offset) {
33
+ return offset;
34
+ }
35
+
36
+ template <typename T, typename... Args>
37
+ inline T data_index_init(T offset, T& x, const T& X, Args&&... args) {
38
+ offset = data_index_init(offset, std::forward<Args>(args)...);
39
+ x = offset % X;
40
+ return offset / X;
41
+ }
42
+
43
+ inline bool data_index_step() {
44
+ return true;
45
+ }
46
+
47
+ template <typename T, typename... Args>
48
+ inline bool data_index_step(T& x, const T& X, Args&&... args) {
49
+ if (data_index_step(std::forward<Args>(args)...)) {
50
+ x = ((x + 1) == X) ? 0 : (x + 1);
51
+ return x == 0;
52
+ }
53
+ return false;
54
+ }
55
+
56
+ // Helper struct for bfloat16/float16 vectorization
57
+ // Useful when you need float as immediate dtype or accumulate dtype
58
+ using namespace vec;
59
+ struct Vec2 {
60
+ Vectorized<float> val0, val1;
61
+ Vec2(Vectorized<float> v0, Vectorized<float> v1) : val0(v0), val1(v1) {}
62
+ Vec2(float v) : val0(v), val1(v) {}
63
+ static Vec2 loadu(const BFloat16* ptr) {
64
+ auto [v0, v1] = convert_bfloat16_float(Vectorized<BFloat16>::loadu(ptr));
65
+ return {v0, v1};
66
+ }
67
+ static Vec2 loadu(const Half* ptr) {
68
+ auto [v0, v1] = convert_half_float(Vectorized<Half>::loadu(ptr));
69
+ return {v0, v1};
70
+ }
71
+ static Vec2 loadu(const float* ptr) {
72
+ return {Vectorized<float>::loadu(ptr), Vectorized<float>::loadu(ptr + Vectorized<float>::size())};
73
+ }
74
+ void store(BFloat16* ptr) const {
75
+ Vectorized<BFloat16> val = convert_float_bfloat16(val0, val1);
76
+ val.store(ptr);
77
+ }
78
+ void store(Half* ptr) const {
79
+ Vectorized<Half> val = convert_float_half(val0, val1);
80
+ val.store(ptr);
81
+ }
82
+ void store(float* ptr) const {
83
+ val0.store(ptr);
84
+ val1.store(ptr + Vectorized<float>::size());
85
+ }
86
+ };
87
+ inline Vec2 operator+(const Vec2& a, const Vec2& b) { return {a.val0 + b.val0, a.val1 + b.val1}; }
88
+ inline Vec2 operator*(const Vec2& a, const Vec2& b) { return {a.val0 * b.val0, a.val1 * b.val1}; }
89
+ inline Vec2 operator-(const Vec2& a, const Vec2& b) { return {a.val0 - b.val0, a.val1 - b.val1}; }
90
+ inline Vec2 operator/(const Vec2& a, const Vec2& b) { return {a.val0 / b.val0, a.val1 / b.val1}; }
91
+ inline Vec2 maximum(const Vec2& a, const Vec2& b) { return {vec::maximum(a.val0, b.val0), vec::maximum(a.val1, b.val1)}; }
92
+ inline Vec2 minimum(const Vec2& a, const Vec2& b) { return {vec::minimum(a.val0, b.val0), vec::minimum(a.val1, b.val1)}; }
93
+
94
+ template <typename scalar_t> struct VectorizedType { using type = Vectorized<scalar_t>; };
95
+ template <> struct VectorizedType<BFloat16> { using type = Vec2; };
96
+ template <> struct VectorizedType<Half> { using type = Vec2; };
97
+ template <typename scalar_t> using VecType = typename VectorizedType<scalar_t>::type;
98
+
99
+ // Helper for mixed data type parameter Vec::load
100
+ inline std::tuple<Vectorized<float>, Vectorized<float>> load2f(const BFloat16* ptr) {
101
+ return convert_bfloat16_float(Vectorized<BFloat16>::loadu(ptr));
102
+ }
103
+
104
+ inline std::tuple<Vectorized<float>, Vectorized<float>> load2f(const Half* ptr) {
105
+ return convert_half_float(Vectorized<Half>::loadu(ptr));
106
+ }
107
+
108
+ inline std::tuple<Vectorized<float>, Vectorized<float>> load2f(const float* ptr) {
109
+ using Vec = Vectorized<float>;
110
+ return std::make_tuple(Vec::loadu(ptr), Vec::loadu(ptr + Vec::size()));
111
+ }
112
+
113
+ inline std::tuple<Vectorized<float>, Vectorized<float>> load2f(const BFloat16* ptr, int64_t count) {
114
+ return convert_bfloat16_float(Vectorized<BFloat16>::loadu(ptr, count));
115
+ }
116
+
117
+ inline std::tuple<Vectorized<float>, Vectorized<float>> load2f(const Half* ptr, int64_t count) {
118
+ return convert_half_float(Vectorized<Half>::loadu(ptr, count));
119
+ }
120
+
121
+ inline std::tuple<Vectorized<float>, Vectorized<float>> load2f(const float* ptr, int64_t count) {
122
+ using Vec = Vectorized<float>;
123
+ if (count > Vec::size()) {
124
+ return std::make_tuple(Vec::loadu(ptr), Vec::loadu(ptr + Vec::size(), count - Vec::size()));
125
+ } else {
126
+ return std::make_tuple(Vec::loadu(ptr, count), Vec(0));
127
+ }
128
+ }
129
+
130
+ } // namespace
131
+
132
+ namespace utils {
133
+
134
+ template <typename T>
135
+ T CeilLog2(const T& x) {
136
+ if (x <= 2) {
137
+ return 1;
138
+ }
139
+ // Last set bit is floor(log2(x)), floor + 1 is ceil
140
+ // except when x is an exact powers of 2, so subtract 1 first
141
+ return static_cast<T>(llvm::findLastSet(static_cast<uint64_t>(x) - 1)) + 1;
142
+ }
143
+
144
+ // matrix transpose:
145
+ // src has shape of M by N, with leading dimension of ld_src
146
+ // dst has shape of N by M, with leading dimension of ld_dst
147
+ template <typename T>
148
+ inline void transpose(int64_t M, int64_t N, const T* src, int64_t ld_src, T* dst, int64_t ld_dst) {
149
+ for (int64_t j = 0; j < N; j++) {
150
+ for (int64_t i = 0; i < M; i++) {
151
+ dst[j * ld_dst + i] = src[i * ld_src + j];
152
+ }
153
+ }
154
+ }
155
+
156
+ #ifdef USE_FBGEMM
157
+ template <>
158
+ inline void transpose<float>(int64_t M, int64_t N, const float* src, int64_t ld_src, float* dst, int64_t ld_dst) {
159
+ TORCH_CHECK(fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM.");
160
+ fbgemm::transpose_simd<float>(M, N, src, ld_src, dst, ld_dst);
161
+ }
162
+
163
+ template <>
164
+ inline void transpose<uint16_t>(int64_t M, int64_t N, const uint16_t* src, int64_t ld_src, uint16_t* dst, int64_t ld_dst) {
165
+ TORCH_CHECK(fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM.");
166
+ fbgemm::transpose_simd<uint16_t>(M, N, src, ld_src, dst, ld_dst);
167
+ }
168
+ #endif
169
+
170
+ template <typename index_t, typename F>
171
+ inline void parallel_sparse_csr(
172
+ const TensorAccessor<index_t, 1>& crow_acc,
173
+ const int64_t M,
174
+ const int64_t nnz,
175
+ const F& f) {
176
+ TORCH_CHECK(crow_acc.size(0) == M + 1);
177
+
178
+ // directly parallel on `M` may lead to load imbalance,
179
+ // statically determine thread partition here to average payload
180
+ // for each thread.
181
+ int num_threads = at::get_num_threads();
182
+ std::vector<int64_t> thread_splits(num_threads + 1, M);
183
+
184
+ int64_t thread_averge_payload = std::max((int64_t)1, divup(nnz, num_threads));
185
+
186
+ thread_splits[0] = 0;
187
+ int64_t sum = 0;
188
+ int64_t t = 1;
189
+ for (const auto m : c10::irange(M)) {
190
+ int64_t row_start = crow_acc[m];
191
+ int64_t row_end = crow_acc[m + 1];
192
+ sum += row_end - row_start;
193
+ if (sum > t * thread_averge_payload) {
194
+ thread_splits[t] = m;
195
+ t++;
196
+ }
197
+ }
198
+ // need to restore the last index,
199
+ // due to rounding error when calculating `thread_averge_payload`.
200
+ thread_splits[num_threads] = M;
201
+
202
+ at::parallel_for(0, num_threads, 1, [&](int64_t cbegin, int64_t cend) {
203
+ int tid = at::get_thread_num();
204
+ int64_t begin = thread_splits[tid];
205
+ int64_t end = thread_splits[tid + 1];
206
+ f(begin, end);
207
+ });
208
+ }
209
+
210
+ } // namespace utils
211
+
212
+ } // namespace at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/zmath.h ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // Complex number math operations that act as no-ops for other dtypes.
4
+ #include <c10/util/complex.h>
5
+ #include <c10/util/MathConstants.h>
6
+ #include<ATen/NumericUtils.h>
7
+
8
+ namespace at::native {
9
+ inline namespace CPU_CAPABILITY {
10
+
11
+ template <typename SCALAR_TYPE, typename VALUE_TYPE=SCALAR_TYPE>
12
+ inline VALUE_TYPE zabs (SCALAR_TYPE z) {
13
+ return z;
14
+ }
15
+
16
+ template<>
17
+ inline c10::complex<float> zabs <c10::complex<float>> (c10::complex<float> z) {
18
+ return c10::complex<float>(std::abs(z));
19
+ }
20
+
21
+ template<>
22
+ inline float zabs <c10::complex<float>, float> (c10::complex<float> z) {
23
+ return std::abs(z);
24
+ }
25
+
26
+ template<>
27
+ inline c10::complex<double> zabs <c10::complex<double>> (c10::complex<double> z) {
28
+ return c10::complex<double>(std::abs(z));
29
+ }
30
+
31
+ template<>
32
+ inline double zabs <c10::complex<double>, double> (c10::complex<double> z) {
33
+ return std::abs(z);
34
+ }
35
+
36
+ // This overload corresponds to non-complex dtypes.
37
+ // The function is consistent with its NumPy equivalent
38
+ // for non-complex dtypes where `pi` is returned for
39
+ // negative real numbers and `0` is returned for 0 or positive
40
+ // real numbers.
41
+ // Note: `nan` is propagated.
42
+ template <typename SCALAR_TYPE, typename VALUE_TYPE=SCALAR_TYPE>
43
+ inline VALUE_TYPE angle_impl (SCALAR_TYPE z) {
44
+ if (at::_isnan(z)) {
45
+ return z;
46
+ }
47
+ return z < 0 ? c10::pi<double> : 0;
48
+ }
49
+
50
+ template<>
51
+ inline c10::complex<float> angle_impl <c10::complex<float>> (c10::complex<float> z) {
52
+ return c10::complex<float>(std::arg(z), 0.0);
53
+ }
54
+
55
+ template<>
56
+ inline float angle_impl <c10::complex<float>, float> (c10::complex<float> z) {
57
+ return std::arg(z);
58
+ }
59
+
60
+ template<>
61
+ inline c10::complex<double> angle_impl <c10::complex<double>> (c10::complex<double> z) {
62
+ return c10::complex<double>(std::arg(z), 0.0);
63
+ }
64
+
65
+ template<>
66
+ inline double angle_impl <c10::complex<double>, double> (c10::complex<double> z) {
67
+ return std::arg(z);
68
+ }
69
+
70
+ template <typename SCALAR_TYPE, typename VALUE_TYPE=SCALAR_TYPE>
71
+ constexpr VALUE_TYPE real_impl (SCALAR_TYPE z) {
72
+ return z; //No-Op
73
+ }
74
+
75
+ template<>
76
+ constexpr c10::complex<float> real_impl <c10::complex<float>> (c10::complex<float> z) {
77
+ return c10::complex<float>(z.real(), 0.0);
78
+ }
79
+
80
+ template<>
81
+ constexpr float real_impl <c10::complex<float>, float> (c10::complex<float> z) {
82
+ return z.real();
83
+ }
84
+
85
+ template<>
86
+ constexpr c10::complex<double> real_impl <c10::complex<double>> (c10::complex<double> z) {
87
+ return c10::complex<double>(z.real(), 0.0);
88
+ }
89
+
90
+ template<>
91
+ constexpr double real_impl <c10::complex<double>, double> (c10::complex<double> z) {
92
+ return z.real();
93
+ }
94
+
95
+ template <typename SCALAR_TYPE, typename VALUE_TYPE=SCALAR_TYPE>
96
+ constexpr VALUE_TYPE imag_impl (SCALAR_TYPE /*z*/) {
97
+ return 0;
98
+ }
99
+
100
+ template<>
101
+ constexpr c10::complex<float> imag_impl <c10::complex<float>> (c10::complex<float> z) {
102
+ return c10::complex<float>(z.imag(), 0.0);
103
+ }
104
+
105
+ template<>
106
+ constexpr float imag_impl <c10::complex<float>, float> (c10::complex<float> z) {
107
+ return z.imag();
108
+ }
109
+
110
+ template<>
111
+ constexpr c10::complex<double> imag_impl <c10::complex<double>> (c10::complex<double> z) {
112
+ return c10::complex<double>(z.imag(), 0.0);
113
+ }
114
+
115
+ template<>
116
+ constexpr double imag_impl <c10::complex<double>, double> (c10::complex<double> z) {
117
+ return z.imag();
118
+ }
119
+
120
+ template <typename TYPE>
121
+ inline TYPE conj_impl (TYPE z) {
122
+ return z; //No-Op
123
+ }
124
+
125
+ template<>
126
+ inline c10::complex<at::Half> conj_impl <c10::complex<at::Half>> (c10::complex<at::Half> z) {
127
+ return c10::complex<at::Half>{z.real(), -z.imag()};
128
+ }
129
+
130
+ template<>
131
+ inline c10::complex<float> conj_impl <c10::complex<float>> (c10::complex<float> z) {
132
+ return c10::complex<float>(z.real(), -z.imag());
133
+ }
134
+
135
+ template<>
136
+ inline c10::complex<double> conj_impl <c10::complex<double>> (c10::complex<double> z) {
137
+ return c10::complex<double>(z.real(), -z.imag());
138
+ }
139
+
140
+ template <typename TYPE>
141
+ inline TYPE ceil_impl (TYPE z) {
142
+ return std::ceil(z);
143
+ }
144
+
145
+ template <>
146
+ inline c10::complex<float> ceil_impl (c10::complex<float> z) {
147
+ return c10::complex<float>(std::ceil(z.real()), std::ceil(z.imag()));
148
+ }
149
+
150
+ template <>
151
+ inline c10::complex<double> ceil_impl (c10::complex<double> z) {
152
+ return c10::complex<double>(std::ceil(z.real()), std::ceil(z.imag()));
153
+ }
154
+
155
+ template<typename T>
156
+ inline c10::complex<T> sgn_impl (c10::complex<T> z) {
157
+ if (z == c10::complex<T>(0, 0)) {
158
+ return c10::complex<T>(0, 0);
159
+ } else {
160
+ return z / zabs(z);
161
+ }
162
+ }
163
+
164
+ template <typename TYPE>
165
+ inline TYPE floor_impl (TYPE z) {
166
+ return std::floor(z);
167
+ }
168
+
169
+ template <>
170
+ inline c10::complex<float> floor_impl (c10::complex<float> z) {
171
+ return c10::complex<float>(std::floor(z.real()), std::floor(z.imag()));
172
+ }
173
+
174
+ template <>
175
+ inline c10::complex<double> floor_impl (c10::complex<double> z) {
176
+ return c10::complex<double>(std::floor(z.real()), std::floor(z.imag()));
177
+ }
178
+
179
+ template <typename TYPE>
180
+ inline TYPE round_impl (TYPE z) {
181
+ return std::nearbyint(z);
182
+ }
183
+
184
+ template <>
185
+ inline c10::complex<float> round_impl (c10::complex<float> z) {
186
+ return c10::complex<float>(std::nearbyint(z.real()), std::nearbyint(z.imag()));
187
+ }
188
+
189
+ template <>
190
+ inline c10::complex<double> round_impl (c10::complex<double> z) {
191
+ return c10::complex<double>(std::nearbyint(z.real()), std::nearbyint(z.imag()));
192
+ }
193
+
194
+ template <typename TYPE>
195
+ inline TYPE trunc_impl (TYPE z) {
196
+ return std::trunc(z);
197
+ }
198
+
199
+ template <>
200
+ inline c10::complex<float> trunc_impl (c10::complex<float> z) {
201
+ return c10::complex<float>(std::trunc(z.real()), std::trunc(z.imag()));
202
+ }
203
+
204
+ template <>
205
+ inline c10::complex<double> trunc_impl (c10::complex<double> z) {
206
+ return c10::complex<double>(std::trunc(z.real()), std::trunc(z.imag()));
207
+ }
208
+
209
+ template <typename TYPE, std::enable_if_t<!c10::is_complex<TYPE>::value, int> = 0>
210
+ inline TYPE max_impl (TYPE a, TYPE b) {
211
+ if (_isnan<TYPE>(a) || _isnan<TYPE>(b)) {
212
+ return std::numeric_limits<TYPE>::quiet_NaN();
213
+ } else {
214
+ return std::max(a, b);
215
+ }
216
+ }
217
+
218
+ template <typename TYPE, std::enable_if_t<c10::is_complex<TYPE>::value, int> = 0>
219
+ inline TYPE max_impl (TYPE a, TYPE b) {
220
+ if (_isnan<TYPE>(a)) {
221
+ return a;
222
+ } else if (_isnan<TYPE>(b)) {
223
+ return b;
224
+ } else {
225
+ return std::abs(a) > std::abs(b) ? a : b;
226
+ }
227
+ }
228
+
229
+ template <typename TYPE, std::enable_if_t<!c10::is_complex<TYPE>::value, int> = 0>
230
+ inline TYPE min_impl (TYPE a, TYPE b) {
231
+ if (_isnan<TYPE>(a) || _isnan<TYPE>(b)) {
232
+ return std::numeric_limits<TYPE>::quiet_NaN();
233
+ } else {
234
+ return std::min(a, b);
235
+ }
236
+ }
237
+
238
+ template <typename TYPE, std::enable_if_t<c10::is_complex<TYPE>::value, int> = 0>
239
+ inline TYPE min_impl (TYPE a, TYPE b) {
240
+ if (_isnan<TYPE>(a)) {
241
+ return a;
242
+ } else if (_isnan<TYPE>(b)) {
243
+ return b;
244
+ } else {
245
+ return std::abs(a) < std::abs(b) ? a : b;
246
+ }
247
+ }
248
+
249
+ } // end namespace
250
+ } //end at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Activation.h ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <ATen/native/Activation.h>
3
+ #include <cstdint>
4
+
5
+ namespace at {
6
+ struct TensorIteratorBase;
7
+ class TensorBase;
8
+ }
9
+
10
+ namespace at { namespace native {
11
+
12
+ void launch_glu_backward_kernel(const TensorIteratorBase& iter,
13
+ int64_t gI_stride, int64_t I_stride);
14
+
15
+ void launch_log_sigmoid_forward_kernel(TensorIteratorBase& iter);
16
+
17
+ void GeluCUDAKernelImpl(TensorIteratorBase& it, GeluType approximate);
18
+ void GeluBackwardCUDAKernelImpl(TensorIteratorBase& it, GeluType approximate);
19
+
20
+ }} // namespace at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/BinaryInternal.h ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // DON'T include this except from Binary*.cu files. It should not leak into
2
+ // headers.
3
+ #pragma once
4
+ #define TORCH_ASSERT_NO_OPERATORS
5
+ #include <ATen/AccumulateType.h>
6
+ #include <ATen/Dispatch.h>
7
+ #include <ATen/native/BinaryOps.h>
8
+ #include <ATen/native/DispatchStub.h>
9
+ #include <ATen/native/TensorIterator.h>
10
+ #include <c10/cuda/CUDAGuard.h>
11
+ #include <c10/cuda/CUDAMathCompat.h>
12
+ #include <c10/util/TypeSafeSignMath.h>
13
+ #include <ATen/native/cuda/JitLoops.cuh>
14
+ #include <ATen/native/cuda/Loops.cuh>
15
+
16
+ #include <type_traits>
17
+
18
+ namespace at {
19
+ namespace native {
20
+ namespace binary_internal {
21
+
22
+ template <typename scalar_t>
23
+ struct DivFunctor {
24
+ __device__ scalar_t operator()(scalar_t a, scalar_t b) const {
25
+ return a / b;
26
+ }
27
+ };
28
+
29
+ template <typename T>
30
+ struct MulFunctor {
31
+ __device__ T operator()(T a, T b) const {
32
+ return a * b;
33
+ }
34
+ };
35
+
36
+ // Workaround for the error: '*' in boolean context, suggest '&&' instead
37
+ // [-Werror=int-in-bool-context]
38
+ template <>
39
+ struct MulFunctor<bool> {
40
+ __device__ bool operator()(bool a, bool b) const {
41
+ return a && b;
42
+ }
43
+ };
44
+ void div_true_kernel_cuda(TensorIteratorBase& iter);
45
+ void div_trunc_kernel_cuda(TensorIteratorBase& iter);
46
+ } // namespace binary_internal
47
+ } // namespace native
48
+ } // namespace at
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/CUDAJitLoops.cuh ADDED
@@ -0,0 +1,296 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <ATen/jit_macros.h>
3
+
4
+ // Jiterator functions are guarded behind this macro
5
+ #if AT_USE_JITERATOR()
6
+
7
+ #include <ATen/OpMathType.h>
8
+ #include <ATen/TensorIterator.h>
9
+ #include <ATen/core/Array.h>
10
+ #include <ATen/cuda/CUDAContext.h>
11
+ #include <ATen/cuda/detail/OffsetCalculator.cuh>
12
+ #include <ATen/native/cuda/jit_utils.h>
13
+ #include <ATen/native/cuda/MemoryAccess.cuh>
14
+ #include <ATen/native/cuda/thread_constants.h>
15
+
16
+ #include <ATen/native/cuda/Loops.cuh>
17
+
18
+ #include <c10/macros/Macros.h>
19
+ #include <c10/core/ScalarType.h>
20
+ #include <c10/util/SmallBuffer.h>
21
+
22
+ #include <initializer_list>
23
+ #include <type_traits>
24
+ #include <tuple>
25
+ #include <mutex>
26
+
27
+ namespace at {
28
+ namespace native {
29
+
30
+ template <typename Tuple, std::size_t... I>
31
+ constexpr auto tuple_to_array_helper(Tuple& t, std::index_sequence<I...> seq) {
32
+ constexpr auto size = seq.size();
33
+ (void)t; // warning : unused parameter when tuple is empty.
34
+ return std::array<void*, size>{static_cast<void*>(&std::get<I>(t))...};
35
+ }
36
+
37
+ // Helper function convert tuple to std::array<void*, N>
38
+ // for passing the arguments to CUDA Kernel
39
+ // NOTE: We capture tuple by reference,
40
+ // so the pointers in returned array are only valid
41
+ // till tuple is alive.
42
+ template <typename ...Args>
43
+ constexpr auto tuple_to_array(std::tuple<Args...>& extra_args) {
44
+ constexpr auto tuple_size = sizeof...(Args);
45
+ return tuple_to_array_helper(extra_args, std::make_index_sequence<tuple_size>{});
46
+ }
47
+
48
+ struct JittedVecKernelCache {
49
+ // Different kernels are compiled depending on what we're vectorizing up to (1, 2 or 4 elements)
50
+ at::cuda::jit::NvrtcFunction vec1;
51
+ at::cuda::jit::NvrtcFunction vec2;
52
+ at::cuda::jit::NvrtcFunction vec4;
53
+ };
54
+
55
+ struct JittedKernelVariantCache {
56
+ JittedVecKernelCache vec;
57
+ at::cuda::jit::NvrtcFunction noncontiguous;
58
+ at::cuda::jit::NvrtcFunction dynamic_contiguous;
59
+ at::cuda::jit::NvrtcFunction dynamic_noncontiguous;
60
+ };
61
+
62
+ inline c10::SmallBuffer<void*, 64> pack_kernel_args(
63
+ std::initializer_list<void*> args,
64
+ c10::ArrayRef<void*> extra_args) {
65
+ c10::SmallBuffer<void*, 64> ret(args.size() + extra_args.size());
66
+ std::copy(args.begin(), args.end(), ret.data());
67
+ std::copy(extra_args.begin(), extra_args.end(), ret.data() + args.size());
68
+ return ret;
69
+ }
70
+
71
+ template<typename array_t,
72
+ typename inp_calc_t,
73
+ typename out_calc_t,
74
+ typename loader_t,
75
+ typename storer_t>
76
+ void launch_jitted_unrolled_kernel(
77
+ std::mutex &jiterator_mutex,
78
+ at::cuda::jit::NvrtcFunction &fn_cache,
79
+ const at::cuda::jit::KernelDescriptor &desc,
80
+ int64_t N,
81
+ array_t data,
82
+ inp_calc_t ic,
83
+ out_calc_t oc,
84
+ loader_t l,
85
+ storer_t s,
86
+ bool contiguous,
87
+ at::cuda::jit::BinaryFuncVariant scalar_pos,
88
+ void* scalar_val,
89
+ c10::ArrayRef<void*> extra_args) {
90
+
91
+ TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
92
+ //casting result to int is always safe, intermediate is int64 and won't overflow
93
+ const uint32_t grid = (N + block_work_size() - 1) / block_work_size();
94
+
95
+ if (!fn_cache.function) {
96
+ const std::lock_guard<std::mutex> lock{jiterator_mutex};
97
+ if (!fn_cache.function) {
98
+ constexpr bool dynamic_casting = !std::is_same<decltype(l), memory::LoadWithoutCast>() ||
99
+ !std::is_same<decltype(s), memory::StoreWithoutCast>();
100
+ auto code = at::cuda::jit::generate_code(
101
+ desc, contiguous, dynamic_casting, scalar_pos);
102
+ fn_cache = at::cuda::jit::jit_pwise_function(code, desc.name);
103
+ }
104
+ }
105
+
106
+ auto args = pack_kernel_args({&N, &data, &ic, &oc, &l, &s, scalar_val}, extra_args);
107
+ at::cuda::jit::launch_jitted_pwise_function(fn_cache, args.data(), {grid, 1u, 1u},
108
+ {num_threads(), 1u, 1u});
109
+ }
110
+
111
+ template<int arity, typename array_t>
112
+ void launch_jitted_vectorized_kernel(
113
+ std::mutex &jiterator_mutex, JittedVecKernelCache &fn_cache,
114
+ const at::cuda::jit::KernelDescriptor &desc, int64_t N, array_t data,
115
+ at::cuda::jit::BinaryFuncVariant scalar_pos,
116
+ void *scalar_val, c10::ArrayRef<void*> extra_args) {
117
+ TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
118
+ // N is still int64_t for the computation, but it's always safe to cast result to int
119
+ const uint32_t grid = (N + block_work_size() - 1) / block_work_size();
120
+ const int vec_size = at::cuda::jit::can_vectorize_up_to(
121
+ desc, c10::ArrayRef<char*>(data.data, data.size()));
122
+
123
+ // Different kernels are compiled depending on what we're vectorizing up to (1, 2 or 4 elements)
124
+ // fn_ptr is set to the appropriate function based on the vec size and GPU used
125
+ at::cuda::jit::NvrtcFunction* fn_ptr;
126
+ if (vec_size == 4) {
127
+ fn_ptr = &fn_cache.vec4;
128
+ } else if (vec_size == 2) {
129
+ fn_ptr = &fn_cache.vec2;
130
+ } else if (vec_size ==1) {
131
+ fn_ptr = &fn_cache.vec1;
132
+ } else {
133
+ TORCH_INTERNAL_ASSERT(false, "unexpected vec_size for jitter vectorized kernel");
134
+ }
135
+
136
+ bool vectorized = vec_size > 1;
137
+
138
+ if (!fn_ptr->function) {
139
+ const std::lock_guard<std::mutex> lock{jiterator_mutex};
140
+ if (!fn_ptr->function) { // cache miss!
141
+
142
+ // Generates program
143
+ auto code = at::cuda::jit::generate_code(
144
+ desc, /*contiguous=*/true, /*dynamic_casting=*/false,
145
+ scalar_pos, vectorized, vec_size);
146
+ std::string kernel_name = vectorized ? desc.name + "_vectorized" + std::to_string(vec_size) : desc.name;
147
+
148
+ // Acquires the program
149
+ *fn_ptr = at::cuda::jit::jit_pwise_function(code, kernel_name);
150
+ }
151
+ }
152
+
153
+ if (vectorized) {
154
+ auto args = pack_kernel_args({&N, &data, scalar_val}, extra_args);
155
+ at::cuda::jit::launch_jitted_pwise_function(
156
+ *fn_ptr, args.data(), {grid, 1u, 1u}, {num_threads(), 1u, 1u});
157
+ } else {
158
+ // NVCC complains about unused variables l and s.
159
+ // It should be false positive in most cases, so we suppress the warnings.
160
+ #pragma nv_diagnostic push
161
+ #pragma nv_diag_suppress 177
162
+ auto ic = TrivialOffsetCalculator<arity>();
163
+ auto oc = TrivialOffsetCalculator<1>();
164
+ auto l = memory::LoadWithoutCast();
165
+ auto s = memory::StoreWithoutCast();
166
+
167
+ auto args = pack_kernel_args(
168
+ {&N, &data, &ic, &oc, &l, &s, scalar_val}, extra_args);
169
+ at::cuda::jit::launch_jitted_pwise_function(
170
+ *fn_ptr, args.data(), {grid, 1u, 1u}, {num_threads(), 1u, 1u});
171
+ #pragma nv_diagnostic pop
172
+ }
173
+ }
174
+
175
+ template <int arity>
176
+ void jitted_gpu_kernel_generic(
177
+ std::mutex &jiterator_mutex,
178
+ JittedKernelVariantCache &cache,
179
+ const at::cuda::jit::KernelDescriptor &desc,
180
+ at::cuda::jit::BinaryFuncVariant scalar_pos,
181
+ c10::ArrayRef<void*> extra_args,
182
+ TensorIteratorBase& iter,
183
+ const bool dynamic_casting,
184
+ void *scalar_val) {
185
+ TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing());
186
+ TORCH_INTERNAL_ASSERT(iter.ninputs() == arity);
187
+ TORCH_INTERNAL_ASSERT(iter.noutputs() == 1);
188
+
189
+ constexpr int ntensors = arity + 1;
190
+ at::detail::Array<char*, ntensors> data;
191
+ for (auto i : c10::irange(ntensors)) {
192
+ data[i] = (char*)iter.data_ptr(i);
193
+ }
194
+
195
+ int64_t numel = iter.numel();
196
+ bool contiguous = iter.is_contiguous();
197
+
198
+ // Decides which of 4 kernel types to launch
199
+ // Variations are:
200
+ // - Case 1: no dynamic casting and contiguous
201
+ // - Case 2: no dynamic casting and noncontiguous
202
+ // - Case 3: dynamic casting and contiguous
203
+ // - Case 4: dynamic casting and noncontiguous
204
+ // These cases align with the non-jitted CUDALoops.cuh cases in gpu_kernel_impl
205
+
206
+ if (!dynamic_casting) {
207
+ if (contiguous) {
208
+ // Case 1: no dynamic casting and contiguous
209
+ launch_jitted_vectorized_kernel<arity>(
210
+ jiterator_mutex, cache.vec, desc,
211
+ numel, data, scalar_pos, scalar_val, extra_args);
212
+ return;
213
+ }
214
+
215
+ // Case 2: no dynamic casting and noncontiguous
216
+ auto input_offset_calculator = make_input_offset_calculator<arity>(iter);
217
+ auto output_offset_calculator = make_output_offset_calculator(iter);
218
+ auto loader = memory::LoadWithoutCast();
219
+ auto storer = memory::StoreWithoutCast();
220
+ launch_jitted_unrolled_kernel(
221
+ jiterator_mutex, cache.noncontiguous, desc, numel, data,
222
+ input_offset_calculator, output_offset_calculator, loader,
223
+ storer, contiguous, scalar_pos, scalar_val, extra_args);
224
+ return;
225
+ }
226
+
227
+ // Cases 3 and 4 are handled below
228
+ // Both require construction of a storer (this asserts 1 output) and one or more loaders
229
+
230
+ // Creates store cast to output (the zeroth tensor in TensorIterator)
231
+ auto storer = memory::StoreWithCast<1>(iter);
232
+
233
+ // Creates load casts from inputs (note offset indexing into the iterators 1...n tensors)
234
+ auto loader = memory::LoadWithCast<arity>(iter);
235
+
236
+ if (contiguous) {
237
+ // Case 3: dynamic casting and contiguous
238
+ auto input_offset_calculator = TrivialOffsetCalculator<arity>();
239
+ auto output_offset_calculator = TrivialOffsetCalculator<1>();
240
+ launch_jitted_unrolled_kernel(
241
+ jiterator_mutex, cache.dynamic_contiguous, desc, numel, data, input_offset_calculator,
242
+ output_offset_calculator, loader, storer, contiguous, scalar_pos, scalar_val, extra_args);
243
+ return;
244
+ }
245
+
246
+ // Case 4: dynamic casting and noncontiguous
247
+ auto input_offset_calculator = make_input_offset_calculator<arity>(iter);
248
+ auto output_offset_calculator = make_output_offset_calculator(iter);
249
+ launch_jitted_unrolled_kernel(
250
+ jiterator_mutex, cache.dynamic_noncontiguous, desc, numel, data, input_offset_calculator,
251
+ output_offset_calculator, loader, storer, contiguous, scalar_pos, scalar_val, extra_args);
252
+ }
253
+
254
+ // NOTE: static to reduce chances of name collision.
255
+ template <
256
+ char const* name,
257
+ typename result_type,
258
+ typename f_inputs_type,
259
+ int arity,
260
+ at::cuda::jit::BinaryFuncVariant scalar_pos =
261
+ at::cuda::jit::BinaryFuncVariant::NoScalar,
262
+ typename... ExtraArgs>
263
+ static void jitted_gpu_kernel_impl(
264
+ TensorIteratorBase& iter,
265
+ const std::string &f,
266
+ const bool dynamic_casting,
267
+ at::opmath_type<f_inputs_type> scalar_val,
268
+ std::tuple<ExtraArgs...> extra_args) {
269
+
270
+ // TODO: Memory use can probably be optimized by re-using kernels across GPUs with
271
+ // the same compute capability
272
+ static std::mutex jiterator_mutex;
273
+ static std::vector<JittedKernelVariantCache> device_caches(c10::cuda::device_count());
274
+
275
+ constexpr int nInputs = arity;
276
+ constexpr int nOutputs = 1; // TODO: Support more than 1 output
277
+ static const auto desc = at::cuda::jit::make_kernel_descriptor<
278
+ result_type, f_inputs_type, ExtraArgs...>(name, f, nInputs, nOutputs);
279
+
280
+ auto &cache = device_caches[iter.device().index()];
281
+ auto extra_args_array = tuple_to_array(extra_args);
282
+ return jitted_gpu_kernel_generic<arity>(
283
+ jiterator_mutex,
284
+ cache,
285
+ desc,
286
+ scalar_pos,
287
+ extra_args_array,
288
+ iter,
289
+ dynamic_casting,
290
+ &scalar_val
291
+ );
292
+ }
293
+
294
+ }} // at::native
295
+
296
+ #endif // AT_USE_JITERATOR()
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/CUDALoops.cuh ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // This file provides two functions to help write GPU elementwise kernels:
4
+ //
5
+ // gpu_kernel(TensorIterator iter, <lambda>)
6
+ // gpu_kernel_with_scalars(TensorIterator iter, <lambda>)
7
+ //
8
+ // The gpu_kernel_with_scalars generates specializations that support a
9
+ // single scalar CPU argument, such as from `cuda_tensor + 5`. The CPU scalar
10
+ // is lifted to a kernel parameter instead of copying to device memory.
11
+ // This should be used in conjunction with TensorIterator::allow_cpu_scalars_,
12
+ // which is the default for TensorIterator::binary_op. Otherwise, all inputs
13
+ // and the output must be on the GPU.
14
+ //
15
+ // For example, to write a reciprocal kernel for GPU float Tensors:
16
+ //
17
+ // gpu_kernel(iter, []GPU_LAMBDA(float a) {
18
+ // return 1.0f / a;
19
+ // });
20
+ //
21
+ // To write a multiplication kernel for GPU float Tensors where one argument
22
+ // may be a CPU scalar:
23
+ //
24
+ // gpu_kernel_with_scalars(iter, []GPU_LAMBDA(float a, float b) {
25
+ // return a * b;
26
+ // });
27
+ //
28
+ // See BinaryOpsKernel.cu for the complete implementation
29
+ //
30
+
31
+ #include <iostream>
32
+ #include <tuple>
33
+ #include <type_traits>
34
+
35
+ #include <ATen/core/Array.h>
36
+ #include <ATen/cuda/CUDAContext.h>
37
+ #include <ATen/detail/FunctionTraits.h>
38
+ #include <ATen/native/TensorIterator.h>
39
+ #include <c10/core/DynamicCast.h>
40
+ #include <c10/core/ScalarType.h>
41
+ #include <c10/macros/Macros.h>
42
+ #include <c10/util/TypeCast.h>
43
+
44
+ #ifdef __NVCC__
45
+ #define ASSERT_HOST_DEVICE_LAMBDA(type) \
46
+ static_assert( \
47
+ __nv_is_extended_host_device_lambda_closure_type(type), \
48
+ #type " must be a __host__ __device__ lambda")
49
+ #else
50
+ #define ASSERT_HOST_DEVICE_LAMBDA(type)
51
+ #endif
52
+
53
+ namespace at {
54
+ namespace native {
55
+
56
+ template <int vec_size, typename func_t, typename array_t>
57
+ C10_LAUNCH_BOUNDS_1(num_threads())
58
+ __global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) {
59
+ using traits = function_traits<func_t>;
60
+ int remaining = N - block_work_size() * blockIdx.x;
61
+
62
+ if (remaining < block_work_size()) { // if this block handles the reminder,
63
+ // just do a naive unrolled loop
64
+ auto input_calc = TrivialOffsetCalculator<traits::arity>();
65
+ auto output_calc = TrivialOffsetCalculator<1>();
66
+ auto loader = memory::LoadWithoutCast();
67
+ auto storer = memory::StoreWithoutCast();
68
+ auto policy = memory::policies::unroll<
69
+ array_t,
70
+ decltype(input_calc),
71
+ decltype(output_calc),
72
+ memory::LoadWithoutCast,
73
+ memory::StoreWithoutCast>(
74
+ data, remaining, input_calc, output_calc, loader, storer);
75
+ elementwise_kernel_helper(f, policy);
76
+ } else { // if this block has a full `block_work_size` data to handle, use
77
+ // vectorized memory access
78
+ elementwise_kernel_helper(
79
+ f, memory::policies::vectorized<vec_size, array_t>(data));
80
+ }
81
+ }
82
+
83
+ template <
84
+ typename func_t,
85
+ typename array_t,
86
+ typename inp_calc_t,
87
+ typename out_calc_t,
88
+ typename loader_t,
89
+ typename storer_t>
90
+ C10_LAUNCH_BOUNDS_1(num_threads())
91
+ __global__ void unrolled_elementwise_kernel(
92
+ int N,
93
+ func_t f,
94
+ array_t data,
95
+ inp_calc_t ic,
96
+ out_calc_t oc,
97
+ loader_t l,
98
+ storer_t s) {
99
+ int remaining = N - block_work_size() * blockIdx.x;
100
+ auto policy = memory::policies::
101
+ unroll<array_t, inp_calc_t, out_calc_t, loader_t, storer_t>(
102
+ data, remaining, ic, oc, l, s);
103
+ elementwise_kernel_helper(f, policy);
104
+ }
105
+
106
+ // this function assume trivial 1d and no dynamic casting
107
+ template <typename func_t, typename array_t>
108
+ static inline void launch_vectorized_kernel(
109
+ int64_t N,
110
+ const func_t& f,
111
+ array_t data) {
112
+ TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
113
+ using traits = function_traits<func_t>;
114
+ int64_t grid = (N + block_work_size() - 1) / block_work_size();
115
+ auto stream = at::cuda::getCurrentCUDAStream();
116
+ int vec_size = memory::can_vectorize_up_to<func_t>(data);
117
+
118
+ switch (vec_size) {
119
+ case 4:
120
+ vectorized_elementwise_kernel<4, func_t, array_t>
121
+ <<<grid, num_threads(), 0, stream>>>(N, f, data);
122
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
123
+ break;
124
+ case 2:
125
+ vectorized_elementwise_kernel<2, func_t, array_t>
126
+ <<<grid, num_threads(), 0, stream>>>(N, f, data);
127
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
128
+ break;
129
+ case 1: {
130
+ auto input_calc = TrivialOffsetCalculator<traits::arity>();
131
+ auto output_calc = TrivialOffsetCalculator<1>();
132
+ auto loader = memory::LoadWithoutCast();
133
+ auto storer = memory::StoreWithoutCast();
134
+ unrolled_elementwise_kernel<func_t, array_t>
135
+ <<<grid, num_threads(), 0, stream>>>(
136
+ N, f, data, input_calc, output_calc, loader, storer);
137
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
138
+ break;
139
+ }
140
+ default:
141
+ TORCH_INTERNAL_ASSERT(false, "Unexpected vectorization size");
142
+ }
143
+ }
144
+
145
+ template <
146
+ typename func_t,
147
+ typename array_t,
148
+ typename inp_calc_t,
149
+ typename out_calc_t,
150
+ typename loader_t,
151
+ typename storer_t>
152
+ static inline void launch_unrolled_kernel(
153
+ int64_t N,
154
+ const func_t& f,
155
+ array_t data,
156
+ inp_calc_t ic,
157
+ out_calc_t oc,
158
+ loader_t l,
159
+ storer_t s) {
160
+ TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
161
+ int64_t grid = (N + block_work_size() - 1) / block_work_size();
162
+ auto stream = at::cuda::getCurrentCUDAStream();
163
+ unrolled_elementwise_kernel<func_t, array_t>
164
+ <<<grid, num_threads(), 0, stream>>>(N, f, data, ic, oc, l, s);
165
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
166
+ }
167
+
168
+ template <int nt, int vt, typename func_t>
169
+ C10_LAUNCH_BOUNDS_2(nt, 4)
170
+ __global__ void elementwise_kernel(int N, func_t f) {
171
+ int tid = threadIdx.x;
172
+ int nv = nt * vt;
173
+ int idx = nv * blockIdx.x + tid;
174
+ #pragma unroll
175
+ for (int i = 0; i < vt; i++) {
176
+ if (idx < N) {
177
+ f(idx);
178
+ idx += nt;
179
+ }
180
+ }
181
+ }
182
+
183
+ template <int nt, int vt, typename func_t>
184
+ static void launch_legacy_kernel(int64_t N, const func_t& f) {
185
+ TORCH_INTERNAL_ASSERT(N >= 0 && N <= std::numeric_limits<int32_t>::max());
186
+ if (N == 0) {
187
+ return;
188
+ }
189
+ dim3 block(nt);
190
+ dim3 grid((N + block.x * vt - 1) / (block.x * vt));
191
+ auto stream = at::cuda::getCurrentCUDAStream();
192
+ elementwise_kernel<nt, vt, func_t><<<grid, block, 0, stream>>>(N, f);
193
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
194
+ }
195
+
196
+ template <typename traits, typename func_t, typename index_t, size_t... INDEX>
197
+ C10_HOST_DEVICE typename traits::result_type invoke_impl(
198
+ const func_t& f,
199
+ char* const C10_RESTRICT data[],
200
+ const index_t strides[],
201
+ int i,
202
+ std::index_sequence<INDEX...>) {
203
+ (void)strides;
204
+ (void)i;
205
+ return f(c10::load<typename traits::template arg<INDEX>::type>(
206
+ data[INDEX] + i * strides[INDEX])...);
207
+ }
208
+
209
+ template <
210
+ typename func_t,
211
+ typename index_t,
212
+ typename traits = function_traits<func_t>>
213
+ C10_HOST_DEVICE typename traits::result_type invoke(
214
+ const func_t& f,
215
+ char* const C10_RESTRICT data[],
216
+ const index_t strides[],
217
+ int i) {
218
+ using Indices = std::make_index_sequence<traits::arity>;
219
+ return invoke_impl<traits>(f, data, strides, i, Indices{});
220
+ }
221
+
222
+ template <typename traits, typename func_t, typename index_t, size_t... I>
223
+ C10_HOST_DEVICE typename traits::result_type invoke_impl(
224
+ const func_t& f,
225
+ char* const C10_RESTRICT data[],
226
+ const index_t strides[],
227
+ const ScalarType dtypes[],
228
+ int i,
229
+ std::index_sequence<I...>) {
230
+ (void)strides;
231
+ (void)i;
232
+ return f(c10::fetch_and_cast<typename traits::template arg<I>::type>(
233
+ dtypes[I], data[I] + i * strides[I])...);
234
+ }
235
+
236
+ template <
237
+ typename func_t,
238
+ typename index_t,
239
+ typename traits = function_traits<func_t>>
240
+ C10_HOST_DEVICE typename traits::result_type invoke(
241
+ const func_t& f,
242
+ char* const C10_RESTRICT data[],
243
+ const index_t strides[],
244
+ const ScalarType dtypes[],
245
+ int i) {
246
+ using Indices = std::make_index_sequence<traits::arity>;
247
+ return invoke_impl<traits>(f, data, strides, dtypes, i, Indices{});
248
+ }
249
+
250
+ template <typename func_t>
251
+ void gpu_kernel_impl_nocast(TensorIteratorBase& iter, const func_t& f) {
252
+ using traits = function_traits<func_t>;
253
+ using arg0_t = typename traits::result_type;
254
+ constexpr int ntensors = traits::arity + 1;
255
+
256
+ TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing());
257
+ TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity);
258
+ TORCH_INTERNAL_ASSERT(iter.noutputs() == 1);
259
+ TORCH_INTERNAL_ASSERT(!needs_dynamic_casting<func_t>::check(iter));
260
+
261
+ at::detail::Array<char*, ntensors> data;
262
+ for (int i = 0; i < ntensors; i++) {
263
+ data[i] = (char*)iter.data_ptr(i);
264
+ }
265
+
266
+ int64_t numel = iter.numel();
267
+
268
+ bool contiguous = iter.is_contiguous();
269
+
270
+ if (contiguous) {
271
+ return launch_vectorized_kernel(numel, f, data);
272
+ }
273
+ auto offset_calc = ::make_offset_calculator<traits::arity + 1>(iter);
274
+ constexpr int unroll_factor = sizeof(arg0_t) >= 4 ? 2 : 4;
275
+ launch_legacy_kernel<128, unroll_factor>(numel, [=] GPU_LAMBDA(int idx) {
276
+ auto offsets = offset_calc.get(idx);
277
+ arg0_t* out = (arg0_t*)(data[0] + offsets[0]);
278
+ *out = invoke(f, &data.data[1], &offsets.data[1], 1);
279
+ });
280
+ }
281
+
282
+ template <typename func_t>
283
+ void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
284
+ if (!needs_dynamic_casting<func_t>::check(iter)) {
285
+ return gpu_kernel_impl_nocast(iter, f);
286
+ }
287
+ using traits = function_traits<func_t>;
288
+ using arg0_t = typename traits::result_type;
289
+ constexpr int ntensors = traits::arity + 1;
290
+
291
+ TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing());
292
+ TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity);
293
+ TORCH_INTERNAL_ASSERT(iter.noutputs() == 1);
294
+
295
+ at::detail::Array<char*, ntensors> data;
296
+ for (int i = 0; i < ntensors; i++) {
297
+ data[i] = (char*)iter.data_ptr(i);
298
+ }
299
+
300
+ int64_t numel = iter.numel();
301
+
302
+ bool contiguous = iter.is_contiguous();
303
+
304
+ if (contiguous) {
305
+ #ifdef USE_ROCM
306
+ at::detail::Array<ScalarType, ntensors> dtypes;
307
+ auto inner_strides = iter.get_inner_strides();
308
+ at::detail::Array<int, ntensors> strides;
309
+ for (int i = 0; i < ntensors; i++) {
310
+ dtypes[i] = iter.dtype(i);
311
+ strides[i] = inner_strides[i];
312
+ }
313
+ launch_legacy_kernel<512, 1>(numel, [=]GPU_LAMBDA(int idx) {
314
+ void* out = data[0] + strides[0] * idx;
315
+ arg0_t result = invoke(f, &data.data[1], &strides.data[1], &dtypes.data[1], idx);
316
+ c10::cast_and_store<arg0_t>(dtypes[0], out, result);
317
+ });
318
+ #else
319
+ auto loader = memory::LoadWithCast<traits::arity>(iter);
320
+ auto storer = memory::StoreWithCast<1>(iter);
321
+ auto input_offset_calculator = TrivialOffsetCalculator<traits::arity>();
322
+ auto output_offset_calculator = TrivialOffsetCalculator<1>();
323
+ launch_unrolled_kernel(
324
+ numel,
325
+ f,
326
+ data,
327
+ input_offset_calculator,
328
+ output_offset_calculator,
329
+ loader,
330
+ storer);
331
+ #endif
332
+ } else {
333
+ at::detail::Array<ScalarType, ntensors> dtypes;
334
+ for (int i = 0; i < ntensors; i++) {
335
+ dtypes[i] = iter.dtype(i);
336
+ }
337
+ auto offset_calc = ::make_offset_calculator<traits::arity + 1>(iter);
338
+ launch_legacy_kernel<128, 4>(numel, [=] GPU_LAMBDA(int idx) {
339
+ auto offsets = offset_calc.get(idx);
340
+ void* out = data[0] + offsets[0];
341
+ arg0_t result = invoke(f, &data.data[1], &offsets.data[1], &dtypes.data[1], 1);
342
+ c10::cast_and_store<arg0_t>(dtypes[0], out, result);
343
+ });
344
+ }
345
+ }
346
+
347
+ } // namespace native
348
+ } // namespace at
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/CompositeRandomAccessor.h ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/native/CompositeRandomAccessorCommon.h>
4
+ #include <thrust/tuple.h>
5
+
6
+ namespace at { namespace native {
7
+
8
+ struct TupleInfoCPU {
9
+ template <typename ...Types>
10
+ using tuple = thrust::tuple<Types...>;
11
+
12
+ template <typename ...Types>
13
+ static constexpr auto tie(Types&... args) noexcept {
14
+ return thrust::tie(args...);
15
+ }
16
+ };
17
+
18
+ template <typename KeyAccessor, typename ValueAccessor>
19
+ using CompositeRandomAccessorCPU =
20
+ CompositeRandomAccessor<KeyAccessor, ValueAccessor, TupleInfoCPU>;
21
+
22
+ template <typename Values, typename References>
23
+ void swap(
24
+ references_holder<Values, References> rh1,
25
+ references_holder<Values, References> rh2
26
+ ) {
27
+ return thrust::swap(rh1.data(), rh2.data());
28
+ }
29
+
30
+ template <int N, typename Values, typename References>
31
+ auto get(references_holder<Values, References> rh) -> decltype(thrust::get<N>(rh.data())) {
32
+ return thrust::get<N>(rh.data());
33
+ }
34
+
35
+ }} // namespace at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Copy.h ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ namespace at {
4
+ struct TensorIteratorBase;
5
+
6
+ namespace native {
7
+
8
+ void direct_copy_kernel_cuda(TensorIteratorBase &iter);
9
+
10
+ }} // namespace at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/CuFFTUtils.h ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/Config.h>
4
+
5
+ #include <string>
6
+ #include <stdexcept>
7
+ #include <sstream>
8
+ #include <cufft.h>
9
+ #include <cufftXt.h>
10
+
11
+ namespace at { namespace native {
12
+
13
+ // This means that max dim is 3 + 2 = 5 with batch dimension and possible
14
+ // complex dimension
15
+ constexpr int max_rank = 3;
16
+
17
+ static inline std::string _cudaGetErrorEnum(cufftResult error)
18
+ {
19
+ switch (error)
20
+ {
21
+ case CUFFT_SUCCESS:
22
+ return "CUFFT_SUCCESS";
23
+ case CUFFT_INVALID_PLAN:
24
+ return "CUFFT_INVALID_PLAN";
25
+ case CUFFT_ALLOC_FAILED:
26
+ return "CUFFT_ALLOC_FAILED";
27
+ case CUFFT_INVALID_TYPE:
28
+ return "CUFFT_INVALID_TYPE";
29
+ case CUFFT_INVALID_VALUE:
30
+ return "CUFFT_INVALID_VALUE";
31
+ case CUFFT_INTERNAL_ERROR:
32
+ return "CUFFT_INTERNAL_ERROR";
33
+ case CUFFT_EXEC_FAILED:
34
+ return "CUFFT_EXEC_FAILED";
35
+ case CUFFT_SETUP_FAILED:
36
+ return "CUFFT_SETUP_FAILED";
37
+ case CUFFT_INVALID_SIZE:
38
+ return "CUFFT_INVALID_SIZE";
39
+ case CUFFT_UNALIGNED_DATA:
40
+ return "CUFFT_UNALIGNED_DATA";
41
+ case CUFFT_INCOMPLETE_PARAMETER_LIST:
42
+ return "CUFFT_INCOMPLETE_PARAMETER_LIST";
43
+ case CUFFT_INVALID_DEVICE:
44
+ return "CUFFT_INVALID_DEVICE";
45
+ case CUFFT_PARSE_ERROR:
46
+ return "CUFFT_PARSE_ERROR";
47
+ case CUFFT_NO_WORKSPACE:
48
+ return "CUFFT_NO_WORKSPACE";
49
+ case CUFFT_NOT_IMPLEMENTED:
50
+ return "CUFFT_NOT_IMPLEMENTED";
51
+ #if !defined(USE_ROCM)
52
+ case CUFFT_LICENSE_ERROR:
53
+ return "CUFFT_LICENSE_ERROR";
54
+ #endif
55
+ case CUFFT_NOT_SUPPORTED:
56
+ return "CUFFT_NOT_SUPPORTED";
57
+ default:
58
+ std::ostringstream ss;
59
+ ss << "unknown error " << error;
60
+ return ss.str();
61
+ }
62
+ }
63
+
64
+ static inline void CUFFT_CHECK(cufftResult error)
65
+ {
66
+ if (error != CUFFT_SUCCESS) {
67
+ std::ostringstream ss;
68
+ ss << "cuFFT error: " << _cudaGetErrorEnum(error);
69
+ AT_ERROR(ss.str());
70
+ }
71
+ }
72
+
73
+ }} // at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/DeviceSqrt.cuh ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ namespace at { namespace native {
4
+ #if defined(USE_ROCM)
5
+ // take these out when ROCm implements std:: math functions
6
+ #include <math.h>
7
+ template <typename scalar_t>
8
+ static __forceinline__ __device__ scalar_t device_sqrt(scalar_t val);
9
+
10
+ template <>
11
+ __forceinline__ __device__ float device_sqrt(float val) {
12
+ return ::sqrtf(val);
13
+ }
14
+
15
+ template <>
16
+ __forceinline__ __device__ double device_sqrt(double val) {
17
+ return ::sqrt(val);
18
+ }
19
+ #else
20
+ template<typename scalar_t>
21
+ __forceinline__ __device__ double device_sqrt(scalar_t val) {
22
+ return std::sqrt(val);
23
+ }
24
+ #endif
25
+ }}
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/DistributionTemplates.h ADDED
@@ -0,0 +1,671 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/AccumulateType.h>
4
+ #include <ATen/Dispatch.h>
5
+ #include <ATen/Dispatch_v2.h>
6
+ #include <ATen/ExpandBase.h>
7
+ #include <ATen/OpMathType.h>
8
+ #include <ATen/native/TensorIterator.h>
9
+ #include <ATen/native/cuda/Loops.cuh>
10
+ #include <c10/util/Half.h>
11
+ #include <ATen/cuda/CUDAApplyUtils.cuh>
12
+ #include <ATen/cuda/CUDAContext.h>
13
+ #include <ATen/cuda/detail/OffsetCalculator.cuh>
14
+ #include <ATen/cuda/CUDAGraphsUtils.cuh>
15
+ #include <ATen/detail/FunctionTraits.h>
16
+ #include <ATen/core/DistributionsHelper.h>
17
+
18
+ #include <curand.h>
19
+ #include <curand_kernel.h>
20
+ #include <curand_philox4x32_x.h>
21
+ #include <cstdint>
22
+ #include <limits>
23
+ #include <utility>
24
+ #include <mutex>
25
+ #include <tuple>
26
+ #include <type_traits>
27
+
28
+ namespace at {
29
+ namespace native {
30
+ namespace {
31
+
32
+ // launch bounds used for kernels utilizing TensorIterator
33
+ const uint32_t block_size_bound = 256;
34
+ const uint32_t grid_size_bound = 4;
35
+ // At the time of writing, there is no curand_* call that increments the offset by more than 4.
36
+ // See: https://docs.nvidia.com/cuda/archive/11.8.0/curand/group__DEVICE.html
37
+ const uint32_t max_generator_offsets_per_curand_call = 4;
38
+
39
+ // utility function that calculates proper philox_offset
40
+ // for distributions utilizing TensorIterator. For distributions using
41
+ // TensorIterator, we are using a grid-stride loop with each
42
+ // thread yielding one element per thread. For the edge of the grid-stride
43
+ // loop, if the tensor size is large, the unroll loop will kick in and the float4
44
+ // from curand4 will start getting utilized (for common tensor sizes, we end up
45
+ // using rand.x from each thread). The philox_offset calculation was changed to
46
+ // (number of elements per thread * maximum generator increment per "curand_*" call), which makes
47
+ // sure that philox offset increment is not less than the number of randoms used
48
+ // in each thread.
49
+ std::tuple<uint64_t, dim3, dim3> calc_execution_policy(const int64_t total_elements, const uint32_t unroll_factor) {
50
+ const uint64_t numel = static_cast<uint64_t>(total_elements);
51
+ const uint32_t block_size = block_size_bound;
52
+ dim3 dim_block(block_size);
53
+ dim3 grid((numel + block_size - 1) / block_size);
54
+ uint32_t blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor / block_size;
55
+ grid.x = std::min(
56
+ static_cast<uint32_t>(at::cuda::getCurrentDeviceProperties()->multiProcessorCount) * blocks_per_sm,
57
+ grid.x);
58
+ //number of times random will be generated per thread, to offset philox counter in thc random state
59
+ uint64_t counter_offset = ((numel - 1) / (block_size * grid.x * unroll_factor) + 1) * max_generator_offsets_per_curand_call;
60
+ return std::make_tuple(counter_offset, grid, dim_block);
61
+ }
62
+
63
+ // grid stride loop kernel for distributions
64
+ template<typename accscalar_t, int unroll_factor, typename dist_t, typename transform_t>
65
+ C10_LAUNCH_BOUNDS_2(block_size_bound, grid_size_bound)
66
+ __global__ void distribution_elementwise_grid_stride_kernel(int numel,
67
+ PhiloxCudaState philox_args,
68
+ const dist_t dist_func,
69
+ const transform_t transform_func) {
70
+ auto seeds = at::cuda::philox::unpack(philox_args);
71
+ int idx = blockIdx.x * blockDim.x + threadIdx.x;
72
+ curandStatePhilox4_32_10_t state;
73
+ curand_init(std::get<0>(seeds),
74
+ idx,
75
+ std::get<1>(seeds),
76
+ &state);
77
+
78
+ int rounded_size = ((numel - 1)/(blockDim.x * gridDim.x * unroll_factor)+1) *
79
+ blockDim.x * gridDim.x * unroll_factor;
80
+ for(int linear_index = idx; linear_index < rounded_size; linear_index += blockDim.x * gridDim.x * unroll_factor) {
81
+ auto rand = dist_func(&state);
82
+ #pragma unroll
83
+ for (int ii = 0; ii < unroll_factor; ii++) {
84
+ int li = linear_index + blockDim.x * gridDim.x * ii;
85
+ if (li < numel) {
86
+ transform_func(li, static_cast<accscalar_t>((&rand.x)[ii]));
87
+ }
88
+ }
89
+ __syncthreads();
90
+ }
91
+ }
92
+
93
+ /**
94
+ * distribution_nullary_kernel is analogous to gpu_kernel in
95
+ * ATen/native/cuda/Loops.cuh. Like gpu_kernel, it uses
96
+ * TensorIterator to launch a kernel. However, the differences are
97
+ * - it launches a grid-stride loop based kernel. The kernel is not
98
+ * generic like elementwise_kernel in Loops.cuh and is specialized
99
+ * for the distribution kernels here.
100
+ * - For big size tensors, we can launch multiple kernels recursively
101
+ * (i.e. if (!iter.can_use_32bit_indexing())) and hence, the philox
102
+ * offset calculation is done in this function.
103
+ *
104
+ * FIXME: Can we specialize elementwise_kernel and launch_kernel in Loops.cuh
105
+ * to have grid-stride loop kernel and then use that to launch our distribution
106
+ * kernels? Note that we need a grid-stride loop kernel because, we found by testing
107
+ * that it achieves peak effective bandwidth.
108
+ */
109
+ template<typename scalar_t,
110
+ typename accscalar_t,
111
+ typename dist_func_return_t,
112
+ typename RNG,
113
+ typename dist_t,
114
+ typename transform_t>
115
+ void distribution_nullary_kernel(at::TensorIteratorBase& iter,
116
+ RNG gen,
117
+ const dist_t& dist_func,
118
+ const transform_t transform_func) {
119
+ const int unroll_factor = sizeof(dist_func_return_t) / sizeof(accscalar_t);
120
+ TORCH_CHECK(unroll_factor >= 1, "unroll_factor must be >= 1.");
121
+ int64_t numel = iter.numel();
122
+ if (numel == 0) {
123
+ return;
124
+ }
125
+
126
+ auto execution_policy = calc_execution_policy(numel, unroll_factor);
127
+ auto counter_offset = std::get<0>(execution_policy);
128
+ auto grid = std::get<1>(execution_policy);
129
+ auto block = std::get<2>(execution_policy);
130
+ PhiloxCudaState rng_engine_inputs;
131
+ {
132
+ // See Note [Acquire lock when using random generators]
133
+ std::lock_guard<std::mutex> lock(gen->mutex_);
134
+ rng_engine_inputs = gen->philox_cuda_state(counter_offset);
135
+ }
136
+
137
+ if (!iter.can_use_32bit_indexing()) {
138
+ for (auto& sub_iter : iter.with_32bit_indexing()) {
139
+ distribution_nullary_kernel<scalar_t, accscalar_t, dist_func_return_t>(sub_iter,
140
+ gen, dist_func, transform_func);
141
+ }
142
+ return;
143
+ }
144
+
145
+ char* out_data = (char*)iter.data_ptr(0);
146
+
147
+ auto stream = at::cuda::getCurrentCUDAStream();
148
+ if (iter.is_trivial_1d()) {
149
+ auto strides = iter.get_inner_strides();
150
+ int stride0 = strides[0];
151
+ distribution_elementwise_grid_stride_kernel<accscalar_t, unroll_factor><<<grid, block, 0, stream>>>(
152
+ numel,
153
+ rng_engine_inputs,
154
+ dist_func,
155
+ [=]__device__(int idx, accscalar_t rand) {
156
+ scalar_t* out = (scalar_t*)&out_data[stride0 * idx];
157
+ *out = transform_func(rand);
158
+ }
159
+ );
160
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
161
+ } else {
162
+ auto offset_calc = make_offset_calculator<1>(iter);
163
+ distribution_elementwise_grid_stride_kernel<accscalar_t, unroll_factor><<<grid, block, 0, stream>>>(
164
+ numel,
165
+ rng_engine_inputs,
166
+ dist_func,
167
+ [=]__device__(int idx, accscalar_t rand) {
168
+ auto offsets = offset_calc.get(idx);
169
+ scalar_t* out = (scalar_t*)&out_data[offsets[0]];
170
+ *out = transform_func(rand);
171
+ }
172
+ );
173
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
174
+ }
175
+ }
176
+
177
+ // Binary kernel
178
+ template <typename func_t, typename inp_offset_calc_t, typename out_offset_calc_t>
179
+ __global__ void distribution_binary_elementwise_kernel(
180
+ int numel,
181
+ func_t f,
182
+ PhiloxCudaState philox_args,
183
+ typename function_traits<func_t>::result_type *output_data,
184
+ const typename function_traits<func_t>::template arg<1>::type *input_data_1,
185
+ const typename function_traits<func_t>::template arg<2>::type *input_data_2,
186
+ inp_offset_calc_t inp_calc,
187
+ out_offset_calc_t out_calc) {
188
+ auto seeds = at::cuda::philox::unpack(philox_args);
189
+
190
+ using input_t_1 = typename function_traits<func_t>::template arg<1>::type;
191
+ using input_t_2 = typename function_traits<func_t>::template arg<2>::type;
192
+
193
+ input_t_1 inputs_1[thread_work_size()];
194
+ input_t_2 inputs_2[thread_work_size()];
195
+
196
+ int base_index = block_work_size() * blockIdx.x;
197
+ int remaining = std::min<int>(numel - base_index, block_work_size());
198
+
199
+ curandStatePhilox4_32_10_t state;
200
+ curand_init(std::get<0>(seeds),
201
+ blockIdx.x * blockDim.x + threadIdx.x,
202
+ std::get<1>(seeds),
203
+ &state);
204
+
205
+ // load data into registers
206
+ int thread_idx = threadIdx.x;
207
+ #pragma unroll
208
+ for (int i = 0; i < thread_work_size(); i++) {
209
+ if (thread_idx >= remaining) {
210
+ break;
211
+ }
212
+ int input_idx = thread_idx + base_index;
213
+ auto offsets = inp_calc.get(input_idx);
214
+ inputs_1[i] = input_data_1[offsets[0]];
215
+ inputs_2[i] = input_data_2[offsets[1]];
216
+
217
+ thread_idx += num_threads();
218
+ }
219
+
220
+ // compute and store
221
+ thread_idx = threadIdx.x;
222
+ #pragma unroll
223
+ for (int i = 0; i < thread_work_size(); i++) {
224
+ if (thread_idx >= remaining) {
225
+ break;
226
+ }
227
+ int input_idx = thread_idx + base_index;
228
+ auto offsets = out_calc.get(input_idx);
229
+ output_data[offsets[0]] = f(state, inputs_1[i], inputs_2[i]);
230
+ thread_idx += num_threads();
231
+ }
232
+ }
233
+
234
+ template <typename func_t>
235
+ void distribution_binary_kernel(TensorIteratorBase &iter, PhiloxCudaState philox_args, const func_t &f) {
236
+ static_assert(std::is_same<typename function_traits<func_t>::template arg<0>::type, curandStatePhilox4_32_10_t&>::value, "the first argument of functor must be curandStatePhilox4_32_10_t");
237
+ using input_t_1 = typename function_traits<func_t>::template arg<1>::type;
238
+ using input_t_2 = typename function_traits<func_t>::template arg<2>::type;
239
+ using output_t = typename function_traits<func_t>::result_type;
240
+
241
+ if (!iter.can_use_32bit_indexing()) {
242
+ for (auto& sub_iter : iter.with_32bit_indexing()) {
243
+ distribution_binary_kernel(sub_iter, philox_args, f);
244
+ }
245
+ return;
246
+ }
247
+
248
+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY(iter.can_use_32bit_indexing());
249
+
250
+ int64_t numel = iter.numel();
251
+ if (numel == 0) {
252
+ return;
253
+ }
254
+
255
+ output_t *output_data = static_cast<output_t *>(iter.data_ptr(0));
256
+ const input_t_1 *input_data_1 = static_cast<const input_t_1 *>(iter.data_ptr(1));
257
+ const input_t_2 *input_data_2 = static_cast<const input_t_2 *>(iter.data_ptr(2));
258
+
259
+ int64_t grid = (numel + block_work_size() - 1) / block_work_size();
260
+ auto stream = at::cuda::getCurrentCUDAStream();
261
+
262
+ if (iter.is_contiguous()) {
263
+ distribution_binary_elementwise_kernel<<<grid,num_threads(), 0, stream>>>(
264
+ numel, f, philox_args, output_data, input_data_1, input_data_2,
265
+ TrivialOffsetCalculator<2>(), TrivialOffsetCalculator<1>());
266
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
267
+ } else {
268
+ distribution_binary_elementwise_kernel<<<grid, num_threads(), 0, stream>>>(
269
+ numel, f, philox_args, output_data, input_data_1, input_data_2,
270
+ make_input_offset_calculator<2>(iter), make_output_offset_calculator(iter));
271
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
272
+ }
273
+ }
274
+
275
+ } // namespace
276
+ }} // namespace at::native
277
+
278
+
279
+ namespace at {
280
+ namespace native {
281
+ namespace templates {
282
+ namespace cuda {
283
+
284
+ // ==================================================== Random ========================================================
285
+
286
+ template<typename RNG>
287
+ void random_from_to_kernel(TensorIteratorBase& iter, uint64_t range, int64_t base, RNG gen) {
288
+ AT_DISPATCH_V2(iter.dtype(), "random_from_to_kernel_cuda", AT_WRAP([&] {
289
+ if ((
290
+ std::is_same<scalar_t, int64_t>::value ||
291
+ std::is_same<scalar_t, double>::value ||
292
+ std::is_same<scalar_t, float>::value ||
293
+ std::is_same<scalar_t, at::BFloat16>::value) && range >= 1ULL << 32)
294
+ {
295
+ // define lambda to mod with range and add base
296
+ auto random_func = [range, base] __device__ (uint64_t rand) {
297
+ return transformation::uniform_int_from_to<scalar_t>(rand, range, base);
298
+ };
299
+ distribution_nullary_kernel<scalar_t, uint64_t, ulonglong2>(iter,
300
+ gen,
301
+ [] __device__ (curandStatePhilox4_32_10_t* state) -> ulonglong2 {
302
+ ulonglong2 ret;
303
+ uint4 rand_val = curand4(state);
304
+ ret.x = (static_cast<uint64_t>(rand_val.x) << 32) | rand_val.y;
305
+ ret.y = (static_cast<uint64_t>(rand_val.z) << 32) | rand_val.w;
306
+ return ret;
307
+ },
308
+ random_func);
309
+ } else {
310
+ auto random_func = [range, base] __device__ (uint32_t rand) {
311
+ return transformation::uniform_int_from_to<scalar_t>(rand, range, base);
312
+ };
313
+ distribution_nullary_kernel<scalar_t, uint32_t, uint4>(iter,
314
+ gen,
315
+ [] __device__ (curandStatePhilox4_32_10_t* state) -> uint4 {
316
+ return curand4(state);
317
+ },
318
+ random_func);
319
+ }
320
+ }), AT_EXPAND(AT_ALL_TYPES), kBool, kHalf, kBFloat16, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
321
+ }
322
+
323
+ // This is the special kernel to handle single specific case:
324
+ // from(inclusive) = std::numeric_limits<int64_t>::lowest()
325
+ // to(exclusive) = None (= std::numeric_limits<int64_t>::max() + 1)
326
+ template<typename RNG>
327
+ void random_full_64_bits_range_kernel(TensorIteratorBase& iter, RNG gen) {
328
+ AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::BFloat16, iter.dtype(), "random_full_64_bits_range_kernel_cuda", [&] {
329
+ if (std::is_same<scalar_t, int64_t>::value ||
330
+ std::is_same<scalar_t, double>::value ||
331
+ std::is_same<scalar_t, float>::value ||
332
+ std::is_same<scalar_t, at::BFloat16>::value) {
333
+ auto random_func = [] __device__ (uint64_t rand) {
334
+ return transformation::uniform_int_full_range<scalar_t>(rand);
335
+ };
336
+ distribution_nullary_kernel<scalar_t, uint64_t, ulonglong2>(iter,
337
+ gen,
338
+ [] __device__ (curandStatePhilox4_32_10_t* state) -> ulonglong2 {
339
+ ulonglong2 ret;
340
+ uint4 rand_val = curand4(state);
341
+ ret.x = (static_cast<uint64_t>(rand_val.x) << 32) | rand_val.y;
342
+ ret.y = (static_cast<uint64_t>(rand_val.z) << 32) | rand_val.w;
343
+ return ret;
344
+ },
345
+ random_func);
346
+ } else {
347
+ TORCH_CHECK(false, "random_full_64_bits_range_kernel_cuda handles only int64, double, float and bfloat16");
348
+ }
349
+ });
350
+ }
351
+
352
+ template<typename RNG>
353
+ struct RandomFromToKernel {
354
+ void operator()(TensorIteratorBase& iter, uint64_t range, int64_t base, std::optional<Generator> gen) {
355
+ random_from_to_kernel(iter, range, base, check_generator<RNG>(gen));
356
+ }
357
+ void operator()(TensorIteratorBase& iter, std::optional<Generator> gen) {
358
+ random_full_64_bits_range_kernel(iter, check_generator<RNG>(gen));
359
+ }
360
+ };
361
+
362
+ template<typename RNG>
363
+ void random_kernel(TensorIteratorBase& iter, RNG gen) {
364
+ AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, iter.dtype(), "random_kernel_cuda", [&] {
365
+ if (std::is_same<scalar_t, double>::value || std::is_same<scalar_t, int64_t>::value) {
366
+ auto random_func = [] __device__ (uint64_t rand) {
367
+ return transformation::uniform_int<scalar_t>(rand);
368
+ };
369
+ distribution_nullary_kernel<scalar_t, uint64_t, ulonglong2>(iter, gen,
370
+ [] __device__ (curandStatePhilox4_32_10_t* state) -> ulonglong2 {
371
+ ulonglong2 ret;
372
+ uint4 rand_val = curand4(state);
373
+ ret.x = (static_cast<uint64_t>(rand_val.x) << 32) | rand_val.y;
374
+ ret.y = (static_cast<uint64_t>(rand_val.z) << 32) | rand_val.w;
375
+ return ret;
376
+ },
377
+ random_func);
378
+ } else {
379
+ auto random_func = [] __device__ (uint32_t rand) {
380
+ return transformation::uniform_int<scalar_t>(rand);
381
+ };
382
+ distribution_nullary_kernel<scalar_t, uint32_t, uint4>(iter,
383
+ gen,
384
+ [] __device__ (curandStatePhilox4_32_10_t* state) -> uint4 {
385
+ return curand4(state);
386
+ },
387
+ random_func);
388
+ }
389
+ });
390
+ }
391
+
392
+ template<typename RNG>
393
+ struct RandomKernel {
394
+ void operator()(TensorIteratorBase& iter, RNG gen) {
395
+ random_kernel(iter, gen);
396
+ }
397
+ };
398
+
399
+ // ====================================================================================================================
400
+
401
+ template<typename scalar_t, typename accscalar_t, typename RNG, typename transform_t>
402
+ void uniform_and_transform(TensorIteratorBase& iter, RNG gen, transform_t transform) {
403
+ if (std::is_same<scalar_t, double>::value) {
404
+ distribution_nullary_kernel<scalar_t, accscalar_t, double2>(iter,
405
+ gen,
406
+ [] __device__ (curandStatePhilox4_32_10_t* state) -> double2 { return curand_uniform2_double(state); },
407
+ transform);
408
+ } else {
409
+ distribution_nullary_kernel<scalar_t, accscalar_t, float4>(iter,
410
+ gen,
411
+ [] __device__ (curandStatePhilox4_32_10_t* state) -> float4 { return curand_uniform4(state); },
412
+ transform);
413
+ }
414
+ }
415
+
416
+ template<typename scalar_t, typename accscalar_t, typename RNG, typename transform_t>
417
+ void normal_and_transform(TensorIteratorBase& iter, RNG gen, transform_t transform) {
418
+ if (std::is_same<scalar_t, double>::value) {
419
+ distribution_nullary_kernel<scalar_t, accscalar_t, double2>(iter,
420
+ gen,
421
+ [] __device__ (curandStatePhilox4_32_10_t* state) -> double2 { return curand_normal2_double(state); },
422
+ transform);
423
+ } else {
424
+ distribution_nullary_kernel<scalar_t, accscalar_t, float4>(iter,
425
+ gen,
426
+ [] __device__ (curandStatePhilox4_32_10_t* state) -> float4 { return curand_normal4(state); },
427
+ transform);
428
+ }
429
+ }
430
+
431
+ // ==================================================== Normal ========================================================
432
+
433
+ template<typename RNG>
434
+ void normal_kernel(const TensorBase &self, double mean_, double std_, RNG gen) {
435
+ auto iter = TensorIterator::borrowing_nullary_op(self);
436
+ AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "normal_kernel_cuda", [&] {
437
+ using accscalar_t = at::acc_type<scalar_t, true>;
438
+ auto mean = static_cast<accscalar_t>(mean_);
439
+ auto std = static_cast<accscalar_t>(std_);
440
+ // define lambda to multiply std and add mean
441
+ auto normal_func = [mean, std] __device__ (accscalar_t rand) {
442
+ return static_cast<scalar_t>(transformation::normal<accscalar_t>(rand, mean, std));
443
+ };
444
+ normal_and_transform<scalar_t, accscalar_t>(iter, gen, normal_func);
445
+ });
446
+ }
447
+
448
+ template<typename RNG>
449
+ struct NormalKernel {
450
+ void operator()(const TensorBase &self, double mean, double std, std::optional<Generator> gen) {
451
+ normal_kernel(self, mean, std, check_generator<RNG>(gen));
452
+ }
453
+ };
454
+
455
+ // ==================================================== Uniform ========================================================
456
+
457
+ template<typename RNG>
458
+ void uniform_kernel(TensorIteratorBase& iter, double from_, double to_, RNG gen) {
459
+ AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "uniform_kernel_cuda", [&] {
460
+ auto from = static_cast<scalar_t>(from_);
461
+ auto to = static_cast<scalar_t>(to_);
462
+ using opmath_t = at::opmath_type<scalar_t>;
463
+ auto range = static_cast<opmath_t>(to-from);
464
+ // define lambda to reverse bounds, multiply 'range' and add 'from_'
465
+ auto uniform_func = [range, from, to] __device__ (opmath_t rand) {
466
+ // Compute output value before reversing the bounds
467
+ // BEFORE TOUCHING THIS CODE READ: https://github.com/pytorch/pytorch/issues/96947
468
+ auto value = static_cast<scalar_t>(rand * range + from);
469
+ // reverse the bounds of curand4 from (0, 1] to [0, 1)
470
+ // Note that this method is from legacy THCTensorRandom and is likely to give
471
+ // you more 0-s, since, the probability of gettings 1-s is higher than 0-s and
472
+ // by reversing the bounds, we are flipping the probabilities of 1-s and 0-s.
473
+ // BEFORE TOUCHING THIS CODE READ: https://github.com/pytorch/pytorch/issues/16706
474
+ auto reverse_bound_value = value == to ? from : value;
475
+ return reverse_bound_value;
476
+ };
477
+ uniform_and_transform<scalar_t, opmath_t>(iter, gen, uniform_func);
478
+ });
479
+ }
480
+
481
+ template<typename RNG>
482
+ struct UniformKernel {
483
+ void operator()(TensorIteratorBase& iter, double from, double to, std::optional<Generator> gen) {
484
+ uniform_kernel(iter, from, to, check_generator<RNG>(gen));
485
+ }
486
+ };
487
+
488
+ // ================================================== LogNormal =======================================================
489
+
490
+ template<typename RNG>
491
+ void log_normal_kernel(TensorIteratorBase& iter, double mean_, double std_, RNG gen) {
492
+ AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "log_normal_cuda", [&] {
493
+ using accscalar_t = at::acc_type<scalar_t, true>;
494
+ auto mean = static_cast<accscalar_t>(mean_);
495
+ auto std = static_cast<accscalar_t>(std_);
496
+ // define lambda for log_normal transformation
497
+ auto log_normal_func = [mean, std] __device__ (accscalar_t rand) {
498
+ return static_cast<scalar_t>(transformation::log_normal<accscalar_t>(transformation::normal<accscalar_t>(rand, mean, std)));
499
+ };
500
+ normal_and_transform<scalar_t, accscalar_t>(iter, gen, log_normal_func);
501
+ });
502
+ }
503
+
504
+ template<typename RNG>
505
+ struct LogNormalKernel {
506
+ void operator()(TensorIteratorBase& iter, double mean, double std, std::optional<Generator> gen) {
507
+ log_normal_kernel(iter, mean, std, check_generator<RNG>(gen));
508
+ }
509
+ };
510
+
511
+ // =================================================== Geometric ======================================================
512
+
513
+ template<typename RNG>
514
+ void geometric_kernel(TensorIteratorBase& iter, double p, RNG gen) {
515
+ AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "geometric_cuda", [&] {
516
+ using accscalar_t = at::DiscreteDistributionType<scalar_t>::type;
517
+ // define lambda for geometric transformation
518
+ auto geometric_func = [p] __device__ (accscalar_t rand) {
519
+ return static_cast<scalar_t>(transformation::geometric<accscalar_t>(rand, p));
520
+ };
521
+ uniform_and_transform<scalar_t, accscalar_t>(iter, gen, geometric_func);
522
+ });
523
+ }
524
+
525
+ template<typename RNG>
526
+ struct GeometricKernel {
527
+ void operator()(TensorIteratorBase& iter, double p, std::optional<Generator> gen) {
528
+ geometric_kernel(iter, p, check_generator<RNG>(gen));
529
+ }
530
+ };
531
+
532
+ // ================================================== Exponential =====================================================
533
+
534
+ template<typename RNG>
535
+ void exponential_kernel(TensorIteratorBase& iter, double lambda_, RNG gen) {
536
+ TORCH_CHECK(isFloatingType(iter.dtype()), "Exponential distribution is a continuous probability distribution. dtype must be a floating point but you specified ", iter.dtype());
537
+ AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "exponential_cuda", [&] {
538
+ using accscalar_t = at::acc_type<scalar_t, true>;
539
+ auto lambda = static_cast<accscalar_t>(lambda_);
540
+ // define lambda for exponential transformation
541
+ auto exponential_func = [lambda] __device__ (accscalar_t rand) {
542
+ return static_cast<scalar_t>(transformation::exponential<accscalar_t>(rand, lambda));
543
+ };
544
+ uniform_and_transform<scalar_t, accscalar_t>(iter, gen, exponential_func);
545
+ });
546
+ }
547
+
548
+ template<typename RNG>
549
+ struct ExponentialKernel {
550
+ void operator()(TensorIteratorBase& iter, double lambda, std::optional<Generator> gen) {
551
+ exponential_kernel(iter, lambda, check_generator<RNG>(gen));
552
+ }
553
+ };
554
+
555
+ // ==================================================== Cauchy ========================================================
556
+
557
+ template<typename RNG>
558
+ void cauchy_kernel(TensorIteratorBase& iter, double median_, double sigma_, RNG gen) {
559
+ AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "cauchy_cuda", [&] {
560
+ using accscalar_t = at::acc_type<scalar_t, true>;
561
+ auto median = static_cast<accscalar_t>(median_);
562
+ auto sigma = static_cast<accscalar_t>(sigma_);
563
+ // define lambda for cauchy transformation
564
+ auto cauchy_func = [median, sigma] __device__ (accscalar_t rand) {
565
+ return static_cast<scalar_t>(transformation::cauchy<accscalar_t>(rand, median, sigma));
566
+ };
567
+ uniform_and_transform<scalar_t, accscalar_t>(iter, gen, cauchy_func);
568
+ });
569
+ }
570
+
571
+ template<typename RNG>
572
+ struct CauchyKernel {
573
+ void operator()(TensorIteratorBase& iter, double median, double sigma, std::optional<Generator> gen) {
574
+ cauchy_kernel(iter, median, sigma, check_generator<RNG>(gen));
575
+ }
576
+ };
577
+
578
+ // ==================================================== Bernoulli =====================================================
579
+
580
+ template<typename scalar_t, typename prob_t>
581
+ void bernoulli_tensor_cuda_kernel(
582
+ const TensorBase &ret, const at::TensorBase &p,
583
+ PhiloxCudaState philox_args) {
584
+ auto functor = [philox_args] __device__(
585
+ int n, scalar_t& v1, scalar_t& v2, scalar_t& v3, scalar_t& v4,
586
+ const prob_t& p1, const prob_t& p2, const prob_t& p3, const prob_t& p4) {
587
+ auto seeds = at::cuda::philox::unpack(philox_args);
588
+ curandStatePhilox4_32_10_t state;
589
+ curand_init(std::get<0>(seeds),
590
+ blockIdx.x * blockDim.x + threadIdx.x,
591
+ std::get<1>(seeds),
592
+ &state);
593
+
594
+ // See Note [Register spilling in curand call for CUDA < 10]
595
+ float4 rand = curand_uniform4(&state);
596
+ switch (n) {
597
+ case 4: {
598
+ CUDA_KERNEL_ASSERT(0 <= p4 && p4 <= 1);
599
+ v4 = static_cast<scalar_t>(rand.w <= p4);
600
+ [[fallthrough]];
601
+ }
602
+ case 3: {
603
+ CUDA_KERNEL_ASSERT(0 <= p3 && p3 <= 1);
604
+ v3 = static_cast<scalar_t>(rand.z <= p3);
605
+ [[fallthrough]];
606
+ }
607
+ case 2: {
608
+ CUDA_KERNEL_ASSERT(0 <= p2 && p2 <= 1);
609
+ v2 = static_cast<scalar_t>(rand.y <= p2);
610
+ [[fallthrough]];
611
+ }
612
+ case 1: {
613
+ CUDA_KERNEL_ASSERT(0 <= p1 && p1 <= 1);
614
+ v1 = static_cast<scalar_t>(rand.x <= p1);
615
+ }
616
+ }
617
+ };
618
+ // The template argument `4` below indicates that we want to operate on four
619
+ // element at each time. See NOTE [ CUDA_tensor_applyN helpers ] for details.
620
+ at::cuda::CUDA_tensor_apply2<scalar_t, const prob_t, 4, decltype(functor),
621
+ /*max_threads_per_block=*/512,
622
+ /*min_blocks_per_sm==*/2>(ret, p, functor);
623
+ }
624
+
625
+ template<typename RNG>
626
+ void bernoulli_kernel(const TensorBase &self, const TensorBase &p_, RNG gen) {
627
+ PhiloxCudaState rng_engine_inputs;
628
+ {
629
+ // See Note [Acquire lock when using random generators]
630
+ std::lock_guard<std::mutex> lock(gen->mutex_);
631
+ rng_engine_inputs = gen->philox_cuda_state(10);
632
+ }
633
+ TORCH_CHECK(at::isFloatingType(p_.scalar_type()), "expected probabilities tensor to have floating type, got ", p_.scalar_type());
634
+ // cast probabilities tensor to double for double `self` tensor, and to `float` for everything else
635
+ const auto p_type = self.dtype() == at::kDouble ? at::kDouble : at::kFloat;
636
+ auto p_cuda = p_.to(TensorOptions().device(self.device()).dtype(p_type));
637
+ auto p = expand_inplace(self, p_cuda);
638
+ AT_DISPATCH_ALL_TYPES_AND3(
639
+ at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, self.scalar_type(), "bernoulli_tensor_cuda_self_", [&] {
640
+ if (std::is_same<scalar_t, double>::value) {
641
+ return bernoulli_tensor_cuda_kernel<double, double>(self, *p, rng_engine_inputs);
642
+ } else {
643
+ return bernoulli_tensor_cuda_kernel<scalar_t, float>(self, *p, rng_engine_inputs);
644
+ }
645
+ });
646
+ }
647
+
648
+ template<typename RNG>
649
+ void bernoulli_kernel(TensorIteratorBase& iter, double p, RNG gen) {
650
+ AT_DISPATCH_ALL_TYPES_AND3(
651
+ at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, iter.dtype(), "bernoulli_scalar_cuda_", [&] {
652
+ using accscalar_t = at::DiscreteDistributionType<scalar_t>::type;
653
+ // define lambda for bernoulli transformation
654
+ auto bernoulli_func = [p] __device__ (accscalar_t rand) {
655
+ return static_cast<scalar_t>(transformation::bernoulli<accscalar_t>(rand, p));
656
+ };
657
+ uniform_and_transform<scalar_t, accscalar_t>(iter, gen, bernoulli_func);
658
+ });
659
+ }
660
+
661
+ template<typename RNG>
662
+ struct BernoulliKernel {
663
+ void operator()(TensorIteratorBase& iter, double p, std::optional<Generator> gen) {
664
+ bernoulli_kernel(iter, p, check_generator<RNG>(gen));
665
+ }
666
+ void operator()(const TensorBase &self, const TensorBase &p_, std::optional<Generator> gen) {
667
+ bernoulli_kernel(self, p_, check_generator<RNG>(gen));
668
+ }
669
+ };
670
+
671
+ }}}}
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Distributions.h ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ namespace at {
4
+ struct CUDAGeneratorImpl;
5
+ struct TensorIteratorBase;
6
+ class TensorBase;
7
+
8
+ namespace native {
9
+
10
+ void launch_poisson_cuda_kernel(
11
+ const TensorBase &ret, const TensorBase &lambda, CUDAGeneratorImpl *gen);
12
+
13
+ void launch_gamma_kernel(
14
+ const TensorBase &ret, const TensorBase &alpha, CUDAGeneratorImpl *gen);
15
+
16
+ void launch_binomial_cuda_kernel(
17
+ TensorIteratorBase &iter, CUDAGeneratorImpl *gen);
18
+
19
+ void launch_dirichlet_kernel(TensorIteratorBase &iter);
20
+
21
+ void launch_standard_gamma_grad_kernel(TensorIteratorBase &iter);
22
+
23
+ void launch_dirichlet_grad_kernel(TensorIteratorBase &iter);
24
+
25
+ }} // namespace at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/ForeachFunctors.cuh ADDED
@@ -0,0 +1,681 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <ATen/OpMathType.h>
3
+ #include <ATen/native/ForeachUtils.h>
4
+ #include <ATen/native/cuda/MultiTensorApply.cuh>
5
+ #include <ATen/native/cuda/Pow.cuh>
6
+
7
+ namespace at::native {
8
+
9
+ namespace {
10
+
11
+ // TODO(crcrpar): Handle version bump in codegen.
12
+ // rel:
13
+ // https://github.com/pytorch/pytorch/blob/9cf84347767c8abb8feba18a9a1baba321eeb8b9/tools/autograd/gen_inplace_or_view_type.py#L481-L482
14
+ inline void increment_version(TensorList tensors) {
15
+ for (const auto& t : tensors) {
16
+ t.unsafeGetTensorImpl()->bump_version();
17
+ }
18
+ }
19
+
20
+ // Initializes args and checks if all args are aligned
21
+ template <int depth, typename T>
22
+ __device__ bool init_args(
23
+ T** args,
24
+ TensorListMetadata<depth>& tl,
25
+ const int64_t chunk_idx,
26
+ const int64_t chunk_size,
27
+ const int64_t tensor_loc) {
28
+ bool all_aligned = true;
29
+ for (int i = 0; i < depth; i++) {
30
+ args[i] = (T*)tl.addresses[i][tensor_loc];
31
+ args[i] += chunk_idx * chunk_size;
32
+
33
+ if (!is_aligned(args[i])) {
34
+ all_aligned = false;
35
+ }
36
+ }
37
+ return all_aligned;
38
+ }
39
+
40
+ // Initializes args and checks if all args are aligned
41
+ template <int depth, typename T, typename T2>
42
+ __device__ bool init_args(
43
+ T** args,
44
+ TensorListScalarListMetadata<T2, depth>& tl,
45
+ const int64_t chunk_idx,
46
+ const int64_t chunk_size,
47
+ const int64_t tensor_loc) {
48
+ bool all_aligned = true;
49
+ for (int i = 0; i < depth; i++) {
50
+ args[i] = (T*)tl.addresses[i][tensor_loc];
51
+ args[i] += chunk_idx * chunk_size;
52
+
53
+ if (!is_aligned(args[i])) {
54
+ all_aligned = false;
55
+ }
56
+ }
57
+ return all_aligned;
58
+ }
59
+
60
+ template <int depth, typename T>
61
+ __device__ bool init_args(
62
+ T** args,
63
+ FusedOptimizerTensorListMetadata<depth>& tl,
64
+ const int64_t chunk_idx,
65
+ const int64_t chunk_size,
66
+ const int64_t tensor_loc) {
67
+ bool all_aligned = true;
68
+ for (int i = 0; i < depth; i++) {
69
+ args[i] = (T*)tl.addresses[i][tensor_loc];
70
+ args[i] += chunk_idx * chunk_size;
71
+
72
+ if (!is_aligned(args[i])) {
73
+ all_aligned = false;
74
+ }
75
+ }
76
+ return all_aligned;
77
+ }
78
+
79
+ template <int depth, typename T>
80
+ __device__ void load_args(
81
+ T r_args[][kILP],
82
+ T** args,
83
+ const int64_t i_start,
84
+ const int64_t chunk_size,
85
+ const int64_t n) {
86
+ #pragma unroll
87
+ for (int ii = 0; ii < kILP; ii++) {
88
+ const auto i = i_start + threadIdx.x + ii * blockDim.x;
89
+ for (int r_index = 0; r_index < depth; r_index++) {
90
+ r_args[r_index][ii] = 0;
91
+ if (i < n && i < chunk_size) {
92
+ r_args[r_index][ii] = args[r_index][i];
93
+ }
94
+ }
95
+ }
96
+ }
97
+
98
+ template <typename T>
99
+ __device__ void store_args(
100
+ T* dst,
101
+ T* src,
102
+ const int64_t i_start,
103
+ const int64_t chunk_size,
104
+ const int64_t n) {
105
+ #pragma unroll
106
+ for (int ii = 0; ii < kILP; ii++) {
107
+ const int64_t i = i_start + threadIdx.x + ii * blockDim.x;
108
+ if (i < n && i < chunk_size)
109
+ dst[i] = src[ii];
110
+ }
111
+ }
112
+
113
+ template <int res_arg_index, typename Op, typename T, typename opmath_t>
114
+ __device__ __forceinline__ void binary_op_scalar(
115
+ T r_args[][kILP],
116
+ T** args,
117
+ opmath_t scalar,
118
+ const int64_t n,
119
+ const int64_t chunk_size,
120
+ const bool all_aligned,
121
+ Op op) {
122
+ // to make things simple, we put aligned case in a different code path
123
+ if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
124
+ for (int64_t i_start = threadIdx.x;
125
+ i_start * kILP < n && i_start * kILP < chunk_size;
126
+ i_start += blockDim.x) {
127
+ // load
128
+ load_store(r_args[0], args[0], 0, i_start);
129
+ #pragma unroll
130
+ for (int ii = 0; ii < kILP; ii++) {
131
+ r_args[0][ii] = static_cast<T>(
132
+ op(static_cast<opmath_t>(r_args[0][ii]),
133
+ static_cast<opmath_t>(scalar)));
134
+ }
135
+ // store
136
+ load_store(args[res_arg_index], r_args[0], i_start, 0);
137
+ }
138
+ } else {
139
+ for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
140
+ i_start += blockDim.x * kILP) {
141
+ // Regardless if depth is 1 (for inplace) or 2 (for out of place), r_args
142
+ // has depth 1
143
+ load_args<1>(r_args, args, i_start, chunk_size, n);
144
+ #pragma unroll
145
+ for (int ii = 0; ii < kILP; ii++) {
146
+ r_args[0][ii] = static_cast<T>(
147
+ op(static_cast<opmath_t>(r_args[0][ii]),
148
+ static_cast<opmath_t>(scalar)));
149
+ }
150
+ store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
151
+ }
152
+ }
153
+ }
154
+
155
+ template <int res_arg_index, typename Op, typename T, typename opmath_t>
156
+ __device__ __forceinline__ void pointwise_op_scalar(
157
+ T r_args[][kILP],
158
+ T** args,
159
+ opmath_t scalar,
160
+ const int64_t n,
161
+ const int64_t chunk_size,
162
+ const bool all_aligned,
163
+ Op op) {
164
+ // to make things simple, we put aligned case in a different code path
165
+ if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
166
+ for (int64_t i_start = threadIdx.x;
167
+ i_start * kILP < n && i_start * kILP < chunk_size;
168
+ i_start += blockDim.x) {
169
+ // load
170
+ load_store(r_args[0], args[0], 0, i_start);
171
+ load_store(r_args[1], args[1], 0, i_start);
172
+ load_store(r_args[2], args[2], 0, i_start);
173
+ #pragma unroll
174
+ for (int ii = 0; ii < kILP; ii++) {
175
+ r_args[0][ii] = static_cast<T>(
176
+ static_cast<opmath_t>(r_args[0][ii]) +
177
+ scalar *
178
+ op(static_cast<opmath_t>(r_args[1][ii]),
179
+ static_cast<opmath_t>(r_args[2][ii])));
180
+ }
181
+ // store
182
+ load_store(args[res_arg_index], r_args[0], i_start, 0);
183
+ }
184
+ } else {
185
+ for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
186
+ i_start += blockDim.x * kILP) {
187
+ // Regardless if depth is 3 (for inplace) or 4 (for out of place), r_args
188
+ // has depth 3
189
+ load_args<3>(r_args, args, i_start, chunk_size, n);
190
+ #pragma unroll
191
+ for (int ii = 0; ii < kILP; ii++) {
192
+ r_args[0][ii] = static_cast<T>(
193
+ static_cast<opmath_t>(r_args[0][ii]) +
194
+ scalar *
195
+ op(static_cast<opmath_t>(r_args[1][ii]),
196
+ static_cast<opmath_t>(r_args[2][ii])));
197
+ }
198
+ store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
199
+ }
200
+ }
201
+ }
202
+
203
+ //
204
+ // Binary Functors
205
+ //
206
+ template <typename T, int depth, int r_args_depth, int res_arg_index>
207
+ struct BinaryOpScalarFunctor {
208
+ using opmath_t = at::opmath_type<T>;
209
+ template <typename Op>
210
+ __device__ __forceinline__ void operator()(
211
+ int chunk_size,
212
+ TensorListMetadata<depth>& tl,
213
+ Op op,
214
+ opmath_t scalar) {
215
+ const int tensor_loc = tl.block_to_tensor[blockIdx.x];
216
+ const int chunk_idx = tl.block_to_chunk[blockIdx.x];
217
+ auto n = tl.numel_for_tensor[tensor_loc];
218
+
219
+ T* args[depth];
220
+ const bool all_aligned =
221
+ init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
222
+ n -= chunk_idx * chunk_size;
223
+ T r_args[r_args_depth][kILP];
224
+
225
+ binary_op_scalar<res_arg_index>(
226
+ r_args, args, scalar, n, chunk_size, all_aligned, op);
227
+ }
228
+ };
229
+
230
+ template <typename T, int depth, int r_args_depth, int res_arg_index>
231
+ struct BinaryOpScalarListFunctor {
232
+ using opmath_t = at::opmath_type<T>;
233
+ template <typename Op>
234
+ __device__ __forceinline__ void operator()(
235
+ int chunk_size,
236
+ TensorListScalarListMetadata<opmath_t, depth>& tl,
237
+ Op op) {
238
+ const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
239
+ const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
240
+ auto n = tl.numel_for_tensor[tensor_loc];
241
+
242
+ T* args[depth];
243
+ const bool all_aligned =
244
+ init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
245
+ opmath_t scalar = tl.scalar_vals[tensor_loc];
246
+ n -= chunk_idx * chunk_size;
247
+ T r_args[r_args_depth][kILP];
248
+
249
+ binary_op_scalar<res_arg_index>(
250
+ r_args, args, scalar, n, chunk_size, all_aligned, op);
251
+ }
252
+ };
253
+
254
+ template <typename T, int depth, int r_args_depth, int res_arg_index>
255
+ struct BinaryOpListAlphaFunctor {
256
+ using opmath_t = at::opmath_type<T>;
257
+ template <typename Op>
258
+ __device__ __forceinline__ void operator()(
259
+ int chunk_size,
260
+ TensorListMetadata<depth>& tl,
261
+ Op op,
262
+ opmath_t alpha) {
263
+ const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
264
+ const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
265
+ auto n = tl.numel_for_tensor[tensor_loc];
266
+
267
+ T* args[depth];
268
+ const bool all_aligned =
269
+ init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
270
+ n -= chunk_idx * chunk_size;
271
+ T r_args[r_args_depth][kILP];
272
+
273
+ // to make things simple, we put aligned case in a different code path
274
+ if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
275
+ for (int64_t i_start = threadIdx.x;
276
+ i_start * kILP < n && i_start * kILP < chunk_size;
277
+ i_start += blockDim.x) {
278
+ // load
279
+ load_store(r_args[0], args[0], 0, i_start);
280
+ load_store(r_args[1], args[1], 0, i_start);
281
+ #pragma unroll
282
+ for (int ii = 0; ii < kILP; ii++) {
283
+ r_args[0][ii] = static_cast<T>(
284
+ op(static_cast<opmath_t>(r_args[0][ii]),
285
+ alpha * static_cast<opmath_t>(r_args[1][ii])));
286
+ }
287
+ // store
288
+ load_store(args[res_arg_index], r_args[0], i_start, 0);
289
+ }
290
+ } else {
291
+ for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
292
+ i_start += blockDim.x * kILP) {
293
+ load_args<r_args_depth>(r_args, args, i_start, chunk_size, n);
294
+ #pragma unroll
295
+ for (int ii = 0; ii < kILP; ii++) {
296
+ r_args[0][ii] = static_cast<T>(
297
+ op(static_cast<opmath_t>(r_args[0][ii]),
298
+ alpha * static_cast<opmath_t>(r_args[1][ii])));
299
+ }
300
+ store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
301
+ }
302
+ }
303
+ }
304
+ };
305
+
306
+ template <typename T, int depth, int r_args_depth, int res_arg_index>
307
+ struct BinaryOpScalarTensorFunctor {
308
+ using opmath_t = at::opmath_type<T>;
309
+ template <typename Op>
310
+ __device__ __forceinline__ void operator()(
311
+ int chunk_size,
312
+ TensorListMetadata<depth>& tl,
313
+ Op op,
314
+ T* scalar,
315
+ opmath_t alpha) {
316
+ const int tensor_loc = tl.block_to_tensor[blockIdx.x];
317
+ const int chunk_idx = tl.block_to_chunk[blockIdx.x];
318
+ auto n = tl.numel_for_tensor[tensor_loc];
319
+
320
+ T* args[depth];
321
+ const bool all_aligned =
322
+ init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
323
+ n -= chunk_idx * chunk_size;
324
+ T r_args[r_args_depth][kILP];
325
+
326
+ // to make things simple, we put aligned case in a different code path
327
+ if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
328
+ for (int64_t i_start = threadIdx.x;
329
+ i_start * kILP < n && i_start * kILP < chunk_size;
330
+ i_start += blockDim.x) {
331
+ // load
332
+ load_store(r_args[0], args[0], 0, i_start);
333
+ #pragma unroll
334
+ for (int ii = 0; ii < kILP; ii++) {
335
+ r_args[0][ii] = static_cast<T>(op(
336
+ static_cast<opmath_t>(r_args[0][ii]),
337
+ static_cast<opmath_t>(alpha) * static_cast<opmath_t>(*scalar)));
338
+ }
339
+ // store
340
+ load_store(args[res_arg_index], r_args[0], i_start, 0);
341
+ }
342
+ } else {
343
+ for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
344
+ i_start += blockDim.x * kILP) {
345
+ // Regardless if depth is 1 (for inplace) or 2 (for out of place),
346
+ // r_args has depth 1
347
+ load_args<1>(r_args, args, i_start, chunk_size, n);
348
+ #pragma unroll
349
+ for (int ii = 0; ii < kILP; ii++) {
350
+ r_args[0][ii] = static_cast<T>(op(
351
+ static_cast<opmath_t>(r_args[0][ii]),
352
+ static_cast<opmath_t>(alpha) * static_cast<opmath_t>(*scalar)));
353
+ }
354
+ store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
355
+ }
356
+ }
357
+ }
358
+ };
359
+
360
+ //
361
+ // Unary Functors
362
+ //
363
+
364
+ template <typename T, int depth, int r_args_depth, int res_arg_index>
365
+ struct ZeroFunctor {
366
+ __device__ __forceinline__ void operator()(
367
+ int chunk_size,
368
+ TensorListMetadata<1>& tl) {
369
+ const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
370
+ const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
371
+ auto n = tl.numel_for_tensor[tensor_loc];
372
+
373
+ T* args[depth];
374
+ const auto all_aligned =
375
+ init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
376
+ n -= chunk_idx * chunk_size;
377
+ T r_args[r_args_depth][kILP];
378
+
379
+ // to make things simple, we put aligned case in a different code path
380
+ if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
381
+ for (int64_t i_start = threadIdx.x;
382
+ i_start * kILP < n && i_start * kILP < chunk_size;
383
+ i_start += blockDim.x) {
384
+ #pragma unroll
385
+ for (int ii = 0; ii < kILP; ii++) {
386
+ r_args[0][ii] = 0;
387
+ }
388
+ // store
389
+ load_store(args[0], r_args[0], i_start, 0);
390
+ }
391
+ } else {
392
+ for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
393
+ i_start += blockDim.x * kILP) {
394
+ #pragma unroll
395
+ for (int ii = 0; ii < kILP; ii++) {
396
+ r_args[0][ii] = 0;
397
+ }
398
+ store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
399
+ }
400
+ }
401
+ }
402
+ };
403
+
404
+ template <typename T, int depth, int r_args_depth, int res_arg_index>
405
+ struct UnaryOpFunctor {
406
+ using opmath_t = at::opmath_type<T>;
407
+ template <typename Op>
408
+ __device__ __forceinline__ void operator()(
409
+ int chunk_size,
410
+ TensorListMetadata<depth>& tl,
411
+ Op op) {
412
+ const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
413
+ const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
414
+ auto n = tl.numel_for_tensor[tensor_loc];
415
+
416
+ T* args[depth];
417
+ bool all_aligned =
418
+ init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
419
+ n -= chunk_idx * chunk_size;
420
+ T r_args[r_args_depth][kILP];
421
+
422
+ // to make things simple, we put aligned case in a different code path
423
+ if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
424
+ for (int64_t i_start = threadIdx.x;
425
+ i_start * kILP < n && i_start * kILP < chunk_size;
426
+ i_start += blockDim.x) {
427
+ // load
428
+ load_store(r_args[0], args[0], 0, i_start);
429
+ #pragma unroll
430
+ for (int ii = 0; ii < kILP; ii++) {
431
+ r_args[0][ii] =
432
+ static_cast<T>(op(static_cast<opmath_t>(r_args[0][ii])));
433
+ }
434
+ // store
435
+ load_store(args[res_arg_index], r_args[0], i_start, 0);
436
+ }
437
+ } else {
438
+ for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
439
+ i_start += blockDim.x * kILP) {
440
+ load_args<r_args_depth>(r_args, args, i_start, chunk_size, n);
441
+ #pragma unroll
442
+ for (int ii = 0; ii < kILP; ii++) {
443
+ r_args[0][ii] =
444
+ static_cast<T>(op(static_cast<opmath_t>(r_args[0][ii])));
445
+ }
446
+ store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
447
+ }
448
+ }
449
+ }
450
+ };
451
+
452
+ //
453
+ // Pointwise Functors
454
+ //
455
+
456
+ template <typename T, int depth, int r_args_depth, int res_arg_index>
457
+ struct PointwiseOpScalarFunctor {
458
+ using opmath_t = at::opmath_type<T>;
459
+ template <typename Op>
460
+ __device__ __forceinline__ void operator()(
461
+ int chunk_size,
462
+ TensorListMetadata<depth>& tl,
463
+ Op op,
464
+ opmath_t scalar) {
465
+ const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
466
+ const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
467
+ auto n = tl.numel_for_tensor[tensor_loc];
468
+
469
+ T* args[depth];
470
+ const bool all_aligned =
471
+ init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
472
+ n -= chunk_idx * chunk_size;
473
+ T r_args[r_args_depth][kILP];
474
+
475
+ pointwise_op_scalar<res_arg_index>(
476
+ r_args, args, scalar, n, chunk_size, all_aligned, op);
477
+ }
478
+ };
479
+
480
+ template <typename T, int depth, int r_args_depth, int res_arg_index>
481
+ struct PointwiseOpScalarListFunctor {
482
+ using opmath_t = at::opmath_type<T>;
483
+ template <typename Op>
484
+ __device__ __forceinline__ void operator()(
485
+ int chunk_size,
486
+ TensorListScalarListMetadata<opmath_t, depth>& tl,
487
+ Op op) {
488
+ const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
489
+ const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
490
+ auto n = tl.numel_for_tensor[tensor_loc];
491
+
492
+ T* args[depth];
493
+ const bool all_aligned =
494
+ init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
495
+ opmath_t scalar = tl.scalar_vals[tensor_loc];
496
+ n -= chunk_idx * chunk_size;
497
+ T r_args[r_args_depth][kILP];
498
+
499
+ pointwise_op_scalar<res_arg_index>(
500
+ r_args, args, scalar, n, chunk_size, all_aligned, op);
501
+ }
502
+ };
503
+
504
+ template <typename T, int depth>
505
+ struct PointwiseOpListFunctor {
506
+ using opmath_t = at::opmath_type<T>;
507
+ template <typename Op>
508
+ __device__ __forceinline__ void operator()(
509
+ int chunk_size,
510
+ TensorListMetadata<depth>& tl,
511
+ Op op) {
512
+ const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
513
+ const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
514
+ auto n = tl.numel_for_tensor[tensor_loc];
515
+
516
+ T* args[depth];
517
+ const bool all_aligned =
518
+ init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
519
+ n -= chunk_idx * chunk_size;
520
+ T r_args[depth - 1][kILP];
521
+
522
+ // to make things simple, we put aligned case in a different code path
523
+ if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
524
+ for (int64_t i_start = threadIdx.x;
525
+ i_start * kILP < n && i_start * kILP < chunk_size;
526
+ i_start += blockDim.x) {
527
+ // load
528
+ load_store(r_args[0], args[0], 0, i_start);
529
+ load_store(r_args[1], args[1], 0, i_start);
530
+ #pragma unroll
531
+ for (int ii = 0; ii < kILP; ii++) {
532
+ r_args[0][ii] = static_cast<T>(
533
+ op(static_cast<opmath_t>(r_args[0][ii]),
534
+ static_cast<opmath_t>(r_args[1][ii])));
535
+ }
536
+ // store
537
+ load_store(args[2], r_args[0], i_start, 0);
538
+ }
539
+ } else {
540
+ for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
541
+ i_start += blockDim.x * kILP) {
542
+ load_args<depth - 1>(r_args, args, i_start, chunk_size, n);
543
+ #pragma unroll
544
+ for (int ii = 0; ii < kILP; ii++) {
545
+ r_args[0][ii] = static_cast<T>(
546
+ op(static_cast<opmath_t>(r_args[0][ii]),
547
+ static_cast<opmath_t>(r_args[1][ii])));
548
+ }
549
+ store_args(args[2], r_args[0], i_start, chunk_size, n);
550
+ }
551
+ }
552
+ }
553
+ };
554
+
555
+ template <typename T, int depth, int r_args_depth, int res_arg_index>
556
+ struct TernaryOpListFunctor {
557
+ using opmath_t = at::opmath_type<T>;
558
+ template <typename Op>
559
+ __device__ __forceinline__ void operator()(
560
+ int chunk_size,
561
+ TensorListMetadata<depth>& tl,
562
+ Op op) {
563
+ static_assert(depth == 3 || depth == 4, "");
564
+ static_assert(depth >= r_args_depth, "");
565
+ static_assert(res_arg_index == depth - 1 || res_arg_index == 0, "");
566
+ const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
567
+ const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
568
+ auto n = tl.numel_for_tensor[tensor_loc];
569
+
570
+ T* args[depth];
571
+ const bool all_aligned =
572
+ init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
573
+ n -= chunk_idx * chunk_size;
574
+ T r_args[r_args_depth][kILP];
575
+
576
+ if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
577
+ for (int64_t i_start = threadIdx.x;
578
+ i_start * kILP < n && i_start * kILP < chunk_size;
579
+ i_start += blockDim.x) {
580
+ load_store(r_args[0], args[0], 0, i_start);
581
+ load_store(r_args[1], args[1], 0, i_start);
582
+ load_store(r_args[2], args[2], 0, i_start);
583
+ #pragma unroll
584
+ for (int ii = 0; ii < kILP; ii++) {
585
+ r_args[0][ii] =
586
+ op(static_cast<opmath_t>(r_args[0][ii]),
587
+ static_cast<opmath_t>(r_args[1][ii]),
588
+ static_cast<opmath_t>(r_args[2][ii]));
589
+ }
590
+ load_store(args[res_arg_index], r_args[0], i_start, 0);
591
+ }
592
+ } else {
593
+ for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
594
+ i_start += blockDim.x * kILP) {
595
+ load_args<r_args_depth>(r_args, args, i_start, chunk_size, n);
596
+ #pragma unroll
597
+ for (int ii = 0; ii < kILP; ii++) {
598
+ r_args[0][ii] =
599
+ op(static_cast<opmath_t>(r_args[0][ii]),
600
+ static_cast<opmath_t>(r_args[1][ii]),
601
+ static_cast<opmath_t>(r_args[2][ii]));
602
+ }
603
+ store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
604
+ }
605
+ }
606
+ }
607
+ };
608
+
609
+ template <typename T, int depth, int r_args_depth, int res_arg_index>
610
+ struct TernaryOpScalarFunctor {
611
+ using opmath_t = at::opmath_type<T>;
612
+ template <typename Op>
613
+ __device__ __forceinline__ void operator()(
614
+ int chunk_size,
615
+ TensorListMetadata<depth>& tl,
616
+ Op op,
617
+ opmath_t alpha) {
618
+ static_assert(depth == 2 || depth == 3, "");
619
+ static_assert(depth >= r_args_depth, "");
620
+ static_assert(res_arg_index == depth - 1 || res_arg_index == 0, "");
621
+ const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
622
+ const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
623
+ auto n = tl.numel_for_tensor[tensor_loc];
624
+
625
+ T* args[depth];
626
+ const bool all_aligned =
627
+ init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
628
+ n -= chunk_idx * chunk_size;
629
+ T r_args[r_args_depth][kILP];
630
+
631
+ // to make things simple, we put aligned case in a different code path
632
+ if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
633
+ for (int64_t i_start = threadIdx.x;
634
+ i_start * kILP < n && i_start * kILP < chunk_size;
635
+ i_start += blockDim.x) {
636
+ // load
637
+ load_store(r_args[0], args[0], 0, i_start);
638
+ load_store(r_args[1], args[1], 0, i_start);
639
+ #pragma unroll
640
+ for (int ii = 0; ii < kILP; ii++) {
641
+ r_args[0][ii] =
642
+ op(static_cast<opmath_t>(r_args[0][ii]),
643
+ static_cast<opmath_t>(r_args[1][ii]),
644
+ alpha);
645
+ }
646
+ // store
647
+ load_store(args[res_arg_index], r_args[0], i_start, 0);
648
+ }
649
+ } else {
650
+ for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
651
+ i_start += blockDim.x * kILP) {
652
+ load_args<r_args_depth>(r_args, args, i_start, chunk_size, n);
653
+ #pragma unroll
654
+ for (int ii = 0; ii < kILP; ii++) {
655
+ r_args[0][ii] =
656
+ op(static_cast<opmath_t>(r_args[0][ii]),
657
+ static_cast<opmath_t>(r_args[1][ii]),
658
+ alpha);
659
+ }
660
+ store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
661
+ }
662
+ }
663
+ }
664
+ };
665
+
666
+ template <typename T>
667
+ struct power_functor {
668
+ C10_DEVICE T operator()(const T& a, const T& b) const {
669
+ return at::native::pow_(a, b);
670
+ }
671
+ };
672
+
673
+ template <typename T>
674
+ struct reverse_power_functor {
675
+ C10_DEVICE T operator()(const T& a, const T& b) const {
676
+ return at::native::pow_(b, a);
677
+ }
678
+ };
679
+
680
+ } // namespace
681
+ } // namespace at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/GridSampler.cuh ADDED
@@ -0,0 +1,321 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <ATen/native/cuda/KernelUtils.cuh>
3
+ #include <ATen/native/GridSamplerUtils.h>
4
+
5
+ namespace at { namespace native {
6
+
7
+ using detail::GridSamplerInterpolation;
8
+ using detail::GridSamplerPadding;
9
+
10
+ // Unnormalizes a coordinate from the -1 to +1 scale to its pixel index value,
11
+ // where we view each pixel as an area between (idx - 0.5) and (idx + 0.5).
12
+ // if align_corners: -1 and +1 get sent to the centers of the corner pixels
13
+ // -1 --> 0
14
+ // +1 --> (size - 1)
15
+ // scale_factor = (size - 1) / 2
16
+ // if not align_corners: -1 and +1 get sent to the image edges
17
+ // -1 --> -0.5
18
+ // +1 --> (size - 1) + 0.5 == size - 0.5
19
+ // scale_factor = size / 2
20
+ template <typename scalar_t>
21
+ __forceinline__ __device__
22
+ scalar_t grid_sampler_unnormalize(scalar_t coord, int size, bool align_corners) {
23
+ if (align_corners) {
24
+ // unnormalize coord from [-1, 1] to [0, size - 1]
25
+ return ((coord + 1.f) / 2) * (size - 1);
26
+ } else {
27
+ // unnormalize coord from [-1, 1] to [-0.5, size - 0.5]
28
+ return ((coord + 1.f) * size - 1) / 2;
29
+ }
30
+ }
31
+
32
+ // grid_sampler_unnormalize_set_grad works the same as grid_sampler_unnormalize
33
+ // except that it also returns the `d output / d input` via pointer argument
34
+ // `grad_in`.
35
+ // This is useful in the backward pass of grid_sampler.
36
+ template <typename scalar_t>
37
+ __forceinline__ __device__
38
+ scalar_t grid_sampler_unnormalize_set_grad(scalar_t coord, int size,
39
+ bool align_corners, scalar_t *grad_in) {
40
+ if (align_corners) {
41
+ // unnormalize coord from [-1, 1] to [0, size - 1]
42
+ *grad_in = static_cast<scalar_t>(size - 1) / 2;
43
+ return ((coord + 1.f) / 2) * (size - 1);
44
+ } else {
45
+ // unnormalize coord from [-1, 1] to [-0.5, size - 0.5]
46
+ *grad_in = static_cast<scalar_t>(size) / 2;
47
+ return ((coord + 1.f) * size - 1) / 2;
48
+ }
49
+ }
50
+
51
+ // Clips coordinates to between 0 and clip_limit - 1
52
+ template <typename scalar_t>
53
+ __forceinline__ __device__
54
+ scalar_t clip_coordinates(scalar_t in, int clip_limit) {
55
+ return ::min(static_cast<scalar_t>(clip_limit - 1), ::max(in, static_cast<scalar_t>(0)));
56
+ }
57
+
58
+ // clip_coordinates_set_grad works similarly to clip_coordinates except that
59
+ // it also returns the `d output / d input` via pointer argument `grad_in`.
60
+ // This is useful in the backward pass of grid_sampler.
61
+ template <typename scalar_t>
62
+ __forceinline__ __device__
63
+ scalar_t clip_coordinates_set_grad(scalar_t in, int clip_limit, scalar_t *grad_in) {
64
+ // Note that it is important for the gradient calculation that borders
65
+ // are considered out of bounds.
66
+ if (in <= static_cast<scalar_t>(0)) {
67
+ *grad_in = static_cast<scalar_t>(0);
68
+ return static_cast<scalar_t>(0);
69
+ } else {
70
+ scalar_t max = static_cast<scalar_t>(clip_limit - 1);
71
+ if (in >= max) {
72
+ *grad_in = static_cast<scalar_t>(0);
73
+ return max;
74
+ } else {
75
+ *grad_in = static_cast<scalar_t>(1);
76
+ return in;
77
+ }
78
+ }
79
+ }
80
+
81
+ // Reflects coordinates until they fall between low and high (inclusive).
82
+ // The bounds are passed as twice their value so that half-integer values
83
+ // can be represented as ints.
84
+ template <typename scalar_t>
85
+ __forceinline__ __device__
86
+ scalar_t reflect_coordinates(scalar_t in, int twice_low, int twice_high) {
87
+ if (twice_low == twice_high) {
88
+ return static_cast<scalar_t>(0);
89
+ }
90
+ scalar_t min = static_cast<scalar_t>(twice_low) / 2;
91
+ scalar_t span = static_cast<scalar_t>(twice_high - twice_low) / 2;
92
+ in = ::fabs(in - min);
93
+ // `fmod` returns same sign as `in`, which is positive after the `fabs` above.
94
+ scalar_t extra = ::fmod(in, span);
95
+ int flips = static_cast<int>(::floor(in / span));
96
+ if (flips % 2 == 0) {
97
+ return extra + min;
98
+ } else {
99
+ return span - extra + min;
100
+ }
101
+ }
102
+
103
+ // reflect_coordinates_set_grad works similarly to reflect_coordinates except
104
+ // that it also returns the `d output / d input` via pointer argument
105
+ // `grad_in`.
106
+ // This is useful in the backward pass of grid_sampler.
107
+ template <typename scalar_t>
108
+ __forceinline__ __device__
109
+ scalar_t reflect_coordinates_set_grad(scalar_t in, int twice_low, int twice_high,
110
+ scalar_t *grad_in) {
111
+ if (twice_low == twice_high) {
112
+ *grad_in = static_cast<scalar_t>(0);
113
+ return static_cast<scalar_t>(0);
114
+ }
115
+ int grad_in_mult_;
116
+ scalar_t min = static_cast<scalar_t>(twice_low) / 2;
117
+ scalar_t span = static_cast<scalar_t>(twice_high - twice_low) / 2;
118
+ in = in - min;
119
+ if (in < static_cast<scalar_t>(0)) {
120
+ grad_in_mult_ = -1;
121
+ in = -in;
122
+ } else {
123
+ grad_in_mult_ = 1;
124
+ }
125
+ // `fmod` returns same sign as `in`, which is positive after the `if` above.
126
+ scalar_t extra = ::fmod(in, span);
127
+ int flips = static_cast<int>(::floor(in / span));
128
+ if (flips % 2 == 0) {
129
+ *grad_in = static_cast<scalar_t>(grad_in_mult_);
130
+ return extra + min;
131
+ } else {
132
+ *grad_in = static_cast<scalar_t>(-grad_in_mult_);
133
+ return span - extra + min;
134
+ }
135
+ }
136
+
137
+ template<typename scalar_t>
138
+ __forceinline__ __device__
139
+ scalar_t safe_downgrade_to_int_range(scalar_t x){
140
+ // -100.0 does not have special meaning. This is just to make sure
141
+ // it's not within_bounds_2d or within_bounds_3d, and does not cause
142
+ // undefined behavior. See #35506.
143
+ if (x > INT_MAX-1 || x < INT_MIN || !::isfinite(static_cast<double>(x)))
144
+ return static_cast<scalar_t>(-100.0);
145
+ return x;
146
+ }
147
+
148
+ template<typename scalar_t>
149
+ __forceinline__ __device__
150
+ scalar_t compute_coordinates(scalar_t coord, int size,
151
+ GridSamplerPadding padding_mode,
152
+ bool align_corners) {
153
+ if (padding_mode == GridSamplerPadding::Border) {
154
+ // clip coordinates to image borders
155
+ coord = clip_coordinates(coord, size);
156
+ } else if (padding_mode == GridSamplerPadding::Reflection) {
157
+ // reflect coordinates by image borders
158
+ if (align_corners) {
159
+ coord = reflect_coordinates(coord, 0, 2*(size - 1));
160
+ } else {
161
+ coord = reflect_coordinates(coord, -1, 2*size - 1);
162
+ }
163
+ // clip coordinates to image borders
164
+ coord = clip_coordinates(coord, size);
165
+ }
166
+
167
+ coord = safe_downgrade_to_int_range(coord);
168
+ return coord;
169
+ }
170
+
171
+ // Computes the pixel source index value for a grid coordinate
172
+ template <typename scalar_t>
173
+ __forceinline__ __device__
174
+ scalar_t grid_sampler_compute_source_index(
175
+ scalar_t coord,
176
+ int size,
177
+ GridSamplerPadding padding_mode,
178
+ bool align_corners) {
179
+ coord = grid_sampler_unnormalize(coord, size, align_corners);
180
+ coord = compute_coordinates(coord, size, padding_mode, align_corners);
181
+ return coord;
182
+ }
183
+
184
+ // grid_sampler_compute_source_index_set_grad works similarly to
185
+ // grid_sampler_compute_source_index except that it also returns the
186
+ // `d output / d input` via pointer argument `grad_in`.
187
+ // This is useful in the backward pass of grid_sampler.
188
+ template <typename scalar_t>
189
+ __forceinline__ __device__
190
+ scalar_t grid_sampler_compute_source_index_set_grad(
191
+ scalar_t coord,
192
+ int size,
193
+ GridSamplerPadding padding_mode,
194
+ bool align_corners,
195
+ scalar_t *grad_in) {
196
+ scalar_t grad_clip, grad_refl;
197
+ coord = grid_sampler_unnormalize_set_grad(coord, size, align_corners, grad_in);
198
+ if (padding_mode == GridSamplerPadding::Border) {
199
+ // clip coordinates to image borders
200
+ coord = clip_coordinates_set_grad(coord, size, &grad_clip);
201
+ *grad_in = (*grad_in) * grad_clip;
202
+ } else if (padding_mode == GridSamplerPadding::Reflection) {
203
+ // reflect coordinates by image borders
204
+ if (align_corners) {
205
+ coord = reflect_coordinates_set_grad(coord, 0, 2*(size - 1), &grad_refl);
206
+ } else {
207
+ coord = reflect_coordinates_set_grad(coord, -1, 2*size - 1, &grad_refl);
208
+ }
209
+ // clip coordinates to image borders
210
+ coord = clip_coordinates_set_grad(coord, size, &grad_clip);
211
+ *grad_in = (*grad_in) * grad_refl * grad_clip;
212
+ }
213
+
214
+ coord = safe_downgrade_to_int_range(coord);
215
+ return coord;
216
+ }
217
+
218
+ __forceinline__ __device__
219
+ bool within_bounds_2d(int h, int w, int H, int W) {
220
+ return h >= 0 && h < H && w >= 0 && w < W;
221
+ }
222
+
223
+ __forceinline__ __device__
224
+ bool within_bounds_3d(int d, int h, int w, int D, int H, int W) {
225
+ return d >= 0 && d < D && h >= 0 && h < H && w >= 0 && w < W;
226
+ }
227
+
228
+ template<typename scalar_t>
229
+ __forceinline__ __device__
230
+ scalar_t get_value_bounded(
231
+ const scalar_t *data, scalar_t x, scalar_t y, int W, int H, int sW, int sH,
232
+ GridSamplerPadding padding_mode,
233
+ bool align_corners) {
234
+
235
+ x = compute_coordinates(x, W, padding_mode, align_corners);
236
+ y = compute_coordinates(y, H, padding_mode, align_corners);
237
+
238
+ int ix = static_cast<int>(x);
239
+ int iy = static_cast<int>(y);
240
+
241
+ if (within_bounds_2d(iy, ix, H, W)) {
242
+ return data[iy * sH + ix * sW];
243
+ }
244
+ return static_cast<scalar_t>(0);
245
+ }
246
+
247
+ template<typename scalar_t, typename index_t>
248
+ __forceinline__ __device__
249
+ void safe_add_2d(scalar_t *data, int h, int w,
250
+ int sH, int sW, int H, int W,
251
+ scalar_t delta,
252
+ const index_t NC_offset,
253
+ const index_t memory_span) {
254
+ if (within_bounds_2d(h, w, H, W)) {
255
+ fastAtomicAdd(data,
256
+ NC_offset + h * sH + w * sW,
257
+ memory_span,
258
+ delta,
259
+ true);
260
+ }
261
+ }
262
+
263
+ template<typename scalar_t, typename index_t>
264
+ __forceinline__ __device__
265
+ void safe_add_3d(scalar_t *data, int d, int h, int w,
266
+ int sD, int sH, int sW, int D, int H, int W,
267
+ scalar_t delta,
268
+ const index_t NC_offset,
269
+ const index_t memory_span) {
270
+ if (within_bounds_3d(d, h, w, D, H, W)) {
271
+ fastAtomicAdd(data,
272
+ NC_offset + d * sD + h * sH + w * sW,
273
+ memory_span,
274
+ delta,
275
+ true);
276
+ }
277
+ }
278
+
279
+ template<typename scalar_t, typename index_t>
280
+ __forceinline__ __device__
281
+ void add_value_bounded(
282
+ scalar_t* data, scalar_t x, scalar_t y, int W, int H, int sW, int sH,
283
+ scalar_t delta,
284
+ GridSamplerPadding padding_mode,
285
+ bool align_corners,
286
+ const index_t NC_offset,
287
+ const index_t memory_span) {
288
+
289
+ x = compute_coordinates(x, W, padding_mode, align_corners);
290
+ y = compute_coordinates(y, H, padding_mode, align_corners);
291
+
292
+ int ix = static_cast<int>(x);
293
+ int iy = static_cast<int>(y);
294
+
295
+ safe_add_2d(data, iy, ix, sH, sW, H, W, delta, NC_offset, memory_span);
296
+ }
297
+
298
+ // Calculate the differential of the cubic convolution, i.e. `d coeff / d x`
299
+ template<typename scalar_t>
300
+ __forceinline__ __device__
301
+ void get_cubic_coefficients_grad(
302
+ scalar_t coeffs[4],
303
+ scalar_t t) {
304
+
305
+ // Must be the same as forward calculation in
306
+ // aten/src/ATen/native/cuda/UpSample.cuh:get_cubic_upsample_coefficients
307
+ scalar_t A = -0.75;
308
+
309
+ scalar_t x;
310
+ x = -1 - t; // 1 < x = |-1 - tx| < 2
311
+ coeffs[0] = (-3 * A * x - 10 * A ) * x - 8 * A;
312
+ x = -t; // x = |0 - tx| <= 1
313
+ coeffs[1] = (-3 * (A + 2) * x - 2 * (A + 3)) * x;
314
+ x = 1 - t; // x = |1 - tx| <= 1
315
+ coeffs[2] = (3 * (A + 2) * x - 2 * (A + 3)) * x;
316
+ x = 2 - t; // 1 < x = |2 - tx| < 2
317
+ coeffs[3] = (3 * A * x - 10 * A) * x + 8 * A;
318
+ }
319
+
320
+
321
+ }} // namespace at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/IndexKernel.h ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <c10/core/ScalarType.h>
3
+ #include <cstdint>
4
+
5
+ namespace at {
6
+ struct TensorIteratorBase;
7
+ class TensorBase;
8
+ }
9
+
10
+ namespace at {
11
+ namespace native {
12
+ /// @param maskPrefixSum[in,out]
13
+ void launch_masked_scatter_kernel(
14
+ const TensorBase &self, const TensorBase &mask,
15
+ const TensorBase &maskPrefixSum, const TensorBase &source);
16
+ }}
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/KernelUtils.cuh ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <ATen/cuda/Atomic.cuh>
3
+
4
+ #if !(defined(USE_ROCM) || ((defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))))
5
+ #include <cuda_bf16.h>
6
+ #endif
7
+
8
+ namespace at {
9
+ namespace native {
10
+
11
+ __device__ __forceinline__ size_t
12
+ idx(const size_t nc,
13
+ const size_t height,
14
+ const size_t width,
15
+ const size_t h,
16
+ const size_t w) {
17
+ return (nc * height + h) * width + w;
18
+ }
19
+
20
+ // for channels-last
21
+ __device__ __forceinline__ size_t
22
+ idx_cl(
23
+ const size_t n, const size_t h, const size_t w, const size_t c,
24
+ const size_t height, const size_t width, const size_t channel
25
+ ) {
26
+ return ((n * height + h) * width + w) * channel + c;
27
+ }
28
+
29
+ // fastSpecializedAtomicAdd (and fastAtomicAdd) are an optimization
30
+ // that speed up half-precision atomics. The situation with half
31
+ // precision atomics is that we have a slow __half atomic, and
32
+ // a fast vectored __half2 atomic (this can be worth up to a 6x
33
+ // speedup, see https://github.com/pytorch/pytorch/pull/21879).
34
+ // We can convert a __half atomic into a __half2 atomic by simply
35
+ // pairing the __half with a zero entry on the left/right depending
36
+ // on alignment... but only if this wouldn't cause an out of bounds
37
+ // access! Thus, you must specify tensor and numel so we can check
38
+ // if you would be out-of-bounds and use a plain __half atomic if
39
+ // you would be.
40
+ template <
41
+ typename scalar_t,
42
+ typename index_t,
43
+ typename std::enable_if<std::is_same<c10::Half, scalar_t>::value>::type* =
44
+ nullptr>
45
+ __device__ __forceinline__ void fastSpecializedAtomicAdd(
46
+ scalar_t* tensor,
47
+ index_t index,
48
+ const index_t numel,
49
+ scalar_t value) {
50
+ #if ( \
51
+ (defined(USE_ROCM)) || \
52
+ (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700)))
53
+ gpuAtomicAddNoReturn(
54
+ reinterpret_cast<at::Half*>(tensor) + index,
55
+ static_cast<at::Half>(value));
56
+ #else
57
+ // Accounts for the chance tensor falls on an odd 16 bit alignment (ie, not 32 bit aligned)
58
+ __half* target_addr = reinterpret_cast<__half*>(tensor + index);
59
+ bool low_byte = (reinterpret_cast<std::uintptr_t>(target_addr) % sizeof(__half2) == 0);
60
+
61
+ if (low_byte && index < (numel - 1)) {
62
+ __half2 value2;
63
+ value2.x = static_cast<__half>(value);
64
+ value2.y = __int2half_rz(0);
65
+ atomicAdd(reinterpret_cast<__half2*>(target_addr), value2);
66
+
67
+ } else if (!low_byte && index > 0) {
68
+ __half2 value2;
69
+ value2.x = __int2half_rz(0);
70
+ value2.y = static_cast<__half>(value);
71
+ atomicAdd(reinterpret_cast<__half2*>(target_addr - 1), value2);
72
+
73
+ } else {
74
+ atomicAdd(
75
+ reinterpret_cast<__half*>(tensor) + index, static_cast<__half>(value));
76
+ }
77
+ #endif
78
+ }
79
+
80
+ template <
81
+ typename scalar_t,
82
+ typename index_t,
83
+ typename std::enable_if<std::is_same<c10::BFloat16, scalar_t>::value>::type* =
84
+ nullptr>
85
+ __device__ __forceinline__ void fastSpecializedAtomicAdd(
86
+ scalar_t* tensor,
87
+ index_t index,
88
+ const index_t numel,
89
+ scalar_t value) {
90
+ #if ( \
91
+ (defined(USE_ROCM)) || \
92
+ (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)))
93
+ gpuAtomicAddNoReturn(
94
+ reinterpret_cast<at::BFloat16*>(tensor) + index,
95
+ static_cast<at::BFloat16>(value));
96
+ #else
97
+ // Accounts for the chance tensor falls on an odd 16 bit alignment (ie, not 32 bit aligned)
98
+ __nv_bfloat16* target_addr = reinterpret_cast<__nv_bfloat16*>(tensor + index);
99
+ bool low_byte = (reinterpret_cast<std::uintptr_t>(target_addr) % sizeof(__nv_bfloat162) == 0);
100
+
101
+ if (low_byte && index < (numel - 1)) {
102
+ __nv_bfloat162 value2;
103
+ value2.x = *reinterpret_cast<__nv_bfloat16*>(&value);
104
+ value2.y = __int2bfloat16_rz(0);
105
+ atomicAdd(reinterpret_cast<__nv_bfloat162*>(target_addr), value2);
106
+
107
+ } else if (!low_byte && index > 0) {
108
+ __nv_bfloat162 value2;
109
+ value2.x = __int2bfloat16_rz(0);
110
+ value2.y = *reinterpret_cast<__nv_bfloat16*>(&value);
111
+ atomicAdd(reinterpret_cast<__nv_bfloat162*>(target_addr - 1), value2);
112
+
113
+ } else {
114
+ atomicAdd(
115
+ reinterpret_cast<__nv_bfloat16*>(tensor) + index, *reinterpret_cast<__nv_bfloat16*>(&value));
116
+ }
117
+ #endif
118
+ }
119
+
120
+
121
+ template <
122
+ typename scalar_t,
123
+ typename index_t,
124
+ typename std::enable_if<!std::is_same<c10::Half, scalar_t>::value && !std::is_same<c10::BFloat16, scalar_t>::value >::type* =
125
+ nullptr>
126
+ __device__ __forceinline__ void fastSpecializedAtomicAdd(
127
+ scalar_t* tensor,
128
+ index_t index,
129
+ const index_t numel,
130
+ scalar_t value) {
131
+ gpuAtomicAddNoReturn(tensor + index, value);
132
+ }
133
+
134
+ template <class scalar_t, class index_t>
135
+ __device__ __forceinline__ void fastAtomicAdd(
136
+ scalar_t* tensor,
137
+ index_t index,
138
+ const index_t numel,
139
+ scalar_t value,
140
+ bool fast_atomics) {
141
+ if (fast_atomics) {
142
+ fastSpecializedAtomicAdd(tensor, index, numel, value);
143
+ } else {
144
+ gpuAtomicAddNoReturn(tensor + index, value);
145
+ }
146
+ }
147
+
148
+ } // namespace native
149
+ } // namespace at
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/LaunchUtils.h ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include<algorithm>
3
+
4
+ namespace at {
5
+ namespace native {
6
+
7
+ // returns 2**floor(log2(n))
8
+ static int lastPow2(unsigned int n) {
9
+ n |= (n >> 1);
10
+ n |= (n >> 2);
11
+ n |= (n >> 4);
12
+ n |= (n >> 8);
13
+ n |= (n >> 16);
14
+ return std::max<int>(1, n - (n >> 1));
15
+ }
16
+
17
+ } // namespace native
18
+ } // namespace at
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/MemoryAccess.cuh ADDED
@@ -0,0 +1,389 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <cstdint>
4
+ #include <type_traits>
5
+ #include <c10/core/DynamicCast.h>
6
+ #include <c10/util/Exception.h>
7
+ #include <c10/util/TypeCast.h>
8
+ #include <c10/macros/Macros.h>
9
+ #include <ATen/core/Array.h>
10
+ #include <ATen/detail/FunctionTraits.h>
11
+ #include <ATen/cuda/detail/OffsetCalculator.cuh>
12
+ #include <ATen/native/cuda/thread_constants.h>
13
+
14
+ #include <thrust/tuple.h>
15
+
16
+ // References:
17
+ // https://devblogs.nvidia.com/cuda-pro-tip-increase-performance-with-vectorized-memory-access/
18
+
19
+ namespace at { namespace native { namespace memory {
20
+
21
+ namespace detail {
22
+
23
+ // What does the `static_unroll` do?
24
+ //
25
+ // We want to do something like:
26
+ //
27
+ // using args_t = typename traits::ArgsTuple;
28
+ // args_t args;
29
+ // #pragma unroll
30
+ // for (int i = 0; i < traits::arity; i++) {
31
+ // std::get<i>(args) = ....
32
+ // }
33
+ //
34
+ // but unfortunately the above code does not work because
35
+ // the template argument has to be a compile time constant
36
+ // so `static_unroll` is created to simulate `#pragma unroll`
37
+ // using template metaprogramming.
38
+
39
+ template<template<int i> typename func, int end, int current=0>
40
+ struct static_unroll {
41
+ template<typename... Args>
42
+ static inline C10_HOST_DEVICE void with_args(Args&&... args) {
43
+ func<current>::apply(std::forward<Args>(args)...);
44
+ static_unroll<func, end, current+1>::with_args(args...);
45
+ }
46
+ };
47
+
48
+ template<template<int i> typename func, int end>
49
+ struct static_unroll<func, end, end> {
50
+ template<typename... Args>
51
+ static inline C10_HOST_DEVICE void with_args(Args... args) {}
52
+ };
53
+
54
+ // helper structs to be used with static_unroll to load arguments
55
+ // one by one
56
+
57
+ template<int arg_index>
58
+ struct vectorized_load_helper {
59
+ template <typename args_t, typename policy_t>
60
+ static __device__ void apply(policy_t &self, args_t *args, int idx) {
61
+ using arg_t = std::tuple_element_t<arg_index, args_t>;
62
+ // `data` hold the data_ptr for tensors [output, input0, input1, ...], so we
63
+ // need a +1 offset to get the input
64
+ auto ptr = reinterpret_cast<arg_t *>(self.data[arg_index + 1]) + block_work_size() * idx;
65
+ auto args_accessor = [&args] __device__ (int thread_unroll_idx) -> arg_t & { return std::get<arg_index>(args[thread_unroll_idx]); };
66
+ self.load_single_arg(args_accessor, ptr);
67
+ }
68
+ };
69
+
70
+ template<int arg_index>
71
+ struct unroll_load_helper {
72
+ template <typename args_t, typename policy_t, typename offset_t, typename loader_t>
73
+ static __device__ void apply(policy_t &self, args_t *args, offset_t offset, loader_t loader, int j, int num_outputs) {
74
+ using arg_t = std::tuple_element_t<arg_index, args_t>;
75
+ // `data` hold the data_ptr for tensors [output, input0, input1, ...], so we
76
+ // need a +1 offset to get the input
77
+ std::get<arg_index>(args[j]) = loader.template load<arg_t>(self.data[arg_index + num_outputs], offset[arg_index], arg_index);
78
+ }
79
+ };
80
+
81
+ template <int current>
82
+ struct multi_outputs_store_helper {
83
+ template<int ntensors, int num_outputs, typename ...Args>
84
+ C10_HOST_DEVICE static void apply(
85
+ at::detail::Array<char*, ntensors> data,
86
+ at::detail::Array<uint32_t, num_outputs> offsets,
87
+ thrust::tuple<Args...> ret) {
88
+ using T = typename thrust::tuple_element<current, thrust::tuple<Args...>>::type;
89
+ T *to = reinterpret_cast<T *>(data[current]) + offsets[current];
90
+ *to = thrust::get<current>(ret);
91
+ }
92
+ };
93
+
94
+ } // namespace detail
95
+
96
+ struct LoadWithoutCast {
97
+ template<typename scalar_t>
98
+ __device__ scalar_t load(char *base_ptr, uint32_t offset, int arg) {
99
+ return c10::load(reinterpret_cast<scalar_t *>(base_ptr) + offset);
100
+ }
101
+ };
102
+
103
+ template <int N>
104
+ struct LoadWithCast {
105
+ using array_t = at::detail::Array<at::ScalarType, std::max<int>(N, 1)>;
106
+ using size_array_t = at::detail::Array<uint32_t, std::max<int>(N, 1)>;
107
+
108
+ array_t dtypes;
109
+ size_array_t element_sizes;
110
+
111
+ LoadWithCast(const TensorIteratorBase& iter) {
112
+ CUDA_KERNEL_ASSERT(iter.ninputs() == N);
113
+ #pragma unroll
114
+ for (auto i = 0; i < N; ++i) {
115
+ this->dtypes[i] = iter.dtype(i + iter.noutputs());
116
+ element_sizes[i] = c10::elementSize(iter.dtype(i + iter.noutputs()));
117
+ }
118
+ }
119
+
120
+ template<typename scalar_t>
121
+ __device__ scalar_t load(char *base_ptr, uint32_t offset, int arg) {
122
+ void *ptr = base_ptr + element_sizes[arg] * offset;
123
+ return c10::fetch_and_cast<scalar_t>(dtypes[arg], ptr);
124
+ }
125
+ };
126
+
127
+ struct StoreWithoutCast {
128
+ template<typename scalar_t>
129
+ __device__ void store(scalar_t value, char *base_ptr, uint32_t offset, int arg = 0) {
130
+ *(reinterpret_cast<scalar_t *>(base_ptr) + offset) = value;
131
+ }
132
+ };
133
+
134
+ template <int N = 1>
135
+ struct StoreWithCast {
136
+ using array_t = at::detail::Array<at::ScalarType, std::max<int>(N, 1)>;
137
+ using size_array_t = at::detail::Array<uint32_t, std::max<int>(N, 1)>;
138
+
139
+ array_t dtypes;
140
+ size_array_t element_sizes;
141
+
142
+ StoreWithCast(const TensorIteratorBase& iter) {
143
+ CUDA_KERNEL_ASSERT(iter.noutputs() == N);
144
+ #pragma unroll
145
+ for (auto i = 0; i < N; ++i) {
146
+ this->dtypes[i] = iter.dtype(i);
147
+ element_sizes[i] = c10::elementSize(iter.dtype(i));
148
+ }
149
+ }
150
+
151
+ template<typename scalar_t>
152
+ __device__ void store(scalar_t value, char *base_ptr, uint32_t offset, int arg = 0) {
153
+ void *ptr = base_ptr + element_sizes[arg] * offset;
154
+ c10::cast_and_store<scalar_t>(dtypes[arg], ptr, value);
155
+ }
156
+ };
157
+
158
+ // aligned vector generates vectorized load/store on CUDA
159
+ template<typename scalar_t, int vec_size>
160
+ struct alignas(sizeof(scalar_t) * vec_size) aligned_vector {
161
+ scalar_t val[vec_size];
162
+ };
163
+
164
+ template <int vec_size, typename scalar_t>
165
+ __device__ aligned_vector<scalar_t, vec_size> load_vector(const scalar_t *base_ptr, uint32_t offset) {
166
+ using vec_t = aligned_vector<scalar_t, vec_size>;
167
+ auto *from = reinterpret_cast<const vec_t *>(base_ptr);
168
+ return from[offset];
169
+ }
170
+
171
+ template <int vec_size>
172
+ __device__ aligned_vector<bool, vec_size> load_vector(const bool *base_ptr, uint32_t offset) {
173
+ // See NOTE [Loading boolean values]
174
+ auto tmp = load_vector<vec_size>(reinterpret_cast<const uint8_t*>(base_ptr), offset);
175
+ aligned_vector<bool, vec_size> ret;
176
+ for (int i = 0; i < vec_size; ++i) {
177
+ ret.val[i] = bool(tmp.val[i]);
178
+ }
179
+ return ret;
180
+ }
181
+
182
+ namespace policies {
183
+
184
+ // Assumption:
185
+ // all tensors are contiguous, that is: stride == sizeof(type) for all tensors
186
+ template<typename data_t, typename inp_calc_t, typename out_calc_t, typename loader_t, typename storer_t, int num_outputs = 1>
187
+ struct unroll {
188
+
189
+ data_t data;
190
+ int remaining;
191
+ inp_calc_t input_offset_calculator;
192
+ out_calc_t output_offset_calculator;
193
+ loader_t loader;
194
+ storer_t storer;
195
+
196
+ __device__ unroll(data_t data, int remaining, inp_calc_t ic, out_calc_t oc, loader_t l, storer_t s):
197
+ data(data), remaining(remaining), input_offset_calculator(ic), output_offset_calculator(oc), loader(l), storer(s) {}
198
+
199
+ __device__ inline bool check_inbounds(int thread_work_elem) {
200
+ return ((int)(threadIdx.x + thread_work_elem*num_threads()) < remaining);
201
+ }
202
+
203
+ template<typename args_t>
204
+ __device__ inline void load(args_t *args, int idx) {
205
+ constexpr int arity = std::tuple_size<args_t>::value;
206
+ int thread_idx = threadIdx.x;
207
+ #pragma unroll
208
+ for (int i = 0; i < thread_work_size(); i++) {
209
+ if (thread_idx >= remaining) {
210
+ return;
211
+ }
212
+ int linear_idx = thread_idx + block_work_size() * idx;
213
+ auto offset = input_offset_calculator.get(linear_idx);
214
+ detail::static_unroll<detail::unroll_load_helper, arity>::with_args(*this, args, offset, loader, i, num_outputs);
215
+ thread_idx += num_threads();
216
+ }
217
+ }
218
+
219
+ template<typename scalar_t>
220
+ __device__ inline void store(scalar_t *from, int idx) {
221
+ int thread_idx = threadIdx.x;
222
+ #pragma unroll
223
+ for (int i = 0; i < thread_work_size(); i++) {
224
+ if (thread_idx >= remaining) {
225
+ return;
226
+ }
227
+ int linear_idx = thread_idx + block_work_size() * idx;
228
+ int offset = output_offset_calculator.get(linear_idx)[0];
229
+ storer.store(from[i], data[0], offset);
230
+ thread_idx += num_threads();
231
+ }
232
+ }
233
+ };
234
+
235
+ // Assumption:
236
+ // all tensors are contiguous, that is: stride == sizeof(type) for all tensors
237
+ // Note:
238
+ // Functions in vectorized policy does not do boundary check. It assumes the whole block
239
+ // has its job to do. So the reminders should be handled by the caller manually.
240
+ template <int vec_size, typename data_t> // vec_size: number of scalars, can be 1, 2, or 4.
241
+ struct vectorized {
242
+
243
+ static_assert(thread_work_size() % vec_size == 0, "The workload per thread must be a multiple of vec_size");
244
+ static constexpr int loop_size = thread_work_size() / vec_size;
245
+
246
+ data_t data;
247
+
248
+ __device__ vectorized(data_t data) : data(data) {}
249
+
250
+ __device__ inline constexpr bool check_inbounds(int thread_work_elem) {
251
+ return true;
252
+ }
253
+
254
+ template<typename accessor_t, typename scalar_t>
255
+ __device__ inline void load_single_arg(accessor_t to, scalar_t *from) {
256
+ int thread_idx = threadIdx.x;
257
+ #pragma unroll
258
+ for (int i = 0; i < loop_size; i++) {
259
+ int index = thread_idx + i * num_threads();
260
+ auto v = load_vector<vec_size>(from, index);
261
+ #pragma unroll
262
+ for (int j = 0; j < vec_size; j++) {
263
+ to(vec_size * i + j) = v.val[j];
264
+ }
265
+ }
266
+ }
267
+
268
+ template<typename args_t>
269
+ __device__ inline void load(args_t *args, int idx) {
270
+ constexpr int arity = std::tuple_size<args_t>::value;
271
+ detail::static_unroll<detail::vectorized_load_helper, arity>::with_args(*this, args, idx);
272
+ }
273
+
274
+ template<typename scalar_t>
275
+ __device__ inline void store(scalar_t *from, int idx) {
276
+ using vec_t = aligned_vector<scalar_t, vec_size>;
277
+ scalar_t *to = reinterpret_cast<scalar_t *>(data[0]) + block_work_size() * idx;
278
+ vec_t *to_ = reinterpret_cast<vec_t *>(to);
279
+ int thread_idx = threadIdx.x;
280
+ #pragma unroll
281
+ for (int i = 0; i < loop_size; i++) {
282
+ int index = thread_idx + i * num_threads();
283
+ vec_t v;
284
+ for (int j = 0; j < vec_size; j++) {
285
+ v.val[j] = from[vec_size * i + j];
286
+ }
287
+ to_[index] = v;
288
+ }
289
+ }
290
+ };
291
+
292
+ template <typename data_t, typename inp_calc_t, typename out_calc_t, int num_outputs>
293
+ struct multi_outputs_unroll {
294
+ //multi_outputs_unroll struct members and check_inbounds and load methods are copypasted from unroll struct
295
+ //we don't use inheritance because of compiler bug in cuda 10.2+
296
+ data_t data;
297
+ int remaining;
298
+ inp_calc_t input_offset_calculator;
299
+ out_calc_t output_offset_calculator;
300
+ LoadWithoutCast loader;
301
+ StoreWithoutCast storer;
302
+
303
+ __device__ multi_outputs_unroll(data_t data, int remaining, inp_calc_t ic, out_calc_t oc):
304
+ data(data), remaining(remaining), input_offset_calculator(ic), output_offset_calculator(oc) {}
305
+
306
+ __device__ inline bool check_inbounds(int thread_work_elem) {
307
+ return ((int)(threadIdx.x + thread_work_elem*num_threads()) < remaining);
308
+ }
309
+
310
+ template<typename args_t>
311
+ __device__ inline void load(args_t *args, int idx) {
312
+ constexpr int arity = std::tuple_size<args_t>::value;
313
+ int thread_idx = threadIdx.x;
314
+ #pragma unroll
315
+ for (int i = 0; i < thread_work_size(); i++) {
316
+ if (thread_idx >= remaining) {
317
+ return;
318
+ }
319
+ int linear_idx = thread_idx + block_work_size() * idx;
320
+ auto offset = input_offset_calculator.get(linear_idx);
321
+ detail::static_unroll<detail::unroll_load_helper, arity>::with_args(*this, args, offset, loader, i, num_outputs);
322
+ thread_idx += num_threads();
323
+ }
324
+ }
325
+
326
+
327
+ template <typename return_t>
328
+ __device__ inline void store(return_t *from, int idx) {
329
+ int thread_idx = threadIdx.x;
330
+ #pragma unroll
331
+ for (int i = 0; i < thread_work_size(); i++) {
332
+ if (thread_idx >= this->remaining) {
333
+ return;
334
+ }
335
+ int linear_idx = thread_idx + block_work_size() * idx;
336
+ auto offsets = this->output_offset_calculator.get(linear_idx);
337
+ memory::detail::static_unroll<detail::multi_outputs_store_helper, num_outputs>::with_args(this->data, offsets, from[i]);
338
+ thread_idx += num_threads();
339
+ }
340
+ }
341
+ };
342
+
343
+ } // namespace policies
344
+
345
+ // This is only used in host, but we will wrap this into some templates
346
+ // which is C10_HOST_DEVICE, so we have to make this C10_HOST_DEVICE
347
+ // in order to compile
348
+ template<typename scalar_t>
349
+ inline C10_HOST_DEVICE int can_vectorize_up_to(const char *pointer) {
350
+ uint64_t address = reinterpret_cast<uint64_t>(pointer);
351
+ constexpr int vec2_alignment = std::alignment_of<aligned_vector<scalar_t, 2>>::value;
352
+ constexpr int vec4_alignment = std::alignment_of<aligned_vector<scalar_t, 4>>::value;
353
+ if (address % vec4_alignment == 0) {
354
+ return 4;
355
+ } else if (address % vec2_alignment == 0) {
356
+ return 2;
357
+ }
358
+ return 1;
359
+ }
360
+
361
+ template<typename scalar_t>
362
+ inline C10_HOST_DEVICE int can_vectorize_up_to(char *pointer) {
363
+ return can_vectorize_up_to<scalar_t>(static_cast<const char*>(pointer));
364
+ }
365
+
366
+ template<int i>
367
+ struct can_vectorize_up_to_helper {
368
+ template <typename array_t, typename traits>
369
+ static C10_HOST_DEVICE void apply(int &result, array_t pointers, traits _) {
370
+ using arg_t = typename traits::template arg<i>::type;
371
+ // `pointers` hold the data_ptr for tensors [output, input0, input1, ...], so we
372
+ // need a +1 offset to get the input
373
+ result = std::min<int>(result, can_vectorize_up_to<arg_t>(pointers[i + 1]));
374
+ }
375
+ };
376
+
377
+ template<typename func_t, typename array_t>
378
+ inline int can_vectorize_up_to(array_t pointers) {
379
+ using traits = function_traits<func_t>;
380
+ using return_t = typename traits::result_type;
381
+ constexpr int arity = traits::arity;
382
+ int result = can_vectorize_up_to<return_t>(pointers[0]);
383
+ // We need to get the type for each argument of `func_t`, this can only
384
+ // be done at compile time.
385
+ detail::static_unroll<can_vectorize_up_to_helper, arity>::with_args(result, pointers, traits());
386
+ return result;
387
+ }
388
+
389
+ }}} // namespace at::native::memory
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/MiscUtils.h ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <ATen/cuda/Exceptions.h>
3
+ #include <ATen/cuda/CUDAContext.h>
4
+ #include <ATen/cuda/CUDAConfig.h>
5
+ #include <ATen/cuda/PinnedMemoryAllocator.h>
6
+
7
+ namespace at {
8
+ namespace native {
9
+
10
+ static inline int cuda_int_cast(int64_t value, const char* varname) {
11
+ auto result = static_cast<int>(value);
12
+ TORCH_CHECK(static_cast<int64_t>(result) == value,
13
+ "cuda_int_cast: The value of ", varname, "(", (long long)value,
14
+ ") is too large to fit into a int (", sizeof(int), " bytes)");
15
+ return result;
16
+ }
17
+
18
+ // Creates an array of size elements of type T, backed by pinned memory
19
+ // wrapped in a Storage
20
+ template<class T>
21
+ static inline Storage pin_memory(int64_t size) {
22
+ auto* allocator = cuda::getPinnedMemoryAllocator();
23
+ int64_t adjusted_size = size * sizeof(T);
24
+ return Storage(
25
+ Storage::use_byte_size_t(),
26
+ adjusted_size,
27
+ allocator,
28
+ /*resizable=*/false);
29
+ }
30
+
31
+ } // namespace native
32
+ } // namespace at
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/MultiTensorApply.cuh ADDED
@@ -0,0 +1,379 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <ATen/core/Tensor.h>
3
+ #include <ATen/cuda/CUDAContext.h>
4
+ #include <c10/cuda/CUDAGuard.h>
5
+ #include <ATen/native/cuda/Loops.cuh>
6
+ #include <ATen/native/cuda/MemoryAccess.cuh>
7
+ #include <vector>
8
+
9
+ namespace at::native {
10
+
11
+ namespace {
12
+
13
+ static constexpr int64_t kILP = 4;
14
+ static constexpr int64_t kChunkSize = 65536;
15
+ static constexpr int64_t kBlockSize = 512;
16
+
17
+ // TODO(crcrpar): Add `n>5` for `low prec params & their higher prec copy`
18
+ // TensorListMetadata has to be < 4KB - the limit for kernel launch argument
19
+ static constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30};
20
+ static constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320};
21
+ static constexpr int depth_to_max_tensors_scalarlist[5] = {96, 64, 48, 36, 30};
22
+ static constexpr int depth_to_max_tensors_scalarlist_of_complex_double[2] = {
23
+ 72,
24
+ 60};
25
+
26
+ template <typename T>
27
+ __device__ __forceinline__ bool is_aligned(T* p) {
28
+ return ((uint64_t)p) % (kILP * sizeof(T)) == 0;
29
+ }
30
+
31
+ template <typename T>
32
+ __device__ __forceinline__ void load_store(
33
+ T* dst,
34
+ T* src,
35
+ int64_t dst_offset,
36
+ int64_t src_offset) {
37
+ using LT = at::native::memory::aligned_vector<T, kILP>;
38
+ ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
39
+ }
40
+
41
+ template <int n>
42
+ struct TensorListMetadata {
43
+ const void* addresses[n][depth_to_max_tensors[n - 1]];
44
+ int64_t numel_for_tensor[depth_to_max_tensors[n - 1]];
45
+ unsigned char block_to_tensor[depth_to_max_blocks[n - 1]];
46
+ int block_to_chunk[depth_to_max_blocks[n - 1]];
47
+ int start_tensor_this_launch;
48
+ };
49
+
50
+ template <typename scalar_vals_t, int n>
51
+ struct TensorListScalarListMetadata {
52
+ const void* addresses[n][depth_to_max_tensors_scalarlist[n - 1]];
53
+ int64_t numel_for_tensor[depth_to_max_tensors_scalarlist[n - 1]];
54
+ scalar_vals_t scalar_vals[depth_to_max_tensors_scalarlist[n - 1]];
55
+ unsigned char block_to_tensor[depth_to_max_blocks[n - 1]];
56
+ int block_to_chunk[depth_to_max_blocks[n - 1]];
57
+ };
58
+
59
+ // note(mkozuki): `n` of 1&2 violate the limit of cuda kernel argument size of
60
+ // 4kb with `c10::complex<double>`
61
+ template <>
62
+ struct TensorListScalarListMetadata<c10::complex<double>, 1> {
63
+ const void* addresses[1]
64
+ [depth_to_max_tensors_scalarlist_of_complex_double[0]];
65
+ int64_t
66
+ numel_for_tensor[depth_to_max_tensors_scalarlist_of_complex_double[0]];
67
+ c10::complex<double>
68
+ scalar_vals[depth_to_max_tensors_scalarlist_of_complex_double[0]];
69
+ unsigned char block_to_tensor[depth_to_max_blocks[1 - 1]];
70
+ int block_to_chunk[depth_to_max_blocks[1 - 1]];
71
+ };
72
+
73
+ template <>
74
+ struct TensorListScalarListMetadata<c10::complex<double>, 2> {
75
+ const void* addresses[2]
76
+ [depth_to_max_tensors_scalarlist_of_complex_double[1]];
77
+ int64_t
78
+ numel_for_tensor[depth_to_max_tensors_scalarlist_of_complex_double[1]];
79
+ c10::complex<double>
80
+ scalar_vals[depth_to_max_tensors_scalarlist_of_complex_double[1]];
81
+ unsigned char block_to_tensor[depth_to_max_blocks[2 - 1]];
82
+ int block_to_chunk[depth_to_max_blocks[2 - 1]];
83
+ };
84
+
85
+ // NOTE(crcrpar): This is a conservative resolution to handle `state_steps`
86
+ // whose each element is `at::Tensor` of 1 element representing the number of
87
+ // `step`s called so far.
88
+ template <int n>
89
+ struct FusedOptimizerTensorListMetadata {
90
+ const void* addresses[n][depth_to_max_tensors[n - 1]];
91
+ int64_t numel_for_tensor[depth_to_max_tensors[n - 1]];
92
+ const void* state_steps_addresses[depth_to_max_tensors_scalarlist[n - 1]];
93
+ unsigned char block_to_tensor[depth_to_max_blocks[n - 1]];
94
+ int block_to_chunk[depth_to_max_blocks[n - 1]];
95
+ int start_tensor_this_launch;
96
+ };
97
+
98
+ template <typename T, typename U, typename... ArgTypes>
99
+ C10_LAUNCH_BOUNDS_1(kBlockSize)
100
+ __global__ void multi_tensor_apply_kernel(
101
+ T tensorListMeta,
102
+ U callable,
103
+ ArgTypes... args) {
104
+ // Hand the chunk information to the user-supplied functor to process however
105
+ // it likes.
106
+ callable(kChunkSize, tensorListMeta, args...);
107
+ }
108
+
109
+ } // namespace
110
+
111
+ // multi_tensor_apply enables horizontal fusion across lists of tensors.
112
+ // For example, whereas you once had a for-loop of a + b = c, where a, b,
113
+ // and c are individual tensors in lists as, bs, and cs, you can now with
114
+ // fewer kernel launches compute as + bs = cs.
115
+ //
116
+ // You can also imagine bs to be a scalar list vs a tensor list.
117
+ //
118
+ // The function below takes in tensor lists, scalars, and a callable and
119
+ // chunks up the computation to launch as few kernels as possible by iterating
120
+ // through every "chunk" in every tensor (thus the nested for loops). In the
121
+ // simplest case, everything gets bundled into just one kernel launch, but
122
+ // due to blocksize constraints, we may need to launch multiple kernels.
123
+ // Each kernel launch is defined by one tensorListMeta construct, which we
124
+ // use to track and reset the necessary metadata for each launch.
125
+ template <int depth, typename scalar_T, typename T, typename... ArgTypes>
126
+ void multi_tensor_apply(
127
+ std::vector<std::vector<at::Tensor>>& tensor_lists,
128
+ at::ArrayRef<Scalar> scalars,
129
+ T callable,
130
+ ArgTypes... args) {
131
+ TORCH_CHECK(
132
+ tensor_lists.size() == depth,
133
+ "Number of tensor lists has to match the depth.");
134
+ const size_t n_tensors = tensor_lists[0].size();
135
+ using scalar_vals_t = typename T::opmath_t;
136
+ TensorListScalarListMetadata<scalar_vals_t, depth> tensorListMeta;
137
+
138
+ int loc_block_info = 0;
139
+ int loc_tensor_info = 0;
140
+ for (size_t t = 0; t < n_tensors; t++) {
141
+ // short-circuit to avoid adding empty tensors to tensorListMeta
142
+ if (tensor_lists[0][t].numel() == 0) {
143
+ continue;
144
+ }
145
+ tensorListMeta.scalar_vals[loc_tensor_info] = scalars[t].to<scalar_T>();
146
+ tensorListMeta.numel_for_tensor[loc_tensor_info] =
147
+ tensor_lists[0][t].numel();
148
+ for (int d = 0; d < depth; d++) {
149
+ tensorListMeta.addresses[d][loc_tensor_info] =
150
+ tensor_lists[d][t].const_data_ptr();
151
+ }
152
+ loc_tensor_info++;
153
+
154
+ // now we enter [chunking territory].
155
+ // we will launch a kernel when EITHER the blocks get filled up OR
156
+ // the tensors get filled up. There will always be at least one block
157
+ // per tensor since the zero-sized ones will not enter the loop, so
158
+ // the nested forloop within represents iterating through the chunks
159
+ // of a single tensor.
160
+ const auto numel = tensor_lists[0][t].numel();
161
+ const auto chunks = numel / kChunkSize + (numel % kChunkSize != 0);
162
+ for (auto chunk = 0; chunk < chunks; chunk++) {
163
+ tensorListMeta.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
164
+ tensorListMeta.block_to_chunk[loc_block_info] = chunk;
165
+ loc_block_info++;
166
+
167
+ // a tensor is not considered full unless all its chunks have been
168
+ // processed
169
+ const bool tensors_full =
170
+ (loc_tensor_info == depth_to_max_tensors_scalarlist[depth - 1] &&
171
+ chunk == chunks - 1);
172
+ const bool blocks_full =
173
+ (loc_block_info == depth_to_max_blocks[depth - 1]);
174
+
175
+ if (tensors_full || blocks_full) {
176
+ multi_tensor_apply_kernel<<<
177
+ loc_block_info,
178
+ kBlockSize,
179
+ 0,
180
+ at::cuda::getCurrentCUDAStream()>>>(
181
+ tensorListMeta, callable, args...);
182
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
183
+
184
+ // Reset.
185
+ loc_block_info = 0;
186
+ // all chunks have already been handled in the kernel
187
+ if (chunk == chunks - 1) {
188
+ loc_tensor_info = 0;
189
+ } else { // blocks were full and tensor chunks remain
190
+ tensorListMeta.numel_for_tensor[0] =
191
+ tensorListMeta.numel_for_tensor[loc_tensor_info - 1];
192
+ tensorListMeta.scalar_vals[0] =
193
+ tensorListMeta.scalar_vals[loc_tensor_info - 1];
194
+ for (int d = 0; d < depth; d++) {
195
+ tensorListMeta.addresses[d][0] =
196
+ tensorListMeta.addresses[d][loc_tensor_info - 1];
197
+ }
198
+ loc_tensor_info = 1;
199
+ }
200
+ }
201
+ }
202
+ }
203
+
204
+ // note: [finishing what we started]
205
+ // if there's remaining work to be done but the tensors/blocks aren't full
206
+ // yet we are at the end, submit the kernel to do the work!
207
+ if (loc_block_info != 0) {
208
+ multi_tensor_apply_kernel<<<
209
+ loc_block_info,
210
+ kBlockSize,
211
+ 0,
212
+ at::cuda::getCurrentCUDAStream()>>>(tensorListMeta, callable, args...);
213
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
214
+ }
215
+ }
216
+
217
+ template <int depth, typename T, typename... ArgTypes>
218
+ void multi_tensor_apply(
219
+ std::vector<std::vector<at::Tensor>>& tensor_lists,
220
+ T callable,
221
+ ArgTypes... args) {
222
+ TORCH_CHECK(
223
+ tensor_lists.size() == depth,
224
+ "Number of tensor lists has to match the depth.");
225
+ const size_t n_tensors = tensor_lists[0].size();
226
+ TensorListMetadata<depth> tensorListMeta;
227
+ tensorListMeta.start_tensor_this_launch = 0;
228
+
229
+ int loc_block_info = 0;
230
+ int loc_tensor_info = 0;
231
+ for (size_t t = 0; t < n_tensors; t++) {
232
+ // short-circuit to avoid adding empty tensors to tensorListMeta
233
+ if (tensor_lists[0][t].numel() == 0) {
234
+ continue;
235
+ }
236
+ tensorListMeta.numel_for_tensor[loc_tensor_info] =
237
+ tensor_lists[0][t].numel();
238
+ for (int d = 0; d < depth; d++) {
239
+ tensorListMeta.addresses[d][loc_tensor_info] =
240
+ tensor_lists[d][t].const_data_ptr();
241
+ }
242
+ loc_tensor_info++;
243
+
244
+ // see note: [chunking territory].
245
+ const auto numel = tensor_lists[0][t].numel();
246
+ const auto chunks = numel / kChunkSize + (numel % kChunkSize != 0);
247
+ for (auto chunk = 0; chunk < chunks; chunk++) {
248
+ tensorListMeta.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
249
+ tensorListMeta.block_to_chunk[loc_block_info] = chunk;
250
+ loc_block_info++;
251
+
252
+ const bool tensors_full =
253
+ (loc_tensor_info == depth_to_max_tensors[depth - 1] &&
254
+ chunk == chunks - 1);
255
+ const bool blocks_full =
256
+ (loc_block_info == depth_to_max_blocks[depth - 1]);
257
+
258
+ if (tensors_full || blocks_full) {
259
+ multi_tensor_apply_kernel<<<
260
+ loc_block_info,
261
+ kBlockSize,
262
+ 0,
263
+ at::cuda::getCurrentCUDAStream()>>>(
264
+ tensorListMeta, callable, args...);
265
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
266
+
267
+ // Reset.
268
+ loc_block_info = 0;
269
+ if (chunk == chunks - 1) {
270
+ loc_tensor_info = 0;
271
+ tensorListMeta.start_tensor_this_launch = t + 1;
272
+ } else {
273
+ tensorListMeta.numel_for_tensor[0] =
274
+ tensorListMeta.numel_for_tensor[loc_tensor_info - 1];
275
+ for (int d = 0; d < depth; d++) {
276
+ tensorListMeta.addresses[d][0] =
277
+ tensorListMeta.addresses[d][loc_tensor_info - 1];
278
+ }
279
+ loc_tensor_info = 1;
280
+ tensorListMeta.start_tensor_this_launch = t;
281
+ }
282
+ }
283
+ }
284
+ }
285
+
286
+ // see note: [finishing what we started]
287
+ if (loc_block_info != 0) {
288
+ multi_tensor_apply_kernel<<<
289
+ loc_block_info,
290
+ kBlockSize,
291
+ 0,
292
+ at::cuda::getCurrentCUDAStream()>>>(tensorListMeta, callable, args...);
293
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
294
+ }
295
+ }
296
+
297
+ template <int depth, typename T, typename... ArgTypes>
298
+ void multi_tensor_apply_for_fused_optimizer(
299
+ std::vector<std::vector<at::Tensor>>& tensor_lists,
300
+ at::TensorList state_steps,
301
+ T callable,
302
+ ArgTypes... args) {
303
+ TORCH_CHECK(
304
+ tensor_lists.size() == depth,
305
+ "Number of tensor lists has to match the depth");
306
+ const auto num_tensors = tensor_lists[0].size();
307
+ FusedOptimizerTensorListMetadata<depth> tensorListMeta;
308
+
309
+ int loc_block_info = 0;
310
+ int loc_tensor_info = 0;
311
+ for (const auto& tensor_index : c10::irange(num_tensors)) {
312
+ // short-circuit to avoid adding empty tensors to tensorListMeta
313
+ if (tensor_lists[0][tensor_index].numel() == 0) {
314
+ continue;
315
+ }
316
+ tensorListMeta.state_steps_addresses[loc_tensor_info] =
317
+ state_steps[tensor_index].const_data_ptr();
318
+ tensorListMeta.numel_for_tensor[loc_tensor_info] =
319
+ tensor_lists[0][tensor_index].numel();
320
+ for (const auto& d : c10::irange(depth)) {
321
+ tensorListMeta.addresses[d][loc_tensor_info] =
322
+ tensor_lists[d][tensor_index].const_data_ptr();
323
+ }
324
+ loc_tensor_info++;
325
+
326
+ // see above note: [chunking territory]
327
+ const auto numel = tensor_lists[0][tensor_index].numel();
328
+ const auto chunks = numel / kChunkSize + (numel % kChunkSize != 0);
329
+ TORCH_CHECK(chunks > -1);
330
+ for (const auto& chunk : c10::irange(chunks)) {
331
+ tensorListMeta.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
332
+ tensorListMeta.block_to_chunk[loc_block_info] = chunk;
333
+ loc_block_info++;
334
+
335
+ const auto tensor_full =
336
+ (loc_tensor_info == depth_to_max_tensors[depth - 1] &&
337
+ chunk == chunks - 1);
338
+ const auto blocks_full = loc_block_info == depth_to_max_blocks[depth - 1];
339
+
340
+ if (tensor_full || blocks_full) {
341
+ multi_tensor_apply_kernel<<<
342
+ loc_block_info,
343
+ kBlockSize,
344
+ 0,
345
+ at::cuda::getCurrentCUDAStream()>>>(
346
+ tensorListMeta, callable, args...);
347
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
348
+
349
+ // Reset.
350
+ loc_block_info = 0;
351
+ if (chunk == chunks - 1) {
352
+ loc_tensor_info = 0;
353
+ } else {
354
+ tensorListMeta.numel_for_tensor[0] =
355
+ tensorListMeta.numel_for_tensor[loc_tensor_info - 1];
356
+ tensorListMeta.state_steps_addresses[0] =
357
+ tensorListMeta.state_steps_addresses[loc_tensor_info - 1];
358
+ for (const auto& d : c10::irange(depth)) {
359
+ tensorListMeta.addresses[d][0] =
360
+ tensorListMeta.addresses[d][loc_tensor_info - 1];
361
+ }
362
+ loc_tensor_info = 1;
363
+ }
364
+ }
365
+ }
366
+ }
367
+
368
+ // see above note: [finishing what we've started]
369
+ if (loc_block_info != 0) {
370
+ multi_tensor_apply_kernel<<<
371
+ loc_block_info,
372
+ kBlockSize,
373
+ 0,
374
+ at::cuda::getCurrentCUDAStream()>>>(tensorListMeta, callable, args...);
375
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
376
+ }
377
+ }
378
+
379
+ } // namespace at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Normalization.cuh ADDED
@@ -0,0 +1,1742 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/core/Tensor.h>
4
+ #include <ATen/Dispatch.h>
5
+ #include <ATen/AccumulateType.h>
6
+ #include <ATen/ceil_div.h>
7
+ #include <ATen/cuda/CUDAContext.h>
8
+ #include <ATen/cuda/DeviceUtils.cuh>
9
+ #include <ATen/native/cuda/block_reduce.cuh>
10
+ #include <ATen/native/cuda/DeviceSqrt.cuh>
11
+ #include <ATen/native/cuda/LaunchUtils.h>
12
+ #include <c10/macros/Macros.h>
13
+
14
+ #ifndef AT_PER_OPERATOR_HEADERS
15
+ #include <ATen/Functions.h>
16
+ #else
17
+ #include <ATen/ops/empty.h>
18
+ #include <ATen/ops/empty_like.h>
19
+ #include <ATen/ops/zeros.h>
20
+ #endif
21
+
22
+ namespace at { namespace native {
23
+
24
+ // The maximum number of threads in a block
25
+ #if defined(USE_ROCM)
26
+ constexpr int MAX_BLOCK_SIZE = 256;
27
+ #else
28
+ constexpr int MAX_BLOCK_SIZE = 512;
29
+ #endif
30
+
31
+ constexpr unsigned MAX_GRID_SIZE = 65535u;
32
+
33
+ // Number of threads in a block given an input size up to MAX_BLOCK_SIZE
34
+ static int getNumThreads(int nElem) {
35
+ #if defined(USE_ROCM)
36
+ int threadSizes[5] = { 16, 32, 64, 128, MAX_BLOCK_SIZE };
37
+ #else
38
+ int threadSizes[5] = { 32, 64, 128, 256, MAX_BLOCK_SIZE };
39
+ #endif
40
+ for (int i = 0; i != 5; ++i) {
41
+ if (nElem <= threadSizes[i]) {
42
+ return threadSizes[i];
43
+ }
44
+ }
45
+ return MAX_BLOCK_SIZE;
46
+ }
47
+
48
+ // Returns the index of the most significant 1 bit in `val`.
49
+ __device__ __forceinline__ int getMSB(int val) {
50
+ return 31 - __clz(val);
51
+ }
52
+
53
+ template <typename scalar_t, typename accscalar_t>
54
+ struct Float2 {
55
+ accscalar_t v1, v2;
56
+ __device__ Float2() {}
57
+ __device__ Float2(scalar_t v1, scalar_t v2) : v1(static_cast<accscalar_t>(v1)), v2(static_cast<accscalar_t>(v2)) {}
58
+ __device__ Float2(int v) : v1(static_cast<accscalar_t>(v)), v2(static_cast<accscalar_t>(v)) {}
59
+ __device__ Float2& operator+=(const Float2& a) {
60
+ v1 += a.v1;
61
+ v2 += a.v2;
62
+ return *this;
63
+ }
64
+ __device__ friend Float2 operator+(Float2 a, const Float2& b) {
65
+ a += b;
66
+ return a;
67
+ }
68
+ };
69
+
70
+ template <typename scalar_t, typename accscalar_t, typename PTA>
71
+ struct GradOp {
72
+ __device__ GradOp(accscalar_t m, const PTA& i, const PTA& g)
73
+ : mean(m), input(i), grad_output(g) {}
74
+ __device__ __forceinline__ Float2<scalar_t, accscalar_t> operator()(int batch, int plane, int n) {
75
+ accscalar_t g = grad_output[batch][plane][n];
76
+ accscalar_t c = static_cast<accscalar_t>(input[batch][plane][n]) - mean;
77
+ return Float2<scalar_t, accscalar_t>(g, g * c);
78
+ }
79
+ const accscalar_t mean;
80
+ const PTA& input;
81
+ const PTA& grad_output;
82
+ };
83
+
84
+ template <typename acc_t>
85
+ struct SumReduceOp {
86
+ __device__ __forceinline__ acc_t combine(acc_t a, acc_t b) const { return a + b; }
87
+
88
+ __device__ __forceinline__ acc_t warp_shfl_down(acc_t data, int offset) const {
89
+ return WARP_SHFL_DOWN(data, offset);
90
+ }
91
+ };
92
+
93
+ template <typename scalar_t, typename accscalar_t>
94
+ struct SumReduceOp<Float2<scalar_t, accscalar_t>> {
95
+ using acc_t = Float2<scalar_t, accscalar_t>;
96
+
97
+ __device__ __forceinline__ acc_t combine(acc_t a, acc_t b) const { return a + b; }
98
+
99
+ __device__ __forceinline__ acc_t warp_shfl_down(acc_t data, int offset) const {
100
+ return {WARP_SHFL_DOWN(data.v1, offset), WARP_SHFL_DOWN(data.v2, offset)};
101
+ }
102
+ };
103
+
104
+ // Sum across (batch, x/y/z) applying Op() pointwise
105
+ // this works by first having each thread sum it's part
106
+ // of the data. Then there is a double-shuffling reduction.
107
+ // First each warp (of C10_WARP_SIZE threads) uses warpSum to reduce its
108
+ // data to the "warp leader", who writes its value into shared memory.
109
+ // Then a single warp reads the remaining (at most C10_WARP_SIZE) items
110
+ // and reduces them using another warpSum.
111
+ // The implicit assumption is that there are no more
112
+ // than C10_WARP_SIZE**2 threads.
113
+ template<typename scalar_t, typename Op, typename PTA>
114
+ __device__ scalar_t reduce(Op op, PTA tensor, int plane) {
115
+ // first the reductions each thread does separately
116
+ scalar_t sum = static_cast<scalar_t>(0);
117
+ for (int batch = threadIdx.y; batch < tensor.size(0); batch += blockDim.y) {
118
+ for (int x = threadIdx.x; x < tensor.size(2); x += blockDim.x) {
119
+ sum += op(batch, plane, x);
120
+ }
121
+ }
122
+ __shared__ scalar_t shared[C10_WARP_SIZE];
123
+ SumReduceOp<scalar_t> reduce_op;
124
+ sum = cuda_utils::BlockReduce<scalar_t, SumReduceOp<scalar_t>, cuda_utils::Block2D>(sum, reduce_op, 0, shared);
125
+ if (threadIdx.x == 0 && threadIdx.y == 0) {
126
+ shared[0] = sum;
127
+ }
128
+ __syncthreads();
129
+ // Everyone picks it up, should be broadcast into the whole grad_input
130
+ return shared[0];
131
+ }
132
+
133
+ constexpr int ELEMENTS_PER_ITER = 4; // enables concurrency within each thread to hide latency
134
+ constexpr int ELEMENTS_PER_THREAD = 16;
135
+ constexpr int OPTIMAL_TILE_W = 32;
136
+ constexpr int MAX_H_BLOCK = 128;
137
+
138
+ __host__ void flexible_launch_configs(
139
+ const int reduction,
140
+ const int stride,
141
+ dim3 &block,
142
+ dim3 &grid,
143
+ const bool coop_flag = false) {
144
+ int block_x = std::min(lastPow2(stride), OPTIMAL_TILE_W);
145
+ int block_y = std::min(lastPow2(at::ceil_div(reduction , ELEMENTS_PER_THREAD)),
146
+ MAX_BLOCK_SIZE / block_x);
147
+ if (block_x * block_y != MAX_BLOCK_SIZE) {
148
+ block_x = std::min(lastPow2(stride), MAX_BLOCK_SIZE / block_y);
149
+ }
150
+
151
+ int grid_x = at::ceil_div(stride, block_x);
152
+ int grid_y = std::min(at::ceil_div(reduction, block_y * ELEMENTS_PER_THREAD), MAX_H_BLOCK);
153
+ if (coop_flag) {
154
+ // it's not worth having a grid reduction if the reduction dimension is not big enough
155
+ grid_y = grid_y < 8 ? 1 : grid_y;
156
+ }
157
+
158
+ block.x = block_x;
159
+ block.y = block_y;
160
+ block.z = 1;
161
+ grid.x = grid_x;
162
+ grid.y = grid_y;
163
+ grid.z = 1;
164
+ }
165
+
166
+ template<typename T, typename C>
167
+ __device__ __forceinline__ void welford_merge_element(C& count,
168
+ T& mean,
169
+ T& m2n,
170
+ const C& count_new,
171
+ const T& mean_new,
172
+ const T& m2n_new) {
173
+ T factor = T(1.0) / ::max(1, (count + count_new));
174
+ T delta0 = mean - mean_new;
175
+ mean = (mean_new * count_new + mean * count) * factor;
176
+ m2n += m2n_new + delta0 * delta0 * count_new * count * factor;
177
+ count += count_new;
178
+ }
179
+
180
+ // merge mean/m2n among threadIdx.y within block
181
+ template<typename T, typename C>
182
+ __device__ __forceinline__ void welford_merge_block_vertical(C& count,
183
+ T& mean,
184
+ T& m2n,
185
+ C* shmem_count,
186
+ T* shmem_mean,
187
+ T* shmem_m2n) {
188
+ // write to shared memory
189
+ auto address_base = threadIdx.x + threadIdx.y * blockDim.x;
190
+
191
+ #pragma unroll
192
+ for (int offset = blockDim.y/2; offset > 0; offset >>= 1) {
193
+ if (threadIdx.y < offset*2) {
194
+ shmem_mean[address_base] = mean;
195
+ shmem_m2n[address_base] = m2n;
196
+ shmem_count[address_base] = count;
197
+ }
198
+ __syncthreads();
199
+ if (threadIdx.y < offset && threadIdx.y + offset < blockDim.y) {
200
+ auto address = address_base + offset * blockDim.x;
201
+ // read shared memory back to register for reduction
202
+ auto count_new = shmem_count[address];
203
+ auto mean_new = shmem_mean[address];
204
+ auto m2n_new = shmem_m2n[address];
205
+
206
+ welford_merge_element(count, mean, m2n, count_new, mean_new, m2n_new);
207
+ }
208
+ }
209
+ }
210
+
211
+ template <typename input_scalar_t, typename stat_scalar_t, typename stat_accscalar_t, bool train, typename index_t>
212
+ __global__ void batch_norm_transform_input_kernel(
213
+ const GenericPackedTensorAccessor<const input_scalar_t, 3, RestrictPtrTraits, index_t> input,
214
+ GenericPackedTensorAccessor<input_scalar_t, 3, RestrictPtrTraits, index_t> output,
215
+ const GenericPackedTensorAccessor<typename std::conditional<train, stat_accscalar_t, stat_scalar_t>::type, 1, RestrictPtrTraits, index_t> mean_,
216
+ const GenericPackedTensorAccessor<typename std::conditional<train, stat_accscalar_t, stat_scalar_t>::type, 1, RestrictPtrTraits, index_t> var_or_invstd,
217
+ const GenericPackedTensorAccessor<const stat_scalar_t, 1, RestrictPtrTraits, index_t> weight,
218
+ const GenericPackedTensorAccessor<const stat_scalar_t, 1, RestrictPtrTraits, index_t> bias,
219
+ stat_accscalar_t epsilon) {
220
+
221
+ index_t plane = blockIdx.x;
222
+
223
+ if (plane >= input.size(1)) {
224
+ return;
225
+ }
226
+
227
+ stat_accscalar_t gamma = weight.size(0) > 0 ? static_cast<stat_accscalar_t>(weight[plane]) : static_cast<stat_accscalar_t>(1);
228
+ stat_accscalar_t beta = bias.size(0) > 0 ? static_cast<stat_accscalar_t>(bias[plane]) : static_cast<stat_accscalar_t>(0);
229
+ stat_accscalar_t mean = static_cast<stat_accscalar_t>(mean_[plane]);
230
+ stat_accscalar_t invstd;
231
+ if (train) {
232
+ invstd = var_or_invstd[plane];
233
+ } else {
234
+ invstd = static_cast<stat_accscalar_t>(1) / device_sqrt(static_cast<stat_accscalar_t>(var_or_invstd[plane]) + epsilon);
235
+ }
236
+
237
+ index_t bs = input.size(0);
238
+ index_t fs = input.size(2);
239
+
240
+ index_t bstep = blockDim.y * gridDim.y;
241
+ for (index_t batch = threadIdx.y + blockIdx.y * blockDim.y; batch < bs; batch += bstep) {
242
+ auto o = output[batch][plane];
243
+ auto i = input[batch][plane];
244
+ for (index_t feature = threadIdx.x; feature < fs; feature += blockDim.x) {
245
+ o[feature] = static_cast<input_scalar_t>(gamma * (i[feature] - mean) * invstd + beta);
246
+ }
247
+ }
248
+ }
249
+
250
+ struct InvStd {
251
+ template <typename T>
252
+ __device__ __forceinline__ T operator()(T var, double epsilon) const {
253
+ T invstd = 0;
254
+ if (var != static_cast<T>(0) || epsilon != static_cast<T>(0)) {
255
+ invstd = static_cast<T>(1) / device_sqrt(var + epsilon);
256
+ }
257
+ return invstd;
258
+ }
259
+ };
260
+
261
+ struct Var {
262
+ template <typename T>
263
+ __device__ __forceinline__ T operator()(T var, double epsilon) const {
264
+ return var;
265
+ }
266
+ };
267
+
268
+ template <typename VarTransform, typename input_scalar_t, typename stat_scalar_t, typename stat_accscalar_t, typename index_t>
269
+ __global__ void batch_norm_collect_statistics_kernel(
270
+ const GenericPackedTensorAccessor<const input_scalar_t, 3, RestrictPtrTraits, index_t> input,
271
+ const stat_accscalar_t epsilon,
272
+ const stat_accscalar_t momentum,
273
+ GenericPackedTensorAccessor<stat_accscalar_t, 1, RestrictPtrTraits, index_t> save_mean,
274
+ GenericPackedTensorAccessor<stat_accscalar_t, 1, RestrictPtrTraits, index_t> save_transformed_var) {
275
+
276
+ __shared__ int shared_n[2 * 2 * C10_WARP_SIZE + C10_WARP_SIZE];
277
+
278
+ int plane = blockIdx.x;
279
+ int N = input.size(0) * input.size(2);
280
+ int tid = threadIdx.x + threadIdx.y * blockDim.x;
281
+
282
+ // Compute the mean and variance across (batch, x/y/z)
283
+ // this uses the Welford (in the for loop)/parallel algorithm (to sum across the block)
284
+ // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_Online_algorithm
285
+ // and the parallel algorithm on the same page.
286
+ // We use two shuffles to reduce across the entire block.
287
+ // https://devblogs.nvidia.com/faster-parallel-reductions-kepler/ has a description.
288
+ stat_accscalar_t* shared_avg_var = (stat_accscalar_t*) &shared_n[C10_WARP_SIZE];
289
+
290
+ // first the reductions each thread does separately
291
+ stat_accscalar_t avg = 0;
292
+ stat_accscalar_t var_n = 0;
293
+ int n = 0;
294
+ for (int batch = threadIdx.y; batch < input.size(0); batch += blockDim.y) {
295
+ for (int x = threadIdx.x; x < input.size(2); x += blockDim.x) {
296
+ stat_accscalar_t v = input[batch][plane][x];
297
+ stat_accscalar_t d1 = v - avg;
298
+ n++;
299
+ avg += d1 / n;
300
+ var_n += d1 * (v - avg);
301
+ }
302
+ }
303
+
304
+ // first warpSum to get one value per thread to
305
+ // one value per warp
306
+ for (int i = 0; i < getMSB(C10_WARP_SIZE); ++i) {
307
+ stat_accscalar_t o_avg = WARP_SHFL_XOR(avg, 1 << i, C10_WARP_SIZE);
308
+ int o_n = WARP_SHFL_XOR(n, 1 << i, C10_WARP_SIZE);
309
+ stat_accscalar_t factor = 1.0 / fmaxf(1.0, n+o_n);
310
+ var_n += WARP_SHFL_XOR(var_n, 1 << i, C10_WARP_SIZE) + (avg - o_avg) * (avg - o_avg) * n * o_n * factor;
311
+ avg = (n * avg + o_n * o_avg) * factor;
312
+ n += o_n;
313
+ }
314
+
315
+ // this writes each warps item into shared memory
316
+ // there are at most C10_WARP_SIZE items left because
317
+ // there are at most C10_WARP_SIZE**2 threads at the beginning
318
+ __syncthreads();
319
+ if (tid % C10_WARP_SIZE == 0) {
320
+ shared_n[tid / C10_WARP_SIZE] = n;
321
+ shared_avg_var[tid / C10_WARP_SIZE * 2] = avg;
322
+ shared_avg_var[tid / C10_WARP_SIZE * 2 + 1] = var_n;
323
+ }
324
+ __syncthreads();
325
+ // now have a second warpSum to reduce the intermediate values
326
+ // from shared memory to a single number. The very first
327
+ // thread writes it to shared memory.
328
+
329
+ if (tid < C10_WARP_SIZE) {
330
+ n = (tid < blockDim.x * blockDim.y / C10_WARP_SIZE ? shared_n[tid] : 0);
331
+ avg = (tid < blockDim.x * blockDim.y / C10_WARP_SIZE ? shared_avg_var[2 * tid] : stat_accscalar_t(0));
332
+ var_n = (tid < blockDim.x * blockDim.y / C10_WARP_SIZE ? shared_avg_var[2 * tid + 1] : stat_accscalar_t(0));
333
+ }
334
+ for (int i = 0; i < getMSB(C10_WARP_SIZE); ++i) {
335
+ stat_accscalar_t o_avg = WARP_SHFL_XOR(avg, 1 << i, C10_WARP_SIZE);
336
+ int o_n = WARP_SHFL_XOR(n, 1 << i, C10_WARP_SIZE);
337
+ stat_accscalar_t factor = 1.0 / fmaxf(1.0, n+o_n);
338
+ var_n += WARP_SHFL_XOR(var_n, 1 << i, C10_WARP_SIZE) + (avg - o_avg) * (avg - o_avg) * n * o_n * factor;
339
+ avg = (n * avg + o_n * o_avg) * factor;
340
+ n += o_n;
341
+ }
342
+
343
+ // Save the mean, variance, and moving averages
344
+ if (tid == 0) {
345
+ if (save_mean.data() != NULL) {
346
+ save_mean[plane] = avg;
347
+ }
348
+ if (save_transformed_var.data() != NULL) {
349
+ save_transformed_var[plane] = VarTransform{}(var_n / N, epsilon);
350
+ }
351
+ }
352
+
353
+ }
354
+
355
+ template <typename input_scalar_t, typename stat_scalar_t, typename stat_accscalar_t, typename index_t>
356
+ __global__ void batch_norm_backward_kernel(
357
+ const GenericPackedTensorAccessor<const input_scalar_t, 3, DefaultPtrTraits, index_t> input,
358
+ const GenericPackedTensorAccessor<const input_scalar_t, 3, DefaultPtrTraits, index_t> grad_output,
359
+ GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> grad_input,
360
+ GenericPackedTensorAccessor<stat_scalar_t, 1, DefaultPtrTraits, index_t> grad_weight,
361
+ GenericPackedTensorAccessor<stat_scalar_t, 1, DefaultPtrTraits, index_t> grad_bias,
362
+ const GenericPackedTensorAccessor<const stat_scalar_t, 1, DefaultPtrTraits, index_t> weight,
363
+ const GenericPackedTensorAccessor<const stat_scalar_t, 1, DefaultPtrTraits, index_t> running_mean,
364
+ const GenericPackedTensorAccessor<const stat_scalar_t, 1, DefaultPtrTraits, index_t> running_var,
365
+ const GenericPackedTensorAccessor<const stat_accscalar_t, 1, DefaultPtrTraits, index_t> save_mean,
366
+ const GenericPackedTensorAccessor<const stat_accscalar_t, 1, DefaultPtrTraits, index_t> save_invstd,
367
+ bool train,
368
+ stat_accscalar_t epsilon) {
369
+
370
+ index_t plane = blockIdx.x;
371
+ index_t N = grad_output.size(0) * grad_output.size(2);
372
+
373
+ stat_accscalar_t mean, invstd;
374
+ if (train) {
375
+ mean = save_mean[plane];
376
+ invstd = save_invstd[plane];
377
+ } else {
378
+ mean = static_cast<stat_accscalar_t>(running_mean[plane]);
379
+ invstd = static_cast<stat_accscalar_t>(1) / device_sqrt(static_cast<stat_accscalar_t>(running_var[plane]) + epsilon);
380
+ }
381
+
382
+ stat_accscalar_t weight_val = weight.size(0) > 0 ? static_cast<stat_accscalar_t>(weight[plane]) : stat_accscalar_t(1);
383
+ stat_accscalar_t norm = stat_accscalar_t(1) / N;
384
+
385
+ // Compute two values across (batch, x/y/z) in one pass:
386
+ // 1. Sum(grad_output)
387
+ // 2. DotProduct(input - mean, grad_output)
388
+ GradOp<input_scalar_t, stat_accscalar_t, GenericPackedTensorAccessor<const input_scalar_t, 3, DefaultPtrTraits, index_t>> g(mean, input, grad_output);
389
+ auto res = reduce<Float2<input_scalar_t, stat_accscalar_t>>(g, grad_output, plane);
390
+
391
+ stat_accscalar_t grad_output_sum = res.v1;
392
+ stat_accscalar_t dot_p = res.v2;
393
+
394
+ stat_accscalar_t grad_mean = grad_output_sum * norm;
395
+ stat_accscalar_t proj_scale = dot_p * norm * invstd * invstd;
396
+ stat_accscalar_t grad_scale = invstd * weight_val;
397
+
398
+ if (grad_input.data() != NULL) {
399
+ for (int batch = threadIdx.y; batch < grad_output.size(0); batch += blockDim.y) {
400
+ for (int x = threadIdx.x; x < grad_output.size(2); x += blockDim.x) {
401
+ input_scalar_t go = grad_output[batch][plane][x];
402
+ if (train) {
403
+ stat_accscalar_t inp = input[batch][plane][x];
404
+ stat_accscalar_t proj = (inp - mean) * proj_scale;
405
+ grad_input[batch][plane][x] = static_cast<input_scalar_t>((go - proj - grad_mean) * grad_scale);
406
+ } else {
407
+ grad_input[batch][plane][x] = static_cast<input_scalar_t>(go * grad_scale);
408
+ }
409
+ }
410
+ }
411
+ }
412
+
413
+ if (grad_weight.size(0) > 0) {
414
+ if (threadIdx.x == 0) {
415
+ grad_weight[plane] = static_cast<stat_scalar_t>(dot_p * invstd);
416
+ }
417
+ }
418
+
419
+ if (grad_bias.size(0) > 0) {
420
+ if (threadIdx.x == 0) {
421
+ grad_bias[plane] = static_cast<stat_scalar_t>(grad_output_sum);
422
+ }
423
+ }
424
+ }
425
+
426
+ template <typename scalar_t, typename accscalar_t, typename index_t>
427
+ __global__ void batch_norm_reduce_statistics_kernel(
428
+ const GenericPackedTensorAccessor<accscalar_t, 2, RestrictPtrTraits, index_t> vec_mean,
429
+ const GenericPackedTensorAccessor<accscalar_t, 2, RestrictPtrTraits, index_t> vec_invstd,
430
+ GenericPackedTensorAccessor<accscalar_t, 1, RestrictPtrTraits, index_t> mean,
431
+ GenericPackedTensorAccessor<accscalar_t, 1, RestrictPtrTraits, index_t> invstd,
432
+ GenericPackedTensorAccessor<scalar_t, 1, RestrictPtrTraits, index_t> running_mean,
433
+ GenericPackedTensorAccessor<scalar_t, 1, RestrictPtrTraits, index_t> running_var,
434
+ const accscalar_t epsilon,
435
+ const accscalar_t momentum,
436
+ const GenericPackedTensorAccessor<scalar_t, 1, RestrictPtrTraits, index_t> counts) {
437
+
438
+ int feature_size = vec_mean.size(1);
439
+ int world_size = vec_mean.size(0);
440
+
441
+ int bid = blockIdx.x;
442
+ int tid = threadIdx.x;
443
+
444
+ // first the reductions each thread does separately
445
+ for (int i = bid*blockDim.x+tid; i < feature_size; i += gridDim.x*blockDim.x) {
446
+ accscalar_t avg = 0;
447
+ accscalar_t var_n = 0;
448
+ index_t n = 0;
449
+ for (int j = 0; j < world_size; j++) {
450
+ scalar_t count = counts[j];
451
+ accscalar_t m = vec_mean[j][i];
452
+ accscalar_t v = accscalar_t(1.0) / (vec_invstd[j][i]);
453
+ v = (v * v - epsilon) * count;
454
+ accscalar_t factor = 1.0 / (n + count);
455
+ var_n += v + (avg - m) * (avg - m) * n * count * factor;
456
+ avg = n * factor * avg + count * factor * m;
457
+ n += count;
458
+ }
459
+ mean[i] = avg;
460
+ invstd[i] = static_cast<accscalar_t>(1) / device_sqrt(var_n / n + epsilon);
461
+ if (running_mean.data() != NULL) {
462
+ running_mean[i] = static_cast<scalar_t>((1 - momentum) * running_mean[i] + momentum * avg);
463
+ }
464
+ accscalar_t unbiasedVar = var_n / (n - 1);
465
+ if (running_var.data() != NULL) {
466
+ running_var[i] = static_cast<scalar_t>((1 - momentum) * running_var[i] + momentum * unbiasedVar);
467
+ }
468
+ }
469
+
470
+ }
471
+
472
+ template <typename input_scalar_t, typename stat_scalar_t, typename stat_accscalar_t, typename index_t>
473
+ __global__ void batch_norm_backward_reduce_kernel(
474
+ const GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> input,
475
+ const GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> grad_output,
476
+ GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> mean,
477
+ GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> invstd,
478
+ GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> sum_dy,
479
+ GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> sum_dy_xmu,
480
+ GenericPackedTensorAccessor<stat_scalar_t, 1, DefaultPtrTraits, index_t> grad_weight,
481
+ GenericPackedTensorAccessor<stat_scalar_t, 1, DefaultPtrTraits, index_t> grad_bias) {
482
+
483
+ index_t plane = blockIdx.x;
484
+
485
+ stat_accscalar_t r_mean = mean[plane];
486
+ stat_accscalar_t factor = invstd[plane];
487
+
488
+ GradOp<input_scalar_t, stat_accscalar_t, GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t>> g(r_mean, input, grad_output);
489
+ auto res = reduce<Float2<input_scalar_t, stat_accscalar_t>>(g, grad_output, plane);
490
+
491
+ if (threadIdx.x == 0) {
492
+ if (grad_weight.size(0) > 0) {
493
+ grad_weight[plane] = static_cast<stat_scalar_t>(res.v2 * factor);
494
+ }
495
+ if (grad_bias.size(0) > 0) {
496
+ grad_bias[plane] = static_cast<stat_scalar_t>(res.v1);
497
+ }
498
+ if (sum_dy.size(0) > 0) {
499
+ sum_dy[plane] = static_cast<stat_accscalar_t>(res.v1);
500
+ }
501
+ if (sum_dy_xmu.size(0) > 0) {
502
+ sum_dy_xmu[plane] = static_cast<stat_accscalar_t>(res.v2);
503
+ }
504
+ }
505
+ }
506
+
507
+ template <typename input_scalar_t, typename stat_scalar_t, typename stat_accscalar_t, typename index_t>
508
+ __device__ __forceinline__ void batch_norm_backward_elemt_kernel_impl(
509
+ const GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> input,
510
+ const GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> grad_output,
511
+ const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> mean,
512
+ const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> invstd,
513
+ const GenericPackedTensorAccessor<stat_scalar_t, 1, DefaultPtrTraits, index_t> weight,
514
+ const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> sum_dy,
515
+ const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> sum_dy_xmu,
516
+ GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> grad_input,
517
+ const stat_accscalar_t norm_fct) {
518
+ index_t plane = blockIdx.x;
519
+
520
+ if (plane >= input.size(1)) {
521
+ return;
522
+ }
523
+
524
+ stat_accscalar_t m_c = mean[plane];
525
+ stat_accscalar_t m_dy_c = sum_dy[plane] * norm_fct;
526
+ stat_accscalar_t factor_1_c = invstd[plane];
527
+ stat_accscalar_t factor_2_c = weight.size(0) > 0 ? static_cast<stat_accscalar_t>(weight[plane]) : stat_accscalar_t(1);
528
+ factor_2_c *= factor_1_c;
529
+ factor_1_c = factor_1_c * factor_1_c * sum_dy_xmu[plane] * norm_fct;
530
+
531
+ index_t bs = input.size(0);
532
+ index_t fs = input.size(2);
533
+
534
+ index_t bstep = blockDim.y * gridDim.y;
535
+ for (index_t batch = threadIdx.y + blockIdx.y * blockDim.y; batch < bs; batch += bstep) {
536
+ auto g_i = grad_input[batch][plane];
537
+ auto g_o = grad_output[batch][plane];
538
+ auto i = input[batch][plane];
539
+ for (index_t feature = threadIdx.x; feature < fs; feature += blockDim.x) {
540
+ g_i[feature] = static_cast<input_scalar_t>((g_o[feature] - m_dy_c - (i[feature] - m_c) * factor_1_c) * factor_2_c);
541
+ }
542
+ }
543
+ }
544
+
545
+ template <typename input_scalar_t, typename stat_scalar_t, typename stat_accscalar_t, typename index_t>
546
+ __global__ void batch_norm_backward_elemt_kernel(
547
+ const GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> input,
548
+ const GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> grad_output,
549
+ const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> mean,
550
+ const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> invstd,
551
+ const GenericPackedTensorAccessor<stat_scalar_t, 1, DefaultPtrTraits, index_t> weight,
552
+ const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> sum_dy,
553
+ const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> sum_dy_xmu,
554
+ GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> grad_input,
555
+ const int* __restrict__ numel, const int world_size) {
556
+ int64_t total_numel = 0;
557
+ for (int i = 0; i < world_size; i ++) {
558
+ total_numel += numel[i];
559
+ }
560
+
561
+ const stat_accscalar_t norm_fct =
562
+ static_cast<stat_accscalar_t>(1) / static_cast<stat_accscalar_t>(total_numel);
563
+ batch_norm_backward_elemt_kernel_impl(
564
+ input, grad_output, mean, invstd, weight, sum_dy, sum_dy_xmu, grad_input, norm_fct);
565
+ }
566
+
567
+ template <typename input_scalar_t, typename stat_scalar_t, typename stat_accscalar_t, typename index_t>
568
+ __global__ void batch_norm_backward_elemt_kernel(
569
+ const GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> input,
570
+ const GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> grad_output,
571
+ const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> mean,
572
+ const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> invstd,
573
+ const GenericPackedTensorAccessor<stat_scalar_t, 1, DefaultPtrTraits, index_t> weight,
574
+ const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> sum_dy,
575
+ const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> sum_dy_xmu,
576
+ GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> grad_input,
577
+ const stat_accscalar_t norm_fct) {
578
+ batch_norm_backward_elemt_kernel_impl(
579
+ input, grad_output, mean, invstd, weight, sum_dy, sum_dy_xmu, grad_input, norm_fct);
580
+ }
581
+
582
+ template <typename scalar_t, int64_t dim, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
583
+ static GenericPackedTensorAccessor<scalar_t, dim, PtrTraits, index_t> get_packed_accessor(
584
+ const Tensor& t, c10::string_view var_name) {
585
+ constexpr auto expect_type = c10::CppTypeToScalarType<typename std::remove_const<scalar_t>::type>::value;
586
+ const auto actual_type = t.scalar_type();
587
+ TORCH_CHECK(actual_type == expect_type, "Expected ", var_name,
588
+ " to have type ", expect_type, " but got ", actual_type);
589
+ return t.generic_packed_accessor<scalar_t, dim, PtrTraits, index_t>();
590
+ }
591
+
592
+ template <typename scalar_t, int64_t dim, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
593
+ static GenericPackedTensorAccessor<scalar_t, dim, PtrTraits, index_t> packed_accessor_or_dummy(
594
+ const Tensor& t, c10::string_view var_name) {
595
+ if (!t.defined()) {
596
+ const std::array<index_t, dim> zeros{{0}};
597
+ return GenericPackedTensorAccessor<scalar_t, dim, PtrTraits, index_t>(nullptr, zeros.data(), zeros.data());
598
+ }
599
+ return get_packed_accessor<scalar_t, dim, PtrTraits, index_t>(t, var_name);
600
+ }
601
+
602
+ template<typename input_scalar_t, typename stat_scalar_t, typename index_t>
603
+ std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cuda_template(const Tensor& grad_out_, const Tensor& input_, const Tensor& weight_,
604
+ const Tensor& running_mean_, const Tensor& running_var_, const Tensor& save_mean_, const Tensor& save_invstd_,
605
+ bool train, double epsilon, std::array<bool,3> grad_input_mask) {
606
+
607
+ using accscalar_t = at::acc_type<stat_scalar_t, true>;
608
+ Tensor grad_input_;
609
+ Tensor grad_input_reshaped;
610
+ Tensor grad_weight_;
611
+ Tensor grad_bias_;
612
+ auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1});
613
+ auto grad_output_reshaped = grad_out_.reshape(input_reshaped.sizes());
614
+
615
+ if (grad_input_mask[0]) {
616
+ grad_input_ = at::empty_like(input_, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
617
+ grad_input_reshaped = grad_input_.view(input_reshaped.sizes());
618
+ }
619
+ if (grad_input_mask[1]) {
620
+ grad_weight_ = at::empty_like(weight_, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
621
+ }
622
+ if (grad_input_mask[2]) {
623
+ grad_bias_ = at::empty_like(weight_, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
624
+ }
625
+
626
+ auto input = get_packed_accessor<
627
+ const input_scalar_t, 3, DefaultPtrTraits, index_t>(input_reshaped, "input");
628
+ auto grad_output = get_packed_accessor<
629
+ const input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_output_reshaped, "grad_output");
630
+ auto grad_input = packed_accessor_or_dummy<
631
+ input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_input_reshaped, "grad_input");
632
+ auto weight = packed_accessor_or_dummy<
633
+ const stat_scalar_t, 1, DefaultPtrTraits, index_t>(weight_, "weight");
634
+ auto grad_weight = packed_accessor_or_dummy<
635
+ stat_scalar_t, 1, DefaultPtrTraits, index_t>(grad_weight_, "grad_weight");
636
+ auto grad_bias = packed_accessor_or_dummy<
637
+ stat_scalar_t, 1, DefaultPtrTraits, index_t>(grad_bias_, "grad_bias");
638
+ auto running_mean = packed_accessor_or_dummy<
639
+ const stat_scalar_t, 1, DefaultPtrTraits, index_t>(running_mean_, "running_mean");
640
+ auto running_var = packed_accessor_or_dummy<
641
+ const stat_scalar_t, 1, DefaultPtrTraits, index_t>(running_var_, "running_var");
642
+ auto save_mean = packed_accessor_or_dummy<
643
+ const accscalar_t, 1, DefaultPtrTraits, index_t>(save_mean_, "save_mean");
644
+ auto save_invstd = packed_accessor_or_dummy<
645
+ const accscalar_t, 1, DefaultPtrTraits, index_t>(save_invstd_, "save_invstd");
646
+
647
+ auto stream = at::cuda::getCurrentCUDAStream();
648
+ dim3 blocks(input.size(1));
649
+ int tf = getNumThreads(input.size(2));
650
+ dim3 threads(tf, std::max<int>(1, MAX_BLOCK_SIZE/tf));
651
+
652
+ batch_norm_backward_kernel<input_scalar_t, stat_scalar_t, accscalar_t, index_t> <<<blocks, threads, 0, stream>>>
653
+ (input, grad_output, grad_input, grad_weight, grad_bias, weight, running_mean, running_var,
654
+ save_mean, save_invstd, train, epsilon);
655
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
656
+
657
+ return std::make_tuple(grad_input_, grad_weight_, grad_bias_);
658
+ }
659
+
660
+ template<typename scalar_t, typename index_t, typename VarTransform>
661
+ void batch_norm_stats_cuda_template(
662
+ const Tensor& out_mean, const Tensor& out_invstd, const Tensor& input_, double epsilon) {
663
+
664
+ using accscalar_t = at::acc_type<scalar_t, true>;
665
+ int64_t n_input = input_.size(1);
666
+ Tensor dummy_mean_;
667
+ Tensor dummy_var_;
668
+ auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions
669
+
670
+ resize_output(out_mean, {n_input});
671
+ resize_output(out_invstd, {n_input});
672
+ auto input = get_packed_accessor<
673
+ const scalar_t, 3, RestrictPtrTraits, index_t>(input_reshaped, "input");
674
+ TORCH_INTERNAL_ASSERT(out_invstd.dim() == 1 && out_invstd.is_contiguous() &&
675
+ out_invstd.sizes()[0]);
676
+ TORCH_INTERNAL_ASSERT(out_mean.dim() == 1 && out_mean.is_contiguous() &&
677
+ out_mean.sizes()[0]);
678
+
679
+ auto mean = packed_accessor_or_dummy<
680
+ accscalar_t, 1, RestrictPtrTraits, index_t>(out_mean, "out_mean");
681
+ auto invstd = packed_accessor_or_dummy<
682
+ accscalar_t, 1, RestrictPtrTraits, index_t>(out_invstd, "out_invstd");
683
+ auto stream = at::cuda::getCurrentCUDAStream();
684
+
685
+ dim3 blocks(input.size(1));
686
+ int tf = getNumThreads(input.size(2));
687
+ dim3 threads(tf, std::max<int>(1, MAX_BLOCK_SIZE/tf));
688
+ batch_norm_collect_statistics_kernel<VarTransform, scalar_t, scalar_t, accscalar_t, index_t> <<<blocks, threads, 0, stream>>>
689
+ (input, epsilon, 0.0, mean, invstd);
690
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
691
+ }
692
+
693
+ template<typename input_scalar_t, typename stat_scalar_t, typename index_t>
694
+ void batch_norm_elemt_cuda_template(const Tensor& output_, const Tensor& input_, const Tensor& weight_,
695
+ const Tensor& bias_, const Tensor& mean_, const Tensor& invstd_) {
696
+
697
+ using stat_accscalar_t = at::acc_type<stat_scalar_t, true>;
698
+ int64_t n_input = input_.size(1);
699
+ auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions
700
+ auto output_reshaped = output_.view({input_.size(0), input_.size(1), -1});
701
+
702
+ auto input = get_packed_accessor<
703
+ const input_scalar_t, 3, RestrictPtrTraits, index_t>(input_reshaped, "input");
704
+ auto output = get_packed_accessor<
705
+ input_scalar_t, 3, RestrictPtrTraits, index_t>(output_reshaped, "output");
706
+ auto weight = packed_accessor_or_dummy<
707
+ const stat_scalar_t, 1, RestrictPtrTraits, index_t>(weight_, "weight");
708
+ auto bias = packed_accessor_or_dummy<
709
+ const stat_scalar_t, 1, RestrictPtrTraits, index_t>(bias_, "bias");
710
+ auto mean = packed_accessor_or_dummy<
711
+ stat_accscalar_t, 1, RestrictPtrTraits, index_t>(mean_, "mean");
712
+ auto invstd = packed_accessor_or_dummy<
713
+ stat_accscalar_t, 1, RestrictPtrTraits, index_t>(invstd_, "invstd");
714
+ auto stream = at::cuda::getCurrentCUDAStream();
715
+
716
+ // NOTE: We use transform_input_kernel in training mode, which ignores epsilon
717
+ const double dummy_epsilon = 1e-5;
718
+
719
+ // The input_transform kernel is pointwise, but we need to balance reading parameters (save_var/mean,
720
+ // weight/bias) - which we only do once and have a for loop afterwards - with having many threads and blocks
721
+ // and good occupancy. Quiet likely, we could go with even more blocks than 1024.
722
+ // The various planes are independent, so we use blocks for them.
723
+ int tf = std::max<int>(getNumThreads(input.size(2)/4),
724
+ std::min<int>(getNumThreads(input.size(2)), 64));
725
+ int tb = std::max<int>(64/tf, 1);
726
+ dim3 blocks_trans(input.size(1), std::max<int>(1, std::min<int>((256*1024)/input.size(1),
727
+ (input.size(0)+tb-1)/tb)));
728
+ blocks_trans.y = std::min(blocks_trans.y, MAX_GRID_SIZE);
729
+ dim3 threads_trans(tf, tb);
730
+ batch_norm_transform_input_kernel<input_scalar_t, stat_scalar_t, stat_accscalar_t, true, index_t> <<<blocks_trans, threads_trans, 0, stream>>>
731
+ (input, output, mean, invstd, weight, bias, dummy_epsilon);
732
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
733
+ }
734
+
735
+ template<typename scalar_t, typename accscalar_t, typename index_t>
736
+ std::tuple<Tensor, Tensor> batch_norm_gather_stats_cuda_template(const Tensor& mean_, const Tensor& invstd_,
737
+ const Tensor& running_mean_, const Tensor& running_var_,
738
+ double momentum, double epsilon, const Tensor& counts_) {
739
+
740
+ Tensor save_mean_;
741
+ Tensor save_invstd_;
742
+
743
+ auto features = mean_.size(1);
744
+ auto input_options = mean_.options();
745
+ if (mean_.scalar_type() == at::ScalarType::Half || mean_.scalar_type() == at::ScalarType::BFloat16) {
746
+ input_options = input_options.dtype(ScalarType::Float);
747
+ }
748
+ save_mean_ = at::empty({features}, input_options);
749
+ save_invstd_ = at::empty({features}, input_options);
750
+
751
+ auto mean = packed_accessor_or_dummy<
752
+ accscalar_t, 2, RestrictPtrTraits, index_t>(mean_, "mean");
753
+ auto invstd = packed_accessor_or_dummy<
754
+ accscalar_t, 2, RestrictPtrTraits, index_t>(invstd_, "invstd");
755
+ auto running_mean = packed_accessor_or_dummy<
756
+ scalar_t, 1, RestrictPtrTraits, index_t>(running_mean_, "running_mean");
757
+ auto running_var = packed_accessor_or_dummy<
758
+ scalar_t, 1, RestrictPtrTraits, index_t>(running_var_, "running_mean");
759
+ auto counts = packed_accessor_or_dummy<
760
+ scalar_t, 1, RestrictPtrTraits, index_t>(counts_, "counts");
761
+
762
+ auto save_mean = get_packed_accessor<
763
+ accscalar_t, 1, RestrictPtrTraits, index_t>(save_mean_, "save_mean");
764
+ auto save_invstd = get_packed_accessor<
765
+ accscalar_t, 1, RestrictPtrTraits, index_t>(save_invstd_, "save_invstd");
766
+ auto stream = at::cuda::getCurrentCUDAStream();
767
+
768
+ int block = getNumThreads(features);
769
+ int grid = std::max<int>(1, features/block);
770
+ batch_norm_reduce_statistics_kernel<scalar_t, accscalar_t, index_t> <<<grid, block, 0, stream>>>
771
+ (mean, invstd, save_mean, save_invstd, running_mean, running_var, epsilon, momentum, counts);
772
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
773
+
774
+ return std::make_tuple(save_mean_, save_invstd_);
775
+ }
776
+
777
+ template<typename input_scalar_t, typename stat_scalar_t, typename index_t>
778
+ std::tuple<Tensor, Tensor, Tensor, Tensor> batch_norm_backward_reduce_cuda_template(const Tensor& grad_out_, const Tensor& input_,
779
+ const Tensor& mean_, const Tensor& invstd_, const Tensor& weight_,
780
+ const bool input_g, const bool weight_g, const bool bias_g) {
781
+
782
+ using stat_accscalar_t = at::acc_type<stat_scalar_t, true>;
783
+ int64_t n_input = input_.size(1);
784
+ Tensor sum_dy_;
785
+ Tensor sum_dy_xmu_;
786
+ Tensor grad_weight_;
787
+ Tensor grad_bias_;
788
+ auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions
789
+ auto grad_output_reshaped = grad_out_.reshape(input_reshaped.sizes());
790
+
791
+ if (input_g) {
792
+ sum_dy_ = at::empty_like(mean_, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
793
+ sum_dy_xmu_ = at::empty_like(mean_, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
794
+ }
795
+ if (weight_g) {
796
+ grad_weight_ = at::empty({n_input}, weight_.options());
797
+ }
798
+ if (bias_g) {
799
+ grad_bias_ = at::empty({n_input}, weight_.options());
800
+ }
801
+
802
+ auto input = get_packed_accessor<
803
+ input_scalar_t, 3, DefaultPtrTraits, index_t>(input_reshaped, "input");
804
+ auto grad_output = get_packed_accessor<
805
+ input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_output_reshaped, "grad_output");
806
+ auto grad_weight = packed_accessor_or_dummy<
807
+ stat_scalar_t, 1, DefaultPtrTraits, index_t>(grad_weight_, "grad_weight");
808
+ auto grad_bias = packed_accessor_or_dummy<
809
+ stat_scalar_t, 1, DefaultPtrTraits, index_t>(grad_bias_, "grad_bias");
810
+ auto mean = packed_accessor_or_dummy<
811
+ stat_accscalar_t, 1, DefaultPtrTraits, index_t>(mean_, "mean");
812
+ auto invstd = packed_accessor_or_dummy<
813
+ stat_accscalar_t, 1, DefaultPtrTraits, index_t>(invstd_, "invstd");
814
+ auto sum_dy = packed_accessor_or_dummy<
815
+ stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_, "sum_dy");
816
+ auto sum_dy_xmu = packed_accessor_or_dummy<
817
+ stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_xmu_, "sum_dy_xmu");
818
+
819
+ auto batch_size = input_reshaped.size(0);
820
+ auto feature_size = input_reshaped.size(2);
821
+ auto stream = at::cuda::getCurrentCUDAStream();
822
+
823
+ int warp_size = at::cuda::warp_size();
824
+ int block_y = std::min<int>(lastPow2(batch_size), MAX_BLOCK_SIZE/warp_size);
825
+ // We want block_x to be at least a warp width
826
+ int block_x = std::min<int>(std::max<int>(getNumThreads(feature_size), warp_size), MAX_BLOCK_SIZE/block_y);
827
+ const dim3 block(block_x, block_y);
828
+ const dim3 grid(n_input);
829
+
830
+ batch_norm_backward_reduce_kernel<input_scalar_t, stat_scalar_t, stat_accscalar_t, index_t> <<<grid, block, 0, stream>>>
831
+ (input, grad_output, mean, invstd, sum_dy, sum_dy_xmu, grad_weight, grad_bias);
832
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
833
+
834
+ return std::make_tuple(sum_dy_, sum_dy_xmu_, grad_weight_, grad_bias_);
835
+ }
836
+
837
+ template<typename input_scalar_t, typename stat_scalar_t, typename index_t>
838
+ Tensor batch_norm_backward_elemt_cuda_template(const Tensor& grad_out_, const Tensor& input_,
839
+ const Tensor& mean_, const Tensor& invstd_,
840
+ const Tensor& weight_, const Tensor& sum_dy_, const Tensor& sum_dy_xmu_) {
841
+
842
+ using stat_accscalar_t = at::acc_type<stat_scalar_t, true>;
843
+ int64_t n_input = input_.size(1);
844
+ auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions
845
+ auto grad_output_reshaped = grad_out_.reshape(input_reshaped.sizes());
846
+ auto grad_input_reshaped = at::empty_like(input_reshaped, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
847
+
848
+ auto input = get_packed_accessor<
849
+ input_scalar_t, 3, DefaultPtrTraits, index_t>(input_reshaped, "input");
850
+ auto grad_input = get_packed_accessor<
851
+ input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_input_reshaped, "grad_input");
852
+ auto grad_output = get_packed_accessor<
853
+ input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_output_reshaped, "grad_output");
854
+ auto mean = packed_accessor_or_dummy<
855
+ stat_accscalar_t, 1, DefaultPtrTraits, index_t>(mean_, "mean");
856
+ auto invstd = packed_accessor_or_dummy<
857
+ stat_accscalar_t, 1, DefaultPtrTraits, index_t>(invstd_, "invstd");
858
+ auto weight = packed_accessor_or_dummy<
859
+ stat_scalar_t, 1, DefaultPtrTraits, index_t>(weight_, "weight");
860
+ auto sum_dy = packed_accessor_or_dummy<
861
+ stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_, "sum_dy");
862
+ auto sum_dy_xmu = packed_accessor_or_dummy<
863
+ stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_xmu_, "sum_dy_xmu");
864
+
865
+ auto stream = at::cuda::getCurrentCUDAStream();
866
+
867
+ // The kernel is pointwise, but we need to balance reading parameters (save_var/mean,
868
+ // weight/bias) - which we only do once and have a for loop afterwards - with having many threads and blocks
869
+ // and good occupancy. Quiet likely, we could go with even more blocks than 1024.
870
+ // The various planes are independent, so we use blocks for them.
871
+ int tf = std::max<int>(getNumThreads(input.size(2)/4),
872
+ std::min<int>(getNumThreads(input.size(2)), 64));
873
+ int tb = std::max<int>(64/tf, 1);
874
+ dim3 blocks_trans(input.size(1), std::max<int>(1, std::min<int>((256*1024)/input.size(1),
875
+ (input.size(0)+tb-1)/tb)));
876
+ blocks_trans.y = std::min(blocks_trans.y, MAX_GRID_SIZE);
877
+ dim3 threads_trans(tf, tb);
878
+ auto reduction_size = input_.numel() / n_input;
879
+ auto norm_fct = static_cast<stat_accscalar_t>(1.0 / reduction_size);
880
+ batch_norm_backward_elemt_kernel<input_scalar_t, stat_scalar_t, stat_accscalar_t, index_t>
881
+ <<<blocks_trans, threads_trans, 0, stream>>>
882
+ (input, grad_output, mean, invstd, weight, sum_dy, sum_dy_xmu, grad_input, norm_fct);
883
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
884
+
885
+ return grad_input_reshaped.view(input_.sizes());
886
+ }
887
+
888
+ template<typename input_scalar_t, typename stat_scalar_t, typename index_t>
889
+ Tensor batch_norm_backward_elemt_cuda_template(const Tensor& grad_out_, const Tensor& input_,
890
+ const Tensor& mean_, const Tensor& invstd_,
891
+ const Tensor& weight_, const Tensor& sum_dy_, const Tensor& sum_dy_xmu_, const Tensor& count) {
892
+
893
+ using stat_accscalar_t = at::acc_type<stat_scalar_t, true>;
894
+ int64_t n_input = input_.size(1);
895
+ auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions
896
+ auto grad_output_reshaped = grad_out_.reshape(input_reshaped.sizes());
897
+ auto grad_input_reshaped = at::empty_like(input_reshaped, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
898
+
899
+ auto input = get_packed_accessor<
900
+ input_scalar_t, 3, DefaultPtrTraits, index_t>(input_reshaped, "input");
901
+ auto grad_input = get_packed_accessor<
902
+ input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_input_reshaped, "grad_input");
903
+ auto grad_output = get_packed_accessor<
904
+ input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_output_reshaped, "grad_output");
905
+ auto mean = packed_accessor_or_dummy<
906
+ stat_accscalar_t, 1, DefaultPtrTraits, index_t>(mean_, "mean");
907
+ auto invstd = packed_accessor_or_dummy<
908
+ stat_accscalar_t, 1, DefaultPtrTraits, index_t>(invstd_, "invstd");
909
+ auto weight = packed_accessor_or_dummy<
910
+ stat_scalar_t, 1, DefaultPtrTraits, index_t>(weight_, "weight");
911
+ auto sum_dy = packed_accessor_or_dummy<
912
+ stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_, "sum_dy");
913
+ auto sum_dy_xmu = packed_accessor_or_dummy<
914
+ stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_xmu_, "sum_dy_xmu");
915
+
916
+ auto stream = at::cuda::getCurrentCUDAStream();
917
+
918
+ // The kernel is pointwise, but we need to balance reading parameters (save_var/mean,
919
+ // weight/bias) - which we only do once and have a for loop afterwards - with having many threads and blocks
920
+ // and good occupancy. Quiet likely, we could go with even more blocks than 1024.
921
+ // The various planes are independent, so we use blocks for them.
922
+ int tf = std::max<int>(getNumThreads(input.size(2)/4),
923
+ std::min<int>(getNumThreads(input.size(2)), 64));
924
+ int tb = std::max<int>(64/tf, 1);
925
+ dim3 blocks_trans(input.size(1), std::max<int>(1, std::min<int>((256*1024)/input.size(1),
926
+ (input.size(0)+tb-1)/tb)));
927
+ blocks_trans.y = std::min(blocks_trans.y, MAX_GRID_SIZE);
928
+ dim3 threads_trans(tf, tb);
929
+ batch_norm_backward_elemt_kernel<input_scalar_t, stat_scalar_t, stat_accscalar_t, index_t> <<<blocks_trans, threads_trans, 0, stream>>>
930
+ (input, grad_output, mean, invstd, weight, sum_dy, sum_dy_xmu, grad_input, count.const_data_ptr<int>(), count.numel());
931
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
932
+
933
+ return grad_input_reshaped.view(input_.sizes());
934
+ }
935
+
936
+ // welford kernel for c last tensor calculating mean/biased_variance/unbiased_variance
937
+ // original apex name: welford_kernel_c_last
938
+ template
939
+ <typename VarTransform,
940
+ typename scalar_t,
941
+ typename accscalar_t,
942
+ int PARALLEL_LOADS>
943
+ __global__ void
944
+ batch_norm_collect_statistics_channels_last_kernel(
945
+ const scalar_t* __restrict__ input,
946
+ accscalar_t* __restrict__ out_mean,
947
+ accscalar_t* __restrict__ out_invstd,
948
+ volatile accscalar_t* staging_data,
949
+ int* semaphores,
950
+ const int reduction_size,
951
+ const int stride,
952
+ accscalar_t epsilon) {
953
+ // hide latency with concurrency
954
+ accscalar_t x_mean[PARALLEL_LOADS];
955
+ accscalar_t m_2_n[PARALLEL_LOADS];
956
+ int count[PARALLEL_LOADS];
957
+
958
+ #pragma unroll
959
+ for (int i = 0; i < PARALLEL_LOADS; i++) {
960
+ x_mean[i] = accscalar_t(0);
961
+ m_2_n[i] = accscalar_t(0);
962
+ count[i] = accscalar_t(0);
963
+ }
964
+ // tensor dimension (m,c)
965
+
966
+ // loop along m dimension
967
+ int inner_loop_stride = blockDim.y * gridDim.y;
968
+
969
+ // offset along m dimension
970
+ int m_offset = blockIdx.y * blockDim.y + threadIdx.y;
971
+ int c_offset = blockIdx.x * blockDim.x + threadIdx.x;
972
+
973
+ int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
974
+ int address_base = m_offset * stride + c_offset;
975
+ int address_increment = inner_loop_stride * stride;
976
+
977
+ for (int i = 0; i < loop_count; i++) {
978
+ accscalar_t x_math[PARALLEL_LOADS];
979
+ accscalar_t x_count_inv[PARALLEL_LOADS];
980
+ accscalar_t is_valid[PARALLEL_LOADS];
981
+
982
+ // load multiple data in
983
+ #pragma unroll
984
+ for (int j = 0; j < PARALLEL_LOADS; j++) {
985
+ if (c_offset < stride && m_offset < reduction_size) {
986
+ x_math[j] = input[address_base];
987
+ count[j]++;
988
+ x_count_inv[j] = accscalar_t(1) / count[j];
989
+ is_valid[j] = accscalar_t(1);
990
+ } else {
991
+ x_math[j] = accscalar_t(0);
992
+ x_count_inv[j] = accscalar_t(0);
993
+ is_valid[j] = accscalar_t(0);
994
+ }
995
+ m_offset += inner_loop_stride;
996
+ address_base += address_increment;
997
+ }
998
+
999
+ // calculate mean/m2n with welford
1000
+ #pragma unroll
1001
+ for (int j = 0; j < PARALLEL_LOADS; j++) {
1002
+ accscalar_t delta0 = x_math[j] - x_mean[j];
1003
+ x_mean[j] += delta0 * x_count_inv[j];
1004
+ accscalar_t delta1 = x_math[j] - x_mean[j];
1005
+ m_2_n[j] += delta0 * delta1 * is_valid[j];
1006
+ }
1007
+ }
1008
+
1009
+ // thread reduction to accumulate mean/m_2_n/count between PARALLEL_LOADS
1010
+ #pragma unroll
1011
+ for (int j = 1; j < PARALLEL_LOADS; j++) {
1012
+ welford_merge_element(count[0], x_mean[0], m_2_n[0], count[j], x_mean[j], m_2_n[j]);
1013
+ }
1014
+
1015
+ // release x_mean / m_2_n
1016
+ auto mean_th = x_mean[0];
1017
+ auto m2_th = m_2_n[0];
1018
+ auto count_th = count[0];
1019
+
1020
+ // block-wise reduction with shared memory (since reduction cannot be done within a warp)
1021
+ static __shared__ accscalar_t shmem_mean[MAX_BLOCK_SIZE];
1022
+ static __shared__ accscalar_t shmem_m2n[MAX_BLOCK_SIZE];
1023
+ static __shared__ int shmem_count[MAX_BLOCK_SIZE];
1024
+
1025
+ welford_merge_block_vertical(count_th, mean_th, m2_th, shmem_count, shmem_mean, shmem_m2n);
1026
+
1027
+ if (gridDim.y > 1) {
1028
+ volatile accscalar_t* staging_mean = staging_data;
1029
+ volatile accscalar_t* staging_m2n = &staging_data[stride*gridDim.y];
1030
+ volatile int* staging_count = reinterpret_cast<volatile int*>(&staging_m2n[stride*gridDim.y]);
1031
+
1032
+ address_base = c_offset + blockIdx.y * stride;
1033
+ // write data to staging_data;
1034
+ if (threadIdx.y == 0 && c_offset < stride) {
1035
+ staging_mean[address_base] = mean_th;
1036
+ staging_m2n[address_base] = m2_th;
1037
+ staging_count[address_base] = count_th;
1038
+ }
1039
+
1040
+ __threadfence();
1041
+ __syncthreads(); // ensuring writes to staging_ is visible to all blocks
1042
+
1043
+ __shared__ bool is_last_block_done;
1044
+ // mark block done
1045
+ if (threadIdx.x == 0 && threadIdx.y == 0) {
1046
+ int old = atomicAdd(&semaphores[blockIdx.x], 1);
1047
+ is_last_block_done = (old == (gridDim.y-1));
1048
+ }
1049
+
1050
+ __syncthreads();
1051
+
1052
+ // check that all data is now available in global memory
1053
+ if (is_last_block_done) {
1054
+ count_th = 0;
1055
+ mean_th = accscalar_t(0.0);
1056
+ m2_th = accscalar_t(0.0);
1057
+
1058
+ for (int y = threadIdx.y; y < gridDim.y; y += blockDim.y) {
1059
+ address_base = c_offset + y * stride;
1060
+ int count_new = c_offset < stride ? staging_count[address_base] : 0;
1061
+ accscalar_t mean_new = c_offset < stride ? staging_mean[address_base] : accscalar_t(0.0);
1062
+ accscalar_t m2n_new = c_offset < stride ? staging_m2n[address_base] : accscalar_t(0.0);
1063
+
1064
+ welford_merge_element(count_th, mean_th, m2_th, count_new, mean_new, m2n_new);
1065
+ }
1066
+
1067
+ welford_merge_block_vertical(count_th, mean_th, m2_th, shmem_count, shmem_mean, shmem_m2n);
1068
+ if (threadIdx.y == 0 && c_offset < stride) {
1069
+ out_mean[c_offset] = static_cast<accscalar_t>(mean_th);
1070
+ out_invstd[c_offset] = VarTransform{}(m2_th/count_th, epsilon);
1071
+ }
1072
+ }
1073
+ } else {
1074
+ if (blockIdx.y == 0 && threadIdx.y == 0 && c_offset < stride) {
1075
+ out_mean[c_offset] = static_cast<accscalar_t>(mean_th);
1076
+ out_invstd[c_offset] = VarTransform{}(m2_th/count_th, epsilon);
1077
+ }
1078
+ }
1079
+ }
1080
+
1081
+ // elementwise BN kernel
1082
+ // original apex name: batchnorm_forward_c_last_kernel
1083
+ template <
1084
+ typename scalar_t,
1085
+ typename accscalar_t,
1086
+ typename layerscalar_t,
1087
+ int PARALLEL_LOADS>
1088
+ __global__ void batch_norm_transform_input_channels_last_kernel(
1089
+ const scalar_t* __restrict__ input,
1090
+ const scalar_t* __restrict__ z,
1091
+ const accscalar_t* __restrict__ mean,
1092
+ const accscalar_t* __restrict__ inv_std,
1093
+ const layerscalar_t* __restrict__ weight,
1094
+ const layerscalar_t* __restrict__ shift,
1095
+ scalar_t* __restrict__ out,
1096
+ const int reduction_size,
1097
+ const int stride,
1098
+ const bool fuse_relu) {
1099
+ // tensor dimension (m,c)
1100
+ // loop along m dimension
1101
+ int inner_loop_stride = blockDim.y * gridDim.y;
1102
+
1103
+ // offset along m dimension
1104
+ int m_offset = blockIdx.y * blockDim.y + threadIdx.y;
1105
+ int c_offset = blockIdx.x * blockDim.x + threadIdx.x;
1106
+
1107
+ if (c_offset >= stride || m_offset >= reduction_size) {
1108
+ return;
1109
+ }
1110
+
1111
+ auto m_c = mean[c_offset];
1112
+ auto inv_std_c = static_cast<accscalar_t>(inv_std[c_offset]);
1113
+ auto w_c = weight == nullptr ? accscalar_t(1.0) : static_cast<accscalar_t>(weight[c_offset]);
1114
+ auto s_c = shift == nullptr ? accscalar_t(0.0) : static_cast<accscalar_t>(shift[c_offset]);
1115
+
1116
+ int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
1117
+ int address_base = m_offset * stride + c_offset;
1118
+ int address_increment = inner_loop_stride * stride;
1119
+
1120
+ for (int i = 0; i < loop_count; i++) {
1121
+ #pragma unroll
1122
+ for (int j = 0; j < PARALLEL_LOADS; j++) {
1123
+ if (c_offset < stride && m_offset < reduction_size) {
1124
+ auto tmp = w_c * (static_cast<accscalar_t>(input[address_base]) - m_c ) * inv_std_c + s_c;
1125
+ if (z != nullptr) {
1126
+ tmp += z[address_base];
1127
+ }
1128
+ out[address_base] = (fuse_relu && tmp <= accscalar_t(0.0) ? scalar_t(0.0) : static_cast<scalar_t>(tmp));
1129
+ }
1130
+ m_offset += inner_loop_stride;
1131
+ address_base += address_increment;
1132
+ }
1133
+ }
1134
+ }
1135
+
1136
+ template<typename T>
1137
+ __device__ __forceinline__ void merge_block_vertical_backward(T& sum_dy,
1138
+ T& sum_dy_xmu,
1139
+ T* shmem_sum_dy,
1140
+ T* shmem_sum_dy_xmu) {
1141
+ // write to shared memory
1142
+ auto address_base = threadIdx.x + threadIdx.y * blockDim.x;
1143
+
1144
+ #pragma unroll
1145
+ for (int offset = blockDim.y/2; offset > 0; offset >>= 1) {
1146
+ if (threadIdx.y < offset*2) {
1147
+ shmem_sum_dy[address_base] = sum_dy;
1148
+ shmem_sum_dy_xmu[address_base] = sum_dy_xmu;
1149
+ }
1150
+ __syncthreads();
1151
+ if (threadIdx.y < offset && threadIdx.y + offset < blockDim.y) {
1152
+ auto address = address_base + offset * blockDim.x;
1153
+
1154
+ sum_dy += shmem_sum_dy[address];
1155
+ sum_dy_xmu += shmem_sum_dy_xmu[address];
1156
+ }
1157
+ }
1158
+ }
1159
+
1160
+ // batchnorm backward kernel for c last tensor
1161
+ // original apex name: reduce_bn_c_last_kernel
1162
+ template <
1163
+ int PARALLEL_LOADS,
1164
+ typename scalar_t,
1165
+ typename accscalar_t,
1166
+ typename layerscalar_t>
1167
+ __global__ void batch_norm_backward_reduce_channels_last_kernel(
1168
+ const scalar_t* __restrict__ input,
1169
+ const scalar_t* __restrict__ grad_output,
1170
+ const accscalar_t* __restrict__ mean,
1171
+ const accscalar_t* __restrict__ inv_std,
1172
+ accscalar_t* __restrict__ sum_dy_o,
1173
+ accscalar_t* __restrict__ sum_dy_xmu_o,
1174
+ layerscalar_t* __restrict__ grad_weight,
1175
+ layerscalar_t* __restrict__ grad_bias,
1176
+ volatile accscalar_t* staging_data,
1177
+ int* semaphores,
1178
+ const int reduction_size,
1179
+ const int stride) {
1180
+
1181
+ // hide latency with concurrency
1182
+ accscalar_t sum_dy[PARALLEL_LOADS];
1183
+ accscalar_t sum_dy_xmu[PARALLEL_LOADS];
1184
+
1185
+ #pragma unroll
1186
+ for (int i = 0; i < PARALLEL_LOADS; i++) {
1187
+ sum_dy[i] = accscalar_t(0);
1188
+ sum_dy_xmu[i] = accscalar_t(0);
1189
+ }
1190
+ // tensor dimension (m,c)
1191
+
1192
+ // loop along m dimension
1193
+ int inner_loop_stride = blockDim.y * gridDim.y;
1194
+
1195
+ // offset along m dimension
1196
+ int m_offset = blockIdx.y * blockDim.y + threadIdx.y;
1197
+ int c_offset = blockIdx.x * blockDim.x + threadIdx.x;
1198
+
1199
+ if (c_offset >= stride || m_offset >= reduction_size) {
1200
+ return;
1201
+ }
1202
+
1203
+ int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
1204
+ int address_base = m_offset * stride + c_offset;
1205
+ int address_increment = inner_loop_stride * stride;
1206
+
1207
+ auto r_mean = mean[c_offset];
1208
+ auto factor = inv_std[c_offset];
1209
+
1210
+ for (int i = 0; i < loop_count; i++) {
1211
+ accscalar_t x_input[PARALLEL_LOADS];
1212
+ accscalar_t x_grad_output[PARALLEL_LOADS];
1213
+
1214
+ // load multiple data in
1215
+ #pragma unroll
1216
+ for (int j = 0; j < PARALLEL_LOADS; j++) {
1217
+ if (c_offset < stride && m_offset < reduction_size) {
1218
+ x_input[j] = input[address_base];
1219
+ x_grad_output[j] = grad_output[address_base];
1220
+ } else {
1221
+ x_input[j] = accscalar_t(0);
1222
+ x_grad_output[j] = accscalar_t(0);
1223
+ }
1224
+ m_offset += inner_loop_stride;
1225
+ address_base += address_increment;
1226
+ }
1227
+
1228
+ // calculate sum_dy / sum_dy_xmu
1229
+ #pragma unroll
1230
+ for (int j = 0; j < PARALLEL_LOADS; j++) {
1231
+ sum_dy[j] += x_grad_output[j];
1232
+ sum_dy_xmu[j] += x_grad_output[j] * (x_input[j] - r_mean);
1233
+ }
1234
+ }
1235
+
1236
+ // thread reduction to accumulate sum_dy / sum_dy_xmu between PARALLEL_LOADS
1237
+ #pragma unroll
1238
+ for (int j = 1; j < PARALLEL_LOADS; j++) {
1239
+ sum_dy[0] += sum_dy[j];
1240
+ sum_dy_xmu[0] += sum_dy_xmu[j];
1241
+ }
1242
+
1243
+ // release array of registers
1244
+ auto sum_dy_th = sum_dy[0];
1245
+ auto sum_dy_xmu_th = sum_dy_xmu[0];
1246
+
1247
+ // block-wise reduction with shared memory (since reduction cannot be done within a warp)
1248
+ static __shared__ accscalar_t shmem_sum_dy[MAX_BLOCK_SIZE];
1249
+ static __shared__ accscalar_t shmem_sum_dy_xmu[MAX_BLOCK_SIZE];
1250
+
1251
+ merge_block_vertical_backward(sum_dy_th, sum_dy_xmu_th, shmem_sum_dy, shmem_sum_dy_xmu);
1252
+
1253
+ if (gridDim.y > 1) {
1254
+ volatile accscalar_t* staging_sum_dy = staging_data;
1255
+ volatile accscalar_t* staging_sum_dy_xmu = &staging_data[stride*gridDim.y];
1256
+
1257
+ address_base = c_offset + blockIdx.y * stride;
1258
+ // write data to staging_data;
1259
+ if (threadIdx.y == 0 && c_offset < stride) {
1260
+ staging_sum_dy[address_base] = sum_dy_th;
1261
+ staging_sum_dy_xmu[address_base] = sum_dy_xmu_th;
1262
+ }
1263
+
1264
+ __threadfence();
1265
+ __syncthreads(); // ensuring writes to staging_ is visible to all blocks
1266
+
1267
+ __shared__ bool is_last_block_done;
1268
+ // mark block done
1269
+ if (threadIdx.x == 0 && threadIdx.y == 0) {
1270
+ int old = atomicAdd(&semaphores[blockIdx.x], 1);
1271
+ is_last_block_done = (old == (gridDim.y-1));
1272
+ }
1273
+
1274
+ __syncthreads();
1275
+
1276
+ // check that all data is now available in global memory
1277
+ if (is_last_block_done) {
1278
+ sum_dy_th = accscalar_t(0.0);
1279
+ sum_dy_xmu_th = accscalar_t(0.0);
1280
+
1281
+ for (int y = threadIdx.y; y < gridDim.y; y += blockDim.y) {
1282
+ address_base = c_offset + y * stride;
1283
+ sum_dy_th += (c_offset < stride ? staging_sum_dy[address_base] : accscalar_t(0.0));
1284
+ sum_dy_xmu_th += (c_offset < stride ? staging_sum_dy_xmu[address_base] : accscalar_t(0.0));
1285
+ }
1286
+
1287
+ merge_block_vertical_backward(sum_dy_th, sum_dy_xmu_th, shmem_sum_dy, shmem_sum_dy_xmu);
1288
+ if (threadIdx.y == 0 && c_offset < stride) {
1289
+ if (grad_bias != nullptr) {
1290
+ grad_bias[c_offset] = static_cast<layerscalar_t>(sum_dy_th);
1291
+ }
1292
+ if (grad_weight != nullptr) {
1293
+ grad_weight[c_offset] = static_cast<layerscalar_t>(sum_dy_xmu_th * factor);
1294
+ }
1295
+ //mean_dy[c_offset] = sum_dy_th / reduction_size;
1296
+ //mean_dy_xmu[c_offset] = sum_dy_xmu_th / reduction_size;
1297
+ sum_dy_o[c_offset] = sum_dy_th;
1298
+ sum_dy_xmu_o[c_offset] = sum_dy_xmu_th;
1299
+ }
1300
+ }
1301
+ } else {
1302
+ if (blockIdx.y == 0 && threadIdx.y == 0 && c_offset < stride) {
1303
+ if (grad_bias != nullptr) {
1304
+ grad_bias[c_offset] = static_cast<layerscalar_t>(sum_dy_th);
1305
+ }
1306
+ if (grad_weight != nullptr) {
1307
+ grad_weight[c_offset] = static_cast<layerscalar_t>(sum_dy_xmu_th * factor);
1308
+ }
1309
+ //mean_dy[c_offset] = sum_dy_th / reduction_size;
1310
+ //mean_dy_xmu[c_offset] = sum_dy_xmu_th / reduction_size;
1311
+ sum_dy_o[c_offset] = sum_dy_th;
1312
+ sum_dy_xmu_o[c_offset] = sum_dy_xmu_th;
1313
+ }
1314
+ }
1315
+ }
1316
+
1317
+ // elementwise BN kernel
1318
+ // original apex name: batchnorm_backward_c_last_kernel
1319
+ template <
1320
+ int PARALLEL_LOADS,
1321
+ typename scalar_t,
1322
+ typename accscalar_t,
1323
+ typename layerscalar_t>
1324
+ __device__ __forceinline__ void batch_norm_backward_elemt_channels_last_kernel_impl(
1325
+ const scalar_t* __restrict__ grad_output,
1326
+ const scalar_t* __restrict__ input,
1327
+ const accscalar_t* __restrict__ mean,
1328
+ const accscalar_t* __restrict__ inv_std,
1329
+ const layerscalar_t* __restrict__ weight,
1330
+ const accscalar_t* __restrict__ sum_dy,
1331
+ const accscalar_t* __restrict__ sum_dy_xmu,
1332
+ scalar_t* __restrict__ grad_input,
1333
+ const accscalar_t norm_fct,
1334
+ const int reduction_size,
1335
+ const int stride) {
1336
+ // tensor dimension (m,c)
1337
+ // loop along m dimension
1338
+ int inner_loop_stride = blockDim.y * gridDim.y;
1339
+
1340
+ // offset along m dimension
1341
+ int m_offset = blockIdx.y * blockDim.y + threadIdx.y;
1342
+ int c_offset = blockIdx.x * blockDim.x + threadIdx.x;
1343
+
1344
+ if (c_offset >= stride || m_offset >= reduction_size) {
1345
+ return;
1346
+ }
1347
+
1348
+ auto m_c = mean[c_offset];
1349
+ auto m_dy_c = sum_dy[c_offset] * norm_fct;
1350
+ auto factor_1_c = inv_std[c_offset];
1351
+ auto factor_2_c = (weight == nullptr? accscalar_t(1.0) : static_cast<accscalar_t>(weight[c_offset])) * factor_1_c;
1352
+ factor_1_c = factor_1_c * factor_1_c * sum_dy_xmu[c_offset] * norm_fct;
1353
+
1354
+ int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
1355
+ int address_base = m_offset * stride + c_offset;
1356
+ int address_increment = inner_loop_stride * stride;
1357
+
1358
+ for (int i = 0; i < loop_count; i++) {
1359
+ #pragma unroll
1360
+ for (int j = 0; j < PARALLEL_LOADS; j++) {
1361
+ if (c_offset < stride && m_offset < reduction_size) {
1362
+ grad_input[address_base] = static_cast<scalar_t>(
1363
+ (static_cast<accscalar_t>(grad_output[address_base]) - m_dy_c -
1364
+ (static_cast<accscalar_t>(input[address_base]) - m_c) * factor_1_c)
1365
+ * factor_2_c);
1366
+ }
1367
+ m_offset += inner_loop_stride;
1368
+ address_base += address_increment;
1369
+ }
1370
+ }
1371
+ }
1372
+
1373
+ template <
1374
+ int PARALLEL_LOADS,
1375
+ typename scalar_t,
1376
+ typename accscalar_t,
1377
+ typename layerscalar_t>
1378
+ __global__ void batch_norm_backward_elemt_channels_last_kernel(
1379
+ const scalar_t* __restrict__ grad_output,
1380
+ const scalar_t* __restrict__ input,
1381
+ const accscalar_t* __restrict__ mean,
1382
+ const accscalar_t* __restrict__ inv_std,
1383
+ const layerscalar_t* __restrict__ weight,
1384
+ const accscalar_t* __restrict__ sum_dy,
1385
+ const accscalar_t* __restrict__ sum_dy_xmu,
1386
+ const int* __restrict__ numel,
1387
+ scalar_t* __restrict__ grad_input,
1388
+ const int64_t world_size,
1389
+ const int reduction_size,
1390
+ const int stride) {
1391
+
1392
+ int64_t total_numel = 0;
1393
+ for (int i = 0; i < world_size; i++) {
1394
+ total_numel += numel[i];
1395
+ }
1396
+
1397
+ auto norm_fct = static_cast<accscalar_t>(1) / static_cast<accscalar_t>(total_numel);
1398
+ batch_norm_backward_elemt_channels_last_kernel_impl<PARALLEL_LOADS>(
1399
+ grad_output, input, mean, inv_std, weight, sum_dy, sum_dy_xmu,
1400
+ grad_input, norm_fct, reduction_size, stride);
1401
+ }
1402
+
1403
+ template <
1404
+ int PARALLEL_LOADS,
1405
+ typename scalar_t,
1406
+ typename accscalar_t,
1407
+ typename layerscalar_t>
1408
+ __global__ void batch_norm_backward_elemt_channels_last_kernel(
1409
+ const scalar_t* __restrict__ grad_output,
1410
+ const scalar_t* __restrict__ input,
1411
+ const accscalar_t* __restrict__ mean,
1412
+ const accscalar_t* __restrict__ inv_std,
1413
+ const layerscalar_t* __restrict__ weight,
1414
+ const accscalar_t* __restrict__ sum_dy,
1415
+ const accscalar_t* __restrict__ sum_dy_xmu,
1416
+ scalar_t* __restrict__ grad_input,
1417
+ const accscalar_t norm_fct,
1418
+ const int reduction_size,
1419
+ const int stride) {
1420
+ batch_norm_backward_elemt_channels_last_kernel_impl<PARALLEL_LOADS>(
1421
+ grad_output, input, mean, inv_std, weight, sum_dy, sum_dy_xmu,
1422
+ grad_input, norm_fct, reduction_size, stride);
1423
+ }
1424
+
1425
+ template<typename scalar_t, typename VarTransform>
1426
+ void batch_norm_stats_channels_last_cuda_template(
1427
+ const Tensor& out_mean, const Tensor& out_invstd, const Tensor& input, double epsilon) {
1428
+ using accscalar_t = at::acc_type<scalar_t, true>;
1429
+
1430
+ const auto stride = input.sizes()[1];
1431
+ const auto reduction_size = input.numel() / stride;
1432
+
1433
+ resize_output(out_mean, {stride});
1434
+ resize_output(out_invstd, {stride});
1435
+ TORCH_INTERNAL_ASSERT(out_invstd.dim() == 1 && out_invstd.is_contiguous() &&
1436
+ out_invstd.sizes()[0]);
1437
+ TORCH_INTERNAL_ASSERT(out_mean.dim() == 1 && out_mean.is_contiguous() &&
1438
+ out_mean.sizes()[0]);
1439
+
1440
+ dim3 block;
1441
+ dim3 grid;
1442
+ flexible_launch_configs(reduction_size, stride, block, grid, true);
1443
+
1444
+ at::Tensor staging_data;
1445
+ at::Tensor semaphores;
1446
+ if (grid.y > 1) {
1447
+ staging_data = at::empty({4*stride*grid.y}, out_mean.options());
1448
+ semaphores = at::zeros({grid.x}, input.options().dtype(at::kInt));
1449
+ }
1450
+
1451
+ accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.mutable_data_ptr<accscalar_t>() : nullptr;
1452
+ int* semaphores_ptr = grid.y > 1 ? semaphores.mutable_data_ptr<int>() : nullptr;
1453
+ batch_norm_collect_statistics_channels_last_kernel<VarTransform, scalar_t, accscalar_t, ELEMENTS_PER_ITER>
1454
+ <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
1455
+ input.const_data_ptr<scalar_t>(),
1456
+ out_mean.mutable_data_ptr<accscalar_t>(),
1457
+ out_invstd.mutable_data_ptr<accscalar_t>(),
1458
+ staging_data_ptr,
1459
+ semaphores_ptr,
1460
+ reduction_size,
1461
+ stride,
1462
+ epsilon);
1463
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
1464
+ }
1465
+
1466
+ void batch_norm_elemt_channels_last_cuda_template(
1467
+ const at::Tensor& output,
1468
+ const at::Tensor& input,
1469
+ const at::Tensor& weight,
1470
+ const at::Tensor& shift, // bias of BN
1471
+ const at::Tensor& mean,
1472
+ const at::Tensor& inv_std,
1473
+ const std::optional<at::Tensor>& z = std::nullopt, // bias after BN
1474
+ const bool fuse_relu = false) {
1475
+ const auto stride = input.sizes()[1];
1476
+ const auto reduction_size = input.numel() / stride;
1477
+
1478
+ dim3 block;
1479
+ dim3 grid;
1480
+ flexible_launch_configs(reduction_size, stride, block, grid);
1481
+
1482
+ auto stream = at::cuda::getCurrentCUDAStream();
1483
+ const auto second_dtype = weight.defined() ? weight.scalar_type() :
1484
+ (shift.defined() ? shift.scalar_type() : input.scalar_type());
1485
+
1486
+ if (input.scalar_type() != second_dtype) {
1487
+ AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_forward", [&] {
1488
+ using accscalar_t = at::acc_type<scalar_t, true>;
1489
+ batch_norm_transform_input_channels_last_kernel<scalar_t, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
1490
+ <<<grid, block, 0, stream>>>(
1491
+ input.const_data_ptr<scalar_t>(),
1492
+ z.has_value() ? z.value().const_data_ptr<scalar_t>() : nullptr,
1493
+ mean.const_data_ptr<accscalar_t>(),
1494
+ inv_std.const_data_ptr<accscalar_t>(),
1495
+ weight.defined() ? weight.const_data_ptr<accscalar_t>() : nullptr,
1496
+ shift.defined() ? shift.const_data_ptr<accscalar_t>() : nullptr,
1497
+ output.mutable_data_ptr<scalar_t>(),
1498
+ reduction_size,
1499
+ stride,
1500
+ fuse_relu);
1501
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
1502
+ });
1503
+ } else {
1504
+ if (weight.defined()){
1505
+ TORCH_CHECK(input.scalar_type() == weight.scalar_type(), "batchnorm_forward: input.scalar_type() ", input.scalar_type(),
1506
+ " is not supported with weight.scalar_type() ", weight.scalar_type());
1507
+ }
1508
+ AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_forward", [&] {
1509
+ using accscalar_t = at::acc_type<scalar_t, true>;
1510
+ batch_norm_transform_input_channels_last_kernel<scalar_t, accscalar_t, scalar_t, ELEMENTS_PER_ITER>
1511
+ <<<grid, block, 0, stream>>>(
1512
+ input.const_data_ptr<scalar_t>(),
1513
+ z.has_value() ? z.value().const_data_ptr<scalar_t>() : nullptr,
1514
+ mean.const_data_ptr<accscalar_t>(),
1515
+ inv_std.const_data_ptr<accscalar_t>(),
1516
+ weight.defined() ? weight.const_data_ptr<scalar_t>() : nullptr,
1517
+ shift.defined() ? shift.const_data_ptr<scalar_t>(): nullptr,
1518
+ output.mutable_data_ptr<scalar_t>(),
1519
+ reduction_size,
1520
+ stride,
1521
+ fuse_relu);
1522
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
1523
+ });
1524
+ }
1525
+ }
1526
+
1527
+ std::tuple<Tensor, Tensor, Tensor, Tensor>
1528
+ batch_norm_backward_reduce_cuda_channels_last_template(const at::Tensor& grad_output,
1529
+ const at::Tensor& input,
1530
+ const at::Tensor& mean,
1531
+ const at::Tensor& inv_std,
1532
+ const at::Tensor& weight,
1533
+ const bool input_g, const bool weight_g, const bool bias_g) {
1534
+ const auto stride = input.sizes()[1];
1535
+ const auto reduction_size = input.numel() / stride;
1536
+
1537
+ at::Tensor sumn_dy = at::empty({stride}, mean.options());
1538
+ at::Tensor sum_dy_xmu = at::empty({stride}, mean.options());
1539
+
1540
+ at::Tensor grad_weight;
1541
+ at::Tensor grad_bias;
1542
+ if (weight.defined()) {
1543
+ grad_weight = at::empty({stride}, weight.options());
1544
+ grad_bias = at::empty({stride}, weight.options());
1545
+ } else {
1546
+ // because I cannot return an uninitialized at::Tensor
1547
+ grad_weight = at::empty({0}, mean.options());
1548
+ grad_bias = at::empty({0}, mean.options());
1549
+ }
1550
+
1551
+ dim3 block;
1552
+ dim3 grid;
1553
+ flexible_launch_configs(reduction_size, stride, block, grid, true);
1554
+
1555
+ at::Tensor staging_data;
1556
+ at::Tensor semaphores;
1557
+ if (grid.y > 1) {
1558
+ staging_data = at::empty({2*stride*grid.y}, mean.options());
1559
+ semaphores = at::zeros({grid.x}, input.options().dtype(at::kInt));
1560
+ }
1561
+ auto stream = at::cuda::getCurrentCUDAStream();
1562
+
1563
+ if (weight.defined() && input.scalar_type() != weight.scalar_type()) {
1564
+ AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_backward_reduce", [&] {
1565
+ using accscalar_t = at::acc_type<scalar_t, true>;
1566
+ accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.mutable_data_ptr<accscalar_t>() : nullptr;
1567
+ int* semaphores_ptr = grid.y > 1 ? semaphores.mutable_data_ptr<int>() : nullptr;
1568
+ batch_norm_backward_reduce_channels_last_kernel<ELEMENTS_PER_ITER>
1569
+ <<<grid, block, 0, stream>>>(
1570
+ input.const_data_ptr<scalar_t>(),
1571
+ grad_output.const_data_ptr<scalar_t>(),
1572
+ mean.const_data_ptr<accscalar_t>(),
1573
+ inv_std.const_data_ptr<accscalar_t>(),
1574
+ sumn_dy.mutable_data_ptr<accscalar_t>(),
1575
+ sum_dy_xmu.mutable_data_ptr<accscalar_t>(),
1576
+ grad_weight.mutable_data_ptr<accscalar_t>(),
1577
+ grad_bias.mutable_data_ptr<accscalar_t>(),
1578
+ staging_data_ptr,
1579
+ semaphores_ptr,
1580
+ reduction_size,
1581
+ stride);
1582
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
1583
+ });
1584
+ } else {
1585
+ if (weight.defined()) {
1586
+ TORCH_CHECK(input.scalar_type() == weight.scalar_type(), "batchnorm_backward_reduce: input.scalar_type() ", input.scalar_type(),
1587
+ " is not supported with weight.scalar_type() ", weight.scalar_type());
1588
+ }
1589
+ AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_backward_reduce", [&] {
1590
+ using accscalar_t = at::acc_type<scalar_t, true>;
1591
+ accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.mutable_data_ptr<accscalar_t>() : nullptr;
1592
+ int* semaphores_ptr = grid.y > 1 ? semaphores.mutable_data_ptr<int>() : nullptr;
1593
+ batch_norm_backward_reduce_channels_last_kernel<ELEMENTS_PER_ITER>
1594
+ <<<grid, block, 0, stream>>>(
1595
+ input.const_data_ptr<scalar_t>(),
1596
+ grad_output.const_data_ptr<scalar_t>(),
1597
+ mean.const_data_ptr<accscalar_t>(),
1598
+ inv_std.const_data_ptr<accscalar_t>(),
1599
+ sumn_dy.mutable_data_ptr<accscalar_t>(),
1600
+ sum_dy_xmu.mutable_data_ptr<accscalar_t>(),
1601
+ weight.defined() ? grad_weight.mutable_data_ptr<scalar_t>() : nullptr,
1602
+ weight.defined() ? grad_bias.mutable_data_ptr<scalar_t>() : nullptr,
1603
+ staging_data_ptr,
1604
+ semaphores_ptr,
1605
+ reduction_size,
1606
+ stride);
1607
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
1608
+ });
1609
+ }
1610
+
1611
+ return std::make_tuple(sumn_dy, sum_dy_xmu, grad_weight, grad_bias);
1612
+ }
1613
+
1614
+ at::Tensor batch_norm_backward_elemt_channels_last_cuda_template(
1615
+ const at::Tensor& grad_output,
1616
+ const at::Tensor& input,
1617
+ const at::Tensor& mean,
1618
+ const at::Tensor& inv_std,
1619
+ const at::Tensor& weight,
1620
+ const at::Tensor& sum_dy,
1621
+ const at::Tensor& sum_dy_xmu,
1622
+ const at::Tensor& count) {
1623
+ const auto stride = input.sizes()[1];
1624
+ const auto reduction_size = input.numel() / stride;
1625
+
1626
+ // Input is guarunteed to be channels-last compatible
1627
+ at::Tensor grad_input = at::empty_like(input);
1628
+
1629
+ dim3 block;
1630
+ dim3 grid;
1631
+ flexible_launch_configs(reduction_size, stride, block, grid);
1632
+
1633
+ auto stream = at::cuda::getCurrentCUDAStream();
1634
+
1635
+ if (weight.defined() && weight.scalar_type() != input.scalar_type()) {
1636
+ AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_backward_element", [&] {
1637
+ using accscalar_t = at::acc_type<scalar_t, true>;
1638
+ batch_norm_backward_elemt_channels_last_kernel<ELEMENTS_PER_ITER>
1639
+ <<<grid, block, 0, stream>>>(
1640
+ grad_output.const_data_ptr<scalar_t>(),
1641
+ input.const_data_ptr<scalar_t>(),
1642
+ mean.const_data_ptr<accscalar_t>(),
1643
+ inv_std.const_data_ptr<accscalar_t>(),
1644
+ weight.const_data_ptr<accscalar_t>(),
1645
+ sum_dy.const_data_ptr<accscalar_t>(),
1646
+ sum_dy_xmu.const_data_ptr<accscalar_t>(),
1647
+ count.const_data_ptr<int>(),
1648
+ grad_input.mutable_data_ptr<scalar_t>(),
1649
+ count.numel(),
1650
+ reduction_size,
1651
+ stride);
1652
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
1653
+ });
1654
+ } else {
1655
+ if (weight.defined()) {
1656
+ TORCH_CHECK(input.scalar_type() == weight.scalar_type(), "batchnorm_backward_element: input.scalar_type() ", input.scalar_type(),
1657
+ " is not supported with weight.scalar_type() ", weight.scalar_type());
1658
+ }
1659
+ AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, input.scalar_type(), "batchnorm_backward_element", [&] {
1660
+ using accscalar_t = at::acc_type<scalar_t, true>;
1661
+ batch_norm_backward_elemt_channels_last_kernel<ELEMENTS_PER_ITER>
1662
+ <<<grid, block, 0, stream>>>(
1663
+ grad_output.const_data_ptr<scalar_t>(),
1664
+ input.const_data_ptr<scalar_t>(),
1665
+ mean.const_data_ptr<accscalar_t>(),
1666
+ inv_std.const_data_ptr<accscalar_t>(),
1667
+ weight.defined() ? weight.const_data_ptr<scalar_t>() : nullptr,
1668
+ sum_dy.const_data_ptr<accscalar_t>(),
1669
+ sum_dy_xmu.const_data_ptr<accscalar_t>(),
1670
+ count.const_data_ptr<int>(),
1671
+ grad_input.mutable_data_ptr<scalar_t>(),
1672
+ count.numel(),
1673
+ reduction_size,
1674
+ stride);
1675
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
1676
+ });
1677
+ }
1678
+
1679
+ return grad_input;
1680
+ }
1681
+
1682
+ at::Tensor batch_norm_backward_elemt_channels_last_cuda_template(
1683
+ const at::Tensor& grad_output,
1684
+ const at::Tensor& input,
1685
+ const at::Tensor& mean,
1686
+ const at::Tensor& inv_std,
1687
+ const at::Tensor& weight,
1688
+ const at::Tensor& sum_dy,
1689
+ const at::Tensor& sum_dy_xmu) {
1690
+ const auto stride = input.sizes()[1];
1691
+ const auto reduction_size = input.numel() / stride;
1692
+ auto norm_fct = 1.0 / reduction_size;
1693
+
1694
+ // Input is guarunteed to be channels-last compatible
1695
+ at::Tensor grad_input = at::empty_like(input);
1696
+
1697
+ dim3 block;
1698
+ dim3 grid;
1699
+ flexible_launch_configs(reduction_size, stride, block, grid);
1700
+
1701
+ auto stream = at::cuda::getCurrentCUDAStream();
1702
+
1703
+ AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_backward_element", [&] {
1704
+ using accscalar_t = at::acc_type<scalar_t, true>;
1705
+
1706
+ if (weight.defined() && weight.scalar_type() != input.scalar_type()) {
1707
+ batch_norm_backward_elemt_channels_last_kernel<ELEMENTS_PER_ITER>
1708
+ <<<grid, block, 0, stream>>>(
1709
+ grad_output.const_data_ptr<scalar_t>(),
1710
+ input.const_data_ptr<scalar_t>(),
1711
+ mean.const_data_ptr<accscalar_t>(),
1712
+ inv_std.const_data_ptr<accscalar_t>(),
1713
+ weight.const_data_ptr<accscalar_t>(),
1714
+ sum_dy.const_data_ptr<accscalar_t>(),
1715
+ sum_dy_xmu.const_data_ptr<accscalar_t>(),
1716
+ grad_input.mutable_data_ptr<scalar_t>(),
1717
+ static_cast<accscalar_t>(norm_fct),
1718
+ reduction_size,
1719
+ stride);
1720
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
1721
+ } else {
1722
+ batch_norm_backward_elemt_channels_last_kernel<ELEMENTS_PER_ITER>
1723
+ <<<grid, block, 0, stream>>>(
1724
+ grad_output.const_data_ptr<scalar_t>(),
1725
+ input.const_data_ptr<scalar_t>(),
1726
+ mean.const_data_ptr<accscalar_t>(),
1727
+ inv_std.const_data_ptr<accscalar_t>(),
1728
+ weight.defined() ? weight.const_data_ptr<scalar_t>() : nullptr,
1729
+ sum_dy.const_data_ptr<accscalar_t>(),
1730
+ sum_dy_xmu.const_data_ptr<accscalar_t>(),
1731
+ grad_input.mutable_data_ptr<scalar_t>(),
1732
+ static_cast<accscalar_t>(norm_fct),
1733
+ reduction_size,
1734
+ stride);
1735
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
1736
+ }
1737
+ });
1738
+
1739
+ return grad_input;
1740
+ }
1741
+
1742
+ } } // namespace at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Pow.cuh ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <ATen/native/Pow.h>
3
+ #include <c10/core/Scalar.h>
4
+
5
+ namespace at { namespace native {
6
+
7
+ namespace {
8
+
9
+
10
+ // SFINAE doesn't work well with NVCC under Windows for math functions like pow and sqrt.
11
+ // So we need to define the functions with the explicit function signatures.
12
+ // As for pow, the following signatures are defined as the device function:
13
+ // pow(float, int)
14
+ // pow(double, int)
15
+ // pow(float, float)
16
+ // pow(double, double)
17
+ #ifdef _MSC_VER
18
+ // Functions for pow
19
+ // pow for at::Half
20
+ static inline __host__ __device__ at::Half pow_(at::Half base, at::Half exp) {
21
+ return static_cast<at::Half>(std::pow(static_cast<float>(base), static_cast<float>(exp)));
22
+ }
23
+ // pow for at::BFloat16
24
+ static inline __host__ __device__ at::BFloat16 pow_(at::BFloat16 base, at::BFloat16 exp) {
25
+ return static_cast<at::BFloat16>(std::pow(static_cast<float>(base), static_cast<float>(exp)));
26
+ }
27
+ // pow (floating, floating/int)
28
+ template <typename Base_type, typename Exp_type>
29
+ static inline __host__ __device__ typename std::enable_if<std::is_floating_point<Base_type>::value && (std::is_same<Base_type, Exp_type>::value || std::is_same<Exp_type, int>::value), Base_type>::type
30
+ pow_(Base_type base, Exp_type exp) {
31
+ return std::pow(base, exp);
32
+ }
33
+ // pow (Otherwise)
34
+ template <typename Base_type, typename Exp_type>
35
+ static inline __host__ __device__ typename std::enable_if<!std::is_same<Base_type, Exp_type>::value && !std::is_same<Exp_type, int>::value, Base_type>::type
36
+ pow_(Base_type base, Exp_type exp) {
37
+ return static_cast<Base_type>(std::pow(static_cast<double>(base), static_cast<double>(exp)));
38
+ }
39
+ #else
40
+ template <typename Base_type, typename Exp_type>
41
+ static inline __host__ __device__ Base_type pow_(Base_type base, Exp_type exp) {
42
+ return ::pow(base, exp);
43
+ }
44
+ #endif
45
+
46
+ template <typename T>
47
+ static inline __host__ __device__ std::enable_if_t<std::is_integral<T>::value, T> pow_(
48
+ T base, T exp) {
49
+ return at::native::powi(base, exp);
50
+ }
51
+
52
+ template <typename T>
53
+ static inline __host__ __device__ c10::complex<T> pow_(c10::complex<T> base, c10::complex<T> exp) {
54
+ return c10_complex_math::pow(base, exp);
55
+ }
56
+
57
+ } // namespace
58
+ }} // namespace at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Randperm.cuh ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <ATen/cuda/CUDAGeneratorImpl.h>
2
+ #include <ATen/cuda/CUDAGraphsUtils.cuh>
3
+ #include <ATen/Utils.h>
4
+
5
+ #include <curand.h>
6
+ #include <curand_kernel.h>
7
+ #include <curand_philox4x32_x.h>
8
+
9
+ namespace {
10
+
11
+ // See note [Algorithm of randperm]
12
+ template<typename T, typename scalar_t>
13
+ __global__ void randperm_handle_duplicate_keys_kernel(T *keys, scalar_t *data, T mask, int n, at::PhiloxCudaState philox_args) {
14
+ int tid = threadIdx.x + blockDim.x * blockIdx.x;
15
+
16
+ // find the beginning of islands
17
+ if (tid >= n - 1) return; // out of range
18
+ if ((keys[tid] & mask) != (keys[tid + 1] & mask)) return; // not in an island
19
+ if (tid != 0 && (keys[tid] & mask) == (keys[tid - 1] & mask)) return; // not the beginning of an island
20
+
21
+ // find the size of islands
22
+ int island_size = 0;
23
+ do { island_size++; }
24
+ while ((tid + island_size < n) && (keys[tid + island_size] & mask) == (keys[tid] & mask));
25
+
26
+ // do random permutation inside each island.
27
+ data += tid;
28
+ auto seeds = at::cuda::philox::unpack(philox_args);
29
+ curandStatePhilox4_32_10_t state;
30
+ curand_init(std::get<0>(seeds), tid, std::get<1>(seeds), &state);
31
+ for (int i = island_size - 1; i > 0; i--) {
32
+ unsigned int r = curand(&state) % (i + 1);
33
+ if (i != r) {
34
+ scalar_t tmp = data[i];
35
+ data[i] = data[r];
36
+ data[r] = tmp;
37
+ }
38
+ }
39
+ }
40
+
41
+ // See note [Algorithm of randperm]
42
+ template<typename T, typename scalar_t>
43
+ void randperm_handle_duplicate_keys(T *keys, scalar_t *data, int bits, int64_t n, std::optional<at::Generator> &gen_) {
44
+ auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(gen_, at::cuda::detail::getDefaultCUDAGenerator());
45
+ int64_t counter_offset = n;
46
+ at::PhiloxCudaState rng_engine_inputs;
47
+ {
48
+ // See Note [Acquire lock when using random generators]
49
+ std::lock_guard<std::mutex> lock(gen->mutex_);
50
+ rng_engine_inputs = gen->philox_cuda_state(counter_offset);
51
+ }
52
+ T mask = static_cast<T>((1UL << bits) - 1);
53
+ randperm_handle_duplicate_keys_kernel<<<(n + 511) / 512, 512, 0, at::cuda::getCurrentCUDAStream()>>>(
54
+ keys, data, mask, n, rng_engine_inputs);
55
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
56
+ }
57
+
58
+ }
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Resize.h ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/EmptyTensor.h>
4
+ #include <ATen/native/ResizeCommon.h>
5
+
6
+ #include <c10/cuda/CUDAGuard.h>
7
+
8
+ namespace at { namespace native {
9
+
10
+ TORCH_CUDA_CPP_API void resize_bytes_cuda(StorageImpl* storage, size_t size_bytes);
11
+
12
+ static inline void maybe_resize_storage_cuda(TensorImpl* self, size_t new_size_bytes) {
13
+ // It does not make sense to try to resize a storage
14
+ // to hold 0 elements, and this can break
15
+ // if storage_offset is positive but
16
+ // new_size is 0, so just bail in that case
17
+ // (same comment is in Resize.h)
18
+ if (self->numel() == 0) {
19
+ return;
20
+ }
21
+
22
+ const Storage &storage = self->unsafe_storage();
23
+ TORCH_CHECK(storage, "Tensor: invalid null storage");
24
+ if (new_size_bytes > storage.nbytes()) {
25
+ resize_bytes_cuda(storage.unsafeGetStorageImpl(), new_size_bytes);
26
+ }
27
+ }
28
+
29
+ inline TensorImpl* resize_impl_cuda_(
30
+ TensorImpl* self,
31
+ IntArrayRef size,
32
+ at::OptionalIntArrayRef stride) {
33
+ if (self->sizes() == size && (!stride || self->strides() == stride)) {
34
+ return self;
35
+ }
36
+ const auto itemsize = self->dtype().itemsize();
37
+ const auto storage_offset = self->storage_offset();
38
+ size_t storage_size = 1;
39
+ if (stride) {
40
+ self->set_sizes_and_strides(size, *stride);
41
+ storage_size = at::detail::computeStorageNbytes(
42
+ size, *stride, itemsize, storage_offset);
43
+ } else {
44
+ self->set_sizes_contiguous(size);
45
+ storage_size = at::detail::computeStorageNbytesContiguous(
46
+ size, itemsize, storage_offset);
47
+ }
48
+ maybe_resize_storage_cuda(self, storage_size);
49
+
50
+ return self;
51
+ }
52
+
53
+ }}
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/RowwiseScaledMM.h ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <ATen/core/TensorBase.h>
3
+ #include <optional>
4
+
5
+
6
+ namespace at::cuda::detail {
7
+ TORCH_API void f8f8bf16_rowwise(
8
+ at::Tensor XQ, // FP8
9
+ at::Tensor WQ, // FP8
10
+ at::Tensor x_scale, // FP32
11
+ at::Tensor w_scale, // FP32
12
+ std::optional<at::Tensor> bias, // BF16
13
+ bool use_fast_accum,
14
+ at::Tensor& out);
15
+ } // at::cuda::detail
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/ScanUtils.cuh ADDED
@@ -0,0 +1,459 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <ATen/NumericUtils.h>
3
+ #include <ATen/core/TensorBase.h>
4
+ #include <ATen/cuda/cub.cuh>
5
+ #include <ATen/cuda/CUDAContext.h>
6
+
7
+ #include <c10/util/Load.h>
8
+ #include <limits>
9
+ #include <cmath>
10
+
11
+ namespace at {
12
+ namespace native {
13
+
14
+ template <typename integer>
15
+ constexpr inline integer ceil_div(integer n, integer m) {
16
+ return (n + m - 1) / m;
17
+ }
18
+
19
+ template <typename integer>
20
+ constexpr inline integer get_log_num_threads_x_inner_scan(integer num_rows, integer row_size) {
21
+ integer log_num_threads_x = 0;
22
+ integer log_num_threads_y = 0;
23
+ while (((integer)1 << log_num_threads_x) < row_size) {
24
+ ++log_num_threads_x;
25
+ }
26
+ while (((integer)1 << log_num_threads_y) < num_rows) {
27
+ ++log_num_threads_y;
28
+ }
29
+ // we want to keep the ratio between the x-threads and y-threads about the same as
30
+ // the ratio between the row_size and num_rows, but the total number of threads in
31
+ // a block should be about 512
32
+ integer diff = log_num_threads_x - log_num_threads_y;
33
+ // 9 is from log2(512)
34
+ log_num_threads_x = ((integer)9 + diff) / (integer)2;
35
+ // I found that in having larger log_num_threads_x can give significant speed up in some cases,
36
+ // but detrimental in another case, so just keep the lower bound to be log2(16) == 4 to make it
37
+ // similar to the previous implementation
38
+ // Keeping the upper bound to be log2(512) == 9 as the maximum number of threads in a block.
39
+ log_num_threads_x = std::min(std::max((integer)4, log_num_threads_x), (integer)9);
40
+ return log_num_threads_x;
41
+ }
42
+
43
+ template<typename scalar_t, typename idx_t, typename BinaryOperation>
44
+ __device__ void binary_op_update(const scalar_t lhs, scalar_t& rhs, const idx_t lhs_idx, idx_t& rhs_idx, BinaryOperation binary_op) {
45
+ if(!at::_isnan(rhs) && (at::_isnan(lhs) || !binary_op(rhs, lhs))) {
46
+ rhs = lhs;
47
+ rhs_idx = lhs_idx;
48
+ }
49
+ }
50
+ /* Perform an inclusive scan along the innermost dimension of a tensor.
51
+ *
52
+ * - num_rows is the size of the flattened outer dimensions;
53
+ * - row_size is the size of the innermost dimension;
54
+ *
55
+ * The outer dimensions of the tensor are considered as a single dimension, i.e. the tensor is
56
+ * considered as having 'num_rows' rows of size 'row_size'.
57
+ * Each thread block processes one or more sets of contiguous rows (processing multiple rows
58
+ * per thread block is quicker than processing a single row, especially for short rows).
59
+ */
60
+ template<typename scalar_t, class BinaryFunction>
61
+ __global__ void tensor_kernel_scan_innermost_dim_with_indices(const scalar_t *self_, scalar_t *values_, int64_t *indices_,
62
+ int num_rows, int row_size,
63
+ const uint32_t num_threads, const uint32_t log_num_threads_x,
64
+ scalar_t init, BinaryFunction binary_op) {
65
+ // dynamic memory allocation for vbuf and ibuf
66
+ alignas(sizeof(double)) extern __shared__ char buf[];
67
+ scalar_t* vbuf = reinterpret_cast<scalar_t*>(buf); // the size is num_threads * 2
68
+ int64_t* ibuf = reinterpret_cast<int64_t*>(vbuf + num_threads * 2);
69
+ const uint32_t num_threads_x = 1 << log_num_threads_x;
70
+ scalar_t* row_buf = vbuf + 2 * num_threads_x * threadIdx.y;
71
+ int64_t* row_idx_buf = ibuf + 2 * num_threads_x * threadIdx.y;
72
+
73
+ for (int block_row = blockIdx.x * blockDim.y;
74
+ block_row < num_rows;
75
+ block_row += blockDim.y * gridDim.x) {
76
+ int row = block_row + threadIdx.y;
77
+ const scalar_t *row_self = self_ + row * row_size;
78
+ scalar_t *row_values = values_ + row * row_size;
79
+ int64_t *row_indices = indices_ + row * row_size;
80
+ scalar_t block_total = init;
81
+ int64_t block_idx_final = 0;
82
+ const bool row_exists = row < num_rows;
83
+ // Perform scan on one block at a time, keeping track of the total value of
84
+ // all blocks processed so far.
85
+ for (int block_col = 0; block_col < row_size; block_col += 2 * num_threads_x) {
86
+ // Load data into shared memory (two values per thread).
87
+ int col1 = block_col + threadIdx.x;
88
+ int col2 = block_col + num_threads_x + threadIdx.x;
89
+ if (row_exists) {
90
+ if (col1 < row_size) {
91
+ row_buf[threadIdx.x] = c10::load(&row_self[col1]);
92
+ row_idx_buf[threadIdx.x] = col1;
93
+ } else {
94
+ row_buf[threadIdx.x] = init;
95
+ // No need to set the index here as the value in init will never be selected
96
+ }
97
+
98
+ if (col2 < row_size) {
99
+ row_buf[num_threads_x + threadIdx.x] = c10::load(&row_self[col2]);
100
+ row_idx_buf[num_threads_x + threadIdx.x] = col2;
101
+ } else {
102
+ row_buf[num_threads_x + threadIdx.x] = init;
103
+ // No need to set the index here as the value in init will never be selected
104
+ }
105
+
106
+ // Add the total value of all previous blocks to the first value of this block.
107
+ if (threadIdx.x == 0) {
108
+ binary_op_update(block_total, row_buf[0], block_idx_final, row_idx_buf[0], binary_op);
109
+ }
110
+ }
111
+ __syncthreads();
112
+
113
+ // Parallel reduction with Sklansky method. The diagram can be seen on this paper:
114
+ // https://research.nvidia.com/publication/single-pass-parallel-prefix-scan-decoupled-look-back
115
+ for (uint32_t s = 1; s <= num_threads_x; s <<= 1) {
116
+ if (row_exists) {
117
+ uint32_t a = (threadIdx.x / s) * (2 * s) + s;
118
+ uint32_t ti = a + (threadIdx.x % s);
119
+ uint32_t si = a - 1;
120
+ binary_op_update(row_buf[si], row_buf[ti], row_idx_buf[si], row_idx_buf[ti], binary_op);
121
+ }
122
+ __syncthreads();
123
+ }
124
+
125
+ // Write back to output.
126
+ if (row_exists) {
127
+ if (col1 < row_size){
128
+ row_values[col1] = row_buf[threadIdx.x];
129
+ row_indices[col1] = row_idx_buf[threadIdx.x];
130
+ }
131
+ if (col2 < row_size) {
132
+ row_values[col2] = row_buf[num_threads_x + threadIdx.x];
133
+ row_indices[col2] = row_idx_buf[num_threads_x + threadIdx.x];
134
+ }
135
+ }
136
+ block_total = row_buf[2 * num_threads_x - 1];
137
+ block_idx_final = row_idx_buf[2 * num_threads_x - 1];
138
+ __syncthreads();
139
+ }
140
+ }
141
+ }
142
+
143
+ /* Perform an inclusive scan along an outer dimension of a tensor.
144
+ *
145
+ * - num_orows is the size of the flattened outer dimensions;
146
+ * - num_irows is the size of the flattened inner dimensions;
147
+ * - row_size is the size of the dimension along which to compute the variance;
148
+ *
149
+ * The dimensions to the outside and inside of the specified dimension are considered as flattened.
150
+ * Thread blocks with the same blockIdx.y process an "outer row" (i.e. an element of the flattened
151
+ * outer dimensions, which contains several "inner rows").
152
+ * Each thread processes a single inner row at a time.
153
+ */
154
+ template<typename scalar_t, class BinaryFunction>
155
+ __global__ void tensor_kernel_scan_outer_dim_with_indices(const scalar_t *self_, scalar_t *values_, int64_t *indices_,
156
+ const uint32_t num_orows, const uint32_t num_irows, const uint32_t row_size, scalar_t init, BinaryFunction binary_op) {
157
+ for (uint32_t orow = blockIdx.x; orow < num_orows; orow += gridDim.x) {
158
+ for (uint32_t irow = blockIdx.y * blockDim.x + threadIdx.x; irow < num_irows; irow += gridDim.y * blockDim.x) {
159
+ const scalar_t *self = self_ + orow * row_size * num_irows + irow;
160
+ scalar_t *values = values_ + orow * row_size * num_irows + irow;
161
+ int64_t *indices = indices_ + orow * row_size * num_irows + irow;
162
+ scalar_t out = init;
163
+ int64_t out_idx = 0;
164
+
165
+ for (auto col = decltype(row_size){0}; col < row_size; ++col) {
166
+ const auto val = c10::load(self);
167
+ if(at::_isnan(val) || (!at::_isnan(out) && binary_op(val, out))) {
168
+ out = val;
169
+ out_idx = col;
170
+ }
171
+ *values = out;
172
+ *indices = out_idx;
173
+ self += num_irows;
174
+ values += num_irows;
175
+ indices += num_irows;
176
+ }
177
+ }
178
+ }
179
+ }
180
+
181
+ inline void check_fits_in_unsigned(int64_t val, const char* name) {
182
+ constexpr auto umax = std::numeric_limits<uint32_t>::max();
183
+ TORCH_CHECK(
184
+ val >= 0 && val <= umax, name, " must fit in a 32-bit uint32_t value");
185
+ }
186
+
187
+
188
+ template<typename scalar_t, class BinaryFunction>
189
+ __host__ void scan_outer_dim_with_indices(
190
+ const TensorBase& self, const TensorBase& values, const TensorBase& indices,
191
+ int dim, scalar_t init, BinaryFunction binary_op) {
192
+ int64_t row_size = self.size(dim);
193
+ auto sizes = self.sizes();
194
+
195
+ // Treat all outer dimensions (i.e. dim_ < dim) as one.
196
+ const int64_t num_orows = c10::multiply_integers(sizes.begin(), sizes.begin() + dim);
197
+
198
+ // Treat all inner dimensions (i.e. dim > dimension) as one.
199
+ const int64_t num_irows = c10::multiply_integers(sizes.begin() + dim + 1, sizes.end());
200
+ //for performance reasons, cuda kernels use uint32_t for loops over irows, orows and row,
201
+ //make sure that input is not bigger than supported by uint32_t
202
+ check_fits_in_unsigned(num_irows, "num_irows");
203
+ check_fits_in_unsigned(num_orows, "num_orows");
204
+ check_fits_in_unsigned(row_size, "row_size");
205
+
206
+
207
+ dim3 threads(std::min(512, int(num_irows)));
208
+ int64_t maxGridDim = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
209
+ dim3 grid(std::min(maxGridDim, num_orows), std::min(maxGridDim, ceil_div(num_irows, int64_t{threads.x})));
210
+ tensor_kernel_scan_outer_dim_with_indices<scalar_t><<<grid, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
211
+ self.const_data_ptr<scalar_t>(), values.mutable_data_ptr<scalar_t>(), indices.mutable_data_ptr<int64_t>(),
212
+ num_orows, num_irows, row_size, init, binary_op);
213
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
214
+ }
215
+
216
+ template <typename scalar_t, class BinaryFunction>
217
+ __host__ void scan_innermost_dim_with_indices(
218
+ const TensorBase& self, const TensorBase& values, const TensorBase& indices,
219
+ scalar_t init, BinaryFunction binary_op) {
220
+ int ndim = self.dim();
221
+ // Treat all outer dimensions as a single dimension.
222
+ int row_size = self.size(ndim - 1);
223
+ int num_rows = self.numel() / row_size;
224
+
225
+ // assuming max_num_threads per block is 512
226
+ const uint32_t num_threads = 512;
227
+ const uint32_t log_num_threads_x = get_log_num_threads_x_inner_scan<uint32_t>(num_rows, row_size);
228
+ const uint32_t num_threads_x = (1 << log_num_threads_x);
229
+ const uint32_t num_threads_y = num_threads / num_threads_x;
230
+ dim3 threads(num_threads_x, num_threads_y);
231
+ dim3 grid(std::min(at::cuda::getCurrentDeviceProperties()->maxGridSize[0], ceil_div(num_rows, int(threads.y))));
232
+
233
+ const uint32_t mem_size = 2 * num_threads * (sizeof(scalar_t) + sizeof(int64_t));
234
+ tensor_kernel_scan_innermost_dim_with_indices<scalar_t><<<grid, threads, mem_size,
235
+ at::cuda::getCurrentCUDAStream()>>>(
236
+ self.const_data_ptr<scalar_t>(), values.mutable_data_ptr<scalar_t>(), indices.mutable_data_ptr<int64_t>(),
237
+ num_rows, row_size, num_threads, log_num_threads_x, init, binary_op);
238
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
239
+ }
240
+
241
+ template<typename scalar_t, typename BinaryFunction>
242
+ void scan_dim_with_indices(const TensorBase& self, const TensorBase& values, const TensorBase& indices, //int64_t dim) {
243
+ int64_t dim, scalar_t init, BinaryFunction binary_op) {
244
+ int ndim = self.dim();
245
+ auto self_ = self.expect_contiguous();
246
+ TORCH_INTERNAL_ASSERT(values.is_contiguous() && indices.is_contiguous());
247
+ if (dim == ndim - 1) {
248
+ scan_innermost_dim_with_indices<scalar_t>(*self_, values, indices, init, binary_op);
249
+ } else {
250
+ scan_outer_dim_with_indices<scalar_t>(*self_, values, indices, dim, init, binary_op);
251
+ }
252
+ }
253
+
254
+ // TODO: The implementation of `tensor_kernel_scan_outer_dim` and
255
+ // `tensor_kernel_scan_innermost_dim` is similar to
256
+ // `tensor_kernel_scan_outer_dim_with_indices`
257
+ // `tensor_kernel_scan_outer_dim_with_indices` and should be refactored to
258
+ // remove the duplication.
259
+
260
+ /* Perform an inclusive scan along an outer dimension of a tensor.
261
+ *
262
+ * - num_orows is the size of the flattened outer dimensions;
263
+ * - num_irows is the size of the flattened inner dimensions;
264
+ * - row_size is the size of the dimension along which to scan;
265
+ *
266
+ * The dimensions to the outside and inside of the specified dimension are considered as flattened.
267
+ * Thread blocks with the same blockIdx.y process an "outer row" (i.e. an element of the flattened
268
+ * outer dimensions, which contains several "inner rows").
269
+ * Each thread processes a single inner row at a time.
270
+ */
271
+ template<typename scalar_t, class BinaryOp>
272
+ __global__ void tensor_kernel_scan_outer_dim(scalar_t *tgt_, const scalar_t *src_,
273
+ const uint32_t num_orows, const uint32_t num_irows, const uint32_t row_size,
274
+ const scalar_t init, BinaryOp binary_op)
275
+ {
276
+ for (uint32_t orow = blockIdx.x; orow < num_orows; orow += gridDim.x) {
277
+ for (uint32_t irow = blockIdx.y * blockDim.x + threadIdx.x; irow < num_irows; irow += gridDim.y * blockDim.x) {
278
+ const scalar_t *src = src_ + orow * row_size * num_irows + irow;
279
+ scalar_t *tgt = tgt_ + orow * row_size * num_irows + irow;
280
+ scalar_t acc = init;
281
+
282
+ for (uint32_t col = 0; col < row_size; ++col) {
283
+ acc = binary_op(acc, c10::load(src));
284
+ *tgt = acc;
285
+
286
+ src += num_irows;
287
+ tgt += num_irows;
288
+ }
289
+ }
290
+ }
291
+ }
292
+
293
+ /* Perform an inclusive scan along the innermost dimension of a tensor.
294
+ *
295
+ * - num_rows is the size of the flattened outer dimensions;
296
+ * - row_size is the size of the innermost dimension;
297
+ *
298
+ * The outer dimensions of the tensor are considered as a single dimension, i.e. the tensor is
299
+ * considered as having 'num_rows' rows of size 'row_size'.
300
+ * Each thread block processes one or more sets of contiguous rows (processing multiple rows
301
+ * per thread block is quicker than processing a single row, especially for short rows).
302
+ */
303
+ template<typename T, class BinaryFunction>
304
+ __device__ void tensor_kernel_scan_innermost_dim_impl(T* row_buf, T *tgt_, const T *src_,
305
+ const uint32_t num_rows, const uint32_t row_size,
306
+ const uint32_t log_num_threads_x,
307
+ T init, BinaryFunction binary_op){
308
+ const uint32_t num_threads_x = 1 << log_num_threads_x;
309
+ for (uint32_t block_row = blockIdx.x * blockDim.y;
310
+ block_row < num_rows;
311
+ block_row += blockDim.y * gridDim.x) {
312
+ uint32_t row = block_row + threadIdx.y;
313
+ T block_total = init;
314
+
315
+ const T *row_src = src_ + row * row_size;
316
+ T *row_tgt = tgt_ + row * row_size;
317
+ const bool row_exists = row < num_rows;
318
+
319
+ // Perform scan on one block at a time, keeping track of the total value of
320
+ // all blocks processed so far.
321
+ for (uint32_t block_col = 0; block_col < row_size; block_col += 2 * num_threads_x) {
322
+ // Load data into shared memory (two values per thread).
323
+ uint32_t col1 = block_col + threadIdx.x;
324
+ uint32_t col2 = block_col + num_threads_x + threadIdx.x;
325
+ if (row_exists) {
326
+ if (col1 < row_size) {
327
+ row_buf[threadIdx.x] = row_src[col1];
328
+ } else {
329
+ row_buf[threadIdx.x] = init;
330
+ }
331
+
332
+ if (col2 < row_size) {
333
+ row_buf[num_threads_x + threadIdx.x] = row_src[col2];
334
+ } else {
335
+ row_buf[num_threads_x + threadIdx.x] = init;
336
+ }
337
+
338
+ // Add the total value of all previous blocks to the first value of this block.
339
+ if (threadIdx.x == 0) {
340
+ row_buf[0] = binary_op(row_buf[0], block_total);
341
+ }
342
+ }
343
+ __syncthreads();
344
+
345
+ // Parallel reduction with Sklansky method. The diagram can be seen on this paper:
346
+ // https://research.nvidia.com/publication/single-pass-parallel-prefix-scan-decoupled-look-back
347
+ for (uint32_t m = 0; m <= log_num_threads_x; ++m) {
348
+ if (row_exists) {
349
+ uint32_t s = 1 << m; // s = 2 ^ m
350
+ uint32_t a = ((threadIdx.x >> m) << (m + 1)) | s; // a = (threadIdx.x / s) * (2 * s) + s
351
+ uint32_t ti = a + (threadIdx.x % s);
352
+ uint32_t si = a - 1;
353
+ row_buf[ti] = binary_op(row_buf[ti], row_buf[si]);
354
+ }
355
+ __syncthreads();
356
+ }
357
+
358
+ // Write back to output.
359
+ if (row_exists) {
360
+ if (col1 < row_size) row_tgt[col1] = row_buf[threadIdx.x];
361
+ if (col2 < row_size) row_tgt[col2] = row_buf[num_threads_x + threadIdx.x];
362
+ }
363
+ block_total = row_buf[2 * num_threads_x - 1];
364
+ __syncthreads();
365
+ }
366
+ }
367
+ }
368
+
369
+ template <
370
+ typename T,
371
+ class BinaryFunction>
372
+ __global__ void tensor_kernel_scan_innermost_dim(
373
+ T* tgt_,
374
+ const T* src_,
375
+ const uint32_t num_rows,
376
+ const uint32_t row_size,
377
+ const uint32_t log_num_threads_x,
378
+ T init,
379
+ BinaryFunction binary_op) {
380
+ alignas(sizeof(double)) extern __shared__ char sbuf[];
381
+ T* sbuf2 = reinterpret_cast<T*>(sbuf);
382
+ const uint32_t num_threads_x = 1 << log_num_threads_x;
383
+ T* row_buf = reinterpret_cast<T*>(sbuf2 + num_threads_x * 2 * threadIdx.y);
384
+
385
+ tensor_kernel_scan_innermost_dim_impl<T>(
386
+ row_buf, tgt_, src_, num_rows, row_size, log_num_threads_x, init, binary_op);
387
+ }
388
+
389
+
390
+ template<typename scalar_t, class BinaryFunction>
391
+ __host__ void scan_outer_dim(const TensorBase& self, const TensorBase& result,
392
+ int dim, scalar_t init, BinaryFunction binary_op) {
393
+ const int64_t row_size = self.size(dim);
394
+ auto sizes = self.sizes();
395
+
396
+ // Treat all outer dimensions (i.e. dim_ < dim) as one.
397
+ const int64_t num_orows = c10::multiply_integers(sizes.begin(), sizes.begin() + dim);
398
+
399
+ // Treat all inner dimensions (i.e. dim > dimension) as one.
400
+ const int64_t num_irows = c10::multiply_integers(sizes.begin() + dim + 1, sizes.end());
401
+
402
+ dim3 threads(std::min(512, int(num_irows)));
403
+ int64_t maxGridDim = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
404
+ dim3 grid(std::min(maxGridDim, num_orows), std::min(maxGridDim, ceil_div(num_irows, int64_t{threads.x})));
405
+
406
+ check_fits_in_unsigned(num_irows, "num_irows");
407
+ check_fits_in_unsigned(num_orows, "num_orows");
408
+ check_fits_in_unsigned(row_size, "row_size");
409
+
410
+ tensor_kernel_scan_outer_dim<scalar_t><<<grid, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
411
+ result.mutable_data_ptr<scalar_t>(), self.const_data_ptr<scalar_t>(),
412
+ num_orows, num_irows, row_size, init, binary_op);
413
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
414
+ }
415
+
416
+ template <typename scalar_t, class BinaryFunction>
417
+ void scan_innermost_dim(const TensorBase& self, const TensorBase& result,
418
+ scalar_t init, BinaryFunction binary_op) {
419
+ int64_t ndim = self.dim();
420
+ // Treat all outer dimensions as a single dimension.
421
+ int64_t row_size = self.size(ndim - 1);
422
+ int64_t num_rows = self.numel() / row_size;
423
+
424
+ // assuming max_num_threads per block is 512
425
+ const uint32_t num_threads = 512;
426
+ const uint32_t log_num_threads_x = get_log_num_threads_x_inner_scan<uint32_t>(num_rows, row_size);
427
+ const uint32_t num_threads_x = (1 << log_num_threads_x);
428
+ const uint32_t num_threads_y = num_threads / num_threads_x;
429
+ dim3 threads(num_threads_x, num_threads_y);
430
+ int64_t maxGridDim = at::cuda::getCurrentDeviceProperties()->maxGridSize[0];
431
+ dim3 grid(std::min(maxGridDim, ceil_div(num_rows, int64_t{threads.y})));
432
+
433
+ check_fits_in_unsigned(num_rows, "Number of rows (self.numel()/self.size(self.dim()-1))");
434
+ check_fits_in_unsigned(row_size, "row_size");
435
+
436
+ tensor_kernel_scan_innermost_dim<scalar_t><<<grid, threads, num_threads * 2 * sizeof(scalar_t),
437
+ at::cuda::getCurrentCUDAStream()>>>(
438
+ result.mutable_data_ptr<scalar_t>(), self.const_data_ptr<scalar_t>(),
439
+ num_rows, row_size, log_num_threads_x, init, binary_op);
440
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
441
+ }
442
+
443
+ template<typename scalar_t, typename BinaryFunction>
444
+ void scan_dim(const TensorBase& self, const TensorBase& result,
445
+ int64_t dim, scalar_t init, BinaryFunction binary_op) {
446
+ int ndim = self.dim();
447
+ auto self_ = self.expect_contiguous();
448
+ TORCH_INTERNAL_ASSERT(result.is_contiguous());
449
+
450
+ if (self.numel() == self.size(dim)) {
451
+ cuda::cub::inclusive_scan(self_->const_data_ptr<scalar_t>(), result.mutable_data_ptr<scalar_t>(), binary_op, self.numel());
452
+ } else if (dim == ndim - 1) {
453
+ scan_innermost_dim<scalar_t>(*self_, result, init, binary_op);
454
+ } else {
455
+ scan_outer_dim<scalar_t>(*self_, result, dim, init, binary_op);
456
+ }
457
+ }
458
+
459
+ }} // namespace at::native