koichi12 commited on
Commit
9cc5909
·
verified ·
1 Parent(s): edce735

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .venv/lib/python3.11/site-packages/torch/include/ATen/mps/EmptyTensor.h +29 -0
  2. .venv/lib/python3.11/site-packages/torch/include/ATen/mps/IndexKernels.h +535 -0
  3. .venv/lib/python3.11/site-packages/torch/include/ATen/mps/MPSAllocator.h +403 -0
  4. .venv/lib/python3.11/site-packages/torch/include/ATen/mps/MPSAllocatorInterface.h +64 -0
  5. .venv/lib/python3.11/site-packages/torch/include/ATen/mps/MPSDevice.h +84 -0
  6. .venv/lib/python3.11/site-packages/torch/include/ATen/mps/MPSEvent.h +100 -0
  7. .venv/lib/python3.11/site-packages/torch/include/ATen/mps/MPSGeneratorImpl.h +52 -0
  8. .venv/lib/python3.11/site-packages/torch/include/ATen/mps/MPSGuardImpl.h +179 -0
  9. .venv/lib/python3.11/site-packages/torch/include/ATen/mps/MPSHooks.h +60 -0
  10. .venv/lib/python3.11/site-packages/torch/include/ATen/mps/MPSProfiler.h +402 -0
  11. .venv/lib/python3.11/site-packages/torch/include/ATen/mps/MPSStream.h +133 -0
  12. .venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/CatKernel.h +12 -0
  13. .venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/Loops.h +394 -0
  14. .venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/ReduceUtils.h +238 -0
  15. .venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/AffineQuantizer.h +130 -0
  16. .venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/AffineQuantizerBase.h +47 -0
  17. .venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/ConvUtils.h +62 -0
  18. .venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/Copy.h +10 -0
  19. .venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/FakeQuantAffine.h +67 -0
  20. .venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/IndexKernel.h +14 -0
  21. .venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/PackedParams.h +147 -0
  22. .venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/BinaryOps.h +8 -0
  23. .venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/EmbeddingPackedParams.h +29 -0
  24. .venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/OnednnUtils.h +457 -0
  25. .venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/QnnpackUtils.h +527 -0
  26. .venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/QuantUtils.h +239 -0
  27. .venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/QuantizedOps.h +258 -0
  28. .venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/RuyUtils.h +21 -0
  29. .venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/XnnpackUtils.h +335 -0
  30. .venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/conv_serialization.h +414 -0
  31. .venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/fbgemm_utils.h +413 -0
  32. .venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/init_qnnpack.h +13 -0
  33. .venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/qembeddingbag.h +34 -0
  34. .venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/qembeddingbag_prepack.h +13 -0
  35. .venv/lib/python3.11/site-packages/torch/include/ATen/native/transformers/attention.h +72 -0
  36. .venv/lib/python3.11/site-packages/torch/include/ATen/native/transformers/sdp_utils_cpp.h +566 -0
  37. .venv/lib/python3.11/site-packages/torch/include/ATen/native/utils/Factory.h +24 -0
  38. .venv/lib/python3.11/site-packages/torch/include/ATen/native/utils/ParamUtils.h +42 -0
  39. .venv/lib/python3.11/site-packages/torch/include/ATen/native/utils/ParamsHash.h +104 -0
  40. .venv/lib/python3.11/site-packages/torch/include/ATen/ops/_empty_per_channel_affine_quantized.h +113 -0
  41. .venv/lib/python3.11/site-packages/torch/include/ATen/ops/_foreach_addcmul_native.h +35 -0
  42. .venv/lib/python3.11/site-packages/torch/include/ATen/ops/_foreach_expm1_ops.h +50 -0
  43. .venv/lib/python3.11/site-packages/torch/include/ATen/ops/_sparse_mask_projection_ops.h +39 -0
  44. .venv/lib/python3.11/site-packages/torch/include/ATen/ops/_sparse_sum_backward_ops.h +39 -0
  45. .venv/lib/python3.11/site-packages/torch/include/ATen/ops/_test_autograd_multiple_dispatch_view_copy_ops.h +39 -0
  46. .venv/lib/python3.11/site-packages/torch/include/ATen/ops/_unsafe_masked_index_put_accumulate_ops.h +28 -0
  47. .venv/lib/python3.11/site-packages/torch/include/ATen/ops/alias_ops.h +28 -0
  48. .venv/lib/python3.11/site-packages/torch/include/ATen/ops/arcsinh_compositeimplicitautograd_dispatch.h +26 -0
  49. .venv/lib/python3.11/site-packages/torch/include/ATen/ops/block_diag_native.h +22 -0
  50. .venv/lib/python3.11/site-packages/torch/include/ATen/ops/cat_compositeexplicitautogradnonfunctional_dispatch.h +23 -0
.venv/lib/python3.11/site-packages/torch/include/ATen/mps/EmptyTensor.h ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright © 2022 Apple Inc.
2
+
3
+ #pragma once
4
+ #include <ATen/core/TensorBase.h>
5
+
6
+ namespace at::detail {
7
+
8
+ C10_EXPORT TensorBase empty_mps(
9
+ IntArrayRef size,
10
+ std::optional<ScalarType> dtype_opt,
11
+ std::optional<Layout> layout_opt,
12
+ std::optional<Device> device_opt,
13
+ std::optional<bool> pin_memory_opt,
14
+ std::optional<c10::MemoryFormat> memory_format_opt);
15
+ C10_EXPORT TensorBase empty_mps(
16
+ IntArrayRef size, const TensorOptions &options);
17
+
18
+ C10_EXPORT TensorBase empty_strided_mps(
19
+ IntArrayRef size,
20
+ IntArrayRef stride,
21
+ ScalarType dtype,
22
+ std::optional<Device> device_opt);
23
+
24
+ C10_EXPORT TensorBase empty_strided_mps(
25
+ IntArrayRef size,
26
+ IntArrayRef stride,
27
+ const TensorOptions &options);
28
+
29
+ } // namespace at::detail
.venv/lib/python3.11/site-packages/torch/include/ATen/mps/IndexKernels.h ADDED
@@ -0,0 +1,535 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ namespace at::mps {
4
+
5
+ static const char * indexing_metal_shaders = R"INDEX_METAL(
6
+ #include <metal_stdlib>
7
+ #include <metal_atomic>
8
+
9
+ using namespace metal;
10
+
11
+ struct IndexAB {
12
+ constant int64_t* indexArray;
13
+ };
14
+
15
+ template<typename T, typename OffsetsT>
16
+ kernel void index_select(
17
+ constant IndexAB * indexAB [[buffer(0)]],
18
+ constant void * indexSizes [[buffer(1)]],
19
+ constant void * indexStrides [[buffer(2)]],
20
+ constant OffsetsT * offsets [[buffer(3)]],
21
+ constant void * inputData [[buffer(4)]],
22
+ device void * outputData [[buffer(5)]],
23
+ constant uint32_t & num_indices [[buffer(6)]],
24
+ uint thread_index [[thread_position_in_grid]]) {
25
+ constant int64_t * index_sizes = (constant int64_t *)indexSizes;
26
+ constant int64_t * index_strides = (constant int64_t *)indexStrides;
27
+ int64_t offset = 0;
28
+ for (uint32_t i = 0; i < num_indices; i++) {
29
+ constant int64_t* indexArray = indexAB[i].indexArray;
30
+ int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)];
31
+ if (index < 0) {
32
+ index += index_sizes[i];
33
+ }
34
+ offset += index * index_strides[i];
35
+ }
36
+ device T * out = (device T*)((device char*)outputData + offsets[thread_index].x);
37
+ constant T * in = (constant T*)((constant char*)inputData + offsets[thread_index].y + offset);
38
+ *out = *in;
39
+ }
40
+
41
+ template<typename T, typename OffsetsT>
42
+ void index_put_impl(
43
+ constant IndexAB * indexAB,
44
+ constant int64_t * index_sizes,
45
+ constant int64_t * index_strides,
46
+ constant OffsetsT * offsets,
47
+ constant void * inputData,
48
+ device void * outputData,
49
+ constant uint32_t & num_indices,
50
+ uint thread_index) {
51
+ int64_t offset = 0;
52
+ for (uint32_t i = 0; i < num_indices; i++) {
53
+ constant int64_t* indexArray = indexAB[i].indexArray;
54
+ int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)];
55
+
56
+ if (index < 0) {
57
+ index += index_sizes[i];
58
+ }
59
+ offset += index * index_strides[i];
60
+ }
61
+ device T * out = (device T*)((device char*)outputData + offsets[thread_index].x + offset);
62
+ constant T * in = (constant T*)((constant char*)inputData + offsets[thread_index].y);
63
+ *out = *in;
64
+ }
65
+
66
+ template<typename T, typename OffsetsT>
67
+ kernel void index_put_serial(
68
+ constant IndexAB * indexAB [[buffer(0)]],
69
+ constant void * indexSizes [[buffer(1)]],
70
+ constant void * indexStrides [[buffer(2)]],
71
+ constant OffsetsT * offsets [[buffer(3)]],
72
+ constant void * inputData [[buffer(4)]],
73
+ device void * outputData [[buffer(5)]],
74
+ constant uint32_t & num_indices [[buffer(6)]],
75
+ constant uint * numIters [[buffer(7)]],
76
+ uint thread_index [[thread_position_in_grid]]) {
77
+
78
+ constant int64_t * index_sizes = (constant int64_t *)indexSizes;
79
+ constant int64_t * index_strides = (constant int64_t *)indexStrides;
80
+
81
+ for (uint iter_i = 0; iter_i < *numIters; iter_i++) {
82
+ index_put_impl<T>(indexAB, index_sizes, index_strides, offsets, inputData, outputData, num_indices, iter_i);
83
+ }
84
+ }
85
+
86
+ template<typename T, typename OffsetsT>
87
+ kernel void index_put(
88
+ constant IndexAB * indexAB [[buffer(0)]],
89
+ constant void * indexSizes [[buffer(1)]],
90
+ constant void * indexStrides [[buffer(2)]],
91
+ constant OffsetsT * offsets [[buffer(3)]],
92
+ constant void * inputData [[buffer(4)]],
93
+ device void * outputData [[buffer(5)]],
94
+ constant uint32_t & num_indices [[buffer(6)]],
95
+ uint thread_index [[thread_position_in_grid]]) {
96
+
97
+ constant int64_t * index_sizes = (constant int64_t *)indexSizes;
98
+ constant int64_t * index_strides = (constant int64_t *)indexStrides;
99
+ index_put_impl<T>(indexAB, index_sizes, index_strides, offsets, inputData, outputData, num_indices, thread_index);
100
+ }
101
+
102
+ #define REGISTER_INDEX_OP(DTYPE_SIZE, IDX_SIZE, DTYPE, INDEX_OP_TYPE, IDX_DTYPE) \
103
+ template \
104
+ [[host_name("index_" #INDEX_OP_TYPE "_" #DTYPE_SIZE "_" #IDX_SIZE)]] \
105
+ kernel void index_ ## INDEX_OP_TYPE<DTYPE, IDX_DTYPE>( \
106
+ constant IndexAB * indexAB [[buffer(0)]], \
107
+ constant void * indexSizes [[buffer(1)]], \
108
+ constant void * indexStrides [[buffer(2)]], \
109
+ constant IDX_DTYPE * offsets [[buffer(3)]], \
110
+ constant void * inputData [[buffer(4)]], \
111
+ device void * outputData [[buffer(5)]], \
112
+ constant uint32_t & num_indices [[buffer(6)]], \
113
+ uint thread_index [[thread_position_in_grid]]);
114
+
115
+ #define REGISTER_INDEX_OP_ALL_DTYPES(INDEX_OP_TYPE) \
116
+ REGISTER_INDEX_OP(8bit, idx32, char, INDEX_OP_TYPE, uint3); \
117
+ REGISTER_INDEX_OP(8bit, idx64, char, INDEX_OP_TYPE, ulong3); \
118
+ REGISTER_INDEX_OP(16bit, idx32, short, INDEX_OP_TYPE, uint3); \
119
+ REGISTER_INDEX_OP(16bit, idx64, short, INDEX_OP_TYPE, ulong3); \
120
+ REGISTER_INDEX_OP(32bit, idx32, int, INDEX_OP_TYPE, uint3); \
121
+ REGISTER_INDEX_OP(32bit, idx64, int, INDEX_OP_TYPE, ulong3); \
122
+ REGISTER_INDEX_OP(64bit, idx32, long, INDEX_OP_TYPE, uint3); \
123
+ REGISTER_INDEX_OP(64bit, idx64, long, INDEX_OP_TYPE, ulong3);
124
+
125
+ REGISTER_INDEX_OP_ALL_DTYPES(select);
126
+ REGISTER_INDEX_OP_ALL_DTYPES(put);
127
+
128
+ #define REGISTER_SINGLE_THREADED_INDEX_OP(DTYPE_SIZE, IDX_SIZE, DTYPE, INDEX_OP_TYPE, IDX_DTYPE) \
129
+ template \
130
+ [[host_name("index_" #INDEX_OP_TYPE "_" #DTYPE_SIZE "_" #IDX_SIZE)]] \
131
+ kernel void index_ ## INDEX_OP_TYPE<DTYPE, IDX_DTYPE>( \
132
+ constant IndexAB * indexAB [[buffer(0)]], \
133
+ constant void * indexSizes [[buffer(1)]], \
134
+ constant void * indexStrides [[buffer(2)]], \
135
+ constant IDX_DTYPE * offsets [[buffer(3)]], \
136
+ constant void * inputData [[buffer(4)]], \
137
+ device void * outputData [[buffer(5)]], \
138
+ constant uint32_t & num_indices [[buffer(6)]], \
139
+ constant uint * numIters [[buffer(7)]], \
140
+ uint thread_index [[thread_position_in_grid]]);
141
+
142
+ #define REGISTER_SINGLE_THREADED_INDEX_OP_ALL_DTYPES(INDEX_OP_TYPE) \
143
+ REGISTER_SINGLE_THREADED_INDEX_OP(8bit, idx32, char, INDEX_OP_TYPE, uint3); \
144
+ REGISTER_SINGLE_THREADED_INDEX_OP(8bit, idx64, char, INDEX_OP_TYPE, ulong3); \
145
+ REGISTER_SINGLE_THREADED_INDEX_OP(16bit, idx32, short, INDEX_OP_TYPE, uint3); \
146
+ REGISTER_SINGLE_THREADED_INDEX_OP(16bit, idx64, short, INDEX_OP_TYPE, ulong3); \
147
+ REGISTER_SINGLE_THREADED_INDEX_OP(32bit, idx32, int, INDEX_OP_TYPE, uint3); \
148
+ REGISTER_SINGLE_THREADED_INDEX_OP(32bit, idx64, int, INDEX_OP_TYPE, ulong3); \
149
+ REGISTER_SINGLE_THREADED_INDEX_OP(64bit, idx32, long, INDEX_OP_TYPE, uint3); \
150
+ REGISTER_SINGLE_THREADED_INDEX_OP(64bit, idx64, long, INDEX_OP_TYPE, ulong3);
151
+
152
+ REGISTER_SINGLE_THREADED_INDEX_OP_ALL_DTYPES(put_serial);
153
+
154
+ template<typename StridesT, typename DataT>
155
+ kernel void kernel_index_offsets(constant StridesT * strides [[buffer(0)]],
156
+ device DataT * data_offsets [[buffer(1)]],
157
+ constant uint * iter_shape [[buffer(2)]],
158
+ constant uint & num_dimensions [[buffer(3)]],
159
+ uint thread_index [[thread_position_in_grid]]) {
160
+ data_offsets[thread_index] = 0;
161
+ uint32_t idx = thread_index;
162
+ for (uint32_t dim = 0; dim < num_dimensions; dim++) {
163
+ uint32_t remainder = idx % iter_shape[dim];
164
+ idx /= iter_shape[dim];
165
+
166
+ data_offsets[thread_index] += remainder * DataT(strides[dim]);
167
+ }
168
+ }
169
+
170
+ template
171
+ [[host_name("kernel_index_offsets_32")]]
172
+ kernel void kernel_index_offsets<packed_uint3, uint3>(
173
+ constant packed_uint3 * strides [[buffer(0)]],
174
+ device uint3 * data_offsets [[buffer(1)]],
175
+ constant uint * iter_shape [[buffer(2)]],
176
+ constant uint & num_dimensions [[buffer(3)]],
177
+ uint thread_index [[thread_position_in_grid]]);
178
+
179
+ template
180
+ [[host_name("kernel_index_offsets_64")]]
181
+ kernel void kernel_index_offsets<packed_uint3, ulong3>(
182
+ constant packed_uint3 * strides [[buffer(0)]],
183
+ device ulong3 * data_offsets [[buffer(1)]],
184
+ constant uint * iter_shape [[buffer(2)]],
185
+ constant uint & num_dimensions [[buffer(3)]],
186
+ uint thread_index [[thread_position_in_grid]]);
187
+
188
+ template<typename T, typename E, typename OffsetsT>
189
+ kernel void index_put_accumulate_native_dtypes(
190
+ constant IndexAB * indexAB [[buffer(0)]],
191
+ constant void * indexSizes [[buffer(1)]],
192
+ constant void * indexStrides [[buffer(2)]],
193
+ constant OffsetsT * offsets [[buffer(3)]],
194
+ constant void * inputData [[buffer(4)]],
195
+ device void * outputData [[buffer(5)]],
196
+ constant uint32_t & num_indices [[buffer(6)]],
197
+ uint thread_index [[thread_position_in_grid]]) {
198
+ constant int64_t * index_sizes = (constant int64_t *)indexSizes;
199
+ constant int64_t * index_strides = (constant int64_t *)indexStrides;
200
+ int64_t offset = 0;
201
+ for (uint32_t i = 0; i < num_indices; i++) {
202
+ constant int64_t* indexArray = indexAB[i].indexArray;
203
+ int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)];
204
+ if (index < 0) {
205
+ index += index_sizes[i];
206
+ }
207
+ offset += index * index_strides[i];
208
+ }
209
+ device T * out = (device T*)((device char*)outputData + offsets[thread_index].x + offset);
210
+ constant E * in = (constant E*)((constant char*)inputData + offsets[thread_index].y);
211
+ atomic_fetch_add_explicit(out, *in, memory_order_relaxed);
212
+ }
213
+
214
+ template<typename T>
215
+ __attribute__((__always_inline__)) void atomic_fetch_add_relaxed(device void * addr, T value) {
216
+ device atomic_uint* uintAddr = (device atomic_uint*)addr;
217
+ uint expected = atomic_load_explicit(uintAddr, memory_order_relaxed);
218
+ T updated = as_type<T>(expected) + value;
219
+ while (!atomic_compare_exchange_weak_explicit(uintAddr, &expected, as_type<uint>(updated), memory_order_relaxed, memory_order_relaxed)) {
220
+ updated = as_type<T>(expected) + value;
221
+ }
222
+ }
223
+
224
+ template<typename T, typename OffsetsT>
225
+ kernel void atomic_index_put_accumulate(
226
+ constant IndexAB * indexAB [[buffer(0)]],
227
+ constant void * indexSizes [[buffer(1)]],
228
+ constant void * indexStrides [[buffer(2)]],
229
+ constant OffsetsT * offsets [[buffer(3)]],
230
+ constant void * inputData [[buffer(4)]],
231
+ device void * outputData [[buffer(5)]],
232
+ constant uint32_t & num_indices [[buffer(6)]],
233
+ uint thread_index [[thread_position_in_grid]]) {
234
+ constant int64_t * index_sizes = (constant int64_t *)indexSizes;
235
+ constant int64_t * index_strides = (constant int64_t *)indexStrides;
236
+ int64_t offset = 0;
237
+ for (uint32_t i = 0; i < num_indices; i++) {
238
+ constant int64_t* indexArray = indexAB[i].indexArray;
239
+ int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)];
240
+ if (index < 0) {
241
+ index += index_sizes[i];
242
+ }
243
+ offset += index * index_strides[i];
244
+ }
245
+ device void * out = (device void*)((device char*)outputData + offsets[thread_index].x + offset);
246
+ constant T * in = (constant T*)((constant char*)inputData + offsets[thread_index].y);
247
+ atomic_fetch_add_relaxed<T>(out, *in);
248
+ }
249
+
250
+ template
251
+ [[host_name("index_put_accumulate_32bit_float_idx32")]]
252
+ kernel void atomic_index_put_accumulate<float, uint3>(
253
+ constant IndexAB * indexAB [[buffer(0)]],
254
+ constant void * indexSizes [[buffer(1)]],
255
+ constant void * indexStrides [[buffer(2)]],
256
+ constant uint3 * offsets [[buffer(3)]],
257
+ constant void * inputData [[buffer(4)]],
258
+ device void * outputData [[buffer(5)]],
259
+ constant uint32_t & num_indices [[buffer(6)]],
260
+ uint thread_index [[thread_position_in_grid]]);
261
+
262
+ template
263
+ [[host_name("index_put_accumulate_32bit_float_idx64")]]
264
+ kernel void atomic_index_put_accumulate<float, ulong3>(
265
+ constant IndexAB * indexAB [[buffer(0)]],
266
+ constant void * indexSizes [[buffer(1)]],
267
+ constant void * indexStrides [[buffer(2)]],
268
+ constant ulong3 * offsets [[buffer(3)]],
269
+ constant void * inputData [[buffer(4)]],
270
+ device void * outputData [[buffer(5)]],
271
+ constant uint32_t & num_indices [[buffer(6)]],
272
+ uint thread_index [[thread_position_in_grid]]);
273
+
274
+ template
275
+ [[host_name("index_put_accumulate_32bit_int_idx32")]]
276
+ kernel void index_put_accumulate_native_dtypes<atomic_int, int, uint3>(
277
+ constant IndexAB * indexAB [[buffer(0)]],
278
+ constant void * indexSizes [[buffer(1)]],
279
+ constant void * indexStrides [[buffer(2)]],
280
+ constant uint3 * offsets [[buffer(3)]],
281
+ constant void * inputData [[buffer(4)]],
282
+ device void * outputData [[buffer(5)]],
283
+ constant uint32_t & num_indices [[buffer(6)]],
284
+ uint thread_index [[thread_position_in_grid]]);
285
+
286
+ template
287
+ [[host_name("index_put_accumulate_32bit_int_idx64")]]
288
+ kernel void index_put_accumulate_native_dtypes<atomic_int, int, ulong3>(
289
+ constant IndexAB * indexAB [[buffer(0)]],
290
+ constant void * indexSizes [[buffer(1)]],
291
+ constant void * indexStrides [[buffer(2)]],
292
+ constant ulong3 * offsets [[buffer(3)]],
293
+ constant void * inputData [[buffer(4)]],
294
+ device void * outputData [[buffer(5)]],
295
+ constant uint32_t & num_indices [[buffer(6)]],
296
+ uint thread_index [[thread_position_in_grid]]);
297
+ )INDEX_METAL";
298
+
299
+ static const char *SCATTER_OPS_TEMPLATE = R"METAL_SCATTER(
300
+ struct __attribute__ ((packed)) packed_uint5{{
301
+ uint32_t x; uint32_t y; uint32_t z; uint32_t w; uint32_t u;
302
+ }};
303
+
304
+ template<typename Y, typename X>
305
+ Y cast(const X x);
306
+
307
+ template<>
308
+ {1} cast<{1}, {0}>(const {0} x) {{
309
+ return {2};
310
+ }}
311
+
312
+ kernel void scatter_kernel_5(uint linear_index [[thread_position_in_grid]],
313
+ constant void * src_ [[buffer(0)]],
314
+ device void * dst_ [[buffer(1)]],
315
+ constant packed_uint5 & size [[buffer(2)]],
316
+ constant packed_uint5 & stride [[buffer(3)]],
317
+ constant uint32_t & numel [[buffer(4)]]) {{
318
+ if (linear_index >= numel) return;
319
+
320
+ constant {0} * src = (constant {0} *)src_;
321
+ device {1} * dst = (device {1} *)dst_;
322
+
323
+ packed_uint5 local_index;
324
+ local_index.x = linear_index / (size.u * size.w * size.z * size.y) % size.x;
325
+ local_index.y = linear_index / (size.u * size.w * size.z) % size.y;
326
+ local_index.z = linear_index / (size.u * size.w) % size.z;
327
+ local_index.w = linear_index / size.u % size.w;
328
+ local_index.u = linear_index % size.u;
329
+
330
+ packed_uint5 strided_index;
331
+ strided_index.x = local_index.x * stride.x;
332
+ strided_index.y = local_index.y * stride.y;
333
+ strided_index.z = local_index.z * stride.z;
334
+ strided_index.w = local_index.w * stride.w;
335
+ strided_index.u = local_index.u * stride.u;
336
+
337
+ dst[strided_index.x + strided_index.y + strided_index.z + strided_index.w + strided_index.u] = cast<{1}>(src[linear_index]);
338
+ }}
339
+
340
+ kernel void scatter_kernel_4(uint linear_index [[thread_position_in_grid]],
341
+ constant void * src_ [[buffer(0)]],
342
+ device void * dst_ [[buffer(1)]],
343
+ constant packed_uint4 & size [[buffer(2)]],
344
+ constant packed_uint4 & stride [[buffer(3)]],
345
+ constant uint32_t & numel [[buffer(4)]]) {{
346
+ if (linear_index >= numel) return;
347
+
348
+ constant {0} * src = (constant {0} *)src_;
349
+ device {1} * dst = (device {1} *)dst_;
350
+
351
+ packed_uint4 local_index;
352
+ local_index.x = linear_index / (size[3] * size[2] * size[1]) % size[0];
353
+ local_index.y = linear_index / (size[3] * size[2]) % size[1];
354
+ local_index.z = linear_index / size[3] % size[2];
355
+ local_index.w = linear_index % size[3];
356
+
357
+ const packed_uint4 strided_index = local_index * stride;
358
+ dst[strided_index.x + strided_index.y + strided_index.z + strided_index.w] = cast<{1}>(src[linear_index]);
359
+ }}
360
+
361
+ kernel void scatter_kernel_3(uint linear_index [[thread_position_in_grid]],
362
+ constant void * src_ [[buffer(0)]],
363
+ device void * dst_ [[buffer(1)]],
364
+ constant packed_uint3 & size [[buffer(2)]],
365
+ constant packed_uint3 & stride [[buffer(3)]],
366
+ constant uint32_t & numel [[buffer(4)]]) {{
367
+ if (linear_index >= numel) return;
368
+
369
+ constant {0} * src = (constant {0} *)src_;
370
+ device {1} * dst = (device {1} *)dst_;
371
+
372
+ packed_uint3 local_index;
373
+ local_index.x = linear_index / (size[2] * size[1]) % size[0];
374
+ local_index.y = linear_index / size[2] % size[1];
375
+ local_index.z = linear_index % size[2];
376
+
377
+ const packed_uint3 strided_index = local_index * stride;
378
+ dst[strided_index.x + strided_index.y + strided_index.z] = cast<{1}>(src[linear_index]);
379
+ }}
380
+
381
+ kernel void scatter_kernel_2(uint linear_index [[thread_position_in_grid]],
382
+ constant void * src_ [[buffer(0)]],
383
+ device void * dst_ [[buffer(1)]],
384
+ constant packed_uint2 & size [[buffer(2)]],
385
+ constant packed_uint2 & stride [[buffer(3)]],
386
+ constant uint32_t & numel [[buffer(4)]]) {{
387
+ if (linear_index >= numel) return;
388
+
389
+ constant {0} * src = (constant {0} *)src_;
390
+ device {1} * dst = (device {1} *)dst_;
391
+
392
+ packed_uint2 local_index;
393
+ local_index.x = linear_index / size[1] % size[0];
394
+ local_index.y = linear_index % size[1];
395
+
396
+ const packed_uint2 strided_index = local_index * stride;
397
+ dst[strided_index.x + strided_index.y] = cast<{1}>(src[linear_index]);
398
+ }}
399
+
400
+ kernel void scatter_kernel_1(uint linear_index [[thread_position_in_grid]],
401
+ constant void * src_ [[buffer(0)]],
402
+ device void * dst_ [[buffer(1)]],
403
+ constant int & size [[buffer(2)]],
404
+ constant int & stride [[buffer(3)]],
405
+ constant uint32_t & numel [[buffer(4)]]) {{
406
+ if (linear_index >= numel) return;
407
+
408
+ constant {0} * src = (constant {0} *)src_;
409
+ device {1} * dst = (device {1} *)dst_;
410
+
411
+ const int local_index = linear_index % size;
412
+ const int strided_index = local_index * stride;
413
+ dst[strided_index] = cast<{1}>(src[linear_index]);
414
+ }}
415
+ )METAL_SCATTER";
416
+
417
+ static const char *GATHER_OPS_TEMPLATE = R"METAL_GATHER(
418
+ struct __attribute__ ((packed)) packed_uint5{{
419
+ uint32_t x; uint32_t y; uint32_t z; uint32_t w; uint32_t u;
420
+ }};
421
+
422
+ template<typename Y, typename X>
423
+ Y cast(const X x);
424
+
425
+ template<>
426
+ {1} cast<{1}, {0}>(const {0} x) {{
427
+ return {2};
428
+ }}
429
+
430
+ kernel void gather_kernel_5(uint linear_index [[thread_position_in_grid]],
431
+ constant void * src_ [[buffer(0)]],
432
+ device void * dst_ [[buffer(1)]],
433
+ constant packed_uint5 & size [[buffer(2)]],
434
+ constant packed_uint5 & stride [[buffer(3)]],
435
+ constant uint32_t & numel [[buffer(4)]]) {{
436
+ if (linear_index >= numel) return;
437
+
438
+ constant {0} * src = (constant {0} *)src_;
439
+ device {1} * dst = (device {1} *)dst_;
440
+
441
+
442
+ packed_uint5 local_index;
443
+ local_index.x = linear_index / (size.u * size.w * size.z * size.y) % size.x;
444
+ local_index.y = linear_index / (size.u * size.w * size.z) % size.y;
445
+ local_index.z = linear_index / (size.u * size.w) % size.z;
446
+ local_index.w = linear_index / size.u % size.w;
447
+ local_index.u = linear_index % size.u;
448
+
449
+ packed_uint5 strided_index;
450
+ strided_index.x = local_index.x * stride.x;
451
+ strided_index.y = local_index.y * stride.y;
452
+ strided_index.z = local_index.z * stride.z;
453
+ strided_index.w = local_index.w * stride.w;
454
+ strided_index.u = local_index.u * stride.u;
455
+
456
+ dst[linear_index] = cast<{1}>(src[strided_index.x + strided_index.y + strided_index.z + strided_index.w + strided_index.u]);
457
+ }}
458
+
459
+ kernel void gather_kernel_4(uint linear_index [[thread_position_in_grid]],
460
+ constant void * src_ [[buffer(0)]],
461
+ device void * dst_ [[buffer(1)]],
462
+ constant packed_uint4 & size [[buffer(2)]],
463
+ constant packed_uint4 & stride [[buffer(3)]],
464
+ constant uint32_t & numel [[buffer(4)]]) {{
465
+ if (linear_index >= numel) return;
466
+
467
+ constant {0} * src = (constant {0} *)src_;
468
+ device {1} * dst = (device {1} *)dst_;
469
+
470
+ packed_uint4 local_index;
471
+ local_index.x = linear_index / (size[3] * size[2] * size[1]) % size[0];
472
+ local_index.y = linear_index / (size[3] * size[2]) % size[1];
473
+ local_index.z = linear_index / size[3] % size[2];
474
+ local_index.w = linear_index % size[3];
475
+
476
+ const packed_uint4 strided_index = local_index * stride;
477
+ dst[linear_index] = cast<{1}>(src[strided_index.x + strided_index.y + strided_index.z + strided_index.w]);
478
+ }}
479
+
480
+ kernel void gather_kernel_3(uint linear_index [[thread_position_in_grid]],
481
+ constant void * src_ [[buffer(0)]],
482
+ device void * dst_ [[buffer(1)]],
483
+ constant packed_uint3 & size [[buffer(2)]],
484
+ constant packed_uint3 & stride [[buffer(3)]],
485
+ constant uint32_t & numel [[buffer(4)]]) {{
486
+ if (linear_index >= numel) return;
487
+
488
+ constant {0} * src = (constant {0} *)src_;
489
+ device {1} * dst = (device {1} *)dst_;
490
+
491
+ packed_uint3 local_index;
492
+ local_index.x = linear_index / (size[2] * size[1]) % size[0];
493
+ local_index.y = linear_index / size[2] % size[1];
494
+ local_index.z = linear_index % size[2];
495
+
496
+ const packed_uint3 strided_index = local_index * stride;
497
+ dst[linear_index] = cast<{1}>(src[strided_index.x + strided_index.y + strided_index.z]);
498
+ }}
499
+
500
+ kernel void gather_kernel_2(uint linear_index [[thread_position_in_grid]],
501
+ constant void * src_ [[buffer(0)]],
502
+ device void * dst_ [[buffer(1)]],
503
+ constant packed_uint2 & size [[buffer(2)]],
504
+ constant packed_uint2 & stride [[buffer(3)]],
505
+ constant uint32_t & numel [[buffer(4)]]) {{
506
+ if (linear_index >= numel) return;
507
+
508
+ constant {0} * src = (constant {0} *)src_;
509
+ device {1} * dst = (device {1} *)dst_;
510
+
511
+ packed_uint2 local_index;
512
+ local_index.x = linear_index / size[1] % size[0];
513
+ local_index.y = linear_index % size[1];
514
+
515
+ const packed_uint2 strided_index = local_index * stride;
516
+ dst[linear_index] = cast<{1}>(src[strided_index.x + strided_index.y]);
517
+ }}
518
+
519
+ kernel void gather_kernel_1(uint linear_index [[thread_position_in_grid]],
520
+ constant void * src_ [[buffer(0)]],
521
+ device void * dst_ [[buffer(1)]],
522
+ constant int & size [[buffer(2)]],
523
+ constant int & stride [[buffer(3)]],
524
+ constant uint32_t & numel [[buffer(4)]]) {{
525
+ if (linear_index >= numel) return;
526
+
527
+ constant {0} * src = (constant {0} *)src_;
528
+ device {1} * dst = (device {1} *)dst_;
529
+
530
+ const int local_index = linear_index % size;
531
+ const int strided_index = local_index * stride;
532
+ dst[linear_index] = cast<{1}>(src[strided_index]);
533
+ }}
534
+ )METAL_GATHER";
535
+ } // namespace at::mps
.venv/lib/python3.11/site-packages/torch/include/ATen/mps/MPSAllocator.h ADDED
@@ -0,0 +1,403 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright © 2022 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include <ATen/mps/MPSAllocatorInterface.h>
6
+ #include <ATen/mps/MPSEvent.h>
7
+ #include <ATen/mps/MPSStream.h>
8
+
9
+ #include <cstdio>
10
+ #include <mutex>
11
+ #include <set>
12
+ #include <unordered_set>
13
+ #include <mach/vm_page_size.h>
14
+ #include <c10/util/flat_hash_map.h>
15
+
16
+ // this implementation is based on CUDACachingAllocator.
17
+ // It utilizes Metal Heaps to improve the performance with buffer allocation.
18
+ // Do not include this header. Use MPSAllocatorInterface.h instead.
19
+ // TODO: Unify the logic with CUDACachingAllocator and remove redundant code.
20
+ namespace at::mps::HeapAllocator {
21
+
22
+ static const size_t kMaxSmallAlloc = MB(1); // largest "small" allocation is 1 MiB
23
+ static const size_t kMinLargeAlloc = MB(10); // allocations between 1 and 10 MiB may use kLargeHeap
24
+ static const size_t kRoundLarge = MB(2); // round up large allocations to 2 MiB
25
+ static const size_t kSmallHeap = MB(8); // "small" allocations are packed in 8 MiB heaps
26
+ static const size_t kLargeHeap = MB(32); // "large" allocations may be packed in 32 MiB heaps
27
+ static const size_t kXLargeHeapD = MB(128); // "extra large" allocations on Discrete devices may be packed in 128 MiB heaps
28
+ static const size_t kXLargeHeapU = MB(1024); // "extra large" allocations on Unified devices may be packed in 1 GiB heaps
29
+ static const size_t kMaxScalarAlloc = (sizeof(int64_t)); // largest "scalar" allocation
30
+
31
+ // buffer pools could be customized with a combination of usage flags
32
+ enum UsageFlags : uint32_t {
33
+ PRIVATE = 0,
34
+ SMALL = (1 << 0), // small heaps have sizes of kSmallHeap, and large ones kLargeHeap
35
+ SHARED = (1 << 1), // shared pools allocated on devices with unified memory; otherwise, private between host/device
36
+ MANAGED = (1 << 2), // managed storage mode
37
+ HAZARD = (1 << 3), // enables Automatic Hazard Tracking for the resources allocated on the pool
38
+ SCALAR = (1 << 4), // used to import CPU scalar values to GPU and use them in MPS Stream
39
+ };
40
+ // debug verbosity flags
41
+ enum DebugVerbosity : uint32_t {
42
+ SILENT = 0,
43
+ PROFILING = (1 << 0), // print generic profiling data for total system memory usage
44
+ ALLOCATIONS = (1 << 1), // print buffer allocations
45
+ RECYCLES = (1 << 2), // print buffer recycling
46
+ RELEASES = (1 << 3), // print buffer releases
47
+ LARGE_ONLY = (1 << 4), // only log large buffer pool transactions
48
+ };
49
+
50
+ struct HeapBlock;
51
+
52
+ struct BufferBlock {
53
+ id<MTLBuffer> buffer;
54
+ void* cpu_ptr = nullptr; // stores the pointer to CPU mapping of a Shared MTLBuffer
55
+ size_t size; // size after alignment
56
+ size_t requested_size; // requested size (before alignment)
57
+ // buffer shape is used for retrieving base of views in cached graphs
58
+ std::vector<int64_t> shape;
59
+ bool in_use = false;
60
+ HeapBlock* heap;
61
+ id_t buf_id;
62
+ // counter to candidate least recently used buffers for garbage collection
63
+ uint32_t gc_count = 0;
64
+ uint32_t use_count = 0;
65
+ // counter to assign unique ids to buffer blocks
66
+ static uint64_t buffer_counter;
67
+ // Metal events used to sync GPU/CPU operations on the shared-storage buffers
68
+ MPSEventPtr event;
69
+
70
+ BufferBlock(size_t Size, size_t RequestedSize = 0, const id<MTLBuffer> Buffer = nullptr,
71
+ HeapBlock* Heap = nullptr) :
72
+ buffer(Buffer), size(Size), requested_size(RequestedSize),
73
+ heap(Heap), buf_id(Buffer ? ++buffer_counter : 0) { }
74
+
75
+ static bool Comparator(const BufferBlock* a, const BufferBlock* b) {
76
+ return (a->size != b->size) ? a->size < b->size : (uintptr_t)a->buffer < (uintptr_t)b->buffer;
77
+ }
78
+ static size_t alignUp(size_t Size, size_t Alignment) {
79
+ assert(((Alignment - 1) & Alignment) == 0);
80
+ return ((Size + Alignment - 1) & ~(Alignment - 1));
81
+ }
82
+ uint32_t retainCount() const { return [buffer retainCount]; }
83
+ };
84
+ typedef bool (*BufferComparison)(const BufferBlock*, const BufferBlock*);
85
+
86
+ struct BufferPool;
87
+ struct AllocParams {
88
+ AllocParams(size_t Alloc_Size, size_t Requested_Size, BufferPool* Pool) :
89
+ search_key(Alloc_Size), pool(Pool), requested_size(Requested_Size) { }
90
+ size_t size() const { return search_key.size; }
91
+
92
+ BufferBlock search_key;
93
+ BufferPool* pool;
94
+ BufferBlock* buffer_block = nullptr;
95
+ size_t requested_size;
96
+ // true if we exceed the low watermark limit. In this case
97
+ // we apply strategies to relieve the pressure before allocation.
98
+ bool has_memory_pressure = false;
99
+ // true if we're allocating on a unified memory device
100
+ bool has_unified_memory = true;
101
+ };
102
+
103
+ struct HeapBlock {
104
+ id<MTLHeap> heap;
105
+ struct { size_t total, available; } size;
106
+ BufferPool* pool;
107
+ unsigned int n_buffers = 0;
108
+ id_t heap_id;
109
+ // indicates if we split this heap to sub-allocate 'several' buffers (otherwise single buffer)
110
+ bool is_split;
111
+ // counter to assign unique ids to heap blocks
112
+ static uint64_t heap_counter;
113
+
114
+ HeapBlock(size_t Size, const id<MTLHeap> Heap = nullptr, BufferPool *Pool = nullptr) :
115
+ heap(Heap), size({.total = Size, .available = Size}), pool(Pool),
116
+ heap_id(Heap ? ++heap_counter : 0), is_split(true) { }
117
+
118
+ static MTLResourceOptions getOptions(uint32_t usage) {
119
+ // TODO: check the caching performance of write-combined mode
120
+ MTLResourceOptions options = MTLResourceCPUCacheModeDefaultCache;
121
+
122
+ if (usage & UsageFlags::MANAGED)
123
+ options |= MTLResourceStorageModeManaged;
124
+ else if (usage & UsageFlags::SHARED)
125
+ options |= MTLResourceStorageModeShared;
126
+ else
127
+ options |= MTLResourceStorageModePrivate;
128
+
129
+ options |= (usage & UsageFlags::HAZARD) ? MTLResourceHazardTrackingModeTracked : MTLResourceHazardTrackingModeUntracked;
130
+
131
+ return options;
132
+ }
133
+
134
+ static HeapBlock* createHeapBlock(AllocParams& params, id<MTLDevice> device, uint32_t usage) {
135
+ HeapBlock *heapBlock = nullptr;
136
+ bool is_split = true;
137
+ const size_t size = params.size();
138
+ MTLHeapDescriptor *d = [MTLHeapDescriptor new];
139
+ if (d) {
140
+ const size_t kXLargeHeap = params.has_unified_memory ? kXLargeHeapU : kXLargeHeapD;
141
+ if (size <= kMaxSmallAlloc) {
142
+ d.size = kSmallHeap;
143
+ } else if (size < kMinLargeAlloc) {
144
+ d.size = kLargeHeap;
145
+ } else if (size < kXLargeHeap / 2 && !params.has_memory_pressure) {
146
+ d.size = kXLargeHeap;
147
+ } else {
148
+ d.size = kRoundLarge * ((size + kRoundLarge - 1) / kRoundLarge);
149
+ is_split = false;
150
+ }
151
+ d.storageMode = (usage & UsageFlags::SHARED) ? MTLStorageModeShared : MTLStorageModePrivate;
152
+ d.cpuCacheMode = MTLCPUCacheModeDefaultCache;
153
+ // this automatically handles Metal buffer access synchronizations at the
154
+ // cost of slightly lower performance.
155
+ d.hazardTrackingMode = (usage & UsageFlags::HAZARD) ? MTLHazardTrackingModeTracked : MTLHazardTrackingModeUntracked;
156
+ d.resourceOptions = getOptions(usage);
157
+ d.type = MTLHeapTypeAutomatic;
158
+ id<MTLHeap> heap = [device newHeapWithDescriptor: d];
159
+ if (heap) {
160
+ [heap setPurgeableState:MTLPurgeableStateNonVolatile];
161
+ const size_t heap_size = heapAvailableSize(heap);
162
+ heapBlock = new HeapBlock(heap_size, heap, params.pool);
163
+ if (heapBlock) {
164
+ heapBlock->is_split = is_split;
165
+ }
166
+ }
167
+ [d release];
168
+ }
169
+ return heapBlock;
170
+ }
171
+ static bool Comparator(const HeapBlock* a, const HeapBlock* b) {
172
+ return (a->size.available != b->size.available) ? a->size.available < b->size.available :
173
+ (uintptr_t)a->heap < (uintptr_t)b->heap;
174
+ }
175
+ static NSUInteger heapAvailableSize(id<MTLHeap> heap, size_t Alignment = vm_page_size) {
176
+ return [heap maxAvailableSizeWithAlignment:Alignment];
177
+ }
178
+ NSUInteger Size() {
179
+ return [heap size];
180
+ }
181
+ id<MTLBuffer> newMTLBuffer(size_t length, uint32_t usage) {
182
+ id<MTLBuffer> buf = [heap newBufferWithLength:length options:getOptions(usage)];
183
+ if (buf) {
184
+ updateAvailableSize();
185
+ n_buffers++;
186
+ }
187
+ return buf;
188
+ }
189
+ // returns the retainCount before releasing the buffer
190
+ uint32_t releaseMTLBuffer(id<MTLBuffer>& buffer) {
191
+ const uint32_t retainCount = [buffer retainCount];
192
+ [buffer release];
193
+ buffer = nil;
194
+ updateAvailableSize();
195
+ n_buffers--;
196
+ return retainCount;
197
+ }
198
+ // returns the retainCount before releasing the heap
199
+ uint32_t releaseMTLHeap() {
200
+ const uint32_t retainCount = [heap retainCount];
201
+ TORCH_INTERNAL_ASSERT(!n_buffers); // assert if heap isn't empty
202
+ [heap setPurgeableState:MTLPurgeableStateEmpty];
203
+ [heap release];
204
+ heap = nil;
205
+ size.available = 0;
206
+ return retainCount;
207
+ }
208
+ uint32_t retainCount() const { return [heap retainCount]; }
209
+ void updateAvailableSize() { size.available = heapAvailableSize(heap); }
210
+ };
211
+ typedef bool (*HeapComparison)(const HeapBlock*, const HeapBlock*);
212
+
213
+ struct BufferPool {
214
+ enum class Kind {
215
+ PRIVATE_SMALL,
216
+ PRIVATE_LARGE,
217
+ SHARED_SMALL,
218
+ SHARED_LARGE,
219
+ SCALAR,
220
+ };
221
+
222
+ BufferPool(const id<MTLDevice> Device, uint32_t Usage) :
223
+ device(Device), usage(Usage),
224
+ heaps(HeapBlock::Comparator), available_buffers(BufferBlock::Comparator) { }
225
+
226
+ const id<MTLDevice> device;
227
+ // usage flags to customize the pool for various purposes (see UsageFlags enum)
228
+ const uint32_t usage;
229
+ // total number of buffers in the pool
230
+ uint32_t n_buffers = 0;
231
+ // total allocations size on this pool
232
+ size_t allocated_size = 0;
233
+ // total memory available in the pool
234
+ size_t available_size = 0;
235
+ // list of heaps ordered by their "available" (not total) memory size
236
+ std::set<HeapBlock*, HeapComparison> heaps;
237
+ // list of only "available" buffers in the pool (i.e., buffers not in-use)
238
+ std::set<BufferBlock*, BufferComparison> available_buffers;
239
+ // list of buffers that are in a state of "limbo" where they've already been freed
240
+ // from PyTorch-side, but were not returned to pool due to still being
241
+ // in-use by command buffers with retainCount > 1. In this state, the buffer is
242
+ // neither ready to be recycled, nor could be returned to pool as available.
243
+ // These buffers will be returned to pool once the command buffer's
244
+ // completionHandler callbacks are called.
245
+ std::unordered_set<BufferBlock*> buffers_pending_free;
246
+ // list of heaps pending size update
247
+ std::unordered_set<HeapBlock*> heaps_pending_update;
248
+ };
249
+
250
+ class MPSHeapAllocatorImpl {
251
+ public:
252
+ explicit MPSHeapAllocatorImpl() :
253
+ m_device(at::mps::MPSDevice::getInstance()->device()),
254
+ m_max_buffer_size([m_device maxBufferLength]),
255
+ m_stream(getDefaultMPSStream()),
256
+ m_event_pool(getMPSEventPool()) {
257
+ init_allocator();
258
+ }
259
+ ~MPSHeapAllocatorImpl() {
260
+ emptyCache();
261
+ }
262
+ // interface exposed to at::Allocator
263
+ id<MTLBuffer> malloc(size_t size, uint32_t usage);
264
+ // frees a buffer and returns it into buffer pool
265
+ void free(void* ptr);
266
+ // releases all the cached buffers and their associated heaps
267
+ void emptyCache();
268
+ // free inactive buffers that are pending to be freed
269
+ void freeInactiveBuffers();
270
+ // returns true if buffer was allocated from the shared pool
271
+ bool isSharedBuffer(const void* ptr);
272
+ // get the requested unaligned size of an MTLBuffer
273
+ ssize_t getUnalignedBufferSize(const void* ptr);
274
+ // set the shape of a base tensor from a view tensor
275
+ void setBufferShape(const void* ptr, const IntArrayRef& shape);
276
+ // retrieve the shape of a base tensor from a view tensor
277
+ IntArrayRef getBufferShape(const void* ptr);
278
+ // get the unique ID of the buffer
279
+ id_t getBufferId(const void* ptr);
280
+ // allocate a buffer from a specialized pool to import CPU scalars into GPU
281
+ id<MTLBuffer> allocScalarBufferWithValue(void* value, size_t size);
282
+ // returns a CPU-mapping of the input buffer and its retainCount,
283
+ // if only it has Shared storage-mode and allocated on MPSAllocator
284
+ std::pair<const void*, uint32_t> getSharedBufferPtr(const void* buffer);
285
+ // records events for a list of MTLBuffers (list is used to lock the mutex once)
286
+ // returns true if records any event (given if passed buffers exist and are shared-storage)
287
+ bool recordEvents(c10::ArrayRef<const void*> buffers);
288
+ // waits for the event to signal the completion of GPU execution
289
+ // on the passed shared buffers (list is used to lock the mutex once)
290
+ // returns true if actually waited on any event
291
+ bool waitForEvents(c10::ArrayRef<const void*> buffers);
292
+ // this indicates how far (in Megabytes) the current total allocations are from the
293
+ // low watermark limit which is used to detect if we're under memory pressure
294
+ // This returns zero if we've reached the low watermark limit
295
+ ssize_t getLowWatermarkValue();
296
+ // (see m_low_watermark_ratio for description)
297
+ void setLowWatermarkRatio(double ratio);
298
+ // (see m_high_watermark_ratio for description)
299
+ void setHighWatermarkRatio(double ratio);
300
+ // (see m_low_watermark_limit for description)
301
+ size_t getLowWatermarkLimit() const { return m_low_watermark_limit; }
302
+ // (see m_max_total_allowed_size for description)
303
+ size_t getHighWatermarkLimit() const { return m_max_total_allowed_size; }
304
+ // (see m_total_allocated_memory for description)
305
+ size_t getTotalAllocatedMemory() const { return m_total_allocated_memory; }
306
+ // (see m_current_allocated_memory for description)
307
+ size_t getCurrentAllocatedMemory() const { return m_current_allocated_memory; }
308
+ // total GPU memory allocated in the process by Metal driver; including
309
+ // implicit allocations from MPS/MPSGraph frameworks and MPSHeapAllocatorImpl.
310
+ size_t getDriverAllocatedMemory() const { return current_allocated_size(); }
311
+ // recommended Max memory for Metal
312
+ size_t getRecommendedMaxMemory() const { return max_device_size(); }
313
+ // (see enum DebugVerbosity for description)
314
+ uint32_t getDebugVerbosity() const { return m_debug_verbosity; }
315
+ // returns the device that we allocate from
316
+ inline id<MTLDevice> Device() const { return m_device; }
317
+
318
+ // TODO: make a common function to do size unit conversions in PyTorch.
319
+ inline std::string format_size(uint64_t size) const;
320
+
321
+ private:
322
+ // (see m_high_watermark_ratio for description)
323
+ constexpr static double default_high_watermark_ratio = 1.7;
324
+ // we set the allowed upper bound to twice the size of recommendedMaxWorkingSetSize.
325
+ constexpr static double default_high_watermark_upper_bound = 2.0;
326
+ // (see m_low_watermark_ratio for description)
327
+ // on unified memory, we could allocate beyond the recommendedMaxWorkingSetSize
328
+ constexpr static double default_low_watermark_ratio_unified = 1.4;
329
+ constexpr static double default_low_watermark_ratio_discrete = 1.0;
330
+
331
+ const id<MTLDevice> m_device;
332
+ std::recursive_mutex m_mutex;
333
+ // allocated buffers by device pointer
334
+ ska::flat_hash_map<const void*, BufferBlock*> m_allocated_buffers;
335
+ // using a container for pools to simplify iterating them
336
+ ska::flat_hash_map<BufferPool::Kind, std::unique_ptr<BufferPool>> m_pools;
337
+ // total memory allocated by HeapAllocator (including blocks in pools)
338
+ size_t m_total_allocated_memory = 0;
339
+ // currently active memory allocations in use (i.e., blocks not in pools)
340
+ size_t m_current_allocated_memory = 0;
341
+ // max buffer size allowed by Metal
342
+ size_t m_max_buffer_size = 0;
343
+ // maximum total size allowed to be allocated
344
+ size_t m_max_total_allowed_size = 0;
345
+ // high watermark ratio is a hard limit for the total allowed allocations
346
+ // 0. : disables high watermark limit (may cause system failure if system-wide OOM occurs)
347
+ // 1. : recommended maximum allocation size (i.e., device.recommendedMaxWorkingSetSize)
348
+ // >1.: allows limits beyond the device.recommendedMaxWorkingSetSize
349
+ // e.g., value 0.95 means we allocate up to 95% of recommended maximum
350
+ // allocation size; beyond that, the allocations would fail with OOM error.
351
+ double m_high_watermark_ratio;
352
+ // low watermark ratio is a soft limit to attempt limiting memory allocations up to the lower watermark
353
+ // level by garbage collection or committing command buffers more frequently (a.k.a, adaptive commit).
354
+ // Value between 0 to m_high_watermark_ratio (setting 0.0 disables adaptive commit and garbage collection)
355
+ // e.g., value 0.9 means we 'attempt' to limit allocations up to 90% of recommended maximum
356
+ // allocation size.
357
+ double m_low_watermark_ratio;
358
+ // low watermark size limit (in Bytes) at the time we initialize the allocator
359
+ size_t m_low_watermark_limit;
360
+ // use "PYTORCH_DEBUG_MPS_ALLOCATOR" env-var to set debug verbosity
361
+ uint32_t m_debug_verbosity;
362
+ // default MPS stream
363
+ MPSStream* m_stream;
364
+ // we hold a reference to MPSEventPool so it could get destroyed after MPSAllocator
365
+ std::shared_ptr<MPSEventPool> m_event_pool;
366
+
367
+ void init_allocator();
368
+ void init_buffer_pools();
369
+ HeapBlock* get_free_heap(AllocParams& params);
370
+ bool get_free_buffer(AllocParams& params);
371
+ BufferBlock* get_allocated_buffer_block(const void* ptr);
372
+ BufferBlock* alloc_buffer_block(size_t size, uint32_t usage);
373
+ bool alloc_buffer(AllocParams& params);
374
+ void free_buffer(BufferBlock* buffer_block);
375
+ // returns true if the container heap is also released
376
+ bool release_buffer(BufferBlock* buffer_block, bool remove_empty_heap = true);
377
+ void release_buffers(BufferPool& pool);
378
+ bool release_available_cached_buffers(AllocParams& params);
379
+ bool release_cached_buffers();
380
+ // free unused cached blocks to reclaim GPU memory if memory pressure is high
381
+ void garbage_collect_cached_buffers(AllocParams& params);
382
+ // returns the suitable buffer pool type for the usage or
383
+ // requested/allocated sizes
384
+ BufferPool& get_pool(size_t requested_size, size_t aligned_size, uint32_t usage);
385
+ // returns the aligned allocation size that is optimized
386
+ // for the buffers to get reused frequently
387
+ size_t get_allocation_size(size_t size, uint32_t usage) const;
388
+ // maximum size of device memory available for allocation in current process
389
+ // Note: the recommendedMaxWorkingSetSize is typically 75% of the total system memory.
390
+ size_t max_device_size() const { return [m_device recommendedMaxWorkingSetSize]; }
391
+ // there are implicit allocations from MPS backend, so we need to query the 'device' for
392
+ // total allocated size instead of manually tracking in MPSAllocator
393
+ size_t current_allocated_size() const { return [m_device currentAllocatedSize]; }
394
+
395
+ bool trigger_memory_callbacks(BufferBlock* buffer_block, IMpsAllocatorCallback::EventType event) const {
396
+ for (const auto& name : MPSAllocatorCallbacksRegistry()->Keys()) {
397
+ MPSAllocatorCallbacksRegistry()->Create(name)->executeMPSAllocatorCallback(buffer_block ? buffer_block->buffer : nullptr, event);
398
+ }
399
+ return true;
400
+ }
401
+ };
402
+
403
+ } // namespace at::mps::HeapAllocator
.venv/lib/python3.11/site-packages/torch/include/ATen/mps/MPSAllocatorInterface.h ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright © 2023 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include <c10/core/Allocator.h>
6
+ #include <c10/util/Registry.h>
7
+ #include <ATen/core/ATen_fwd.h>
8
+
9
+ #define MB(x) (x * 1048576UL)
10
+
11
+ namespace at::mps {
12
+
13
+ // this is a public interface to access MPSAllocator.
14
+ // Do not declare methods that would depend on MPS or Metal frameworks.
15
+ class IMPSAllocator : public c10::Allocator {
16
+ public:
17
+ // see the comments in MPSAllocator.h for the description of these methods.
18
+ virtual void emptyCache() const = 0;
19
+ virtual void freeInactiveBuffers() const = 0;
20
+ virtual ssize_t getUnalignedBufferSize(const void* ptr) const = 0;
21
+ virtual IntArrayRef getBufferShape(const void* ptr) const = 0;
22
+ virtual id_t getBufferId(const void* ptr) const = 0;
23
+ virtual void setBufferShape(const void* ptr, const IntArrayRef& shape) const = 0;
24
+ virtual bool isSharedBuffer(const void* ptr) const = 0;
25
+ virtual bool isSharedStorageSupported() const = 0;
26
+ virtual c10::DataPtr allocScalarBufferWithValue(void* value, size_t size) const = 0;
27
+ virtual std::string formatSize(size_t size) const = 0;
28
+ virtual void setLowWatermarkRatio(double ratio) const = 0;
29
+ virtual void setHighWatermarkRatio(double ratio) const = 0;
30
+ virtual ssize_t getLowWatermarkValue() const = 0;
31
+ virtual size_t getLowWatermarkLimit() const = 0;
32
+ virtual size_t getHighWatermarkLimit() const = 0;
33
+ virtual size_t getTotalAllocatedMemory() const = 0;
34
+ virtual size_t getCurrentAllocatedMemory() const = 0;
35
+ virtual size_t getDriverAllocatedMemory() const = 0;
36
+ virtual size_t getRecommendedMaxMemory() const = 0;
37
+ virtual std::pair<const void*, uint32_t> getSharedBufferPtr(const void* ptr) const = 0;
38
+ virtual bool recordEvents(c10::ArrayRef<const void*> buffers) const = 0;
39
+ virtual bool waitForEvents(c10::ArrayRef<const void*> buffers) const = 0;
40
+ };
41
+
42
+ class IMpsAllocatorCallback {
43
+ public:
44
+ enum class EventType {
45
+ ALLOCATED, // buffer got allocated to be used immediately
46
+ RECYCLED, // buffer pulled from free list to be reused
47
+ FREED, // buffer put to free list for future recycling
48
+ RELEASED, // buffer memory released
49
+ ALLOCATION_FAILED // buffer allocation failed
50
+ };
51
+ virtual ~IMpsAllocatorCallback() = default;
52
+ virtual void executeMPSAllocatorCallback(void* ptr, EventType event) = 0;
53
+ };
54
+
55
+ // MPS allocator will execute every registered callback when a block of memory is freed.
56
+ C10_DECLARE_REGISTRY(MPSAllocatorCallbacksRegistry, IMpsAllocatorCallback);
57
+ #define REGISTER_MPS_ALLOCATOR_CALLBACK(name, ...) \
58
+ C10_REGISTER_CLASS(MPSAllocatorCallbacksRegistry, name, __VA_ARGS__);
59
+
60
+ IMPSAllocator* getIMPSAllocator(bool sharedAllocator = false);
61
+
62
+ bool isMPSPinnedPtr(const void* data);
63
+
64
+ } // namespace at::mps
.venv/lib/python3.11/site-packages/torch/include/ATen/mps/MPSDevice.h ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright © 2022 Apple Inc.
2
+
3
+ #pragma once
4
+ #include <c10/core/Allocator.h>
5
+ #include <c10/macros/Macros.h>
6
+ #include <c10/util/Exception.h>
7
+
8
+
9
+ #ifdef __OBJC__
10
+ #include <Foundation/Foundation.h>
11
+ #include <Metal/Metal.h>
12
+ #include <MetalPerformanceShaders/MetalPerformanceShaders.h>
13
+ typedef id<MTLDevice> MTLDevice_t;
14
+ typedef id<MTLLibrary> MTLLibrary_t;
15
+ typedef id<MTLComputePipelineState> MTLComputePipelineState_t;
16
+ typedef id<MTLLibrary> MTLLibrary_t;
17
+ #else
18
+ typedef void* MTLDevice;
19
+ typedef void* MTLDevice_t;
20
+ typedef void* MTLLibrary_t;
21
+ typedef void* MTLComputePipelineState_t;
22
+ typedef void* MTLLibrary_t;
23
+ #endif
24
+
25
+ namespace at::mps {
26
+
27
+ // Helper enum to check if a MPSGraph op is supported in a given macOS version
28
+ enum class MacOSVersion : uint32_t {
29
+ MACOS_VER_13_1_PLUS = 0,
30
+ MACOS_VER_13_2_PLUS,
31
+ MACOS_VER_13_3_PLUS,
32
+ MACOS_VER_14_0_PLUS,
33
+ MACOS_VER_14_4_PLUS,
34
+ MACOS_VER_15_0_PLUS,
35
+ };
36
+
37
+ //-----------------------------------------------------------------
38
+ // MPSDevice
39
+ //
40
+ // MPSDevice is a singleton class that returns the default device
41
+ //-----------------------------------------------------------------
42
+
43
+ class TORCH_API MPSDevice {
44
+ public:
45
+ /**
46
+ * MPSDevice should not be cloneable.
47
+ */
48
+ MPSDevice(MPSDevice& other) = delete;
49
+ /**
50
+ * MPSDevice should not be assignable.
51
+ */
52
+ void operator=(const MPSDevice&) = delete;
53
+ /**
54
+ * Gets single instance of the Device.
55
+ */
56
+ static MPSDevice* getInstance();
57
+ /**
58
+ * Returns the single device.
59
+ */
60
+ MTLDevice_t device() {
61
+ return _mtl_device;
62
+ }
63
+ /**
64
+ * Returns whether running on Ventura or newer
65
+ */
66
+ bool isMacOS13Plus(MacOSVersion version) const;
67
+
68
+ MTLComputePipelineState_t metalIndexingPSO(const std::string &kernel);
69
+ MTLLibrary_t getMetalIndexingLibrary();
70
+
71
+ ~MPSDevice();
72
+
73
+ private:
74
+ static MPSDevice* _device;
75
+ MTLDevice_t _mtl_device;
76
+ MTLLibrary_t _mtl_indexing_library;
77
+ MPSDevice();
78
+ };
79
+
80
+ TORCH_API bool is_available();
81
+ TORCH_API bool is_macos_13_or_newer(MacOSVersion version);
82
+ TORCH_API at::Allocator* GetMPSAllocator(bool useSharedAllocator = false);
83
+
84
+ } // namespace at::mps
.venv/lib/python3.11/site-packages/torch/include/ATen/mps/MPSEvent.h ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright © 2023 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include <ATen/mps/MPSStream.h>
6
+ #include <ctime>
7
+ #include <stack>
8
+
9
+ namespace at::mps {
10
+
11
+ // NOTE: don't create instances of this class directly.
12
+ // Use MPSEventPool to acquire instances of MPSEvent.
13
+ class MPSEvent {
14
+ public:
15
+ explicit MPSEvent(id_t ID, MPSStream* stream, bool enable_timing);
16
+ ~MPSEvent();
17
+
18
+ // records an event on the stream
19
+ void record(bool needsLock, bool syncEvent = false);
20
+ // makes all future work submitted to the stream wait for this event.
21
+ bool wait(bool needsLock, bool syncEvent = false);
22
+ // schedules a notifyListener callback for the event.
23
+ bool notify(bool needsLock, MTLSharedEventNotificationBlock block);
24
+ // checks if events are already signaled.
25
+ bool query() const;
26
+ // blocks the CPU thread until all the GPU work that were scheduled
27
+ // prior to recording this event are completed.
28
+ bool synchronize();
29
+ // resets this event with new parameters in case it gets reused from the event pool
30
+ void reset(MPSStream* stream, bool enable_timing);
31
+ // returns the unique ID of the event instance
32
+ id_t getID() const { return m_id; }
33
+ // returns the completion timestamp of the event
34
+ uint64_t getCompletionTime() const { return m_completion_time; }
35
+ // if already recorded, waits for cpu_sync_cv to be signaled
36
+ void waitForCpuSync();
37
+
38
+ private:
39
+ id_t m_id;
40
+ // enables measuring the completion time of the notifyListener of this event
41
+ bool m_enable_timing;
42
+ uint64_t m_signalCounter = 0;
43
+ MPSStream* m_stream = nullptr;
44
+ MTLSharedEvent_t m_event = nullptr;
45
+ MTLSharedEventListener* m_listener = nullptr;
46
+ // used to sync the events created on this Stream with CPU
47
+ std::mutex m_cpu_sync_mutex{};
48
+ std::condition_variable m_cpu_sync_cv{};
49
+ // CondVar predicate to sync the events created on this Stream with CPU
50
+ bool m_cpu_sync_completed = false;
51
+ // used to compute elapsed time
52
+ uint64_t m_completion_time = 0;
53
+
54
+ void recordLocked(bool syncEvent);
55
+ bool waitLocked(bool syncEvent);
56
+ bool notifyLocked(MTLSharedEventNotificationBlock block);
57
+ void notifyCpuSync();
58
+ static uint64_t getTime() {
59
+ return clock_gettime_nsec_np(CLOCK_MONOTONIC_RAW);
60
+ }
61
+ };
62
+
63
+ typedef std::unique_ptr<MPSEvent, std::function<void(MPSEvent*)>> MPSEventPtr;
64
+
65
+ class MPSEventPool {
66
+ public:
67
+ explicit MPSEventPool(MPSStream* default_stream);
68
+ ~MPSEventPool();
69
+
70
+ MPSEventPtr acquireEvent(bool enable_timing, MPSStream* stream);
71
+ void emptyCache();
72
+
73
+ // these are mainly used for MPSHooks and torch.mps.Event() bindings
74
+ id_t acquireEvent(bool enable_timing);
75
+ void releaseEvent(id_t event_id);
76
+ void recordEvent(id_t event_id, bool syncEvent);
77
+ void waitForEvent(id_t event_id, bool syncEvent);
78
+ void synchronizeEvent(id_t event_id);
79
+ bool queryEvent(id_t event_id);
80
+ // returns elapsed time between two recorded events in milliseconds
81
+ double elapsedTime(id_t start_event_id, id_t end_event_id);
82
+
83
+ private:
84
+ MPSStream* m_default_stream = nullptr;
85
+ std::recursive_mutex m_mutex;
86
+ std::stack<std::unique_ptr<MPSEvent>> m_pool{};
87
+ // dictionary to associate event IDs with event objects
88
+ // used to retain in-use events out of the pool
89
+ // for torch.mps.Event() bindings.
90
+ std::unordered_map<id_t, MPSEventPtr> m_in_use_events{};
91
+ uint64_t m_event_counter = 0;
92
+ std::function<void(MPSEvent*)> m_default_deleter;
93
+
94
+ MPSEvent* getInUseEvent(id_t event_id, bool locked = true);
95
+ };
96
+
97
+ // shared_ptr is used to get MPSEventPool destroyed after dependent instances
98
+ std::shared_ptr<MPSEventPool> getMPSEventPool();
99
+
100
+ } // namespace at::mps
.venv/lib/python3.11/site-packages/torch/include/ATen/mps/MPSGeneratorImpl.h ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright © 2022 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include <ATen/core/Generator.h>
6
+ #include <ATen/core/PhiloxRNGEngine.h>
7
+ #include <c10/core/GeneratorImpl.h>
8
+ #include <optional>
9
+
10
+ namespace at {
11
+ namespace mps::detail {
12
+
13
+ constexpr uint32_t PHILOX_STATE_N = 7;
14
+ struct rng_data_pod {
15
+ std::array<uint32_t, PHILOX_STATE_N> state{1};
16
+ uint64_t seed = default_rng_seed_val;
17
+ };
18
+
19
+ TORCH_API const Generator& getDefaultMPSGenerator();
20
+ TORCH_API Generator createMPSGenerator(uint64_t seed_val = default_rng_seed_val);
21
+
22
+ } // namespace mps::detail
23
+
24
+ struct TORCH_API MPSGeneratorImpl : public c10::GeneratorImpl {
25
+ // Constructors
26
+ MPSGeneratorImpl(uint64_t seed_in = default_rng_seed_val);
27
+ ~MPSGeneratorImpl() override = default;
28
+
29
+ // MPSGeneratorImpl methods
30
+ std::shared_ptr<MPSGeneratorImpl> clone() const;
31
+ void set_current_seed(uint64_t seed) override;
32
+ void set_offset(uint64_t offset) override;
33
+ uint64_t get_offset() const override;
34
+ uint64_t current_seed() const override;
35
+ uint64_t seed() override;
36
+ void set_state(const c10::TensorImpl& new_state) override;
37
+ c10::intrusive_ptr<c10::TensorImpl> get_state() const override;
38
+ void update_philox_counters();
39
+
40
+ void set_engine(at::Philox4_32 engine) { engine_ = engine; };
41
+ at::Philox4_32 engine() { return engine_; };
42
+ uint32_t* state_data() { return data_.state.data(); }
43
+ static DeviceType device_type() { return DeviceType::MPS; };
44
+
45
+ private:
46
+ mps::detail::rng_data_pod data_;
47
+ at::Philox4_32 engine_;
48
+
49
+ MPSGeneratorImpl* clone_impl() const override;
50
+ };
51
+
52
+ } // namespace at
.venv/lib/python3.11/site-packages/torch/include/ATen/mps/MPSGuardImpl.h ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright © 2022 Apple Inc.
2
+
3
+ #pragma once
4
+ #include <c10/core/impl/DeviceGuardImplInterface.h>
5
+ #include <c10/macros/Macros.h>
6
+ #include <c10/util/Exception.h>
7
+ #include <ATen/Context.h>
8
+ #include <ATen/mps/MPSStream.h>
9
+ #include <ATen/mps/MPSEvent.h>
10
+
11
+ #ifdef __OBJC__
12
+ #include <Foundation/Foundation.h>
13
+ #include <Metal/Metal.h>
14
+ #include <MetalPerformanceShaders/MetalPerformanceShaders.h>
15
+ #endif
16
+
17
+ #include <ATen/Tensor.h>
18
+ #include <c10/core/MemoryFormat.h>
19
+ #include <c10/core/Storage.h>
20
+ #include <c10/core/TensorImpl.h>
21
+ #include <sys/_types/_size_t.h>
22
+ #include <memory>
23
+ #include <c10/core/UndefinedTensorImpl.h>
24
+ #include <c10/util/intrusive_ptr.h>
25
+
26
+
27
+ namespace at::mps {
28
+
29
+ typedef MPSEvent* mpsEvent_t;
30
+
31
+ // TODO: Move the MPSGuardImpl to inherit from NoOpDeviceGuardImpl
32
+ // https://github.com/pytorch/pytorch/issues/77170
33
// TODO: Move the MPSGuardImpl to inherit from NoOpDeviceGuardImpl
// https://github.com/pytorch/pytorch/issues/77170
//
// DeviceGuardImplInterface for MPS. Only a single MPS device (index 0) is
// supported here, so the device/stream getters and setters return or accept
// device 0 unconditionally; only the event-related hooks do real work.
struct TORCH_API MPSGuardImpl final : public c10::impl::DeviceGuardImplInterface {
  static constexpr c10::DeviceType static_type = c10::DeviceType::MPS;

  // constructor
  MPSGuardImpl() {}
  explicit MPSGuardImpl(c10::DeviceType t) {
    TORCH_INTERNAL_ASSERT(t == c10::DeviceType::MPS);
  }

  // returns the type
  c10::DeviceType type() const override {
    return c10::DeviceType::MPS;
  }

  // With a single device there is nothing to exchange; always reports MPS:0.
  Device exchangeDevice(Device d) const override {
    return Device(c10::DeviceType::MPS, 0);
  }

  Device getDevice() const override {
    return Device(c10::DeviceType::MPS, 0);
  }

  std::optional<Device> uncheckedGetDevice() const noexcept {
    return Device(c10::DeviceType::MPS, 0);
  }

  // Validates the device type only; there is no per-index state to set.
  void setDevice(Device d) const override {
    TORCH_INTERNAL_ASSERT(d.is_mps());
  }

  void uncheckedSetDevice(Device d) const noexcept override {
    // TODO: Currently setting only device 0
  }

  // All stream queries return the default stream on MPS:0.
  Stream getStream(Device d) const noexcept override {
    return Stream(Stream::DEFAULT, Device(c10::DeviceType::MPS, 0));
  }

  Stream getNewStream(Device, int priority = 0) const override {
    // priority is accepted for interface compatibility but unused.
    (void)priority;
    return Stream(Stream::DEFAULT, Device(c10::DeviceType::MPS, 0));
  }

  Stream getDefaultStream(Device d) const override {
    return Stream(Stream::DEFAULT, Device(c10::DeviceType::MPS, 0));
  }

  // NB: These do NOT set the current device
  Stream exchangeStream(Stream s) const noexcept override {
    return Stream(Stream::DEFAULT, Device(c10::DeviceType::MPS, 0));
  }
  DeviceIndex deviceCount() const noexcept override {
    if (at::hasMPS()) {
      //TODO: extend it for multi-device case
      return 1;
    } else {
      return 0;
    }
  }

  // Event-related functions
  // NOTE(review): createEvent is not marked `override` — presumably an
  // MPS-specific helper rather than part of DeviceGuardImplInterface.
  void createEvent(
    mpsEvent_t* event,
    const EventFlag flag) const;

  void destroyEvent(
    void* event,
    const DeviceIndex device_index) const noexcept override;

  void record(
    void** event,
    const Stream& stream,
    const DeviceIndex device_index,
    const EventFlag flag) const override;

  void block(
    void* event,
    const Stream& stream) const override;

  bool queryEvent(void* event) const override;

};
115
+
116
+ /// A variant of OptionalDeviceGuard that is specialized for MPS.
117
/// A variant of OptionalDeviceGuard that is specialized for MPS.
/// Thin, move/copy-disabled wrapper over InlineOptionalDeviceGuard
/// parameterized on MPSGuardImpl.
struct OptionalMPSGuard {
  /// Default-construct in the uninitialized state.
  explicit OptionalMPSGuard() : guard_() {}

  /// Initialize from an optional device; nullopt leaves the guard
  /// uninitialized.
  explicit OptionalMPSGuard(std::optional<Device> device_opt)
      : guard_(device_opt) {}

  /// Set the current MPS device to the passed device index, if it is not
  /// nullopt
  explicit OptionalMPSGuard(std::optional<DeviceIndex> device_index_opt)
      : guard_(device_index_opt) {}

  // Copy is not allowed
  OptionalMPSGuard(const OptionalMPSGuard&) = delete;
  OptionalMPSGuard& operator=(const OptionalMPSGuard&) = delete;
  OptionalMPSGuard(OptionalMPSGuard&& other) = delete;
  OptionalMPSGuard& operator=(OptionalMPSGuard&& other) = delete;

  /// Sets the MPS device to the given device, initializing the guard if it
  /// is not already initialized. Errors if the given device is not a MPS
  /// device.
  void set_device(Device device) {
    guard_.set_device(device);
  }

  /// Sets the MPS device to the given device, initializing the guard if it is
  /// not already initialized. Errors if the given device is not a MPS device.
  void reset_device(Device device) {
    guard_.reset_device(device);
  }

  /// Sets the MPS device to the given device index, initializing the guard if
  /// it is not already initialized.
  void set_index(DeviceIndex device_index) {
    guard_.set_index(device_index);
  }

  /// Returns the device that was set immediately prior to initialization of the
  /// guard, or nullopt if the guard is uninitialized.
  std::optional<Device> original_device() const {
    return guard_.original_device();
  }

  /// Returns the most recent device that was set using this device guard,
  /// either from construction, or via set_device, if the guard is initialized,
  /// or nullopt if the guard is uninitialized.
  std::optional<Device> current_device() const {
    return guard_.current_device();
  }

  /// Restore the original MPS device, resetting this guard to uninitialized
  /// state.
  void reset() {
    guard_.reset();
  }

 private:
  c10::impl::InlineOptionalDeviceGuard<MPSGuardImpl> guard_;
};
175
+
176
+
177
+ C10_REGISTER_GUARD_IMPL(MPS, MPSGuardImpl);
178
+
179
+ } // namespace at::mps
.venv/lib/python3.11/site-packages/torch/include/ATen/mps/MPSHooks.h ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright © 2022 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include <ATen/detail/MPSHooksInterface.h>
6
+ #include <ATen/Generator.h>
7
+ #include <ATen/mps/MPSEvent.h>
8
+ #include <optional>
9
+
10
+ namespace at::mps {
11
+
12
+ // The real implementation of MPSHooksInterface
13
// The real implementation of MPSHooksInterface
// Declaration-only facade: each group of methods forwards to the
// corresponding MPS subsystem (device, generator, stream, allocator,
// profiler, events) — see the .mm implementation for details.
struct MPSHooks : public at::MPSHooksInterface {
  MPSHooks(at::MPSHooksArgs) {}
  void initMPS() const override;

  // MPSDevice interface
  bool hasMPS() const override;
  // True when running on the given macOS (major, minor) version or newer.
  bool isOnMacOSorNewer(unsigned major, unsigned minor) const override;

  // MPSGeneratorImpl interface
  const Generator& getDefaultMPSGenerator() const override;

  // MPSStream interface
  void deviceSynchronize() const override;
  void commitStream() const override;
  // Returned as void* to keep this header free of Objective-C types.
  void* getCommandBuffer() const override;
  void* getDispatchQueue() const override;

  // MPSAllocator interface
  Allocator* getMPSDeviceAllocator() const override;
  void emptyCache() const override;
  size_t getCurrentAllocatedMemory() const override;
  size_t getDriverAllocatedMemory() const override;
  size_t getRecommendedMaxMemory() const override;
  void setMemoryFraction(double ratio) const override;
  bool isPinnedPtr(const void* data) const override;
  Allocator* getPinnedMemoryAllocator() const override;

  // MPSProfiler interface
  void profilerStartTrace(const std::string& mode, bool waitUntilCompleted) const override;
  void profilerStopTrace() const override;

  // MPSEvent interface
  // Events are referenced by integer ids handed out by acquireEvent().
  uint32_t acquireEvent(bool enable_timing) const override;
  void releaseEvent(uint32_t event_id) const override;
  void recordEvent(uint32_t event_id) const override;
  void waitForEvent(uint32_t event_id) const override;
  void synchronizeEvent(uint32_t event_id) const override;
  bool queryEvent(uint32_t event_id) const override;
  double elapsedTimeOfEvents(uint32_t start_event_id, uint32_t end_event_id) const override;

  // Compatibility with Accelerator API
  bool hasPrimaryContext(DeviceIndex device_index) const override {
    // When MPS is available, it is always in use for the one device.
    return true;
  }
};
59
+
60
+ } // namespace at::mps
.venv/lib/python3.11/site-packages/torch/include/ATen/mps/MPSProfiler.h ADDED
@@ -0,0 +1,402 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright © 2022 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include <ATen/Tensor.h>
6
+ #include <ATen/mps/MPSStream.h>
7
+ #include <ATen/mps/MPSAllocatorInterface.h>
8
+
9
+ #include <os/signpost.h>
10
+ #include <os/log.h>
11
+
12
+ #include <atomic>
13
+ #include <ctime>
14
+ #include <sstream>
15
+ #include <string>
16
+ #include <unordered_map>
17
+ #include <utility>
18
+
19
+ namespace at::mps {
20
+
21
+ namespace Profiler {
22
+
23
// Common profiling record shared by graph/kernel/copy/CPU-fallback entries.
// Holds the signpost ids and accumulated timing counters for one profiled
// instance.
struct BaseInfo {
  // profiling info types
  enum class Type {
    GRAPH,
    KERNEL,
    COPY,
    CPU_FALLBACK,
  };

  BaseInfo(Type infoType, uint64_t Id, const uintptr_t Handle) :
      type(infoType), profileId(Id), handle(Handle) { }
  virtual ~BaseInfo() = default;

  // type of profiling info
  Type type;
  // unique profile ID for execution instances of operations or copies
  uint64_t profileId;
  // ID generated by os_signpost
  // since it's possible to use event and interval-based signposts at the
  // same time, we need separate IDs for each.
  os_signpost_id_t eventSignpostId = 0, intervalSignpostId = 0;
  // accumulated GPU time in ms (obtained from CompletionHandler's "GPUEndTime - GPUStartTime")
  std::atomic<double> totalGpuTime{0.0};
  // accumulated Scheduling time in ms (obtained from CompletionHandler's "KernelEndTime - KernelStartTime")
  std::atomic<double> totalSchedulingTime{0.0};
  // indicates if the operation or copy execution has completed
  std::atomic_bool completed{false};
  // handle used to identify the profile info's instance (usually the pointer)
  const uintptr_t handle;

  // Human-readable summary of this record (overridden per subclass).
  virtual const std::string toString(double gpuTime = 0, double schedulingTime = 0) const;
  // builds a string for a tensor (format: Device:ScalarType[tensor.sizes()])
  static std::string buildTensorString(const Tensor& tensor, bool includeBufferId = false) {
    if (tensor.defined()) {
      std::stringstream tensorStr;
      auto deviceType = tensor.device().type();
      tensorStr << c10::DeviceTypeName(deviceType);
      // see comments for INCLUDE_BUFFER_ID
      if (includeBufferId && deviceType == at::kMPS) {
        id<MTLBuffer> buffer = __builtin_bit_cast(id<MTLBuffer>, tensor.storage().data());
        tensorStr << "(buf#" << (getIMPSAllocator()->getBufferId(buffer))
                  << ":" << buffer.retainCount << ")";
      }
      tensorStr << ":"
                << tensor.scalar_type() << tensor.sizes();
      return tensorStr.str();
    } else {
      return "undefined";
    }
  }
  // Monotonic timestamp in nanoseconds (not wall-clock time).
  static uint64_t getTime() {
    return clock_gettime_nsec_np(CLOCK_MONOTONIC_RAW);
  }
};
77
+
78
// Profiling record for a graph or Metal-kernel execution. `Handle` is either
// an MPSGraph* or an id<MTLComputePipelineState> (see MPSProfiler comments).
struct OperationInfo : BaseInfo {
  OperationInfo(const void* Handle, bool IsGraph, uint64_t Id, const std::string& StrKey) :
      BaseInfo(IsGraph ? Type::GRAPH : Type::KERNEL, Id, uintptr_t(Handle)), strKey(StrKey) { }

  // number of times this op instance has executed
  uint64_t runCount = 0;
  // identifying key (e.g., kernel name plus tensor descriptions)
  std::string strKey;

  const std::string toString(double gpuTime = 0, double schedulingTime = 0) const override;

  // builds a string for a kernel
  // (format: kernelName:tensor1:tensor2:... using BaseInfo::buildTensorString)
  static std::string buildKernelString(const std::string& kernelName,
                                       const TensorList& tensors,
                                       bool includeBufferId = false) {
    std::stringstream kernelStr;
    kernelStr << kernelName;
    for (const Tensor& tensor: tensors) {
      kernelStr << ":" << BaseInfo::buildTensorString(tensor, includeBufferId);
    }
    return kernelStr.str();
  }
};
99
+
100
// Profiling record for an op that fell back to CPU execution. Tracks the
// byte overhead of the MPS<->CPU tensor copies the fallback requires.
struct CpuFbInfo : BaseInfo {
  CpuFbInfo(uint64_t Id, const std::string& OpName) :
      BaseInfo(Type::CPU_FALLBACK, Id, 0), opName(OpName) { }

  // number of times this op has fallen back to CPU
  uint64_t runCount = 0;
  // the current and total overhead of copies in bytes required to convert the Op's
  // input tensors from MPS to CPU and then output from CPU back to MPS
  size_t currentCopyOverhead = 0;
  size_t totalCopyOverhead = 0;
  std::string opName;
  std::string strKey;
  // CPU timestamp at the start of the fallback (see BaseInfo::getTime)
  uint64_t startTime = 0;

  const std::string toString(double gpuTime = 0, double schedulingTime = 0) const override;

  // Recomputes currentCopyOverhead as the sum of nbytes() over the defined
  // tensors, and accumulates it into totalCopyOverhead.
  void updateCopyOverhead(const TensorList& tensors) {
    currentCopyOverhead = 0;
    for (const Tensor& tensor: tensors) {
      if (tensor.defined()) {
        currentCopyOverhead += tensor.nbytes();
      }
    }
    totalCopyOverhead += currentCopyOverhead;
  }
};
125
+
126
// Profiling record for a single copy (blit or memcpy) between MPS and/or CPU
// storage. The direction (`kind`) defaults to MPS_TO_MPS and is refined via
// getCopyKind().
struct CopyInfo : BaseInfo {
  enum class Kind {
    MPS_TO_MPS,
    MPS_TO_CPU,
    CPU_TO_MPS,
  };

  CopyInfo(const void* Handle, size_t Length, uint64_t Id, bool IsNonBlocking, bool UsesBlitter) :
      BaseInfo(Type::COPY, Id, uintptr_t(Handle)), kind(Kind::MPS_TO_MPS),
      length(Length), isNonBlocking(IsNonBlocking), usesBlitter(UsesBlitter) { }

  // copy direction (see Kind)
  Kind kind;
  // number of bytes copied
  size_t length;
  bool isNonBlocking;
  // true if the copy goes through the Metal blit encoder (vs. memcpy)
  bool usesBlitter;
  std::string srcStrKey;
  std::string dstStrKey;
  // for copies that don't use blitters, we measure CPU time
  uint64_t startTime = 0;

  const std::string toString(double gpuTime = 0, double schedulingTime = 0) const override;

  static std::string buildTensorString(const void* buffer, const OptionalTensorRef tensor, bool includeBufferId = false);

  // Determines whether the storage lives on MPS: uses the tensor's device if
  // a tensor is provided, otherwise asks the MPS allocator about the raw
  // buffer pointer.
  static bool isStorageOnMPS(const void* buffer, const OptionalTensorRef tensor) {
    if (tensor.has_value()) {
      return tensor->device().type() == at::kMPS;
    }
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(buffer);
    // getUnalignedBufferSize() returns -1 if input buffer is not on MPS device
    return getIMPSAllocator()->getUnalignedBufferSize(buffer) >= 0;
  }

  // Classifies the copy direction; at least one side must be on MPS
  // (debug-asserted).
  static Kind getCopyKind(const void* srcBuffer, const void* dstBuffer,
                          const OptionalTensorRef srcTensor, const OptionalTensorRef dstTensor) {
    const bool isSrcOnMPS = isStorageOnMPS(srcBuffer, srcTensor);
    const bool isDstOnMPS = isStorageOnMPS(dstBuffer, dstTensor);
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(isSrcOnMPS || isDstOnMPS);
    if (isSrcOnMPS && !isDstOnMPS) {
      return Kind::MPS_TO_CPU;
    } else if (!isSrcOnMPS && isDstOnMPS) {
      return Kind::CPU_TO_MPS;
    }
    return Kind::MPS_TO_MPS;
  }
};
172
+
173
// Aggregate statistics for one copy direction (one instance per
// CopyInfo::Kind). Reuses CopyInfo's accumulated-time fields via a dummy
// base (null handle, zero length/id).
struct CopyStat : CopyInfo {
  explicit CopyStat(std::string CopyKindStr) :
      CopyInfo(nullptr, 0, 0, false, false), kindStr(std::move(CopyKindStr)) {}
  // total number of copies
  size_t totalCount = 0;
  // number of Scalar copies (i.e., less than sizeof(int64))
  size_t scalarsCount = 0;
  // number of blocking copies (i.e., require syncing to GPU)
  size_t blockingCount = 0;
  // number of copies that used memcpy(), instead of Metal Blit Encoder
  size_t memcpyCount = 0;
  // accumulated GPU time in ms for the scalar copies
  std::atomic<double> scalarsGpuTime{0.0};
  // copy kind in string type
  std::string kindStr;
};
189
+
190
// Central profiler for the MPS backend: emits os_signpost events/intervals
// for operations, copies and CPU fallbacks, and optionally logs aggregate
// statistics. Configuration is a single 32-bit word: lower 16 bits are
// ProfileOptions, upper 16 bits are SignpostTypes; LogOptions is separate.
class MPSProfiler {
public:
  // lower 16 bits used for profiler options
  enum ProfileOptions : uint32_t {
    OPTIONS_NONE = 0,
    // ALL_* means, all signpost types (RUN_OPERATION|BLIT_COPY|CPU_FALLBACK, etc.)
    // (used for convenience to not compute bit flags by OR-ing manually)
    // trace all signpost types using events
    ALL_SIGNPOST_EVENTS = (1 << 0),
    // trace all signpost types using intervals
    ALL_SIGNPOST_INTERVALS = (1 << 1),
    // always wait for command buffer to finish executing after each commit
    WAIT_UNTIL_COMPLETED = (1 << 2),
    // for interval-based signposts, include the scheduling portion of
    // Graph/Kernel/Copy executions as well.
    // if this flag is disabled, only "GPU run time" is included in the
    // interval, and not the schedule time.
    INCLUDE_SCHEDULE_INTERVAL = (1 << 3),

    // use these if you need to trace signposts types individually (rarely required)
    // trace signpost using intervals
    USE_INTERVALS = (1 << 4),
    // trace signpost by emitting events
    USE_EVENTS = (1 << 5),
    // used for sanity check (Change this when new option added)
    OPTIONS_COUNT = (USE_EVENTS << 1) - 1,
  };

  // when adding new types, #define the type string in MPSProfiler.mm as well.
  // upper 16 bits used for event types
  enum SignpostTypes : uint32_t {
    SIGNPOST_NONE = 0,
    // trace signposts for PyTorch operation executions
    RUN_OPERATION = (1 << 16),
    // trace signposts for blitter copies
    BLIT_COPY = (1 << 17),
    // trace signposts for ops that fall back on CPU
    CPU_FALLBACK = (1 << 18),
    // used for sanity check (Change this when new type added)
    SIGNPOST_COUNT = (CPU_FALLBACK << 1) - 1,
  };

  enum LogOptions : uint32_t {
    LOG_NONE = 0,

    // Info logging options during execution
    // -------------------------------------
    // prints operation info (id/key/run_count) during execution
    OPERATION_INFO = (1 << 0),
    // prints copy info (src/dst tensors/buffers, size, etc.) during execution
    COPY_INFO = (1 << 1),
    // prints CPU Fallback info (id/runCount/opName/copyOverhead) during execution
    CPU_FALLBACK_INFO = (1 << 2),

    // Profiling Statistics logging options when process terminates
    // ------------------------------------------------------------
    // prints all stats (OPERATION_STATS, COPY_STATS, CPU_FALLBACK_STATS) before process terminates
    // this is convenient to not combine following stats bit flags manually
    ALL_STATS = (1 << 3),
    // prints operation stats (GPU times, run count, etc.) before process terminates
    OPERATION_STATS = (1 << 4),
    // prints copies stats (GPU times, copy kinds, sizes, etc.) before process terminates
    COPY_STATS = (1 << 5),
    // prints CPU Fallback stats (CPU times, run times, size of MPS<->CPU copies
    // for tensors, etc.) before process terminates
    CPU_FALLBACK_STATS = (1 << 6),

    // Metadata format options when logging the info
    // ---------------------------------------------
    // if enabled, includes GPU run time in metadata (i.e., GPUEndTime-GPUStartTime
    // from Metal Command Buffers) (e.g., [GPU=0.324 ms])
    INCLUDE_GPU_TIME = (1 << 7),
    // if enabled, includes GPU scheduling time in metadata separately
    // (i.e., KernelEndTime-KernelStartTime from Metal Command Buffers)
    // e.g., [GPU=0.324 ms, KRNL=0.036 ms]
    INCLUDE_KERNEL_TIME = (1 << 8),
    // if enabled, includes the unique buffer ID in metadata for the storage
    // of a tensor that was allocated on MPSAllocator. This is useful (along with
    // the EV "PYTORCH_DEBUG_MPS_ALLOCATOR") to identify buffers that are involved
    // with various operations.
    INCLUDE_BUFFER_ID = (1 << 9),

    // used for sanity check (Change this when new option added)
    LOG_COUNT = (INCLUDE_BUFFER_ID << 1) - 1,
  };

  explicit MPSProfiler();
  ~MPSProfiler();

  // the handle is either "MPSGraph*" or "id<MTLComputePipelineState>" for Metal Kernels
  // the beginProfile*() functions return a profileId which is unique per graph/kernel/copy
  uint64_t beginProfileKernel(const void* handle, const std::string& strKey, bool isGraph);
  uint64_t beginProfileKernel(const void* handle, const std::string& kernelName, const TensorList& tensors);
  uint64_t beginProfileCopy(const void* srcBuffer, const void* dstBuffer,
                            const OptionalTensorRef srcTensor,
                            const OptionalTensorRef dstTensor,
                            size_t length, bool isNonBlocking, bool usesBlitter = true);
  uint64_t beginProfileCPUFallback(const std::string& opName, const TensorList& tensors);
  void beginProfileGPUInterval(const void* handle);

  void endProfileCopy(uint64_t profileId, SyncType syncType);
  void endProfileKernel(const void* handle, SyncType syncType = SyncType::NONE);
  void endProfileCPUFallback(const std::string& opName);

  // these are used to hook into Python bindings for torch.mps.profiler module.
  // this enables generating OS Signpost traces from MPSProfiler on-demand
  // during runtime (instead of environment variables).
  // The "mode" could be either "interval", "event", or both "interval,event"
  // for interval-based and/or event-based signpost tracing.
  void StartTrace(const std::string& mode, bool waitUntilCompleted);
  void StopTrace();

  // Abstractions for GPU trace capturing
  bool isCaptureEnabled() const;
  bool isCapturing() const;
  void startCapture(const std::string& name, MPSStream* stream = nullptr);
  void stopCapture(MPSStream* stream = nullptr);

  // convenience functions to indicate whether signpost tracing or
  // logging are enabled for the SignpostTypes
  bool isOperationProfilingEnabled() const {
    return (m_signpost_types & SignpostTypes::RUN_OPERATION) ||
           (m_log_options & (LogOptions::OPERATION_INFO | LogOptions::OPERATION_STATS));
  }
  bool isCopyProfilingEnabled() const {
    return (m_signpost_types & SignpostTypes::BLIT_COPY) ||
           (m_log_options & (LogOptions::COPY_INFO | LogOptions::COPY_STATS));
  }
  bool isCPUFallbackProfilingEnabled() const {
    return (m_signpost_types & SignpostTypes::CPU_FALLBACK) ||
           (m_log_options & (LogOptions::CPU_FALLBACK_INFO | LogOptions::CPU_FALLBACK_STATS));
  }
  bool isSignpostTracingEnabled() const {
    return (m_signpost_types != SignpostTypes::SIGNPOST_NONE);
  }

 private:
  // indicates what type of signpost types are enabled and traced by MPS profiler.
  uint32_t m_signpost_types = 0;
  uint32_t m_profile_options = 0;
  uint32_t m_log_options = 0;
  // per-category counters used to generate unique profileIds
  uint64_t m_kernel_counter = 0;
  uint64_t m_graph_counter = 0;
  uint64_t m_cpu_fb_counter = 0;
  uint64_t m_copy_counter = 0;
  // technically, it's possible to trace both events and intervals at the same time
  // so we use separate os_log categories for them
  os_log_t m_os_log_events;
  os_log_t m_os_log_intervals;
  // stats logging could run either from destructor or signal handler
  // so this is used to check if logging has already started.
  std::atomic_bool hasLoggedStats{false};
  // indicates there are pending completionHandler callbacks that haven't been called yet.
  std::atomic_bool hasPendingCompletionHandlers{false};
  // used to capture sigint signal to log profiling stats
  static struct sigaction currentSigint, previousSigint;

  // We use the following lists for two reasons:
  // 1- for interval-based signposts the "begin" point won't be in same function
  //    as the "end" point where we need to be able to retrieve signpost's info
  // 2- if Operations info need to be logged when process ends using LogOptions::OPERATION_INFO.

  // the pointer key for this map is either "MPSGraph*" or "id<MTLComputePipelineState>" for Metal Kernels
  // this list is retained and could be logged along with aggregate profiling numbers when the process ends.
  std::unordered_map<uintptr_t, std::unique_ptr<OperationInfo>> m_op_info_list{};
  // the string key for this map is the op name that we fall back to execute on CPU
  // this list is retained and could be logged along with aggregate profiling numbers when the process ends.
  std::unordered_map<std::string, std::unique_ptr<CpuFbInfo>> m_cpu_fb_info_list{};
  // this list contains the info for copies, and its key is the unique profileId
  // which is generated from m_copy_counter
  // The copyInfo list is not retained.
  std::unordered_map<uint64_t, std::unique_ptr<CopyInfo>> m_copy_info_list{};
  // a short list that contains copy stats
  std::unordered_map<CopyInfo::Kind, std::unique_ptr<CopyStat>> m_copy_stat_list{};

  // lazily-created Metal capture manager used by start/stopCapture
  mutable MTLCaptureManager *captureManager = nil;
  unsigned captureCount = 0;

  void initialize();
  void beginProfileExecution(BaseInfo& info, bool cpuExecution = false);
  void endProfileExecution(BaseInfo& info, os_signpost_id_t event_signpost_id,
                           os_signpost_id_t interval_signpost_id,
                           double gpuTime, double schedulingTime);
  void addProfilerScheduledHandler(BaseInfo& info);
  void addProfilerCompletedHandler(BaseInfo& info, SyncType syncType);
  void emitSignpostEvent(SignpostTypes signpost_type, os_signpost_id_t signpost_id,
                         const std::string& msg) const;
  void beginSignpostInterval(SignpostTypes signpost_type, os_signpost_id_t signpost_id,
                             const std::string& msg) const;
  void endSignpostInterval(SignpostTypes signpost_type, os_signpost_id_t signpost_id) const;

  void updateCopyStats(const CopyInfo& copyInfo, double gpuTime, double schedulingTime);
  // returns true if logging the profiling info "during the execution" is enabled
  bool isProfileInfoLoggingEnabled(BaseInfo::Type infoType, bool isExecutionEnded);
  // logs all the profiling stats that are enabled
  void logProfilingStats();
  // logs kernel profiling stats when the process ends.
  void logOperationsProfilingStats(std::FILE* f) const;
  // logs CPU Fallback profiling stats when the process ends.
  void logCPUFallbackProfilingStats(std::FILE* f) const;
  // logs copy profiling stats when the process ends.
  void logCopyProfilingStats(std::FILE* f) const;

  os_signpost_id_t generateSignpostId(os_signpost_type_t signpostType, const void* ptr = nullptr);
  static SignpostTypes getSignpostType(BaseInfo::Type infoType);
  static void handleIntSignal(int signal);
};
397
+
398
+ } // namespace Profiler
399
+
400
+ Profiler::MPSProfiler& getMPSProfiler();
401
+
402
+ } // namespace at::mps
.venv/lib/python3.11/site-packages/torch/include/ATen/mps/MPSStream.h ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright © 2022 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include <cstdint>
6
+ #include <utility>
7
+
8
+ #include <c10/core/DeviceGuard.h>
9
+ #include <c10/util/Exception.h>
10
+ #include <c10/core/Stream.h>
11
+ #include <ATen/mps/MPSDevice.h>
12
+
13
+ #ifdef __OBJC__
14
+ #include <Foundation/Foundation.h>
15
+ #include <Metal/Metal.h>
16
+ #include <MetalPerformanceShaders/MetalPerformanceShaders.h>
17
+ #include <MetalPerformanceShadersGraph/MetalPerformanceShadersGraph.h>
18
+ typedef id<MTLCommandQueue> MTLCommandQueue_t;
19
+ typedef id<MTLCommandBuffer> MTLCommandBuffer_t;
20
+ typedef id<MTLComputeCommandEncoder> MTLComputeCommandEncoder_t;
21
+ typedef id<MTLSharedEvent> MTLSharedEvent_t;
22
+ typedef id<MTLDevice> MTLDevice_t;
23
+ #else
24
+ typedef void* MTLCommandQueue_t;
25
+ typedef void* MTLCommandQueue;
26
+ typedef void* MTLCommandBuffer_t;
27
+ typedef void* MTLCommandBuffer;
28
+ typedef void* MTLComputeCommandEncoder_t;
29
+ typedef void* MTLSharedEvent_t;
30
+ typedef void* dispatch_queue_t;
31
+ typedef void* MTLDevice_t;
32
+ #define nil NULL;
33
+ #endif
34
+
35
+
36
+ namespace at::mps {
37
+
38
+ //-----------------------------------------------------------------
39
+ // MPSStream
40
+ //-----------------------------------------------------------------
41
+
42
+ enum class SyncType {
43
+ NONE, // no commit to command buffer
44
+ COMMIT, // commit and flush the command buffer
45
+ COMMIT_AND_WAIT, // flush and wait for command buffer execution to finish
46
+ COMMIT_AND_CONTINUE,// commit and continue with a new underlying command buffer
47
+ COMMIT_ADAPTIVE, // commit adaptively based on available memory
48
+ };
49
+
50
// Encapsulates the MPS execution stream: a Metal command queue plus the
// current command buffer/encoder and a serial dispatch queue. Commit
// behavior is selected per call via SyncType (see enum above);
// commit-and-continue is enabled by default.
class TORCH_API MPSStream
{
public:
  enum Unchecked { UNCHECKED };

  /// Construct a MPSStream from a Stream. This construction is checked,
  /// and will raise an error if the Stream is not, in fact, a MPS stream.
  explicit MPSStream(Stream stream);

  ~MPSStream();
  MTLCommandQueue_t commandQueue() const { return _commandQueue; };
  dispatch_queue_t queue() const { return _serialQueue; }

  // Lazily-created current command buffer / compute encoder.
  MPSCommandBuffer* commandBuffer();
  MTLComputeCommandEncoder_t commandEncoder();
  void endKernelCoalescing();
  // Commits/waits on the current command buffer according to syncType.
  void synchronize(SyncType syncType);
  void fill(id<MTLBuffer> buffer, uint8_t value, size_t length, size_t offset, SyncType syncType = SyncType::NONE);
  void copy(id<MTLBuffer> srcBuffer, id<MTLBuffer> dstBuffer,
            size_t length, size_t srcOffset, size_t dstOffset,
            uint64_t profileId, SyncType syncType = SyncType::NONE);
  void copy_and_sync(id<MTLBuffer> srcBuffer, id<MTLBuffer> dstBuffer,
                     size_t length, size_t srcOffset, size_t dstOffset,
                     bool non_blocking, uint64_t profileId);
  void executeMPSGraph(MPSGraph* mpsGraph, NSDictionary* feeds, NSDictionary* results, SyncType syncType = SyncType::NONE);
  void addCompletedHandler(MTLCommandBufferHandler block);

  /// Get the MPS device index that this stream is associated with.
  c10::DeviceIndex device_index() const { return _stream.device_index(); }

  // NOTE(review): stream() returns the command queue, same as commandQueue().
  MTLCommandQueue_t stream() const { return _commandQueue; };

  MTLDevice_t device() const { return [_commandQueue device];}

  /// Explicit conversion to Stream.
  Stream unwrap() const { return _stream; }

private:
  Stream _stream;
  MTLCommandQueue_t _commandQueue = nil;
  MPSCommandBuffer* _commandBuffer = nil;
  // previous buffer kept alive across commit-and-continue
  MPSCommandBuffer* _prevCommandBuffer = nil;
  MTLComputeCommandEncoder_t _commandEncoder = nil;
  MPSGraphExecutionDescriptor *_executionDescriptor = nil;
  MPSGraphCompilationDescriptor *_compilationDescriptor = nil;
  dispatch_queue_t _serialQueue = nullptr;
  // CommitAndContinue is enabled by default
  bool _enableCommitAndContinue = true;

  // use synchronize() to access any of these commit functions outside MPSStream
  void commit();
  void commitAndWait();
  void commitAndContinue();
  void flush();
};
105
+
106
+ /**
107
+ * Get the current MPS stream
108
+ */
109
+ TORCH_API MPSStream* getCurrentMPSStream();
110
+
111
+ /**
112
+ * Get the default MPS stream
113
+ */
114
+ TORCH_API MPSStream* getDefaultMPSStream();
115
+
116
+ //-----------------------------------------------------------------
117
+ // MPSStreamImpl
118
+ //-----------------------------------------------------------------
119
+
120
// Singleton holder for the process-wide MPSStream instance.
class TORCH_API MPSStreamImpl
{
public:
  /**
   * Gets single instance of the MPSStream.
   */
  static MPSStream* getInstance();

private:
  // lazily-initialized singleton (see getInstance in the implementation)
  static MPSStream* _stream;
  // non-instantiable: access is via getInstance() only
  MPSStreamImpl();
};
132
+
133
+ } // namespace at::mps
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/CatKernel.h ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/core/Tensor.h>
4
+ #include <ATen/native/DispatchStub.h>
5
+ #include <ATen/core/IListRef.h>
6
+
7
+ namespace at::native {
8
+
9
+ using cat_serial_fn = void(*)(const Tensor &, const MaterializedITensorListRef&, int64_t);
10
+ DECLARE_DISPATCH(cat_serial_fn, cat_serial_stub);
11
+
12
+ } // namespace at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/Loops.h ADDED
@@ -0,0 +1,394 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // This file provides two functions to help write elementwise kernels:
4
+ //
5
+ // cpu_kernel(TensorIterator iter, <lambda>)
6
+ // cpu_kernel_vec(TensorIterator iter, <lambda>, <vec_lambda>)
7
+ //
8
+ // Both functions may generate vectorized code. The cpu_kernel implementation
9
+ // relies on the compiler's auto-vectorization. The cpu_kernel_vec
10
+ // implementation uses x86 SIMD intrinsics when available. These functions
11
+ // are only intended to be used in the ATen/native/cpu subdirectory, since files
12
+ // in other directories are not compiled with AVX/AVX2 enabled. See README.md
13
+ // for more details.
14
+ //
15
+ // For example, to write a multiplication kernel for float:
16
+ //
17
+ // cpu_kernel(iter, [](float a, float b) { return a * b; });
18
+ //
19
+ // Or you may write:
20
+ //
21
+ // cpu_kernel_vec(iter,
22
+ // [](float a, float b) { return a * b; },
23
+ // [](Vectorized<float> a, Vectorized<float> b) { return a * b; });
24
+ //
25
+ // See BinaryOpsKernel.cpp for the complete implementation
26
+ //
27
+ //
28
+
29
+ #include <cstdint>
30
+ #include <c10/util/C++17.h>
31
+ #include <c10/util/Load.h>
32
+ #include <c10/util/irange.h>
33
+ #include <ATen/detail/FunctionTraits.h>
34
+ #include <ATen/native/cpu/IsContiguous.h>
35
+ #include <ATen/native/TensorIterator.h>
36
+ #include <ATen/native/TensorIteratorDynamicCasting.h>
37
+ #include <ATen/cpu/vec/vec.h>
38
+
39
+ #include <utility>
40
+
41
+ namespace at::native { inline namespace CPU_CAPABILITY {
42
+
43
+ using namespace vec;
44
+
45
+ template <typename traits, std::size_t... INDEX>
46
+ typename traits::ArgsTuple
47
+ dereference_impl(char* C10_RESTRICT data[], const int64_t* strides, int64_t i,
48
+ std::index_sequence<INDEX...>) {
49
+ return std::make_tuple(
50
+ c10::load<typename traits::template arg<INDEX>::type>(
51
+ data[INDEX] + i * strides[INDEX])...);
52
+ }
53
+
54
+ template <typename traits>
55
+ typename traits::ArgsTuple
56
+ dereference(char* C10_RESTRICT data[], const int64_t* strides, int64_t i) {
57
+ using Indices = std::make_index_sequence<traits::arity>;
58
+ return dereference_impl<traits>(data, strides, i, Indices{});
59
+ }
60
+
61
+ template <typename traits, std::size_t... INDEX>
62
+ typename traits::ArgsTuple
63
+ dereference_vec_impl(char* C10_RESTRICT data[],
64
+ const typename traits::result_type& opt_scalar,
65
+ size_t S,
66
+ int64_t i,
67
+ std::index_sequence<INDEX...>) {
68
+ using Vec = typename traits::result_type;
69
+ using scalar_t = typename Vec::value_type;
70
+ return std::make_tuple(
71
+ S == INDEX + 1 ?
72
+ opt_scalar :
73
+ Vec::loadu(data[INDEX] + i * sizeof(scalar_t))...);
74
+ }
75
+
76
+ template <typename traits>
77
+ typename traits::ArgsTuple
78
+ dereference_vec(char* C10_RESTRICT data[], const typename traits::result_type& opt_scalar, size_t S, int64_t i) {
79
+ using Indices = std::make_index_sequence<traits::arity>;
80
+ return dereference_vec_impl<traits>(data, opt_scalar, S, i, Indices{});
81
+ }
82
+
83
+ template <typename func_t,
84
+ std::enable_if_t<!std::is_void_v<typename function_traits<func_t>::result_type>>* = nullptr>
85
+ inline void
86
+ execute_op(char* C10_RESTRICT data[], const int64_t* strides, int64_t i, int64_t n, func_t&& op) {
87
+ using traits = function_traits<func_t>;
88
+ using result_type = typename traits::result_type;
89
+ for (; i < n; i++) {
90
+ result_type* out_ptr = (result_type*)(data[0] + i * strides[0]);
91
+ *out_ptr = c10::guts::apply(op, dereference<traits>(
92
+ &data[1],
93
+ &strides[1],
94
+ i));
95
+ }
96
+ }
97
+
98
+ template <typename func_t,
99
+ std::enable_if_t<std::is_void_v<typename function_traits<func_t>::result_type>>* = nullptr>
100
+ inline void
101
+ execute_op(char* C10_RESTRICT data[], const int64_t* strides, int64_t i, int64_t n, func_t&& op) {
102
+ using traits = function_traits<func_t>;
103
+ for (; i < n; i++) {
104
+ c10::guts::apply(op, dereference<traits>(
105
+ &data[0],
106
+ &strides[0],
107
+ i));
108
+ }
109
+ }
110
+
111
+ // Basic loop operation (one output, N inputs). May be auto-vectorized
112
+ // by the compiler. Supports inputs and outputs of different types.
113
+ template <typename func_t>
114
+ inline void
115
+ basic_loop(char* C10_RESTRICT data[], const int64_t* strides_, int64_t i, int64_t n, func_t&& op) {
116
+ using traits = function_traits<func_t>;
117
+ constexpr int ntensors = traits::arity + 1;
118
+
119
+ // Copying strides to temporary array helps auto vectorization in older GCC
120
+ // versions.
121
+ int64_t strides[ntensors];
122
+ for (const auto arg : c10::irange(ntensors)) {
123
+ strides[arg] = strides_[arg];
124
+ }
125
+
126
+ execute_op(data, strides, i, n, std::forward<func_t>(op));
127
+ }
128
+
129
+ // the recursive variadic template for iterating over the returned tuple
130
+ template<class T, size_t N>
131
+ struct TupleOutput {
132
+ static void handle(char *C10_RESTRICT data[], const int64_t *strides, int64_t i,
133
+ const T &tuple) {
134
+ TupleOutput<T, N - 1>::handle(data, strides, i, tuple);
135
+
136
+ auto output = std::get<N - 1>(tuple);
137
+ using output_type = decltype(output);
138
+ output_type * out_ptr = (output_type *)(data[N - 1] + i * strides[N - 1]);
139
+ *out_ptr = output;
140
+ }
141
+ };
142
+
143
+ // Base case for the above recursive template
144
+ template<class T>
145
+ struct TupleOutput<T, 1> {
146
+ static void handle(char *C10_RESTRICT data[], const int64_t *strides, int64_t i,
147
+ const T &tuple) {
148
+ auto output = std::get<0>(tuple);
149
+ using output_type = decltype(output);
150
+ output_type* out_ptr = (output_type *)(data[0] + i * strides[0]);
151
+ *out_ptr = output;
152
+ }
153
+ };
154
+
155
+ template<class... Args>
156
+ void handle_tuple_outputs(char* C10_RESTRICT data[],
157
+ const int64_t* strides,
158
+ int64_t i,
159
+ const std::tuple<Args...> &tuple) {
160
+ TupleOutput<decltype(tuple), sizeof...(Args)>::handle(data, strides, i, tuple);
161
+ }
162
+
163
+ // Loop operation for `cpu_kernel_multiple_outputs`.
164
+ // 1. Use `c10::guts::apply` to make dynamic method invocation
165
+ // for the lambda passed in `cpu_kernel_multiple_outputs`.
166
+ // 2. Iterate over the members of the returned tuple, set the corresponding
167
+ // output tensor by the tuple member in `handle_tuple_outputs` function.
168
+ template <typename func_t>
169
+ inline void
170
+ multiple_outputs_loop(char* C10_RESTRICT data[], const int64_t* strides_, int64_t i, int64_t n, func_t&& op) {
171
+ using traits = function_traits<func_t>;
172
+
173
+ using result_type = typename traits::result_type;
174
+ constexpr int num_outputs = std::tuple_size<result_type>::value;
175
+ constexpr int ntensors = traits::arity + num_outputs;
176
+
177
+ // Copying strides to temporary array helps auto vectorization in older GCC
178
+ // versions.
179
+ int64_t strides[ntensors];
180
+ for (const auto arg : c10::irange(ntensors)) {
181
+ strides[arg] = strides_[arg];
182
+ }
183
+
184
+ for (; i < n; i++) {
185
+ auto output = c10::guts::apply(op, dereference<traits>(
186
+ &data[num_outputs],
187
+ &strides[num_outputs],
188
+ i));
189
+ handle_tuple_outputs(data, strides, i, output);
190
+ }
191
+ }
192
+
193
+ // Explicitly vectorized loop implementation. All inputs and outputs must be
194
+ // the same type and contiguous with one exception: a single input may be
195
+ // a scalar (stride 0). It's position is indicated by the argument `S`. If `S`
196
+ // is 0, then there are no scalar inputs.
197
+ template <typename func_t, typename vec_func_t>
198
+ inline void
199
+ vectorized_loop(char** C10_RESTRICT data_, int64_t n, int64_t S, func_t&& op, vec_func_t&& vop) {
200
+ using traits = function_traits<vec_func_t>;
201
+ using scalar_t = typename function_traits<func_t>::result_type;
202
+ using Vec = Vectorized<scalar_t>;
203
+ constexpr int ntensors = traits::arity + 1;
204
+
205
+ char* C10_RESTRICT data[ntensors];
206
+ for (const auto arg : c10::irange(ntensors)) {
207
+ data[arg] = data_[arg];
208
+ }
209
+
210
+ Vec opt_scalar = Vec(S > 0 ? *(scalar_t*)data[S] : scalar_t(0));
211
+ int64_t i = 0;
212
+ for (; i <= n - 2 * Vec::size(); i += 2 * Vec::size()) {
213
+ auto args1 = dereference_vec<traits>(&data[1], opt_scalar, S, i);
214
+ auto args2 = dereference_vec<traits>(&data[1], opt_scalar, S, i + Vec::size());
215
+ auto out1 = c10::guts::apply(vop, std::move(args1));
216
+ auto out2 = c10::guts::apply(vop, std::move(args2));
217
+ out1.store(data[0] + i * sizeof(scalar_t));
218
+ out2.store(data[0] + (i + Vec::size()) * sizeof(scalar_t));
219
+ }
220
+ if (i < n) {
221
+ int64_t strides[ntensors];
222
+ for (const auto arg : c10::irange(ntensors)) {
223
+ strides[arg] = (S > 0 && arg == S) ? 0 : sizeof(scalar_t);
224
+ }
225
+ basic_loop(data, strides, i, n, std::forward<func_t>(op));
226
+ }
227
+ }
228
+
229
+
230
+ template <typename traits, typename cb_t>
231
+ inline void unroll_contiguous_scalar_checks(
232
+ const int64_t* /*strides*/,
233
+ std::index_sequence<>,
234
+ cb_t&& cb) {
235
+ cb(0);
236
+ }
237
+
238
+ template <typename traits, typename cb_t, size_t INDEX0, size_t ...INDEX>
239
+ inline void unroll_contiguous_scalar_checks(
240
+ const int64_t* strides,
241
+ std::index_sequence<INDEX0, INDEX...>,
242
+ cb_t&& cb) {
243
+ if (is_contiguous_scalar<traits, INDEX0 + 1>(strides)) {
244
+ cb(INDEX0 + 1);
245
+ } else {
246
+ unroll_contiguous_scalar_checks<traits>(strides, std::index_sequence<INDEX...>{}, std::forward<cb_t>(cb));
247
+ }
248
+ }
249
+
250
+ template <typename op_t, typename vop_t>
251
+ struct VectorizedLoop2d {
252
+ op_t op;
253
+ vop_t vop;
254
+
255
+ using traits = function_traits<op_t>;
256
+ static constexpr int ntensors = traits::arity + 1;
257
+ using data_t = std::array<char*, ntensors>;
258
+
259
+ VectorizedLoop2d(op_t op, vop_t vop):
260
+ op(std::move(op)), vop(std::move(vop)) {}
261
+
262
+ static void advance(data_t &data, const int64_t *outer_strides) {
263
+ for (const auto arg : c10::irange(data.size())) {
264
+ data[arg] += outer_strides[arg];
265
+ }
266
+ }
267
+
268
+ void operator()(char** base, const int64_t *strides, int64_t size0, int64_t size1) {
269
+ data_t data;
270
+ std::copy_n(base, ntensors, data.data());
271
+ const int64_t *outer_strides = &strides[ntensors];
272
+
273
+ if (is_contiguous<traits>(strides)) {
274
+ for (const auto i C10_UNUSED : c10::irange(size1)) {
275
+ vectorized_loop(data.data(), size0, 0, op, vop);
276
+ advance(data, outer_strides);
277
+ }
278
+ } else {
279
+ using Indices = std::make_index_sequence<traits::arity>;
280
+ unroll_contiguous_scalar_checks<traits>(strides, Indices{}, [&](size_t idx) {
281
+ if (idx) {
282
+ for (const auto i C10_UNUSED : c10::irange(size1)) {
283
+ vectorized_loop(data.data(), size0, idx, op, vop);
284
+ advance(data, outer_strides);
285
+ }
286
+ } else {
287
+ for (const auto i C10_UNUSED : c10::irange(size1)) {
288
+ basic_loop(data.data(), strides, 0, size0, op);
289
+ advance(data, outer_strides);
290
+ }
291
+ }
292
+ });
293
+ }
294
+ }
295
+ };
296
+
297
+ template <typename op_t, typename vop_t>
298
+ VectorizedLoop2d<op_t, vop_t> make_vectorized_loop2d(
299
+ op_t &&op, vop_t &&vop) {
300
+ return VectorizedLoop2d<op_t, vop_t>(std::forward<op_t>(op), std::forward<vop_t>(vop));
301
+ }
302
+
303
+ template <typename func_t>
304
+ void cpu_kernel(TensorIteratorBase& iter, func_t&& op, int64_t grain_size = at::internal::GRAIN_SIZE) {
305
+ using traits = function_traits<func_t>;
306
+ // this could be extended to work with void return types
307
+ TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity);
308
+ TORCH_INTERNAL_ASSERT(iter.noutputs() == 1);
309
+ // dynamic casting not currently supported on CPU
310
+ TORCH_INTERNAL_ASSERT(!needs_dynamic_casting<func_t>::check(iter));
311
+
312
+ iter.for_each([&](char** data, const int64_t* strides, int64_t n) {
313
+ // basic loop can handle 1d slices with arbitrary strides, and 1d slices is all that
314
+ // iter.for_each is ever sending to the loop lambda
315
+ basic_loop(data, strides, 0, n, op);
316
+ }, grain_size);
317
+ iter.cast_outputs();
318
+ }
319
+
320
+ // This function helps write elementwise kernels that requires multiple outputs.
321
+ // It follows the similar structure of cpu_kernel.
322
+ // Instead of `basic_loop` function, a new `multiple_outputs_loop` function is
323
+ // manipulated to handle multiple return values.
324
+ // For now `needs_dynamic_casting` check is not added as the passed lambda (`func_t`)
325
+ // of `multiple_outputs_loop` returns `std::tuple` instead of `scalar_t`.
326
+ // The `gpu_kernel_multiple_outputs` is also implemented without this check,
327
+ // We could extend `needs_dynamic_casting` to support both `std::tuple` and
328
+ // `thrust::tuple` in the future.
329
+ template <typename func_t>
330
+ void cpu_kernel_multiple_outputs(TensorIteratorBase& iter, func_t&& op, int64_t grain_size = at::internal::GRAIN_SIZE) {
331
+ using traits = function_traits<func_t>;
332
+ TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity);
333
+
334
+ iter.for_each([&](char** data, const int64_t* strides, int64_t n) {
335
+ multiple_outputs_loop(data, strides, 0, n, op);
336
+ }, grain_size);
337
+ iter.cast_outputs();
338
+ }
339
+
340
+ template <bool check_dynamic_cast=true, typename func_t, typename vec_func_t>
341
+ void cpu_kernel_vec(TensorIteratorBase& iter, func_t&& op, vec_func_t&& vop, int64_t grain_size = at::internal::GRAIN_SIZE) {
342
+ using traits = function_traits<func_t>;
343
+ // this could be extended to work with void return types
344
+ TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity);
345
+ TORCH_INTERNAL_ASSERT(iter.noutputs() == 1);
346
+ // dynamic casting not currently supported on CPU, but some kernels (like Fill)
347
+ // explicitly dynamic_cast, so we give the opt-out of checking.
348
+ if constexpr (check_dynamic_cast) {
349
+ TORCH_INTERNAL_ASSERT(!needs_dynamic_casting<func_t>::check(iter));
350
+ }
351
+
352
+ iter.for_each(make_vectorized_loop2d(std::forward<func_t>(op), std::forward<vec_func_t>(vop)), grain_size);
353
+ iter.cast_outputs();
354
+ }
355
+
356
+ template <typename func_t>
357
+ void cpu_serial_kernel(TensorIteratorBase& iter, func_t&& op, const Range& range) {
358
+ using traits = function_traits<func_t>;
359
+ constexpr bool result_void = std::is_void_v<typename traits::result_type>;
360
+ TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity &&
361
+ ((result_void && iter.noutputs() == 0) || (!result_void && iter.noutputs() == 1)));
362
+ // dynamic casting not currently supported on CPU
363
+ TORCH_INTERNAL_ASSERT(!needs_dynamic_casting<func_t>::check(iter));
364
+
365
+ iter.serial_for_each([&](char** data, const int64_t* strides, int64_t n) {
366
+ basic_loop(data, strides, 0, n, op);
367
+ }, range);
368
+ iter.cast_outputs();
369
+ }
370
+
371
+ template <typename func_t>
372
+ void cpu_serial_kernel(TensorIteratorBase& iter, func_t&& op) {
373
+ cpu_serial_kernel(iter, std::forward<func_t>(op), {0, iter.numel()});
374
+ }
375
+
376
+ template <typename func_t, typename vec_func_t>
377
+ void cpu_serial_kernel_vec(TensorIteratorBase& iter, func_t&& op, vec_func_t&& vop, const Range& range) {
378
+ using traits = function_traits<func_t>;
379
+ // this could be extended to work with void return types
380
+ TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity);
381
+ TORCH_INTERNAL_ASSERT(iter.noutputs() == 1);
382
+ // dynamic casting not currently supported on CPU
383
+ TORCH_INTERNAL_ASSERT(!needs_dynamic_casting<func_t>::check(iter));
384
+
385
+ iter.serial_for_each(make_vectorized_loop2d(std::forward<func_t>(op), std::forward<vec_func_t>(vop)), range);
386
+ iter.cast_outputs();
387
+ }
388
+
389
+ template <typename func_t, typename vec_func_t>
390
+ void cpu_serial_kernel_vec(TensorIteratorBase& iter, func_t&& op, vec_func_t&& vop) {
391
+ cpu_serial_kernel_vec(iter, std::forward<func_t>(op), std::forward<vec_func_t>(vop), {0, iter.numel()});
392
+ }
393
+
394
+ }} // namespace at::native::<anonymous>
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/ReduceUtils.h ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/Parallel.h>
4
+ #include <ATen/NumericUtils.h>
5
+ #include <ATen/cpu/vec/vec.h>
6
+ #include <ATen/cpu/vec/functional.h>
7
+ #include <ATen/native/ReductionType.h>
8
+ #include <c10/util/irange.h>
9
+ #include <ATen/OpMathType.h>
10
+ #include <ATen/native/cpu/utils.h>
11
+ #include <ATen/OpMathType.h>
12
+
13
+ namespace at::native {
14
+ inline namespace CPU_CAPABILITY {
15
+
16
+ using namespace vec;
17
+
18
+ #define AT_DISPATCH_REDUCTION_TYPES(op, ...) \
19
+ [&] { \
20
+ switch (op) { \
21
+ case ReductionType::SUM: { \
22
+ static constexpr auto reduce = ReductionType::SUM; \
23
+ return __VA_ARGS__(); \
24
+ } \
25
+ case ReductionType::MEAN: { \
26
+ static constexpr auto reduce = ReductionType::MEAN; \
27
+ return __VA_ARGS__(); \
28
+ } \
29
+ case ReductionType::MIN: { \
30
+ static constexpr auto reduce = ReductionType::MIN; \
31
+ return __VA_ARGS__(); \
32
+ } \
33
+ case ReductionType::MAX: { \
34
+ static constexpr auto reduce = ReductionType::MAX; \
35
+ return __VA_ARGS__(); \
36
+ } \
37
+ case ReductionType::PROD: { \
38
+ static constexpr auto reduce = ReductionType::PROD; \
39
+ return __VA_ARGS__(); \
40
+ } \
41
+ } \
42
+ }()
43
+
44
+ template <typename scalar_t, ReductionType reduce>
45
+ inline vec_scalar_t<scalar_t> init_value() {
46
+ using acc_t = vec_scalar_t<scalar_t>;
47
+ acc_t val;
48
+ if (reduce == ReductionType::SUM ||
49
+ reduce == ReductionType::MEAN) {
50
+ val = static_cast<acc_t>(0);
51
+ } else if (reduce == ReductionType::PROD) {
52
+ val = static_cast<acc_t>(1);
53
+ } else if (reduce == ReductionType::MAX) {
54
+ val = -std::numeric_limits<acc_t>::infinity();
55
+ } else {
56
+ TORCH_INTERNAL_ASSERT(reduce == ReductionType::MIN);
57
+ val = std::numeric_limits<acc_t>::infinity();
58
+ }
59
+ return val;
60
+ }
61
+
62
+ template <typename scalar_t, ReductionType reduce>
63
+ inline vec_scalar_t<scalar_t> init_value(const std::optional<Scalar>& initial) {
64
+ using acc_t = vec_scalar_t<scalar_t>;
65
+ if (initial.has_value()) {
66
+ return initial.value().to<acc_t>();
67
+ } else {
68
+ return init_value<scalar_t, reduce>();
69
+ }
70
+ }
71
+
72
+ template <typename scalar_t>
73
+ inline void init(scalar_t* out, int64_t size, const vec_scalar_t<scalar_t>& val) {
74
+ using Vec = Vectorized<vec_scalar_t<scalar_t>>;
75
+ map<scalar_t>(
76
+ [val](Vec x) { return Vec(val); },
77
+ out,
78
+ out,
79
+ size);
80
+ }
81
+
82
+ template <typename scalar_t, ReductionType reduce>
83
+ inline void init(scalar_t* out, int64_t size, const std::optional<Scalar>& initial) {
84
+ using acc_t = vec_scalar_t<scalar_t>;
85
+ acc_t val = init_value<scalar_t, reduce>(initial);
86
+ init(out, size, val);
87
+ }
88
+
89
+ // overload with `include_self`, used by scatter_reduce
90
+ template <typename scalar_t, ReductionType reduce>
91
+ inline void init(scalar_t* out, int64_t size, bool include_self = false) {
92
+ using acc_t = vec_scalar_t<scalar_t>;
93
+ if (!include_self) {
94
+ acc_t val = init_value<scalar_t, reduce>();
95
+ init(out, size, val);
96
+ }
97
+ }
98
+
99
+ template <typename scalar_t, ReductionType reduce>
100
+ inline void _init(scalar_t* self_ptr, at::opmath_type<scalar_t>* buffer_ptr, int64_t size, bool include_self) {
101
+ if (!include_self) {
102
+ init<at::opmath_type<scalar_t>, reduce>(buffer_ptr, size, include_self);
103
+ } else {
104
+ vec::convert(self_ptr, buffer_ptr, size);
105
+ }
106
+ }
107
+
108
+ template <typename scalar_t>
109
+ inline typename std::enable_if<!std::is_same<scalar_t, Vec2>::value, scalar_t>::type
110
+ _max(const scalar_t& x, const scalar_t& y) {
111
+ return at::_isnan(y) ? y : std::max(x, y);
112
+ }
113
+
114
+ template <typename scalar_t>
115
+ inline Vectorized<scalar_t> _max(const Vectorized<scalar_t>& x, const Vectorized<scalar_t>& y) {
116
+ // vec::maximum propagates NaN
117
+ return vec::maximum(x, y);
118
+ }
119
+
120
+ template <typename vec_t>
121
+ inline typename std::enable_if<std::is_same<vec_t, Vec2>::value, Vec2>::type
122
+ _max(const vec_t& x, const vec_t& y) {
123
+ // vec::maximum propagates NaN
124
+ return maximum(x, y);
125
+ }
126
+
127
+ template <typename scalar_t>
128
+ inline typename std::enable_if<!std::is_same<scalar_t, Vec2>::value, scalar_t>::type
129
+ _min(const scalar_t& x, const scalar_t& y) {
130
+ return at::_isnan(y) ? y : std::min(x, y);
131
+ }
132
+
133
+ template <typename scalar_t>
134
+ inline Vectorized<scalar_t> _min(const Vectorized<scalar_t>& x, const Vectorized<scalar_t>& y) {
135
+ // vec::minimum propagates NaN
136
+ return vec::minimum(x, y);
137
+ }
138
+
139
+ template <typename vec_t>
140
+ inline typename std::enable_if<std::is_same<vec_t, Vec2>::value, Vec2>::type
141
+ _min(const vec_t& x, const vec_t& y) {
142
+ // vec::minimum propagates NaN
143
+ return minimum(x, y);
144
+ }
145
+
146
+ template <typename scalar_t, typename accumut, typename Op,
147
+ typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
148
+ inline void map_acc(
149
+ const Op& vec_fun,
150
+ accumut* output_data,
151
+ const accumut* input_data,
152
+ const scalar_t* input_data2,
153
+ int64_t size) {
154
+ using Vec = vec::Vectorized<scalar_t>;
155
+ using aVec = vec::Vectorized<accumut>;
156
+ int64_t d = 0;
157
+ constexpr int64_t kVecSize = Vec::size();
158
+ constexpr int64_t kaVecSize = aVec::size();
159
+ for (d = 0; d < size - (size % kVecSize); d += kVecSize) {
160
+ Vec data2_vec = Vec::loadu(input_data2 + d);
161
+ auto [data2_avec0, data2_avec1] = convert_to_float<scalar_t>(data2_vec);
162
+ aVec input_vec0 = aVec::loadu(input_data + d);
163
+ aVec input_vec1 = aVec::loadu(input_data + d + kaVecSize);
164
+ vec_fun(input_vec0, data2_avec0).store(output_data + d);
165
+ vec_fun(input_vec1, data2_avec1).store(output_data + d + kaVecSize);
166
+ }
167
+ if (size - d > 0) {
168
+ int64_t tail_size = size - d;
169
+ Vec data2_vec = Vec::loadu(input_data2 + d, tail_size);
170
+ auto [data2_avec0, data2_avec1] = convert_to_float<scalar_t>(data2_vec);
171
+ if (tail_size > kaVecSize) {
172
+ aVec input_vec0 = aVec::loadu(input_data + d);
173
+ aVec input_vec1 = aVec::loadu(input_data + d + kaVecSize, tail_size - kaVecSize);
174
+ vec_fun(input_vec0, data2_avec0).store(output_data + d);
175
+ vec_fun(input_vec1, data2_avec1).store(output_data + d + kaVecSize, tail_size - kaVecSize);
176
+ } else {
177
+ aVec input_vec0 = aVec::loadu(input_data + d, tail_size);
178
+ vec_fun(input_vec0, data2_avec0).store(output_data + d, tail_size);
179
+ }
180
+ }
181
+ }
182
+
183
+ // for Max and Min, propagate NaN:
184
+ template <typename T, ReductionType reduce>
185
+ inline T update(const T& x, const T& y) {
186
+ if (reduce == ReductionType::SUM ||
187
+ reduce == ReductionType::MEAN) {
188
+ return x + y;
189
+ } else if (reduce == ReductionType::PROD) {
190
+ return x * y;
191
+ } else if (reduce == ReductionType::MAX) {
192
+ return _max(x, y);
193
+ } else {
194
+ TORCH_INTERNAL_ASSERT(reduce == ReductionType::MIN);
195
+ return _min(x, y);
196
+ }
197
+ }
198
+
199
+ template <typename scalar_t, ReductionType reduce>
200
+ inline void update(scalar_t* out, const scalar_t* data, int64_t K) {
201
+ using Vec = vec::Vectorized<vec_scalar_t<scalar_t>>;
202
+ map2<scalar_t>(
203
+ [](Vec x, Vec y) { return update<Vec, reduce>(x, y); },
204
+ out,
205
+ out,
206
+ data,
207
+ K);
208
+ }
209
+
210
+ template <typename scalar_t, ReductionType reduce,
211
+ typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
212
+ inline void update(at::opmath_type<scalar_t>* out, const scalar_t* data, int64_t K) {
213
+ using opmath_t = at::opmath_type<scalar_t>;
214
+ using Vec = vec::Vectorized<opmath_t>;
215
+ map_acc<scalar_t, opmath_t>(
216
+ [](Vec x, Vec y) { return update<Vec, reduce>(x, y); },
217
+ out,
218
+ out,
219
+ data,
220
+ K);
221
+ }
222
+
223
+ template <typename scalar_t, ReductionType reduce>
224
+ inline void write(scalar_t* out, int64_t count, int64_t K) {
225
+ using Vec = vec::Vectorized<vec_scalar_t<scalar_t>>;
226
+ if (reduce == ReductionType::MEAN) {
227
+ if (count > 0) {
228
+ vec::map<scalar_t>(
229
+ [count](Vec x) { return x / Vec(count); },
230
+ out,
231
+ out,
232
+ K);
233
+ }
234
+ }
235
+ }
236
+
237
+ } // namespace CPU_CAPABILITY
238
+ } // namespace at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/AffineQuantizer.h ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/core/Tensor.h>
4
+ #include <ATen/Dispatch.h>
5
+ #include <ATen/native/DispatchStub.h>
6
+ #include <ATen/native/quantized/AffineQuantizerBase.h>
7
+
8
+ namespace at {
9
+ namespace native {
10
+
11
+ Tensor& quantize_tensor_per_tensor_affine(
12
+ const Tensor& rtensor,
13
+ Tensor& qtensor,
14
+ double scale,
15
+ int64_t zero_point);
16
+ Tensor& quantize_tensor_per_channel_affine(
17
+ const Tensor& rtensor,
18
+ Tensor& qtensor,
19
+ const Tensor& scales,
20
+ Tensor zero_points,
21
+ int64_t axis);
22
+
23
+ Tensor& quantize_tensor_per_channel_float_qparams(
24
+ const Tensor& rtensor,
25
+ Tensor& qtensor,
26
+ const Tensor& scales,
27
+ const Tensor& zero_points,
28
+ int64_t axis);
29
+
30
+ Tensor& dequantize_tensor_per_tensor_affine(
31
+ const Tensor& qtensor,
32
+ Tensor& rtensor,
33
+ double scale,
34
+ int64_t zero_point);
35
+ Tensor& dequantize_tensor_per_channel_affine(
36
+ const Tensor& qtensor,
37
+ Tensor& rtensor,
38
+ const Tensor& scales,
39
+ Tensor zero_points,
40
+ int64_t axis);
41
+ Tensor& dequantize_tensor_per_channel_float_qparams(
42
+ const Tensor& qtensor,
43
+ Tensor& rtensor,
44
+ const Tensor& scales,
45
+ const Tensor& zero_points,
46
+ int64_t axis);
47
+
48
+ using quantize_tensor_per_tensor_affine_fn =
49
+ void (*)(const Tensor& rtensor, Tensor& qtensor, double scale, int64_t zero_point);
50
+
51
+ using quantize_tensor_per_channel_affine_fn = void (*)(
52
+ const Tensor& rtensor,
53
+ Tensor& qtensor,
54
+ const Tensor& scales,
55
+ const Tensor& zero_points,
56
+ int64_t axis);
57
+
58
+ using quantize_tensor_per_channel_float_qparams_fn = void (*)(
59
+ const Tensor& rtensor,
60
+ Tensor& qtensor,
61
+ const Tensor& scales,
62
+ const Tensor& zero_points,
63
+ int64_t axis);
64
+
65
+ using dequantize_tensor_per_tensor_affine_fn =
66
+ void (*)(const Tensor& qtensor, Tensor& rtensor, double scale, int64_t zero_point);
67
+
68
+ using dequantize_tensor_per_channel_affine_fn = void (*)(
69
+ const Tensor& qtensor,
70
+ Tensor& rtensor,
71
+ const Tensor& scales,
72
+ const Tensor& zero_points,
73
+ int64_t axis);
74
+
75
+ using dequantize_tensor_per_channel_float_qparams_fn = void (*)(
76
+ const Tensor& qtensor,
77
+ Tensor& rtensor,
78
+ const Tensor& scales,
79
+ const Tensor& zero_points,
80
+ int64_t axis);
81
+
82
+ using quantize_tensor_per_tensor_affine_sub_byte_fn =
83
+ void (*)(const Tensor& rtensor, Tensor& qtensor, float scale, float zero_point);
84
+
85
+ using dequantize_tensor_per_tensor_affine_sub_byte_fn =
86
+ void (*)(const Tensor& qtensor, Tensor& rtensor, float scale, float zero_point);
87
+
88
+ DECLARE_DISPATCH(
89
+ quantize_tensor_per_tensor_affine_fn,
90
+ quantize_tensor_per_tensor_affine_stub);
91
+ DECLARE_DISPATCH(
92
+ quantize_tensor_per_channel_affine_fn,
93
+ quantize_tensor_per_channel_affine_stub);
94
+ DECLARE_DISPATCH(
95
+ quantize_tensor_per_channel_float_qparams_fn,
96
+ quantize_tensor_per_channel_float_qparams_stub);
97
+
98
+ DECLARE_DISPATCH(
99
+ dequantize_tensor_per_tensor_affine_fn,
100
+ dequantize_tensor_per_tensor_affine_stub);
101
+ DECLARE_DISPATCH(
102
+ dequantize_tensor_per_channel_affine_fn,
103
+ dequantize_tensor_per_channel_affine_stub);
104
+ DECLARE_DISPATCH(
105
+ dequantize_tensor_per_channel_float_qparams_fn,
106
+ dequantize_tensor_per_channel_float_qparams_stub);
107
+
108
+ DECLARE_DISPATCH(
109
+ quantize_tensor_per_tensor_affine_sub_byte_fn,
110
+ quantize_tensor_per_tensor_affine_sub_byte_stub);
111
+
112
+ DECLARE_DISPATCH(
113
+ dequantize_tensor_per_tensor_affine_sub_byte_fn,
114
+ dequantize_tensor_per_tensor_affine_sub_byte_stub);
115
+
116
+ template <typename T>
117
+ TORCH_API Tensor quantize_tensor(
118
+ Tensor rtensor,
119
+ Tensor qtensor,
120
+ double scale,
121
+ int64_t zero_point);
122
+ template <typename T>
123
+ TORCH_API Tensor dequantize_tensor(
124
+ Tensor qtensor,
125
+ Tensor rtensor,
126
+ double scale,
127
+ int64_t zero_point);
128
+
129
+ } // namespace native
130
+ } // namespace at
.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/AffineQuantizerBase.h ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <c10/macros/Export.h>
3
+ #include <c10/core/ScalarType.h>
4
+
5
+ namespace at {
6
+ namespace native {
7
+
8
+ // Quantize a float value into a uint value given scale and zero_point
9
+ template <typename T>
10
+ TORCH_API T quantize_val(double scale, int64_t zero_point, float value);
11
+ // TODO combine this with quantize_val once the numerics for ARM are aligned
12
+ // with it
13
+ template <typename T>
14
+ T quantize_val_arm(
15
+ const float scale,
16
+ const int32_t zero_point,
17
+ const float value);
18
+ template <typename T, int precision = 8>
19
+ void quantize_vec(
20
+ double scale,
21
+ int64_t zero_point,
22
+ const float* src,
23
+ T* dst,
24
+ size_t count = 8);
25
+ template <typename T>
26
+ TORCH_API float dequantize_val(double scale, int64_t zero_point, T value);
27
+ template <typename T>
28
+ TORCH_API float dequantize_vec(
29
+ double scale,
30
+ int64_t zero_point,
31
+ const T* src,
32
+ float* dst,
33
+ size_t count = 8);
34
+ template <typename SRC_T, typename DST_T>
35
+ TORCH_API DST_T requantize_val(double, int64_t, double, int64_t, SRC_T src);
36
+
37
+ // Given a multiplier and a zero_point, requantize int32_t computed values back
38
+ // to quantized values. See comment above
39
+ // make_per_tensor_affine_quantizer function for the usage of int64_t
40
+ template <typename DST_T>
41
+ TORCH_API DST_T
42
+ requantize_from_int(double multiplier, int64_t zero_point, int64_t src);
43
+
44
+ int quantize_val_float_qparams(float scale, float zero_point, float value, int qmin, int qmax);
45
+
46
+ } // namespace native
47
+ } // namespace at
.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/ConvUtils.h ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <ATen/core/List.h>
3
+ #include <ATen/native/ConvUtils.h>
4
+
5
+ namespace at::native::quantized {
6
+ namespace {
7
+ // MakeConvOutputShape used from both CPU and CUDA libraries
8
+ // and exporting symbol from torch_cpu would probably take more storage
9
+ // than duplicating implementation which likely be inlined away
10
+ template <int kSpatialDim>
11
+ at::SmallVector<int64_t, kSpatialDim + 2> MakeConvOutputShape(
12
+ int N, // mini-batch
13
+ int M, // output channels
14
+ const std::array<int64_t, kSpatialDim>& input_image_shape,
15
+ const std::vector<int64_t>& kernel,
16
+ const torch::List<int64_t>& stride,
17
+ const torch::List<int64_t>& padding,
18
+ const torch::List<int64_t>& dilation);
19
+
20
+ #if defined(USE_CUDA) || defined(USE_PYTORCH_QNNPACK)
21
+ template <>
22
+ at::SmallVector<int64_t, 4> MakeConvOutputShape<2>(
23
+ int N, // mini-batch
24
+ int M, // output channels
25
+ const std::array<int64_t, 2>& input_image_shape,
26
+ const std::vector<int64_t>& kernel,
27
+ const at::List<int64_t>& stride,
28
+ const at::List<int64_t>& padding,
29
+ const at::List<int64_t>& dilation) {
30
+ const int H = input_image_shape[0];
31
+ const int W = input_image_shape[1];
32
+ const int64_t Y_H =
33
+ (H + 2 * padding[0] - dilation[0] * (kernel[0] - 1) - 1) / stride[0] + 1;
34
+ const int64_t Y_W =
35
+ (W + 2 * padding[1] - dilation[1] * (kernel[1] - 1) - 1) / stride[1] + 1;
36
+ return {N, M, Y_H, Y_W};
37
+ }
38
+
39
+ template <>
40
+ at::SmallVector<int64_t, 5> MakeConvOutputShape<3>(
41
+ int N, // mini-batch
42
+ int M, // output channels
43
+ const std::array<int64_t, 3>& input_image_shape,
44
+ const std::vector<int64_t>& kernel,
45
+ const at::List<int64_t>& stride,
46
+ const at::List<int64_t>& padding,
47
+ const torch::List<int64_t>& dilation) {
48
+ const int D = input_image_shape[0];
49
+ const int H = input_image_shape[1];
50
+ const int W = input_image_shape[2];
51
+ const int64_t Y_D =
52
+ (D + 2 * padding[0] - dilation[0] * (kernel[0] - 1) - 1) / stride[0] + 1;
53
+ const int64_t Y_H =
54
+ (H + 2 * padding[1] - dilation[1] * (kernel[1] - 1) - 1) / stride[1] + 1;
55
+ const int64_t Y_W =
56
+ (W + 2 * padding[2] - dilation[2] * (kernel[2] - 1) - 1) / stride[2] + 1;
57
+ return {N, M, Y_D, Y_H, Y_W};
58
+ }
59
+
60
+ #endif
61
+ } // anonymous namespace
62
+ } // namespace at::native::quantized
.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/Copy.h ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/core/Tensor.h>
4
+
5
+ namespace at {
6
+ namespace native {
7
+
8
+ Tensor& quantized_copy_from_float_(Tensor& self, const Tensor& src);
9
+ }
10
+ } // namespace at
.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/FakeQuantAffine.h ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/core/Tensor.h>
4
+ #include <ATen/Dispatch.h>
5
+ #include <ATen/native/DispatchStub.h>
6
+
7
+ namespace at {
8
+
9
+ struct TensorIterator;
10
+
11
+ namespace native {
12
+
13
+ using fake_quant_tensor_cachemask_fn = void (*)(
14
+ Tensor& output,
15
+ Tensor& mask,
16
+ const Tensor& input,
17
+ float sc,
18
+ int64_t z_point,
19
+ int64_t quant_min,
20
+ int64_t quant_max);
21
+
22
+ using fake_quant_tensor_cachemask_tensor_qparams_fn = void (*)(
23
+ Tensor& output,
24
+ Tensor& mask,
25
+ const Tensor& input,
26
+ const Tensor& sc,
27
+ const Tensor& z_point,
28
+ const Tensor& fake_quant_enabled,
29
+ int64_t quant_min,
30
+ int64_t quant_max);
31
+
32
+ using fake_quant_learnable_grad_tensor_fn = void (*)(
33
+ TensorIterator& iter,
34
+ float scale,
35
+ float inv_scale,
36
+ int64_t zero_point,
37
+ int64_t quant_min,
38
+ int64_t quant_max,
39
+ float grad_factor);
40
+
41
+ DECLARE_DISPATCH(fake_quant_tensor_cachemask_fn, fake_quant_tensor_cachemask_stub);
42
+ DECLARE_DISPATCH(fake_quant_tensor_cachemask_tensor_qparams_fn, fake_quant_tensor_cachemask_tensor_qparams_stub);
43
+ DECLARE_DISPATCH(fake_quant_learnable_grad_tensor_fn, fake_quant_grad_learnable_tensor_stub);
44
+
45
+ using fake_quant_per_channel_fn = void (*)(
46
+ TensorIterator &iter,
47
+ int64_t quant_min,
48
+ int64_t quant_max);
49
+
50
+ using fake_quant_per_channel_cachemask_fn = void (*)(
51
+ TensorIterator &iter,
52
+ TensorIterator &iter_mask,
53
+ int64_t quant_min,
54
+ int64_t quant_max);
55
+
56
+ DECLARE_DISPATCH(fake_quant_per_channel_cachemask_fn, fake_quant_per_channel_cachemask_stub);
57
+
58
+ using fake_quant_learnable_per_channel_fn = void (*)(
59
+ TensorIterator &iter,
60
+ int64_t quant_min,
61
+ int64_t quant_max,
62
+ float grad_factor);
63
+
64
+ DECLARE_DISPATCH(fake_quant_learnable_per_channel_fn, fake_quant_grad_learnable_channel_stub);
65
+
66
+ } // namespace native
67
+ } // namespace at
.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/IndexKernel.h ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <ATen/native/TensorIterator.h>
3
+
4
+ namespace at {
5
+ namespace native {
6
+ using masked_fill_kernel_quantized_fn = void(*)(TensorIterator& iter, const Scalar& value, double scale, int zero_point);
7
+ using index_put_kernel_quantized_fn = void(*)(TensorIterator& iter, IntArrayRef index_size, IntArrayRef index_stride, bool accumulate, double scale, int zero_point);
8
+
9
+ DECLARE_DISPATCH(masked_fill_kernel_quantized_fn, masked_fill_kernel_quantized_stub);
10
+ DECLARE_DISPATCH(index_put_kernel_quantized_fn, index_put_kernel_quantized_stub);
11
+
12
+
13
+ } // native
14
+ } // at
.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/PackedParams.h ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/core/Tensor.h>
4
+ #include <ATen/core/ivalue.h>
5
+
6
+ struct LinearPackedParamsBase : public torch::jit::CustomClassHolder {
7
+ virtual at::Tensor apply(
8
+ at::Tensor input,
9
+ double output_scale,
10
+ int64_t output_zero_point) = 0;
11
+ virtual at::Tensor apply_relu(
12
+ at::Tensor input,
13
+ double output_scale,
14
+ int64_t output_zero_point) = 0;
15
+
16
+ // out variant of LinearPackedParamsBase::apply
17
+ virtual at::Tensor& apply_out(
18
+ const at::Tensor& /*input*/,
19
+ double /*output_scale*/,
20
+ int64_t /*output_zero_point*/,
21
+ at::Tensor& output) {
22
+ throw std::runtime_error(
23
+ "apply_out is not implemented for this packed "
24
+ "parameter type");
25
+ return output;
26
+ }
27
+
28
+ virtual at::Tensor& apply_relu_out(
29
+ const at::Tensor& /*input*/,
30
+ double /*output_scale*/,
31
+ int64_t /*output_zero_point*/,
32
+ at::Tensor& output) {
33
+ throw std::runtime_error(
34
+ "apply_relu_out is not implemented for this packed "
35
+ "parameter type");
36
+ return output;
37
+ }
38
+
39
+ // Corresponding pattern (the ops with `*` are part of the pattern that
40
+ // represents the computation of quantized::linear_with_input_q_dq_qweight_dq_output_fp32):
41
+ // input -> q* -> dq* -> linear* ->
42
+ // qweight -> dq* /
43
+ //
44
+ // After fusion:
45
+ // input -> quantized::linear_with_input_q_dq_qweight_dq_output_fp32* ->
46
+ // qweight /
47
+ //
48
+ // Additional Note: the weight is packed as well
49
+ // Params:
50
+ // X: float32 Tensor, will be quantized to quint8 in the op
51
+ // W_prepack: packed qint8 quantized weight and bias
52
+ // Returns:
53
+ // Y: float32 Tensor
54
+ virtual at::Tensor apply_with_input_q_dq_qweight_dq_output_fp32(
55
+ at::Tensor input,
56
+ double input_scale,
57
+ int64_t input_zero_point) {
58
+ throw std::runtime_error(
59
+ "apply_with_input_q_dq_qweight_dq_output_fp32 is not implemented for this packed "
60
+ "parameter type");
61
+ return {};
62
+ }
63
+
64
+ // Corresponding pattern (the ops with `*` are part of the pattern that
65
+ // represents the computation of quantized::linear_with_input_q_dq_qweight_dq_relu_output_fp32):
66
+ // input -> q* -> dq* -> linear* -> relu* ->
67
+ // qweight -> dq* /
68
+ //
69
+ // After fusion:
70
+ // input -> quantized::linear_with_input_q_dq_qweight_dq_relu_output_fp32* ->
71
+ // qweight /
72
+ //
73
+ // Additional Note: the weight is packed as well
74
+ // Params:
75
+ // input: float32 Tensor, will be quantized to quint8 in the op
76
+ // Returns:
77
+ // float32 Tensor
78
+ virtual at::Tensor apply_with_input_q_dq_qweight_dq_relu_output_fp32(
79
+ at::Tensor input,
80
+ double input_scale,
81
+ int64_t input_zero_point) {
82
+ throw std::runtime_error(
83
+ "apply_with_input_q_dq_qweight_dq_relu_output_fp32 is not implemented for this packed "
84
+ "parameter type");
85
+ return {};
86
+ }
87
+
88
+ virtual at::Tensor apply_dynamic(
89
+ at::Tensor input,
90
+ bool reduce_range = false) = 0;
91
+ virtual at::Tensor apply_dynamic_relu(
92
+ at::Tensor input,
93
+ bool reduce_range = false) = 0;
94
+
95
+ virtual at::Tensor& apply_dynamic_out(
96
+ const at::Tensor& /* input */,
97
+ at::Tensor& output,
98
+ bool /* reduce_range */) {
99
+ throw std::runtime_error(
100
+ "apply_dynamic_out is not implemented for this packed "
101
+ "parameter type");
102
+ return output;
103
+ }
104
+ virtual at::Tensor& apply_dynamic_relu_out(
105
+ const at::Tensor& /* input */,
106
+ at::Tensor& output,
107
+ bool /* reduce_range */) {
108
+ throw std::runtime_error(
109
+ "apply_dynamic_relu_out is not implemented for this packed "
110
+ "parameter type");
111
+ return output;
112
+ }
113
+
114
+ virtual std::tuple<at::Tensor, std::optional<at::Tensor>> unpack() = 0;
115
+
116
+ virtual std::optional<at::Tensor> bias() = 0;
117
+
118
+ virtual void set_bias(std::optional<at::Tensor> /*bias*/) {
119
+ throw std::runtime_error(
120
+ "set_bias is not implemented for this packed "
121
+ "parameter type");
122
+ }
123
+ };
124
+
125
+ template <int kSpatialDim = 2>
126
+ struct ConvPackedParamsBase : public torch::jit::CustomClassHolder {
127
+ virtual at::Tensor apply(
128
+ const at::Tensor& input,
129
+ double output_scale,
130
+ int64_t output_zero_point) = 0;
131
+ virtual at::Tensor apply_relu(
132
+ const at::Tensor& input,
133
+ double output_scale,
134
+ int64_t output_zero_point) = 0;
135
+ virtual at::Tensor apply_dynamic(
136
+ const at::Tensor& input,
137
+ bool reduce_range) = 0;
138
+
139
+ virtual std::tuple<at::Tensor, std::optional<at::Tensor>> unpack() = 0;
140
+
141
+ virtual torch::List<int64_t> stride() const = 0;
142
+ virtual torch::List<int64_t> padding() const = 0;
143
+ virtual torch::List<int64_t> output_padding() const = 0;
144
+ virtual torch::List<int64_t> dilation() const = 0;
145
+ virtual int64_t groups() const = 0;
146
+ virtual bool transpose() const = 0;
147
+ };
.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/BinaryOps.h ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ #include <ATen/core/Tensor.h>
2
+
3
+ namespace at {
4
+ namespace native {
5
+ TORCH_API Tensor
6
+ quantized_add(Tensor qa, Tensor qb, double scale, int64_t zero_point);
7
+ }
8
+ } // namespace at
.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/EmbeddingPackedParams.h ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/core/Tensor.h>
4
+ #include <ATen/core/ivalue.h>
5
+
6
+ struct EmbeddingPackedParamsBase : public torch::jit::CustomClassHolder {
7
+ virtual at::Tensor embeddingbag_byte(
8
+ const at::Tensor& indices,
9
+ const std::optional<at::Tensor>& offsets,
10
+ bool pruned_weights,
11
+ const std::optional<at::Tensor>& per_sample_weights_,
12
+ const std::optional<at::Tensor>& compressed_indices_mapping,
13
+ bool include_last_offset,
14
+ bool is_embedding_op) = 0;
15
+
16
+ virtual at::Tensor embeddingbag_4bit(
17
+ const at::Tensor& indices,
18
+ const std::optional<at::Tensor>& offsets,
19
+ bool pruned_weights,
20
+ const std::optional<at::Tensor>& per_sample_weights_,
21
+ const std::optional<at::Tensor>& compressed_indices_mapping,
22
+ bool include_last_offset,
23
+ bool is_embedding_op) = 0;
24
+
25
+ virtual at::Tensor unpack() = 0;
26
+
27
+ virtual int64_t bit_rate() const = 0;
28
+ virtual int64_t version() const = 0;
29
+ };
.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/OnednnUtils.h ADDED
@@ -0,0 +1,457 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/Config.h>
4
+ #if AT_MKLDNN_ENABLED()
5
+ #include <ATen/Tensor.h>
6
+ #include <ATen/native/quantized/PackedParams.h>
7
+ #include <ideep.hpp>
8
+ #include <cpuinfo.h>
9
+
10
+ #include <c10/util/CallOnce.h>
11
+
12
+ using PrimitiveCacheKey = std::tuple<
13
+ double, // input_scale
14
+ int64_t, // input_zero_point
15
+ std::vector<int64_t>, // input_shape
16
+ double, // output_scale
17
+ int64_t, // output_zero_point
18
+ int64_t, // OMP_number_of_threads
19
+ double, // accum_scale
20
+ int64_t>; // accum_zero_point
21
+
22
+ enum CacheKeyIndex {
23
+ InputScale,
24
+ InputZeroPoint,
25
+ InputShape,
26
+ OutputScale,
27
+ OutputZeroPoint,
28
+ NumOfThreads,
29
+ };
30
+
31
+ // Base class of primitive cache
32
+ struct PrimitiveCache {
33
+ PrimitiveCacheKey key;
34
+
35
+ bool hit(const PrimitiveCacheKey& key) {
36
+ return this->key == key;
37
+ }
38
+ };
39
+
40
+ using LinearParams = ideep::matmul_forward_params;
41
+ using Conv = dnnl::convolution_forward;
42
+ using ConvDesc = dnnl::convolution_forward::primitive_desc;
43
+ using ConvParams = ideep::convolution_forward_params;
44
+ using Deconv = dnnl::deconvolution_forward;
45
+ using DeconvDesc = dnnl::deconvolution_forward::primitive_desc;
46
+ using DeconvParams = ideep::deconv_forward_params;
47
+
48
+ struct LinearPrimitiveCache : PrimitiveCache {
49
+ LinearPrimitiveCache() {}
50
+
51
+ LinearPrimitiveCache(
52
+ const PrimitiveCacheKey& key,
53
+ const LinearParams& param) {
54
+ this->key = key;
55
+ this->param = param;
56
+ }
57
+
58
+ LinearParams param;
59
+
60
+ // For dynamic qlinear, scale and zero point
61
+ // are set at execution time. So we only need to compare
62
+ // the rest part of key.
63
+ bool hit_dynamic(const PrimitiveCacheKey& new_key) {
64
+ auto cached_input_shape = std::get<InputShape>(this->key);
65
+ auto new_input_shape = std::get<InputShape>(new_key);
66
+ return (
67
+ cached_input_shape == new_input_shape &&
68
+ std::get<NumOfThreads>(this->key) == std::get<NumOfThreads>(new_key));
69
+ }
70
+
71
+ LinearParams& get_param() {
72
+ return param;
73
+ }
74
+ };
75
+
76
+ struct ConvPrimitiveCache : PrimitiveCache {
77
+ ConvPrimitiveCache() {}
78
+
79
+ ConvPrimitiveCache(
80
+ const PrimitiveCacheKey& key,
81
+ const ConvParams& params) {
82
+ this->key = key;
83
+ this->params = params;
84
+ }
85
+
86
+ ConvParams params;
87
+
88
+ ConvParams& get_params() {
89
+ return params;
90
+ }
91
+ };
92
+
93
+ struct DeconvPrimitiveCache : PrimitiveCache {
94
+ DeconvPrimitiveCache() {}
95
+
96
+ DeconvPrimitiveCache(
97
+ const PrimitiveCacheKey& key,
98
+ const DeconvParams& params) {
99
+ this->key = key;
100
+ this->params = params;
101
+ }
102
+
103
+ DeconvParams params;
104
+
105
+ DeconvParams& get_params() {
106
+ return params;
107
+ }
108
+ };
109
+
110
+ enum PostOps {
111
+ NoPostOp,
112
+ Relu,
113
+ LeakyRelu,
114
+ Tanh,
115
+ Gelu
116
+ };
117
+
118
+
119
+ struct PackedLinearWeightsOnednn : public LinearPackedParamsBase {
120
+ PackedLinearWeightsOnednn(
121
+ std::unique_ptr<ideep::tensor> weight,
122
+ std::optional<ideep::tensor> bias,
123
+ at::Tensor orig_weight,
124
+ std::optional<at::Tensor> orig_bias)
125
+ : weight_(std::move(weight)),
126
+ bias_(std::move(bias)),
127
+ orig_weight_(std::move(orig_weight)),
128
+ orig_bias_(std::move(orig_bias)) {
129
+ cache_initialized_flag = std::make_unique<c10::once_flag>();
130
+ }
131
+ std::unique_ptr<ideep::tensor> weight_;
132
+ std::optional<ideep::tensor> bias_;
133
+ at::Tensor orig_weight_;
134
+ std::optional<at::Tensor> orig_bias_;
135
+
136
+ at::Tensor apply(
137
+ at::Tensor input,
138
+ double output_scale,
139
+ int64_t output_zero_point) override;
140
+ at::Tensor apply_relu(
141
+ at::Tensor input,
142
+ double output_scale,
143
+ int64_t output_zero_point) override;
144
+
145
+ at::Tensor apply_dynamic(at::Tensor input, bool reduce_range=false) override;
146
+ at::Tensor apply_dynamic_relu(at::Tensor input, bool reduce_range=false) override;
147
+
148
+ at::Tensor apply_leaky_relu(
149
+ at::Tensor input,
150
+ double output_scale,
151
+ int64_t output_zero_point,
152
+ double negative_slope);
153
+
154
+ at::Tensor apply_tanh(
155
+ at::Tensor input,
156
+ double output_scale,
157
+ int64_t output_zero_point);
158
+
159
+ std::tuple<at::Tensor, std::optional<at::Tensor>> unpack() override;
160
+
161
+ std::optional<at::Tensor> bias() override {
162
+ return orig_bias_;
163
+ }
164
+
165
+ static c10::intrusive_ptr<LinearPackedParamsBase> prepack(
166
+ at::Tensor weight,
167
+ std::optional<at::Tensor> bias);
168
+
169
+ private:
170
+ LinearPrimitiveCache prim_cache;
171
+ std::unique_ptr<c10::once_flag> cache_initialized_flag;
172
+
173
+ template <PostOps post_op>
174
+ at::Tensor apply_impl(
175
+ at::Tensor input,
176
+ double output_scale,
177
+ int64_t output_zero_point,
178
+ torch::List<at::Scalar> post_op_args = torch::List<at::Scalar>());
179
+
180
+ template <bool ReluFused>
181
+ at::Tensor apply_dynamic_impl(at::Tensor input, bool reduce_range=false);
182
+
183
+ LinearPrimitiveCache& get_cache() {
184
+ return prim_cache;
185
+ }
186
+ };
187
+
188
+ template <int kSpatialDim = 2>
189
+ struct PackedConvWeightsOnednn : public ConvPackedParamsBase<kSpatialDim> {
190
+ PackedConvWeightsOnednn(
191
+ std::unique_ptr<ideep::tensor> weight,
192
+ std::optional<ideep::tensor> bias,
193
+ at::Tensor orig_weight,
194
+ std::optional<at::Tensor> orig_bias,
195
+ torch::List<int64_t> stride,
196
+ torch::List<int64_t> padding,
197
+ torch::List<int64_t> output_padding,
198
+ torch::List<int64_t> dilation,
199
+ int64_t groups,
200
+ uint8_t transpose)
201
+ : weight_(std::move(weight)),
202
+ bias_(std::move(bias)),
203
+ orig_weight_(std::move(orig_weight)),
204
+ orig_bias_(std::move(orig_bias)),
205
+ stride_(std::move(stride)),
206
+ padding_(std::move(padding)),
207
+ output_padding_(std::move(output_padding)),
208
+ dilation_(std::move(dilation)),
209
+ groups_(groups),
210
+ transpose_(transpose) {
211
+ cache_initialized_flag = std::make_unique<c10::once_flag>();
212
+ }
213
+
214
+ std::unique_ptr<ideep::tensor> weight_;
215
+ std::optional<ideep::tensor> bias_;
216
+ at::Tensor orig_weight_;
217
+ std::optional<at::Tensor> orig_bias_;
218
+ torch::List<int64_t> stride_;
219
+ torch::List<int64_t> padding_;
220
+ torch::List<int64_t> output_padding_;
221
+ torch::List<int64_t> dilation_;
222
+ int64_t groups_;
223
+ uint8_t transpose_;
224
+
225
+ at::Tensor apply(
226
+ const at::Tensor& input,
227
+ double output_scale,
228
+ int64_t output_zero_point) override;
229
+
230
+ at::Tensor apply_relu(
231
+ const at::Tensor& input,
232
+ double output_scale,
233
+ int64_t output_zero_point) override;
234
+
235
+ at::Tensor apply_dynamic(
236
+ const at::Tensor& input,
237
+ bool reduce_range) override;
238
+
239
+ at::Tensor apply_add(
240
+ const at::Tensor& input,
241
+ const at::Tensor& accum,
242
+ double output_scale,
243
+ int64_t output_zero_point);
244
+
245
+ at::Tensor apply_add_relu(
246
+ const at::Tensor& input,
247
+ const at::Tensor& accum,
248
+ double output_scale,
249
+ int64_t output_zero_point);
250
+
251
+ std::tuple<at::Tensor, std::optional<at::Tensor>> unpack() override;
252
+
253
+ static c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> prepack(
254
+ at::Tensor weight,
255
+ std::optional<at::Tensor> bias,
256
+ torch::List<int64_t> stride,
257
+ torch::List<int64_t> padding,
258
+ torch::List<int64_t> output_padding,
259
+ torch::List<int64_t> dilation,
260
+ int64_t groups,
261
+ bool transpose);
262
+
263
+ torch::List<int64_t> stride() const override {
264
+ return stride_;
265
+ }
266
+
267
+ torch::List<int64_t> padding() const override {
268
+ return padding_;
269
+ }
270
+
271
+ torch::List<int64_t> output_padding() const override {
272
+ return output_padding_;
273
+ }
274
+
275
+ torch::List<int64_t> dilation() const override {
276
+ return dilation_;
277
+ }
278
+
279
+ int64_t groups() const override {
280
+ return groups_;
281
+ }
282
+
283
+ bool transpose() const override {
284
+ return (bool)transpose_;
285
+ }
286
+
287
+ private:
288
+ ConvPrimitiveCache conv_prim_cache;
289
+ DeconvPrimitiveCache deconv_prim_cache;
290
+ std::unique_ptr<c10::once_flag> cache_initialized_flag;
291
+
292
+ template <bool ReluFused>
293
+ at::Tensor apply_impl(
294
+ const at::Tensor& input,
295
+ const std::optional<at::Tensor>& accum,
296
+ double output_scale,
297
+ int64_t output_zero_point);
298
+
299
+ ConvPrimitiveCache& get_conv_cache() {
300
+ assert(!transpose());
301
+ return conv_prim_cache;
302
+ }
303
+
304
+ DeconvPrimitiveCache& get_deconv_cache() {
305
+ assert(transpose());
306
+ return deconv_prim_cache;
307
+ }
308
+ };
309
+
310
+ namespace onednn_utils {
311
+
312
+ inline ideep::attr_t create_attr_by_post_op(
313
+ const c10::string_view& binary_post_op,
314
+ double binary_alpha,
315
+ double input1_scale,
316
+ int64_t input1_zero_point,
317
+ const ideep::tensor::desc& input1_desc,
318
+ const c10::string_view& unary_post_op,
319
+ const torch::List<std::optional<at::Scalar>>& unary_post_op_args,
320
+ const c10::string_view& unary_post_op_algorithm) {
321
+ using ideep::tensor;
322
+ if (binary_post_op == "none") {
323
+ if (unary_post_op == "relu") {
324
+ return ideep::attr_t::fuse_relu();
325
+ } else if (unary_post_op == "leaky_relu") {
326
+ TORCH_CHECK(
327
+ unary_post_op_args.size() == 1,
328
+ "onednn qlinear: expect one argument for post op leaky_relu but got ", unary_post_op_args.size(), " args");
329
+ auto alpha = unary_post_op_args[0].value().to<float>();
330
+ return ideep::attr_t::fuse_relu_v2(alpha);
331
+ } else if (unary_post_op == "tanh") {
332
+ return ideep::attr_t::fuse_tanh();
333
+ } else if (unary_post_op == "gelu") {
334
+ TORCH_CHECK(
335
+ unary_post_op_algorithm == "none" || unary_post_op_algorithm == "tanh",
336
+ "onednn qlinear: algorithm for post op gelu must be none or tanh but got ", unary_post_op_algorithm);
337
+ auto post_algorithm = unary_post_op_algorithm == "none" ?
338
+ dnnl::algorithm::eltwise_gelu_erf :
339
+ dnnl::algorithm::eltwise_gelu_tanh;
340
+ return ideep::attr_t::fuse_gelu_v2(0.f, 0.f, post_algorithm);
341
+ } else if (unary_post_op == "hardtanh") {
342
+ TORCH_CHECK(
343
+ unary_post_op_args.size() == 2 &&
344
+ unary_post_op_args[0].has_value() &&
345
+ unary_post_op_args[1].has_value(),
346
+ "hardtanh is expected to have two scalar input: min_val and max_val");
347
+ auto lower_bound_value =
348
+ unary_post_op_args[0].value().to<float>();
349
+ auto upper_bound_value =
350
+ unary_post_op_args[1].value().to<float>();
351
+ return ideep::attr_t::fuse_clamp(lower_bound_value, upper_bound_value);
352
+ } else if (unary_post_op == "hardswish") {
353
+ return ideep::attr_t::fuse_hardswish();
354
+ } else if (unary_post_op == "swish") {
355
+ return ideep::attr_t::fuse_swish();
356
+ } else {
357
+ TORCH_CHECK(
358
+ unary_post_op == "none",
359
+ "onednn qlinear: unsupported unary post op ", unary_post_op);
360
+ }
361
+ } else if (binary_post_op == "sum") {
362
+ if (unary_post_op == "none") {
363
+ return ideep::attr_t::fuse_sum(input1_scale, input1_zero_point);
364
+ } else if (unary_post_op == "relu") {
365
+ return ideep::attr_t::residual_with_sum_zero_point(input1_scale, input1_zero_point);
366
+ } else {
367
+ TORCH_CHECK(
368
+ false,
369
+ "onednn qlinear: unsupported unary post op ", unary_post_op, " with binary post op sum");
370
+ }
371
+ } else if (binary_post_op == "add") {
372
+ if (unary_post_op == "none") {
373
+ return ideep::attr_t::fuse_binary(ideep::algorithm::binary_add, input1_desc);
374
+ } else if (unary_post_op == "relu") {
375
+ ideep::post_ops po;
376
+ po.append_binary(ideep::algorithm::binary_add, input1_desc);
377
+ po.append_eltwise(ideep::algorithm::eltwise_relu, 0, 0);
378
+ return ideep::attr_t::attr_post_ops(po);
379
+ } else {
380
+ TORCH_CHECK(
381
+ false,
382
+ "onednn qlinear: unsupported unary post op ", unary_post_op, " with binary post op add");
383
+ }
384
+ } else {
385
+ TORCH_CHECK(
386
+ false,
387
+ "onednn qlinear: unsupported binary post op ", binary_post_op);
388
+ }
389
+ return ideep::attr_t();
390
+ }
391
+
392
+ // ONEDNN requires symmetric quantization of weight
393
+ // Use this util function to check.
394
+ inline bool is_weight_symmetric_quant(
395
+ const at::Tensor& weight,
396
+ bool is_transposed_conv) {
397
+ bool is_symmetric = true;
398
+ const auto qtype = weight.qscheme();
399
+ if (qtype == c10::kPerTensorAffine) {
400
+ is_symmetric &= (weight.q_zero_point() == 0);
401
+ } else if (qtype == c10::kPerChannelAffine) {
402
+ if (is_transposed_conv) {
403
+ // This case is currently not supported in PyTorch
404
+ // but we do not want to raise an error in this util function.
405
+ is_symmetric = false;
406
+ } else {
407
+ auto output_channels = weight.size(0);
408
+ for (int i = 0; i < output_channels; ++i) {
409
+ auto zp = weight.q_per_channel_zero_points()[i].item<int32_t>();
410
+ is_symmetric &= (zp == 0);
411
+ }
412
+ }
413
+ } else {
414
+ // This case is currently not supported in PyTorch
415
+ // but we do not want to raise an error in this util function.
416
+ is_symmetric = false;
417
+ }
418
+ return is_symmetric;
419
+ }
420
+
421
+ // When qengine is x86, use this util func to check if onednn kernel
422
+ // is preferred than fbgemm's to get better performance.
423
+ inline bool should_use_onednn_quant(
424
+ const at::Tensor& weight,
425
+ bool is_transposed_conv,
426
+ int groups,
427
+ torch::List<int64_t> output_padding) {
428
+ // Performance of onednn is only validated on Linux right now.
429
+ // Also, the heuristics for dispatching are based on perf data on Linux.
430
+ // So, for x86 qengine, we always use fbgemm kernels if OS is not Linux.
431
+ // TODO Support more OSs.
432
+ #if !defined(__linux__)
433
+ return false;
434
+ #else
435
+ bool vnni_available = cpuinfo_has_x86_avx512vnni();
436
+ bool w_sym_quant =
437
+ is_weight_symmetric_quant(weight, is_transposed_conv);
438
+ bool opad_all_zero =
439
+ std::all_of(output_padding.begin(), output_padding.end(), [](int i) { return i==0; });
440
+ return vnni_available && (groups <= 100) && w_sym_quant && opad_all_zero;
441
+ #endif
442
+ }
443
+
444
+ } // onednn_utils
445
+
446
+ at::Tensor _qconv_prepack_onednn(
447
+ at::Tensor weight, // from CPU backend instead of QuantizedCPU
448
+ at::Tensor weight_scales, // Weight zero points must be 0 for onednn
449
+ double input_scale,
450
+ int64_t input_zero_point,
451
+ torch::List<int64_t> stride,
452
+ torch::List<int64_t> padding,
453
+ torch::List<int64_t> dilation,
454
+ int64_t groups,
455
+ std::optional<torch::List<int64_t>> input_shape=std::nullopt);
456
+
457
+ #endif // #if AT_MKLDNN_ENABLED()
.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/QnnpackUtils.h ADDED
@@ -0,0 +1,527 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #ifdef USE_PYTORCH_QNNPACK
4
+ #include <ATen/core/Tensor.h>
5
+ #include <c10/util/irange.h>
6
+ #include <pytorch_qnnpack.h>
7
+ #include <qnnpack_func.h>
8
+ #include <ATen/native/quantized/cpu/XnnpackUtils.h>
9
+ #include <ATen/native/quantized/PackedParams.h>
10
+ #include <ATen/native/utils/Factory.h>
11
+
12
+ #ifndef AT_PER_OPERATOR_HEADERS
13
+ #include <ATen/Functions.h>
14
+ #else
15
+ #include <ATen/ops/empty.h>
16
+ #endif
17
+
18
+ #include <utility>
19
+ inline int kPaddingChannels = 8;
20
+ struct QnnpackOperatorDeleter {
21
+ void operator()(pytorch_qnnp_operator_t op) {
22
+ pytorch_qnnp_delete_operator(op);
23
+ }
24
+ };
25
+
26
+ // PackedWeight struct for QNNPACK stores the original Weight and Bias as
27
+ // QNNPACK currently does not support an unpack function.
28
+ // For PyTorch Mobile, once the model is scripted and serialized we don't need
29
+ // to call unpack, so we can save some memory by checking for this case and free
30
+ // the original weights after packing.
31
+ // Input scale is set to null in pre-pack step. QNNPACK needs bias quantized
32
+ // with input scale which is available at runtime in pytorch. During runtime if
33
+ // input scale value changes then we requantize bias with the updated scale. For
34
+ // inference we expect the graph to be static so the input scale should not
35
+ // change across consecutive inference calls.
36
+ struct PackedLinearWeightsQnnp : public LinearPackedParamsBase {
37
+ PackedLinearWeightsQnnp(
38
+ std::unique_ptr<qnnpack::PackBMatrix> w,
39
+ at::Tensor orig_weight,
40
+ at::Tensor bias,
41
+ std::optional<double> input_scale,
42
+ at::Tensor w_scales,
43
+ std::vector<uint8_t>&& w_zps)
44
+ : w(std::move(w)),
45
+ orig_weight(std::move(orig_weight)),
46
+ bias_(at::native::mobile::allocate_padded_contiguous_if_needed(
47
+ bias, bias.suggest_memory_format())),
48
+ per_channel_(this->orig_weight.qscheme() == at::kPerChannelAffine),
49
+ input_scale(std::move(input_scale)),
50
+ w_scales(std::move(w_scales)),
51
+ w_zero_points(std::move(w_zps)),
52
+ q_scheme(this->orig_weight.qscheme()) {
53
+ weight_sizes = this->orig_weight.sizes().vec();
54
+ }
55
+
56
+ std::unique_ptr<qnnpack::PackBMatrix> w;
57
+ at::Tensor orig_weight;
58
+ at::Tensor bias_;
59
+ bool per_channel_;
60
+ std::optional<double> input_scale;
61
+ at::Tensor w_scales;
62
+ std::vector<uint8_t> w_zero_points;
63
+ std::vector<float> requantization_scales;
64
+ std::vector<int64_t> weight_sizes;
65
+ c10::QScheme q_scheme;
66
+
67
+ at::Tensor apply(
68
+ at::Tensor input,
69
+ double output_scale,
70
+ int64_t output_zero_point) override;
71
+ at::Tensor apply_relu(
72
+ at::Tensor input,
73
+ double output_scale,
74
+ int64_t output_zero_point) override;
75
+
76
+ at::Tensor apply_dynamic(at::Tensor input, bool reduce_range=false) override;
77
+ at::Tensor apply_dynamic_relu(at::Tensor input, bool reduce_range=false) override;
78
+
79
+ std::tuple<at::Tensor, std::optional<at::Tensor>> unpack() override;
80
+
81
+ std::optional<at::Tensor> bias() override {
82
+ return bias_;
83
+ }
84
+
85
+ static c10::intrusive_ptr<LinearPackedParamsBase> prepack(
86
+ at::Tensor weight,
87
+ std::optional<at::Tensor> bias);
88
+
89
+ bool per_channel() const {
90
+ return per_channel_;
91
+ }
92
+
93
+ private:
94
+ std::mutex qnnp_mutex_;
95
+
96
+ #ifdef USE_XNNPACK
97
+ xnnpack_operator xnnp_linear_op;
98
+
99
+ template <typename scalar_t, bool kReluFused>
100
+ at::Tensor apply_impl_xnnp(
101
+ const at::Tensor& input,
102
+ double output_scale,
103
+ int64_t output_zero_point);
104
+ #endif // USE_XNNPACK
105
+
106
+ template <bool ReluFused>
107
+ at::Tensor apply_impl(
108
+ at::Tensor input,
109
+ double output_scale,
110
+ int64_t output_zero_point);
111
+
112
+ template <bool ReluFused>
113
+ at::Tensor apply_dynamic_impl(at::Tensor input, bool reduce_range);
114
+ };
115
+
116
+ template <int kSpatialDim = 2>
117
+ struct PackedConvWeightsQnnp : public ConvPackedParamsBase<kSpatialDim> {
118
+ PackedConvWeightsQnnp(
119
+ std::unique_ptr<qnnpack::PrePackConvWeights> w,
120
+ at::Tensor orig_weight,
121
+ at::Tensor bias,
122
+ torch::List<int64_t> stride,
123
+ torch::List<int64_t> padding,
124
+ torch::List<int64_t> output_padding,
125
+ torch::List<int64_t> dilation,
126
+ int64_t groups,
127
+ bool transpose,
128
+ std::optional<double> input_scale,
129
+ std::vector<int64_t> kernel,
130
+ at::Tensor w_scale,
131
+ std::vector<uint8_t>&& w_zps,
132
+ bool is_per_channel)
133
+ : w(std::move(w)),
134
+ orig_weight(std::move(orig_weight)),
135
+ bias(std::move(bias)),
136
+ stride_(std::move(stride)),
137
+ padding_(std::move(padding)),
138
+ output_padding_(std::move(output_padding)),
139
+ dilation_(std::move(dilation)),
140
+ groups_(groups),
141
+ transpose_(transpose),
142
+ is_per_channel_(is_per_channel),
143
+ input_scale(input_scale),
144
+ kernel_(std::move(kernel)),
145
+ w_scales(std::move(w_scale)),
146
+ w_zero_points(std::move(w_zps)) {
147
+ const bool any_padding = std::any_of(
148
+ padding_.begin(), padding_.end(), [](const auto& e) { return e != 0; });
149
+ const size_t kernel_size =
150
+ std::accumulate(kernel_.begin(), kernel_.end(), 1, std::multiplies<>());
151
+
152
+ const size_t group_input_channels = transpose
153
+ ? this->orig_weight.size(0) / groups
154
+ : this->orig_weight.size(1);
155
+ const size_t group_output_channels = transpose
156
+ ? this->orig_weight.size(1)
157
+ : this->orig_weight.size(0) / groups;
158
+
159
+ const size_t kernel_depth = kSpatialDim == 3 ? kernel_[0] : 1;
160
+ const size_t kernel_height = kernel_[kSpatialDim - 2];
161
+ const size_t kernel_width = kernel_[kSpatialDim - 1];
162
+
163
+ pytorch_qnnp_ukernel_type ukernel_type;
164
+ if (transpose_) {
165
+ ukernel_type = pytorch_qnnp_ukernel_type_conv;
166
+ } else {
167
+ ukernel_type = pytorch_qnnp_ukernel_type_none;
168
+
169
+ const bool has_depthwise_dimensions =
170
+ (kSpatialDim == 2 &&
171
+ ((kernel_height == 3 && kernel_width == 3) ||
172
+ (kernel_height == 5 && kernel_width == 5))) ||
173
+ (kSpatialDim == 3 && kernel_height == 3 && kernel_width == 3 &&
174
+ kernel_depth == 3);
175
+ const bool has_depthwise_grouping =
176
+ group_input_channels == 1 && group_output_channels == 1 && groups > 1;
177
+
178
+ if (has_depthwise_dimensions && has_depthwise_grouping) {
179
+ ukernel_type = pytorch_qnnp_ukernel_type_dwconv;
180
+ } else if (
181
+ kernel_size == 1 &&
182
+ std::all_of(
183
+ stride_.begin(),
184
+ stride_.end(),
185
+ [](const auto& e) { return e == 1; }) &&
186
+ !any_padding) {
187
+ ukernel_type = group_input_channels >= SIZE_MAX
188
+ ? pytorch_qnnp_ukernel_type_xzp_gemm
189
+ : pytorch_qnnp_ukernel_type_gemm;
190
+ } else {
191
+ ukernel_type = pytorch_qnnp_ukernel_type_conv;
192
+ }
193
+ }
194
+
195
+ if (is_per_channel && ukernel_type == pytorch_qnnp_ukernel_type_xzp_gemm) {
196
+ TORCH_INTERNAL_ASSERT(
197
+ false, "Per channel quantized weights are not supported for XZP kernels");
198
+ }
199
+
200
+ pytorch_qnnp_operator_t convolution{nullptr};
201
+ // Initially all the params are set to zero.
202
+ convolution = static_cast<pytorch_qnnp_operator_t>(
203
+ calloc(1, sizeof(struct pytorch_qnnp_operator)));
204
+ if (convolution == nullptr) {
205
+ TORCH_INTERNAL_ASSERT(
206
+ false, "failed to allocate %zu bytes for pytorch_qnnp_operator structure",
207
+ sizeof(struct pytorch_qnnp_operator));
208
+ }
209
+
210
+ convolution_op =
211
+ std::unique_ptr<pytorch_qnnp_operator, QnnpackOperatorDeleter>(
212
+ convolution);
213
+
214
+ // NOLINTNEXTLINE(clang-analyzer-core.NullDereference)
215
+ convolution->ukernel_type = ukernel_type;
216
+ convolution->groups = groups;
217
+ convolution->group_input_channels = group_input_channels;
218
+ convolution->group_output_channels = group_output_channels;
219
+ convolution->kernel_depth = kernel_depth;
220
+ convolution->kernel_height = kernel_height;
221
+ convolution->kernel_width = kernel_width;
222
+ convolution->stride_depth = kSpatialDim == 3 ? stride_[0] : 1;
223
+ convolution->stride_height = stride_[kSpatialDim - 2];
224
+ convolution->stride_width = stride_[kSpatialDim - 1];
225
+ convolution->dilation_depth = kSpatialDim == 3 ? dilation_[0] : 1;
226
+ convolution->dilation_height = dilation_[kSpatialDim - 2];
227
+ convolution->dilation_width = dilation_[kSpatialDim - 1];
228
+ convolution->input_padding_height = padding_[kSpatialDim - 2];
229
+ convolution->input_padding_width = padding_[kSpatialDim - 1];
230
+ convolution->input_padding_depth = kSpatialDim == 3 ? padding_[0] : 0;
231
+ convolution->per_channel = is_per_channel_;
232
+ convolution->transpose = transpose_;
233
+
234
+ const uint32_t kr = pytorch_qnnp_params.q8conv.kr;
235
+ const size_t k_stride = (group_input_channels + (kr - 1)) & -kr;
236
+
237
+ size_t zero_size = sizeof(uint8_t) * k_stride;
238
+ size_t zero_offset = 0;
239
+
240
+ if (transpose_) {
241
+ convolution->adjustment_width = output_padding_[1];
242
+ convolution->adjustment_height = output_padding_[0];
243
+ if (group_input_channels < 8) {
244
+ zero_size += 8;
245
+ zero_offset = 8;
246
+ }
247
+ } else {
248
+ zero_buffer_size = 0;
249
+ if (any_padding) {
250
+ zero_size = 0;
251
+ zero_offset = 0;
252
+ if (ukernel_type == pytorch_qnnp_ukernel_type_dwconv) {
253
+ const uint32_t cr = pytorch_qnnp_params.q8dw9.cr;
254
+ const size_t group_stride = (groups + (cr - 1)) & -cr;
255
+ if (groups >= 8) {
256
+ zero_size = sizeof(uint8_t) * group_stride;
257
+ zero_offset = 0;
258
+ } else {
259
+ zero_size = sizeof(uint8_t) * group_stride + 8;
260
+ zero_offset = sizeof(uint8_t) * 8;
261
+ }
262
+ } else if (
263
+ ukernel_type == pytorch_qnnp_ukernel_type_conv ||
264
+ ukernel_type == pytorch_qnnp_ukernel_type_gemm) {
265
+ if (group_input_channels >= 8) {
266
+ zero_size = sizeof(uint8_t) * k_stride;
267
+ zero_offset = 0;
268
+ } else {
269
+ zero_size = sizeof(uint8_t) * k_stride + 8;
270
+ zero_offset = 8;
271
+ }
272
+ }
273
+ }
274
+ }
275
+
276
+ // NOLINTNEXTLINE(clang-analyzer-optin.portability.UnixAPI)
277
+ void* zero_buffer = malloc(zero_size);
278
+ if (zero_buffer == nullptr) {
279
+ pytorch_qnnp_delete_operator(convolution);
280
+ TORCH_INTERNAL_ASSERT(
281
+ false, "failed to allocate %zu bytes for zero padding",
282
+ zero_size);
283
+ }
284
+ // Need to set to input zero point
285
+ // memset(zero_buffer, input_zero_point, zero_size);
286
+ zero_buffer_size = zero_size;
287
+ convolution->zero_buffer = zero_buffer;
288
+ convolution->zero_pointer = (void*)((uintptr_t)zero_buffer + zero_offset);
289
+ }
290
+
291
+ std::unique_ptr<pytorch_qnnp_operator, QnnpackOperatorDeleter> convolution_op;
292
+ #ifdef USE_XNNPACK
293
+ xnnpack_operator xnnp_convolution_op;
294
+ #endif // USE_XNNPACK
295
+ std::unique_ptr<qnnpack::PrePackConvWeights> w;
296
+ at::Tensor orig_weight;
297
+ at::Tensor bias;
298
+ torch::List<int64_t> stride_;
299
+ torch::List<int64_t> padding_;
300
+ torch::List<int64_t> output_padding_;
301
+ torch::List<int64_t> dilation_;
302
+ int64_t groups_;
303
+ bool transpose_;
304
+ bool is_per_channel_;
305
+ std::optional<double> input_scale;
306
+ std::vector<int64_t> kernel_;
307
+ at::Tensor w_scales;
308
+ std::vector<uint8_t> w_zero_points;
309
+ std::vector<float> requantization_scales;
310
+ size_t zero_buffer_size;
311
+
312
+ at::Tensor apply(
313
+ const at::Tensor& input,
314
+ double output_scale,
315
+ int64_t output_zero_point) override;
316
+
317
+ at::Tensor apply_relu(
318
+ const at::Tensor& input,
319
+ double output_scale,
320
+ int64_t output_zero_point) override;
321
+
322
+ at::Tensor apply_dynamic(
323
+ const at::Tensor& input,
324
+ bool reduce_range=false) override;
325
+
326
+ std::tuple<at::Tensor, std::optional<at::Tensor>> unpack() override;
327
+
328
+ static c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> prepack(
329
+ at::Tensor weight,
330
+ std::optional<at::Tensor> bias,
331
+ torch::List<int64_t> stride,
332
+ torch::List<int64_t> padding,
333
+ torch::List<int64_t> output_padding,
334
+ torch::List<int64_t> dilation,
335
+ int64_t groups,
336
+ bool transpose);
337
+
338
+ torch::List<int64_t> stride() const override {
339
+ return stride_;
340
+ }
341
+
342
+ torch::List<int64_t> padding() const override {
343
+ return padding_;
344
+ }
345
+
346
+ torch::List<int64_t> output_padding() const override {
347
+ return output_padding_;
348
+ }
349
+
350
+ torch::List<int64_t> dilation() const override {
351
+ return dilation_;
352
+ }
353
+
354
+ int64_t groups() const override {
355
+ return groups_;
356
+ }
357
+
358
+ bool transpose() const override {
359
+ return transpose_;
360
+ }
361
+
362
+ bool per_channel() const {
363
+ return is_per_channel_;
364
+ }
365
+
366
+ private:
367
+ std::mutex qnnp_mutex_;
368
+ template <bool ReluFused>
369
+ at::Tensor apply_impl(
370
+ const at::Tensor& input,
371
+ double output_scale,
372
+ int64_t output_zero_point);
373
+
374
+ #ifdef USE_XNNPACK
375
+ template <typename scalar_t, bool ReluFused>
376
+ at::Tensor apply_impl_xnnp(
377
+ const at::Tensor& input,
378
+ double output_scale,
379
+ int64_t output_zero_point);
380
+ #endif // USE_XNNPACK
381
+ };
382
+
383
+ enum class Activation : uint8_t { NONE = 0, RELU = 1 };
384
+
385
+ #if defined(__ANDROID__) && !defined(__NDK_MAJOR__)
386
+ template <class T>
387
+ inline float Round(const float x) {
388
+ return ::nearbyintf(x);
389
+ }
390
+ inline double Round(const double x) {
391
+ return ::nearbyint(x);
392
+ }
393
+ #else
394
+ template <class T>
395
+ inline T Round(const T x) {
396
+ return std::nearbyint(x);
397
+ }
398
+ #endif
399
+
400
+ template<typename T>
401
+ inline T QuantizeValue(float scale, int32_t zero_point, float value) {
402
+ const int32_t qmin = std::numeric_limits<T>::min();
403
+ const int32_t qmax = std::numeric_limits<T>::max();
404
+ auto r = zero_point + static_cast<int32_t>(Round(value / scale));
405
+ r = std::max(r, qmin);
406
+ r = std::min(r, qmax);
407
+ return static_cast<T>(r);
408
+ }
409
+
410
+ template<typename T>
411
+ inline std::pair<T, T> activationLimits(
412
+ float scale,
413
+ int32_t zero_point,
414
+ Activation Ac) {
415
+ switch (Ac) {
416
+ case Activation::NONE:
417
+ return {std::numeric_limits<T>::min(),
418
+ std::numeric_limits<T>::max()};
419
+ case Activation::RELU:
420
+ return {QuantizeValue<T>(scale, zero_point, 0.0),
421
+ std::numeric_limits<T>::max()};
422
+ default:
423
+ #ifdef _MSC_VER
424
+ __assume(0);
425
+ #else
426
+ __builtin_unreachable();
427
+ #endif
428
+ }
429
+ }
430
+
431
+ namespace at {
432
+ namespace native {
433
+ namespace qnnp_avgpool_helper {
434
+ Tensor qnnpack_avg_pool2d(
435
+ Tensor input,
436
+ IntArrayRef kernel_size,
437
+ IntArrayRef stride,
438
+ IntArrayRef padding,
439
+ bool ceil_mode,
440
+ bool count_include_pad,
441
+ std::optional<int64_t> divisor_override);
442
+ } // qnnp_avgpool_helper
443
+ } // namespace native
444
+ } // namespace at
445
+
446
+ namespace {
447
+ C10_UNUSED std::vector<float> generate_requantization_scales(
448
+ const at::Tensor& weight_scales,
449
+ const float input_scale,
450
+ const float output_scale,
451
+ std::vector<float>& requant_scales) {
452
+ // Since weight scale is allocated with padding
453
+ // weight_scales.numel() gives us padded num elements.
454
+ const auto num_output_channels_padded = weight_scales.numel();
455
+ float *const weight_scales_data = weight_scales.data_ptr<float>();
456
+ if (static_cast<int64_t>(requant_scales.size()) < num_output_channels_padded) {
457
+ requant_scales.resize(num_output_channels_padded);
458
+ }
459
+ for (const auto i : c10::irange(num_output_channels_padded)) {
460
+ const auto inverse_output_scale = 1.f /output_scale;
461
+ requant_scales[i] = (weight_scales_data[i] * input_scale) * inverse_output_scale;
462
+ TORCH_CHECK(
463
+ (requant_scales[i] > 0.0f && std::isnormal(requant_scales[i])),
464
+ "failed to create op with requantization scale: ",
465
+ requant_scales[i],
466
+ ": requantization scale must be finite and positive");
467
+ }
468
+ return requant_scales;
469
+ }
470
+
471
+ C10_UNUSED std::pair<std::vector<uint8_t>, at::Tensor> make_zero_points_and_scales_tensor(
472
+ const at::Tensor& weight_contig,
473
+ bool transpose = false,
474
+ uint32_t groups = 1
475
+ ) {
476
+ const int out_ch_idx = transpose ? 1 : 0;
477
+ const auto num_output_channels = weight_contig.size(out_ch_idx) * (transpose ? groups : 1);
478
+ // Add 8 to account for bufferring needed by QNNPACK.
479
+ const auto num_output_channels_padded = num_output_channels + kPaddingChannels;
480
+ const auto qtype = weight_contig.qscheme();
481
+ std::vector<uint8_t> weight_zp(num_output_channels_padded, 0);
482
+ // Adjust weight zero point, similar to weight data.
483
+ if (qtype == at::kPerTensorAffine) {
484
+ for (const auto i : c10::irange(num_output_channels)) {
485
+ weight_zp[i] = (uint8_t)(weight_contig.q_zero_point() + 128);
486
+ }
487
+ } else if (qtype == at::kPerChannelAffine) {
488
+ TORCH_CHECK(
489
+ weight_contig.q_per_channel_zero_points().scalar_type() == at::kLong,
490
+ "Per channel zero points dtype must be long int.");
491
+ const int64_t* per_channel_zero_points =
492
+ weight_contig.q_per_channel_zero_points().data_ptr<int64_t>();
493
+ for (const auto i : c10::irange(num_output_channels)) {
494
+ weight_zp[i] = (uint8_t)(per_channel_zero_points[i] + 128);
495
+ }
496
+ } else {
497
+ TORCH_INTERNAL_ASSERT(false, "Unsupported quantization scheme.");
498
+ }
499
+ at:: Tensor weight_scales =
500
+ at::empty(
501
+ {num_output_channels_padded},
502
+ at::device(at::kCPU).dtype(at::kFloat));
503
+ float *const weight_scales_data = weight_scales.data_ptr<float>();
504
+ if (qtype == at::kPerTensorAffine) {
505
+ for (const auto i : c10::irange(num_output_channels)) {
506
+ weight_scales_data[i] = weight_contig.q_scale();
507
+ }
508
+ } else if (qtype == at::kPerChannelAffine) {
509
+ TORCH_CHECK(
510
+ weight_contig.q_per_channel_scales().scalar_type() == at::kDouble,
511
+ "Per channel scales dtype must be double.");
512
+ const double *const per_channel_scales =
513
+ weight_contig.q_per_channel_scales().data_ptr<double>();
514
+ for (const auto i : c10::irange(num_output_channels)) {
515
+ weight_scales_data[i] = static_cast<float>(per_channel_scales[i]);
516
+ }
517
+ } else {
518
+ TORCH_INTERNAL_ASSERT(false, "Unsupported quantization scheme.");
519
+ }
520
+ for (const auto i : c10::irange(num_output_channels, num_output_channels_padded)) {
521
+ weight_scales_data[i] = 1.f;
522
+ }
523
+ return {weight_zp, weight_scales};
524
+ }
525
+ } // namespace
526
+
527
+ #endif
.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/QuantUtils.h ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/core/Tensor.h>
4
+ #include <ATen/core/List.h>
5
+ #include <ATen/TensorOperators.h>
6
+ #include <c10/util/irange.h>
7
+ #include <algorithm>
8
+ #include <cmath>
9
+
10
+ #ifndef AT_PER_OPERATOR_HEADERS
11
+ #include <ATen/Functions.h>
12
+ #include <ATen/NativeFunctions.h>
13
+ #else
14
+ #include <ATen/ops/quantize_per_tensor_native.h>
15
+ #include <ATen/ops/quantize_per_channel_native.h>
16
+ #include <ATen/ops/zeros.h>
17
+ #endif
18
+
19
+ namespace quant_utils {
20
+ namespace {
21
+ float RawUint16ToFp16(unsigned short value) {
22
+ // Convert raw 16 bits half precision floating point number
23
+ // to single precision floating point number.
24
+ const unsigned short sign_bits = value >> 15;
25
+ const unsigned short exponent_bits = value >> 10 & 0x1f;
26
+ const unsigned short significand_bits = value & 0x3ff;
27
+
28
+ const float sign = sign_bits ? -1 : 1;
29
+ const float significand =
30
+ 1 + significand_bits * 0.0009765625f; // 0.0009765625f = 0x1p-10 = 2^-10;
31
+ const float exponent = exponent_bits - 0xf;
32
+
33
+ return sign * std::ldexp(significand, exponent);
34
+ }
35
+
36
+ template <typename T>
37
+ bool CheckAndSaturate(T max_val, T* element) {
38
+ if (*element > max_val) {
39
+ *element = max_val;
40
+ return true;
41
+ }
42
+ if (*element < -max_val) {
43
+ *element = -max_val;
44
+ return true;
45
+ }
46
+ return false;
47
+ }
48
+ }
49
+ using namespace std;
50
+ // A structure to hold quantization parameters 'scale' and 'zero_point'.
51
+ // The meaning of these values is as the constants in the quantization equation
52
+ //
53
+ // real_value = scale * (quantized_value - zero_point)
54
+ //
55
+ // In other words, 'zero_point' is the quantized value that corresponds
56
+ // to the real value 0, and 'scale' is the difference of real values
57
+ // corresponding to consecutive quantized values.
58
+ struct TensorQuantizationParams {
59
+ double scale;
60
+ std::int32_t zero_point;
61
+ int precision;
62
+ };
63
+
64
+ // Use fp16_min as the small scale cutoff because we don't want to use scales in
65
+ // fp16 subnormal range. This is to be consistent with Glow and FakeLowP
66
+ // implementation for NNPI.
67
+ constexpr float SMALL_SCALE_THRESHOLD = 6.1e-5f;
68
+
69
+ // Following implementation should be identical to fbgemm::ChooseQuantizationParams
70
+ inline TensorQuantizationParams ChooseQuantizationParams(
71
+ float min,
72
+ float max,
73
+ int32_t qmin,
74
+ int32_t qmax,
75
+ bool preserve_sparsity = false,
76
+ bool force_scale_power_of_two = false,
77
+ bool reduce_range = false) {
78
+ TORCH_CHECK(
79
+ min <= max,
80
+ "In ChooseQuantizationParams, min should be less than or equal to max");
81
+
82
+ if (reduce_range) {
83
+ qmin = qmin/2;
84
+ qmax = qmax/2;
85
+ }
86
+ if (min < 0 && max > 0 && preserve_sparsity) {
87
+ int symmetric_qmin = -((qmax - qmin) / 2 + 1);
88
+ int symmetric_qmax = (qmax - qmin) / 2;
89
+ double max_scale =
90
+ std::max(fabs(min / symmetric_qmin), fabs(max / symmetric_qmax));
91
+ min = max_scale * symmetric_qmin;
92
+ max = max_scale * symmetric_qmax;
93
+ }
94
+
95
+ // We extend the [min, max] interval to ensure that it contains 0.
96
+ // Otherwise, we would not meet the requirement that 0 be an exactly
97
+ // representable value.
98
+ min = std::min(min, 0.f);
99
+ max = std::max(max, 0.f);
100
+
101
+ TORCH_CHECK(
102
+ qmin < qmax,
103
+ "In ChooseQuantizationParams, qmin should be less than qmax");
104
+
105
+ // Use double precision for intermediate computation but use single precision
106
+ // in final number to reflect the actual number used during quantization.
107
+ double scale = (static_cast<double>(max) - min) / (qmax - qmin);
108
+ // If scale is 0 or too small so its reciprocal is infinity, we arbitrary
109
+ // adjust the scale to 0.1 . We want to avoid scale's reciprocal being
110
+ // infinity because some of fbgemm code pre-computes scale's reciprocal to do
111
+ // multiplication instead of division in the time critical part of code.
112
+ if (float(scale) == 0.0f || std::isinf(1.0f / float(scale))) {
113
+ scale = 0.1;
114
+ }
115
+ TORCH_CHECK(scale > 0, "quantization scale should be > 0");
116
+
117
+ if (force_scale_power_of_two) {
118
+ if (scale < 1) {
119
+ scale = 1.0 / (1 << static_cast<int>(floor(log(1.0 / scale) / log(2))));
120
+ } else {
121
+ scale = 1 << static_cast<int>(ceil(log(scale) / log(2)));
122
+ }
123
+ }
124
+
125
+ // Cut off small scale
126
+ if (scale < SMALL_SCALE_THRESHOLD) {
127
+ float org_scale = scale;
128
+ scale = SMALL_SCALE_THRESHOLD;
129
+ // Adjust the min and max based on the new scale
130
+ if (min == 0.0f) {
131
+ max = SMALL_SCALE_THRESHOLD * (qmax - qmin);
132
+ } else if (max == 0.0f) {
133
+ min = -SMALL_SCALE_THRESHOLD * (qmax - qmin);
134
+ } else {
135
+ float amplifier = SMALL_SCALE_THRESHOLD / org_scale;
136
+ min *= amplifier;
137
+ max *= amplifier;
138
+ }
139
+ }
140
+
141
+ // Zero-point computation.
142
+ // First the initial floating-point computation. The zero-point can be
143
+ // determined from solving an affine equation for any known pair
144
+ // (real value, corresponding quantized value).
145
+ // We know two such pairs: (rmin, qmin) and (rmax, qmax).
146
+ // The arithmetic error on the zero point computed from either pair
147
+ // will be roughly machine_epsilon * (sum of absolute values of terms)
148
+ // so we want to use the variant that adds the smaller terms.
149
+ double zero_point_from_min = qmin - min / static_cast<double>(scale);
150
+ double zero_point_from_max = qmax - max / static_cast<double>(scale);
151
+ double zero_point_from_min_error =
152
+ std::abs(qmin) - std::abs(min / static_cast<double>(scale));
153
+ double zero_point_from_max_error =
154
+ std::abs(qmax) - std::abs(max / static_cast<double>(scale));
155
+ double initial_zero_point =
156
+ zero_point_from_min_error < zero_point_from_max_error
157
+ ? zero_point_from_min
158
+ : zero_point_from_max;
159
+
160
+ // for symmetric quantization (preserve_sparsity == true), we force zero_point
161
+ // to be a middle value between qmin and qmax.
162
+ // If either min or max is 0, then we just use 0 as zero_point.
163
+ if (min < 0 && max > 0 && preserve_sparsity) {
164
+ initial_zero_point = static_cast<double>(qmin + qmax) / 2;
165
+ }
166
+
167
+ // Now we need to nudge the zero point to be an integer
168
+ // (our zero points are integer, and this is motivated by the requirement
169
+ // to be able to represent the real value "0" exactly as a quantized value,
170
+ // which is required in multiple places, for example in Im2col with zero
171
+ // padding).
172
+ int32_t nudged_zero_point = 0;
173
+ if (initial_zero_point < qmin) {
174
+ nudged_zero_point = qmin;
175
+ } else if (initial_zero_point > qmax) {
176
+ nudged_zero_point = qmax;
177
+ } else {
178
+ nudged_zero_point = nearbyint(initial_zero_point);
179
+ }
180
+
181
+ TensorQuantizationParams result;
182
+ result.scale = scale;
183
+ result.zero_point = nudged_zero_point;
184
+ return result;
185
+ }
186
+
187
+ // This function helps to convert the Conv1D dimensions usable by the Conv2d op.
188
+ constexpr int64_t kConv1dSqueezeDim = 0;
189
+ static C10_UNUSED torch::List<int64_t> MakeArgForConv1d(const torch::List<int64_t>& arg,
190
+ int64_t base_value) {
191
+ TORCH_CHECK(!arg.empty(), "Argument must have elements.");
192
+ torch::List<int64_t> result({arg.get(0), base_value});
193
+ if (arg.size() == 1) {
194
+ result[1] = arg.get(0);
195
+ } else {
196
+ result[1] = arg.get(1);
197
+ }
198
+ result[kConv1dSqueezeDim] = base_value;
199
+ return result;
200
+ }
201
+
202
+ // The range for using FP16 quantization of weights requires that the elements
203
+ // should be in the range of [5.96e-8, 65504]. If it is out of range, then the
204
+ // number will be saturated to max or min representable values by FP16.
205
+ inline void HandleWeightsSaturation(int64_t N, float* weight) {
206
+ const float kFp16Max = RawUint16ToFp16(0x7BFF);
207
+ bool found_out_of_range = false;
208
+ for (const auto i : c10::irange(N)) {
209
+ bool saturate = CheckAndSaturate<float>(kFp16Max, weight + i);
210
+ if (saturate) {
211
+ found_out_of_range = true;
212
+ }
213
+ }
214
+ if (found_out_of_range) {
215
+ TORCH_WARN("FOUND weight out of range ");
216
+ }
217
+ }
218
+
219
+ // Util function for quantizing bias.
220
+ inline at::Tensor QuantizeBias(
221
+ bool is_per_channel,
222
+ const at::Tensor& bias,
223
+ const at::Tensor& weight_contig,
224
+ double input_scale) {
225
+ at::Tensor qbias;
226
+ if (is_per_channel) {
227
+ auto bias_quant_scales =
228
+ weight_contig.q_per_channel_scales() * input_scale;
229
+ auto bias_zp = at::zeros(bias_quant_scales.sizes(), c10::kInt);
230
+ qbias = at::native::quantize_per_channel(
231
+ bias, bias_quant_scales, bias_zp, 0, c10::kQInt32);
232
+ } else {
233
+ qbias = at::native::quantize_per_tensor(
234
+ bias, weight_contig.q_scale() * input_scale, 0, c10::kQInt32);
235
+ }
236
+ return qbias;
237
+ }
238
+
239
+ } // namespace quant_utils
.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/QuantizedOps.h ADDED
@@ -0,0 +1,258 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <ATen/core/Tensor.h>
3
+ #include <ATen/core/IListRef.h>
4
+ #include <ATen/Dispatch.h>
5
+ #include <ATen/TensorIterator.h>
6
+ #include <ATen/native/Activation.h>
7
+ #include <ATen/native/DispatchStub.h>
8
+
9
+ namespace at {
10
+ namespace native {
11
+
12
+ using qrelu_fn = void (*)(const at::Tensor& /*qx*/, at::Tensor& /*qy*/);
13
+ using qrelu_leaky_fn = void (*)(Tensor& /*out*/, const Tensor& /*qx*/,
14
+ const Scalar& /*negval_*/);
15
+ using qgelu_fn = void (*)(const at::Tensor& /*qx*/, at::Tensor& /*qy*/, GeluType /* approximate */);
16
+ using qsigmoid_fn = void (*)(const at::Tensor& /*qx*/, at::Tensor& /*qy*/, double output_scale, int64_t output_zero_point);
17
+ using qhardsigmoid_fn = void (*)(const at::Tensor& /*qx*/, at::Tensor& /*qy*/);
18
+ using qclamp_fn = void (*)(
19
+ const at::Tensor& /*qx*/,
20
+ const Scalar& min,
21
+ const Scalar& max,
22
+ at::Tensor& /*qy*/);
23
+ using qclamp_minmax_fn = void (*)(
24
+ const at::Tensor& /*qx*/,
25
+ const Scalar& /*min or max*/,
26
+ at::Tensor& /*qy*/);
27
+ using qthreshold_fn = void (*)(
28
+ const at::Tensor& /*qx*/,
29
+ const Scalar& threshold,
30
+ const Scalar& value,
31
+ at::Tensor& /*qy*/);
32
+ using qtanh_fn = void (*)(const at::Tensor& /*qx*/, at::Tensor& /*qy*/);
33
+ using qelu_fn = void(*)(
34
+ const at::Tensor& /*qx*/,
35
+ const Scalar& /*alpha*/,
36
+ const Scalar& /*scale*/,
37
+ const Scalar& /*input_scale*/,
38
+ at::Tensor& /*qy*/);
39
+ using qbinary_fn =
40
+ void (*)(Tensor& /*out*/, const Tensor& /*self*/, const Tensor& /*other*/);
41
+ using qadd_scalar_fn =
42
+ void (*)(Tensor& /*out*/, const Tensor& /*self*/, const Scalar& other /*other*/);
43
+ using qhardswish_fn = void (*)(const at::Tensor& /*qx*/, at::Tensor& /*qy*/);
44
+ using qdropout_fn = void(*)(
45
+ const at::Tensor& /*qx*/,
46
+ const Scalar& /*p*/,
47
+ bool training /*training*/,
48
+ at::Tensor& /*qy*/);
49
+ using qmaxpool_2d_fn = void (*)(
50
+ const Tensor& qx,
51
+ int64_t iC, // input/output channels
52
+ int64_t iH,
53
+ int64_t iW, // input sizes
54
+ int64_t oH,
55
+ int64_t oW, // output sizes
56
+ int64_t kH,
57
+ int64_t kW, // kernel size
58
+ int64_t sH,
59
+ int64_t sW, // strides
60
+ int64_t pH,
61
+ int64_t pW, // padding
62
+ int64_t dH,
63
+ int64_t dW, // dilation
64
+ Tensor& qy);
65
+ using qmaxpool_3d_fn = void (*)(
66
+ const Tensor& qx,
67
+ int64_t iC, // input/output channels
68
+ int64_t iT,
69
+ int64_t iH,
70
+ int64_t iW, // input sizes
71
+ int64_t oT,
72
+ int64_t oH,
73
+ int64_t oW, // output sizes
74
+ int64_t kT,
75
+ int64_t kH,
76
+ int64_t kW, // kernel size
77
+ int64_t sT,
78
+ int64_t sH,
79
+ int64_t sW, // strides
80
+ int64_t pT,
81
+ int64_t pH,
82
+ int64_t pW, // padding
83
+ int64_t dT,
84
+ int64_t dH,
85
+ int64_t dW, // dilation
86
+ Tensor& qy);
87
+ using qadaptive_avg_pool2d_fn = void (*)(
88
+ const Tensor& qx,
89
+ Tensor& qy,
90
+ int64_t sizeB,
91
+ int64_t sizeC,
92
+ int64_t isizeH,
93
+ int64_t isizeW,
94
+ int64_t osizeH,
95
+ int64_t osizeW,
96
+ int64_t istrideB,
97
+ int64_t istrideC,
98
+ int64_t istrideH,
99
+ int64_t istrideW);
100
+ using qadaptive_avg_pool3d_fn = void (*)(
101
+ const Tensor& qx,
102
+ Tensor& qy,
103
+ int64_t sizeB,
104
+ int64_t sizeC,
105
+ int64_t isizeD,
106
+ int64_t isizeH,
107
+ int64_t isizeW,
108
+ int64_t osizeD,
109
+ int64_t osizeH,
110
+ int64_t osizeW,
111
+ int64_t istrideB,
112
+ int64_t istrideC,
113
+ int64_t istrideD,
114
+ int64_t istrideH,
115
+ int64_t istrideW);
116
+ using qavg_pool2d_fn = void (*)(
117
+ const Tensor& qx,
118
+ Tensor& qy,
119
+ int64_t nBatch,
120
+ int64_t nInputPlane,
121
+ int64_t inputWidth,
122
+ int64_t inputHeight,
123
+ int64_t outputWidth,
124
+ int64_t outputHeight,
125
+ int kW,
126
+ int kH,
127
+ int dW,
128
+ int dH,
129
+ int padW,
130
+ int padH,
131
+ bool count_include_pad,
132
+ std::optional<int64_t> divisor_override);
133
+
134
+ using qavg_pool3d_fn = void (*)(
135
+ const Tensor& qx,
136
+ Tensor& qy,
137
+ int64_t nBatch,
138
+ int64_t nInputPlane,
139
+ int64_t inputWidth,
140
+ int64_t inputHeight,
141
+ int64_t inputDepth,
142
+ int64_t outputWidth,
143
+ int64_t outputHeight,
144
+ int64_t outputDepth,
145
+ int kW,
146
+ int kH,
147
+ int kD,
148
+ int dW,
149
+ int dH,
150
+ int dD,
151
+ int padW,
152
+ int padH,
153
+ int padD,
154
+ bool count_include_pad,
155
+ std::optional<int64_t> divisor_override);
156
+
157
+ using qupsample_bilinear2d_fn = void (*)(
158
+ Tensor& output,
159
+ const Tensor& input,
160
+ int64_t input_height,
161
+ int64_t input_width,
162
+ int64_t output_height,
163
+ int64_t output_width,
164
+ int64_t nbatch,
165
+ int64_t channels,
166
+ bool align_corners,
167
+ std::optional<double> scales_h,
168
+ std::optional<double> scales_w);
169
+
170
+ using qcat_nhwc_fn = Tensor (*)(
171
+ const MaterializedITensorListRef& qxs,
172
+ int64_t dim,
173
+ double scale,
174
+ int64_t zero_point);
175
+ using qtopk_fn = void(*)(Tensor&, Tensor&, const Tensor&, int64_t, int64_t, bool, bool);
176
+
177
+ using qbatch_norm_fn = void(*)(int64_t, int64_t, int64_t, int64_t, int64_t, const Tensor&, const Tensor&, const Tensor&, Tensor&);
178
+
179
+ using qnormalize_fn = void (*)(
180
+ const Tensor& /* X */,
181
+ const Tensor& /* gamma */,
182
+ const Tensor& /* beta */,
183
+ bool /* affine_per_channel */,
184
+ int /* num_channels */,
185
+ int /* num_groups */,
186
+ int64_t /* M */,
187
+ int64_t /* N */,
188
+ double /* eps */,
189
+ Tensor* /* Y */);
190
+
191
+ using qmean_inner_dim_fn = void (*)(
192
+ const Tensor& /* X */,
193
+ OptionalIntArrayRef /* opt_dim */,
194
+ bool /* keepdim */,
195
+ std::optional<ScalarType> /* opt_dtype */,
196
+ Tensor& /* Y */);
197
+
198
+ using qstd_inner_dim_fn = void (*)(
199
+ const Tensor& /* X */,
200
+ OptionalIntArrayRef /* dim */,
201
+ const std::optional<Scalar>& /* correction */,
202
+ bool /* keepdim */,
203
+ Tensor& /* Y */);
204
+
205
+ using qnormalize_nhwc_fn = void (*)(
206
+ const Tensor& /* X */,
207
+ const Tensor& /* gamma */,
208
+ const Tensor& /* beta */,
209
+ bool /* affine_per_channel */,
210
+ int /* num_channels */,
211
+ int /* num_groups */,
212
+ int64_t /* M */,
213
+ int64_t /* N */,
214
+ double /* eps */,
215
+ Tensor* /* Y */);
216
+
217
+ using qprelu_fn = void (*)(Tensor& /*out*/, const Tensor& /*qx*/,
218
+ const Tensor& /*qw*/);
219
+
220
+ DECLARE_DISPATCH(qadaptive_avg_pool2d_fn, qadaptive_avg_pool2d_nhwc_stub);
221
+ DECLARE_DISPATCH(qadaptive_avg_pool3d_fn, qadaptive_avg_pool3d_ndhwc_stub);
222
+ DECLARE_DISPATCH(qadd_scalar_fn, qadd_scalar_relu_stub);
223
+ DECLARE_DISPATCH(qadd_scalar_fn, qadd_scalar_stub);
224
+ DECLARE_DISPATCH(qavg_pool2d_fn, qavg_pool2d_nhwc_stub);
225
+ DECLARE_DISPATCH(qavg_pool3d_fn, qavg_pool3d_nhwc_stub);
226
+ DECLARE_DISPATCH(qbatch_norm_fn, qbatch_norm_relu_stub);
227
+ DECLARE_DISPATCH(qbatch_norm_fn, qbatch_norm_stub);
228
+ DECLARE_DISPATCH(qbinary_fn, qadd_relu_stub);
229
+ DECLARE_DISPATCH(qbinary_fn, qadd_stub);
230
+ DECLARE_DISPATCH(qbinary_fn, qmul_relu_stub);
231
+ DECLARE_DISPATCH(qbinary_fn, qmul_stub);
232
+ DECLARE_DISPATCH(qcat_nhwc_fn, qcat_nhwc_stub);
233
+ DECLARE_DISPATCH(qcat_nhwc_fn, qcat_relu_nhwc_stub);
234
+ DECLARE_DISPATCH(qclamp_fn, qclamp_stub);
235
+ DECLARE_DISPATCH(qclamp_minmax_fn, qclamp_min_stub);
236
+ DECLARE_DISPATCH(qclamp_minmax_fn, qclamp_max_stub);
237
+ DECLARE_DISPATCH(qelu_fn, qelu_stub);
238
+ DECLARE_DISPATCH(qhardsigmoid_fn, qhardsigmoid_stub);
239
+ DECLARE_DISPATCH(qhardswish_fn, qhardswish_stub);
240
+ DECLARE_DISPATCH(qdropout_fn, qdropout_stub);
241
+ DECLARE_DISPATCH(qmaxpool_2d_fn, qmaxpool_2d_nhwc_stub);
242
+ DECLARE_DISPATCH(qmaxpool_3d_fn, qmaxpool_3d_nthwc_stub);
243
+ DECLARE_DISPATCH(qnormalize_fn, quantized_normalize_stub);
244
+ DECLARE_DISPATCH(qnormalize_nhwc_fn, quantized_groupnorm_nhwc_stub);
245
+ DECLARE_DISPATCH(qrelu_fn, qrelu_stub);
246
+ DECLARE_DISPATCH(qrelu_leaky_fn, qrelu_leaky_stub);
247
+ DECLARE_DISPATCH(qgelu_fn, qgelu_stub);
248
+ DECLARE_DISPATCH(qsigmoid_fn, qsigmoid_stub);
249
+ DECLARE_DISPATCH(qtanh_fn, qtanh_stub);
250
+ DECLARE_DISPATCH(qthreshold_fn, qthreshold_stub);
251
+ DECLARE_DISPATCH(qtopk_fn, qtopk_stub);
252
+ DECLARE_DISPATCH(qupsample_bilinear2d_fn, qupsample_bilinear2d_nhwc_stub);
253
+ DECLARE_DISPATCH(qmean_inner_dim_fn, qmean_inner_dim_stub);
254
+ DECLARE_DISPATCH(qstd_inner_dim_fn, qstd_inner_dim_stub);
255
+ DECLARE_DISPATCH(qprelu_fn, qprelu_stub);
256
+
257
+ } // namespace native
258
+ } // namespace at
.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/RuyUtils.h ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #ifdef USE_RUY_QMATMUL
4
+
5
+ #include <ruy/ruy.h>
6
+
7
+ namespace at {
8
+ namespace native {
9
+ namespace ruy_utils {
10
+
11
+ ruy::Context* get_ruy_context();
12
+
13
+ void quantize_multiplier(double scale,
14
+ int* multiplier_fixedpoint,
15
+ int* multiplier_exponent);
16
+
17
+ } // namespace ruy_utils
18
+ } // namespace native
19
+ } // namespace
20
+
21
+ #endif // USE_RUY_QMATMUL
.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/XnnpackUtils.h ADDED
@@ -0,0 +1,335 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #ifdef USE_XNNPACK
4
+ #include <cstdint>
5
+
6
+ #include <ATen/core/Tensor.h>
7
+ #include <ATen/native/xnnpack/Common.h>
8
+
9
+ using xnnpack_operator = at::native::xnnpack::Operator;
10
+
11
+ namespace at {
12
+ namespace native {
13
+ namespace xnnp_utils {
14
+
15
+ /*
16
+ * Return shape in the same order as the memory format
17
+ * e.g. channels_last will return NHWC instead of NCHW
18
+ */
19
+ std::vector<size_t> get_mem_format_aware_shape(const at::Tensor& in);
20
+
21
+ /*
22
+ * Input is always int8_t, output can be [int8_t, uint8_t].
23
+ * input + offset = output
24
+ * int8_t + 128 = uint8_t
25
+ * int8_t + 0 = int8_t
26
+ */
27
+ template <typename PT>
28
+ void q8_copy_int8_weight_and_add_offset(const at::Tensor& in, at::Tensor& out);
29
+
30
+ template <int kSpatialDim>
31
+ Tensor convert_conv_weights_to_channel_last_tensor(
32
+ const at::Tensor& src,
33
+ int groups,
34
+ bool transpose);
35
+
36
+ /*
37
+ * Series of create wrapper functions to call xnn_create_[de]conv* functions.
38
+ */
39
+ C10_ALWAYS_INLINE
40
+ enum xnn_status xnnp_create_convolution2d_nhwc(
41
+ uint32_t pad_top,
42
+ uint32_t pad_right,
43
+ uint32_t pad_bottom,
44
+ uint32_t pad_left,
45
+ uint32_t kernel_h,
46
+ uint32_t kernel_w,
47
+ uint32_t stride_h,
48
+ uint32_t stride_w,
49
+ uint32_t dilation_h,
50
+ uint32_t dilation_w,
51
+ uint32_t groups,
52
+ size_t group_input_channels,
53
+ size_t group_output_channels,
54
+ size_t ip_chan_stride,
55
+ size_t op_chan_stride,
56
+ int8_t izp,
57
+ float ip_scale,
58
+ int8_t kzp,
59
+ const float* k_scales,
60
+ const int8_t* kernel,
61
+ const int32_t* bias,
62
+ int8_t ozp,
63
+ float op_scale,
64
+ int8_t op_min,
65
+ int8_t op_max,
66
+ uint32_t flags,
67
+ xnn_operator_t* op,
68
+ bool per_channel,
69
+ bool transpose) {
70
+ /* Symmetric quantization forces kzp = 0 */
71
+ TORCH_CHECK(!kzp, "XNNPACK Q[SC]8 conv kernels expects kernel zero point to be zero."
72
+ "But got: ", kzp);
73
+
74
+ if (transpose) {
75
+ TORCH_CHECK(!per_channel, "XNNPACK Q[SC]8 does not have a per channel deconvolution!");
76
+ return xnn_create_deconvolution2d_nhwc_qs8(
77
+ pad_top, /* uint32_t output_padding_top */
78
+ pad_right, /* uint32_t output_padding_right */
79
+ pad_bottom, /* uint32_t output_padding_bottom */
80
+ pad_left, /* uint32_t output_padding_left */
81
+ kernel_h, /* uint32_t kernel_height */
82
+ kernel_w, /* uint32_t kernel_width */
83
+ stride_h, /* uint32_t stride_height */
84
+ stride_w, /* uint32_t stride_width */
85
+ dilation_h, /* uint32_t dilation_height */
86
+ dilation_w, /* uint32_t dilation_width */
87
+ groups, /* uint32_t groups */
88
+ group_input_channels, /* size_t group_input_channels */
89
+ group_output_channels, /* size_t group_output_channels */
90
+ ip_chan_stride, /* size_t input_pixel_stride */
91
+ op_chan_stride, /* size_t output_pixel_stride */
92
+ izp, /* int8_t input_zero_point */
93
+ ip_scale, /* float input_scale */
94
+ k_scales[0], /* float kernel_scale */
95
+ kernel, /* const int8_t* kernel */
96
+ bias, /* const int32_t* bias */
97
+ ozp, /* int8_t output_zero_point */
98
+ op_scale, /* float output_scale */
99
+ op_min, /* int8_t output_min */
100
+ op_max, /* int8_t output_max */
101
+ flags, /* uint32_t flags */
102
+ nullptr, /* xnn_caches_t caches */
103
+ nullptr, /* xnn_weights_cache_t weights_cache */
104
+ op); /* xnn_operator_t* deconvolution_op_out */
105
+
106
+ }
107
+
108
+ if (!per_channel) {
109
+ return xnn_create_convolution2d_nhwc_qs8(
110
+ pad_top, /* uint32_t input_padding_top */
111
+ pad_right, /* uint32_t input_padding_right */
112
+ pad_bottom, /* uint32_t input_padding_bottom */
113
+ pad_left, /* uint32_t input_padding_left */
114
+ kernel_h, /* uint32_t kernel_height */
115
+ kernel_w, /* uint32_t kernel_width */
116
+ stride_h, /* uint32_t subsampling_height */
117
+ stride_w, /* uint32_t subsampling_width */
118
+ dilation_h, /* uint32_t dilation_height */
119
+ dilation_w, /* uint32_t dilation_width */
120
+ groups, /* uint32_t groups */
121
+ group_input_channels, /* size_t group_input_channels */
122
+ group_output_channels, /* size_t group_output_channels*/
123
+ ip_chan_stride, /* size_t input_channel_stride */
124
+ op_chan_stride, /* size_t output_channel_stride */
125
+ izp, /* int8_t input_zero_point */
126
+ ip_scale, /* float input_scale */
127
+ k_scales[0], /* float kernel_scale */
128
+ kernel, /* const int8_t* kernel */
129
+ bias, /* const int32_t* bias */
130
+ ozp, /* int8_t output_zero_point */
131
+ op_scale, /* float output_scale */
132
+ op_min, /* int8_t output_min */
133
+ op_max, /* int8_t output_max */
134
+ flags, /* uint32_t flags */
135
+ nullptr, /* xnn_caches_t caches */
136
+ nullptr, /* xnn_weights_cache_t weights_cache */
137
+ op); /* xnn_operator_t* convolution_op_out */
138
+ } else { /* per_channel */
139
+ return xnn_create_convolution2d_nhwc_qs8_qc8w(
140
+ pad_top, /* uint32_t input_padding_top */
141
+ pad_right, /* uint32_t input_padding_right */
142
+ pad_bottom, /* uint32_t input_padding_bottom */
143
+ pad_left, /* uint32_t input_padding_left */
144
+ kernel_h, /* uint32_t kernel_height */
145
+ kernel_w, /* uint32_t kernel_width */
146
+ stride_h, /* uint32_t subsampling_height */
147
+ stride_w, /* uint32_t subsampling_width */
148
+ dilation_h, /* uint32_t dilation_height */
149
+ dilation_w, /* uint32_t dilation_width */
150
+ groups, /* uint32_t groups */
151
+ group_input_channels, /* size_t group_input_channels */
152
+ group_output_channels, /* size_t group_output_channels*/
153
+ ip_chan_stride, /* size_t input_channel_stride */
154
+ op_chan_stride, /* size_t output_channel_stride */
155
+ izp, /* int8_t input_zero_point */
156
+ ip_scale, /* float input_scale */
157
+ k_scales, /* const float* kernel_scale */
158
+ kernel, /* const int8_t* kernel */
159
+ bias, /* const int32_t* bias */
160
+ ozp, /* int8_t output_zero_point */
161
+ op_scale, /* float output_scale */
162
+ op_min, /* int8_t output_min */
163
+ op_max, /* int8_t output_max */
164
+ flags, /* uint32_t flags */
165
+ nullptr, /* xnn_caches_t caches */
166
+ nullptr, /* xnn_weights_cache_t weights_cache */
167
+ op); /* xnn_operator_t* convolution_op_out */
168
+ }
169
+ }
170
+
171
+ /*
172
+ * Series of reshape wrapper functions to call xnn_reshape_[de]conv* functions.
173
+ */
174
+ C10_ALWAYS_INLINE
175
+ enum xnn_status xnnp_reshape_convolution2d_nhwc(
176
+ xnn_operator_t op,
177
+ size_t batch,
178
+ size_t in_h,
179
+ size_t in_w,
180
+ pthreadpool_t pt_pool,
181
+ bool per_channel = false,
182
+ bool transpose = false,
183
+ uint32_t adj_h = 0,
184
+ uint32_t adj_w = 0) {
185
+ if(transpose) {
186
+ TORCH_CHECK(!per_channel, "XNNPACK Q[SC]8 does not have a per channel deconvolution!");
187
+ return xnn_reshape_deconvolution2d_nhwc_qs8(
188
+ op, /* xnn_operator_t deconvolution_op */
189
+ batch, /* size_t batch_size */
190
+ in_h, /* size_t input_height */
191
+ in_w, /* size_t input_width */
192
+ adj_h, /* uint32_t adjustment_height */
193
+ adj_w, /* uint32_t adjustment_width */
194
+ nullptr, /* size_t* output_height_out */
195
+ nullptr, /* size_t* output_width_out */
196
+ pt_pool); /* pthreadpool_t threadpool */
197
+ }
198
+
199
+ size_t workspace_size = SIZE_MAX;
200
+ size_t workspace_alignment = SIZE_MAX;
201
+
202
+ if (!per_channel) {
203
+ return xnn_reshape_convolution2d_nhwc_qs8(
204
+ op, /* xnn_operator_t convolution_op */
205
+ batch, /* size_t batch_size */
206
+ in_h, /* size_t input_height */
207
+ in_w, /* size_t input_width */
208
+ &workspace_size, /* size_t* workspace_size */
209
+ &workspace_alignment, /* size_t* workspace_alignment */
210
+ nullptr, /* size_t* output_height_out */
211
+ nullptr, /* size_t* output_width_out */
212
+ pt_pool); /* pthreadpool_t threadpool */
213
+ } else { /* per_channel */
214
+ return xnn_reshape_convolution2d_nhwc_qs8_qc8w(
215
+ op, /* xnn_operator_t convolution_op */
216
+ batch, /* size_t batch_size */
217
+ in_h, /* size_t input_height */
218
+ in_w, /* size_t input_width */
219
+ &workspace_size, /* size_t* workspace_size */
220
+ &workspace_alignment, /* size_t* workspace_alignment */
221
+ nullptr, /* size_t* output_height_out */
222
+ nullptr, /* size_t* output_width_out */
223
+ pt_pool); /* pthreadpool_t threadpool */
224
+ }
225
+ }
226
+
227
+
228
+ /*
229
+ * Series of setup wrapper functions to call xnn_setup_[de]conv* functions.
230
+ */
231
+ C10_ALWAYS_INLINE
232
+ enum xnn_status xnnp_setup_convolution2d_nhwc(
233
+ xnn_operator_t op,
234
+ const int8_t* inp,
235
+ int8_t* outp,
236
+ bool per_channel = false,
237
+ bool transpose = false) {
238
+ if(transpose) {
239
+ TORCH_CHECK(!per_channel, "XNNPACK Q[SC]8 does not have a per channel deconvolution!");
240
+
241
+ return xnn_setup_deconvolution2d_nhwc_qs8(
242
+ op, /* xnn_operator_t deconvolution_op */
243
+ inp, /* const int8_t* input */
244
+ outp); /* int8_t* output */
245
+ }
246
+
247
+ if (!per_channel) {
248
+ return xnn_setup_convolution2d_nhwc_qs8(
249
+ op, /* xnn_operator_t deconvolution_op */
250
+ nullptr, /* void workspace */
251
+ inp, /* const int8_t* input */
252
+ outp); /* int8_t* output */
253
+ } else { /* per_channel */
254
+ return xnn_setup_convolution2d_nhwc_qs8_qc8w(
255
+ op, /* xnn_operator_t deconvolution_op */
256
+ nullptr, /* void workspace */
257
+ inp, /* const int8_t* input */
258
+ outp); /* int8_t* output */
259
+ }
260
+ }
261
+
262
+
263
+ /*
264
+ * Series of wrapper functions to call xnn_create* and xnn_setup*
265
+ * functions for linear
266
+ */
267
+ C10_ALWAYS_INLINE
268
+ enum xnn_status xnnp_create_fully_connected_nc(
269
+ size_t input_channels,
270
+ size_t output_channels,
271
+ size_t input_stride,
272
+ size_t output_stride,
273
+ int8_t input_zero_point,
274
+ float input_scale,
275
+ int8_t kernel_zero_point,
276
+ float kernel_scale,
277
+ const int8_t* kernel,
278
+ const int32_t* bias,
279
+ int8_t output_zero_point,
280
+ float output_scale,
281
+ int8_t output_min,
282
+ int8_t output_max,
283
+ uint32_t flags,
284
+ xnn_operator_t* fully_connected_op_out) {
285
+ /* Symmetric quantization forces kzp = 0 */
286
+ TORCH_CHECK(!kernel_zero_point, "XNNPACK QS8 linear kernel expects kernel zero point to be zero."
287
+ "But got: ", kernel_zero_point);
288
+ return xnn_create_fully_connected_nc_qs8(
289
+ input_channels, /* size_t input_channels */
290
+ output_channels, /* size_t output_channels */
291
+ input_stride, /* size_t input_stride */
292
+ output_stride, /* size_t output_stride */
293
+ input_zero_point, /* int8_t input_zero_point */
294
+ input_scale, /* float input_scale */
295
+ kernel_scale, /* float kernel_scale */
296
+ kernel, /* const int8_t* kernel */
297
+ bias, /* const int32_t* bias */
298
+ output_zero_point, /* int8_t output_zero_point */
299
+ output_scale, /* float output_scale */
300
+ output_min, /* int8_t output_min */
301
+ output_max, /* int8_t output_max */
302
+ flags, /* uint32_t flags */
303
+ nullptr, /* xnn_caches_t caches */
304
+ nullptr, /* xnn_weights_cache_t */
305
+ fully_connected_op_out); /* xnn_operator_t* fully_connected_op_out */
306
+ }
307
+
308
+ C10_ALWAYS_INLINE
309
+ enum xnn_status xnnp_reshape_fully_connected_nc(
310
+ xnn_operator_t fully_connected_op,
311
+ size_t batch_size,
312
+ pthreadpool_t threadpool) {
313
+ return xnn_reshape_fully_connected_nc_qs8(
314
+ fully_connected_op, /* xnn_operator_t fully_connected_op */
315
+ batch_size, /* size_t batch_size */
316
+ threadpool); /* pthreadpool_t threadpool */
317
+ }
318
+
319
+ C10_ALWAYS_INLINE
320
+ enum xnn_status xnnp_setup_fully_connected_nc(
321
+ xnn_operator_t fully_connected_op,
322
+ const int8_t* input,
323
+ int8_t* output) {
324
+ return xnn_setup_fully_connected_nc_qs8(
325
+ fully_connected_op, /* xnn_operator_t fully_connected_op */
326
+ input, /* const int8_t* input */
327
+ output /* int8_t* output */
328
+ );
329
+ }
330
+
331
+ } // namespace xnnp_utils
332
+ } // namespace native
333
+ } // namespace at
334
+
335
+ #endif // USE_XNNPACK
.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/conv_serialization.h ADDED
@@ -0,0 +1,414 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/core/Tensor.h>
4
+ #include <ATen/core/List.h>
5
+ #include <ATen/native/quantized/cpu/fbgemm_utils.h>
6
+ #include <ATen/native/quantized/cpu/QnnpackUtils.h>
7
+ #include <ATen/native/quantized/cpu/OnednnUtils.h>
8
+ #include <c10/util/irange.h>
9
+ #if !defined(__s390x__) && !defined(__powerpc__)
10
+ #include <cpuinfo.h>
11
+ #endif
12
+
13
+ #ifndef AT_PER_OPERATOR_HEADERS
14
+ #include <ATen/Functions.h>
15
+ #else
16
+ #include <ATen/ops/from_blob.h>
17
+ #endif
18
+
19
+
20
+ #include <tuple>
21
+
22
+ /* Convolution prepacked parameters serialization.
23
+ *
24
+ * Version 1
25
+ *
26
+ * - Fields:
27
+ * 1. weight
28
+ * 2. bias
29
+ * 3. stride x kSpatialDim
30
+ * 4. padding x kSpatialDim
31
+ * 5. dilation x kSpatialDim
32
+ * 6. groups
33
+ *
34
+ * Version 2
35
+ *
36
+ * - Fields:
37
+ * 0. version (string)
38
+ * 1. list of non-optional tensors
39
+ * 0: packed parameters (int16_t)
40
+ * - kSpatialDim
41
+ * - stride x kSpatialDim
42
+ * - padding x kSpatialDim
43
+ * - dilation x kSpatialDim
44
+ * - output_padding x kSpatialDim
45
+ * - groups
46
+ * - transpose (0 or 1)
47
+ * 1: weight
48
+ * 2. list of optional tensors
49
+ * 0: bias
50
+ *
51
+ * Version 3
52
+ *
53
+ * - Fields:
54
+ * 0. version (int64_t)
55
+ * 1. list of int64_t configuration values
56
+ * - kSpatialDim
57
+ * - stride x kSpatialDim
58
+ * - padding x kSpatialDim
59
+ * - dilation x kSpatialDim
60
+ * - output_padding x kSpatialDim
61
+ * - groups
62
+ * - flags (bitmask)
63
+ * - (1 << 0) transpose (1 = yes)
64
+ * 2. list of optional tensors
65
+ * 0: None (helps with type inference)
66
+ * 1: weight (this must be present)
67
+ * 2: bias
68
+ */
69
+
70
+ using ConvParamsSerializationTypeV2 = std::tuple<
71
+ // version, for versions 2 and up
72
+ std::string,
73
+ // non-optional tensors
74
+ std::vector<at::Tensor>,
75
+ // optional tensors
76
+ std::vector<std::optional<at::Tensor>>>;
77
+
78
+ using ConvParamsSerializationTypeV3 = std::tuple<
79
+ // version, int for versions 3 and up
80
+ int64_t,
81
+ // configuration values
82
+ std::vector<int64_t>,
83
+ // optional tensors
84
+ std::vector<std::optional<at::Tensor>>>;
85
+
86
+ // Parses any historical conv packed params format into
87
+ // the current format.
88
+ template <uint32_t kSpatialDim>
89
+ ConvParamsSerializationTypeV3 parse_conv_serialized_state(c10::IValue v) {
90
+
91
+ // determine the version based on IValue contents
92
+ int version = -1;
93
+ if (v.isTuple()) {
94
+ const auto& elements = v.toTupleRef().elements();
95
+ if (!elements.empty()) {
96
+ auto firstElement = elements[0];
97
+ if (firstElement.isTensor()) {
98
+ version = 1;
99
+ } else if (firstElement.isString()) {
100
+ const std::string& version_str = firstElement.toStringRef();
101
+ // note: not parsing the string to automatically handle bad
102
+ // inputs
103
+ if (version_str == "2") {
104
+ version = 2;
105
+ }
106
+ } else if (firstElement.isInt()) {
107
+ auto raw_version = firstElement.toInt();
108
+ if (raw_version == 3) {
109
+ version = 3;
110
+ }
111
+ }
112
+ }
113
+ }
114
+ TORCH_INTERNAL_ASSERT(version != -1, "Unable to parse serialization version");
115
+
116
+ if (version == 1) {
117
+ // version 1 - convert to version 3 manually
118
+
119
+ const auto& elements = v.toTupleRef().elements();
120
+
121
+ at::Tensor weight = elements[0].toTensor();
122
+ std::optional<at::Tensor> bias = elements[1].toOptional<at::Tensor>();
123
+ torch::List<at::Tensor> stride_x_kSpatialDim = elements[2].toTensorList();
124
+ torch::List<at::Tensor> padding_x_kSpatialDim = elements[3].toTensorList();
125
+ torch::List<at::Tensor> dilation_x_kSpatialDim = elements[4].toTensorList();
126
+ at::Tensor groups = elements[5].toTensor();
127
+
128
+ std::vector<int64_t> config_vals;
129
+ config_vals.reserve(
130
+ stride_x_kSpatialDim.size() + padding_x_kSpatialDim.size() +
131
+ dilation_x_kSpatialDim.size() + kSpatialDim + 3);
132
+ config_vals.push_back(kSpatialDim);
133
+ for (const auto i : c10::irange(stride_x_kSpatialDim.size())) {
134
+ auto stride = stride_x_kSpatialDim.get(i);
135
+ config_vals.push_back(stride[0].item<int16_t>());
136
+ }
137
+ for (const auto i : c10::irange(padding_x_kSpatialDim.size())) {
138
+ auto padding = padding_x_kSpatialDim.get(i);
139
+ config_vals.push_back(padding[0].item<int16_t>());
140
+ }
141
+ for (const auto i : c10::irange(dilation_x_kSpatialDim.size())) {
142
+ auto dilation = dilation_x_kSpatialDim.get(i);
143
+ config_vals.push_back(dilation[0].item<int16_t>());
144
+ }
145
+ // output_padding does not exist in v1, so we fill in a default value
146
+ for (C10_UNUSED const auto i : c10::irange(kSpatialDim)) {
147
+ config_vals.push_back(0);
148
+ }
149
+ config_vals.push_back(groups[0].item<int16_t>());
150
+ // transpose does not exist in v1, so we fill in a default value
151
+ config_vals.push_back(0);
152
+
153
+ std::vector<std::optional<at::Tensor>> tensors;
154
+ tensors.emplace_back();
155
+ tensors.emplace_back(weight);
156
+ tensors.emplace_back(bias);
157
+
158
+ int64_t version = 3;
159
+ return std::tie(version, config_vals, tensors);
160
+ } else if (version == 2) {
161
+ // version 2
162
+ const auto& elements = v.toTupleRef().elements();
163
+ std::vector<at::Tensor> non_optional = elements[1].toTensorList().vec();
164
+ std::vector<std::optional<at::Tensor>> optional;
165
+
166
+ if (elements[2].isTensorList()) {
167
+ for (const auto& elem : elements[2].toTensorList()) {
168
+ optional.emplace_back(static_cast<at::Tensor>(elem));
169
+ }
170
+ } else {
171
+ for (const auto& elem : elements[2].toList()) {
172
+ optional.emplace_back(static_cast<c10::IValue>(elem).toOptional<at::Tensor>());
173
+ }
174
+ }
175
+ // create default optional value for bias
176
+ if (optional.empty()) {
177
+ optional.emplace_back();
178
+ }
179
+
180
+ auto config_a = non_optional[0].accessor<int16_t, 1>();
181
+ std::vector<int64_t> config_vals;
182
+ config_vals.reserve(config_a.size(0));
183
+ for (const auto i : c10::irange(config_a.size(0))) {
184
+ config_vals.emplace_back(config_a[i]);
185
+ }
186
+
187
+ auto weight = non_optional[1];
188
+ auto bias = optional[0];
189
+
190
+ std::vector<std::optional<at::Tensor>> tensors;
191
+ tensors.emplace_back();
192
+ tensors.emplace_back(weight);
193
+ tensors.emplace_back(bias);
194
+
195
+ int64_t version = 3;
196
+ return std::tie(version, config_vals, tensors);
197
+ } else if (version == 3) {
198
+ return v.to<ConvParamsSerializationTypeV3>();
199
+ } else {
200
+ TORCH_INTERNAL_ASSERT(false, "Unexpected serialized qconv version: ",
201
+ version);
202
+ }
203
+ }
204
+
205
+ #define QCONV_SERIALIZATION_VERSION 2
206
+
207
+ #if QCONV_SERIALIZATION_VERSION == 2
208
+ using ConvParamsSerializationType = ConvParamsSerializationTypeV2;
209
+
210
+ template <uint32_t kSpatialDim>
211
+ ConvParamsSerializationTypeV2 serialize_conv(
212
+ const c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>>& params) {
213
+
214
+ std::string version = "2";
215
+ std::vector<at::Tensor> non_optional;
216
+ std::vector<std::optional<at::Tensor>> optional;
217
+
218
+ // create a packed int8_t tensor for conv params
219
+ std::vector<int16_t> params_vec;
220
+ params_vec.push_back(kSpatialDim);
221
+ auto stride = params->stride().vec();
222
+ params_vec.insert(params_vec.end(), stride.begin(), stride.end());
223
+ auto padding = params->padding().vec();
224
+ params_vec.insert(params_vec.end(), padding.begin(), padding.end());
225
+ auto dilation = params->dilation().vec();
226
+ params_vec.insert(params_vec.end(), dilation.begin(), dilation.end());
227
+ auto output_padding = params->output_padding().vec();
228
+ params_vec.insert(params_vec.end(), output_padding.begin(),
229
+ output_padding.end());
230
+ params_vec.push_back(params->groups());
231
+ params_vec.push_back(params->transpose());
232
+ int64_t vec_size = params_vec.size();
233
+ at::Tensor params_tensor = at::from_blob(
234
+ params_vec.data(), {vec_size},
235
+ at::TensorOptions().dtype(at::kShort))
236
+ // clone to retain ownership of the data
237
+ .clone();
238
+
239
+ auto [weight, bias] = params->unpack();
240
+
241
+ non_optional.emplace_back(std::move(params_tensor));
242
+ non_optional.emplace_back(std::move(weight));
243
+ optional.emplace_back(std::move(bias));
244
+
245
+ return std::tie(version, non_optional, optional);
246
+ }
247
+
248
+ #elif QCONV_SERIALIZATION_VERSION == 3
249
+ using ConvParamsSerializationType = ConvParamsSerializationTypeV3;
250
+
251
+ template <uint32_t kSpatialDim>
252
+ ConvParamsSerializationTypeV3 serialize_conv(
253
+ const c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>>& params) {
254
+ std::vector<int64_t> config_vals;
255
+ config_vals.push_back(kSpatialDim);
256
+ auto stride = params->stride().vec();
257
+ config_vals.insert(config_vals.end(), stride.begin(), stride.end());
258
+ auto padding = params->padding().vec();
259
+ config_vals.insert(config_vals.end(), padding.begin(), padding.end());
260
+ auto dilation = params->dilation().vec();
261
+ config_vals.insert(config_vals.end(), dilation.begin(), dilation.end());
262
+ auto output_padding = params->output_padding().vec();
263
+ config_vals.insert(config_vals.end(), output_padding.begin(),
264
+ output_padding.end());
265
+ config_vals.push_back(params->groups());
266
+ config_vals.push_back(params->transpose());
267
+
268
+ auto [weight, bias] = params->unpack();
269
+
270
+ std::vector<std::optional<at::Tensor>> tensors;
271
+ tensors.emplace_back();
272
+ tensors.emplace_back(weight);
273
+ tensors.emplace_back(bias);
274
+
275
+ int64_t version = 3;
276
+ return std::tie(version, config_vals, tensors);
277
+ }
278
+
279
+ #else
280
+ #error "Invalid qconv serialization version."
281
+ #endif
282
+
283
+ template <uint32_t kSpatialDim>
284
+ c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> deserialize_conv(
285
+ ConvParamsSerializationTypeV3 state) {
286
+ auto [version, config_vals, tensors] = state;
287
+ TORCH_INTERNAL_ASSERT(version == 3, "Unexpected serialized qconv version: ", version);
288
+
289
+ TORCH_CHECK(tensors.size() == 3, "Wrong number of tensors", tensors.size());
290
+ std::optional<at::Tensor> weight = tensors[1];
291
+ std::optional<at::Tensor> bias = tensors[2];
292
+ TORCH_INTERNAL_ASSERT(weight, "Weight should always be present in serialized qconv.");
293
+
294
+ torch::List<int64_t> stride, padding, output_padding, dilation;
295
+ // skip kSpatialDim
296
+ int idx = 1;
297
+ for (C10_UNUSED const auto i : c10::irange(kSpatialDim)) {
298
+ stride.emplace_back(config_vals.at(idx));
299
+ idx++;
300
+ }
301
+ for (C10_UNUSED const auto i : c10::irange(kSpatialDim)) {
302
+ padding.emplace_back(config_vals.at(idx));
303
+ idx++;
304
+ }
305
+ for (C10_UNUSED const auto i : c10::irange(kSpatialDim)) {
306
+ dilation.emplace_back(config_vals.at(idx));
307
+ idx++;
308
+ }
309
+ for (C10_UNUSED const auto i : c10::irange(kSpatialDim)) {
310
+ TORCH_INTERNAL_ASSERT(idx < static_cast<int64_t>(config_vals.size()),
311
+ "Unexpected index = ", idx, " for config_vals of size ",
312
+ config_vals.size());
313
+ output_padding.emplace_back(config_vals.at(idx));
314
+ idx++;
315
+ }
316
+ int64_t groups = config_vals.at(idx);
317
+ idx++;
318
+ int64_t flags = config_vals.at(idx);
319
+ idx++;
320
+ TORCH_INTERNAL_ASSERT(idx == static_cast<int64_t>(config_vals.size()),
321
+ "Unexpected length of config_vals, expected ",
322
+ idx,
323
+ " got ",
324
+ config_vals.size());
325
+
326
+ bool transpose = flags & (1 << 0);
327
+
328
+ int64_t other_flags = flags & ~(1 << 0);
329
+ TORCH_INTERNAL_ASSERT(other_flags == 0, "Unexpected flags set in ", flags, ".");
330
+
331
+ auto& ctx = at::globalContext();
332
+
333
+ #ifdef USE_FBGEMM
334
+ if (ctx.qEngine() == at::QEngine::X86) {
335
+ #if AT_MKLDNN_ENABLED()
336
+ bool use_onednn = onednn_utils::should_use_onednn_quant(
337
+ weight.value(), transpose, groups, output_padding);
338
+ if (use_onednn) {
339
+ return PackedConvWeightsOnednn<kSpatialDim>::prepack(
340
+ weight.value(),
341
+ bias,
342
+ stride,
343
+ padding,
344
+ output_padding,
345
+ dilation,
346
+ groups,
347
+ transpose
348
+ );
349
+ }
350
+ #endif
351
+ return PackedConvWeight<kSpatialDim>::prepack(
352
+ weight.value(),
353
+ bias,
354
+ stride,
355
+ padding,
356
+ output_padding,
357
+ dilation,
358
+ groups,
359
+ transpose
360
+ );
361
+ } // x86
362
+ #endif
363
+
364
+ #ifdef USE_FBGEMM
365
+ if (ctx.qEngine() == at::QEngine::FBGEMM) {
366
+ return PackedConvWeight<kSpatialDim>::prepack(
367
+ weight.value(),
368
+ bias,
369
+ stride,
370
+ padding,
371
+ output_padding,
372
+ dilation,
373
+ groups,
374
+ transpose
375
+ );
376
+ }
377
+ #endif // USE_FBGEMM
378
+ #ifdef USE_PYTORCH_QNNPACK
379
+ if (ctx.qEngine() == at::QEngine::QNNPACK) {
380
+ TORCH_CHECK(
381
+ kSpatialDim == 2,
382
+ "prepack/__setstate__: QNNPACK only supports Conv2d "
383
+ "now.");
384
+ return PackedConvWeightsQnnp<kSpatialDim>::prepack(
385
+ weight.value(),
386
+ bias,
387
+ stride,
388
+ padding,
389
+ output_padding,
390
+ dilation,
391
+ groups,
392
+ transpose
393
+ );
394
+ }
395
+ #endif // USE_PYTORCH_QNNPACK
396
+ #if AT_MKLDNN_ENABLED()
397
+ if (ctx.qEngine() == at::QEngine::ONEDNN) {
398
+ return PackedConvWeightsOnednn<kSpatialDim>::prepack(
399
+ weight.value(),
400
+ bias,
401
+ stride,
402
+ padding,
403
+ output_padding,
404
+ dilation,
405
+ groups,
406
+ transpose
407
+ );
408
+ }
409
+ #endif // AT_MKLDNN_ENABLED()
410
+ TORCH_CHECK(
411
+ false,
412
+ "Didn't find engine for when deserializing ConvPackedParams: ",
413
+ toString(ctx.qEngine()));
414
+ }
.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/fbgemm_utils.h ADDED
@@ -0,0 +1,413 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/Tensor.h>
4
+ #include <ATen/native/quantized/PackedParams.h>
5
+ #include <ATen/native/quantized/cpu/EmbeddingPackedParams.h>
6
+ #include <c10/core/QScheme.h>
7
+ #include <c10/util/irange.h>
8
+
9
+ #ifdef USE_FBGEMM
10
+ #include <fbgemm/Fbgemm.h>
11
+ C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Winconsistent-missing-destructor-override")
12
+ #include <fbgemm/FbgemmFP16.h>
13
+ C10_DIAGNOSTIC_POP()
14
+ #include <fbgemm/QuantUtils.h>
15
+
16
+ // The struct for the packed weight matrix (PackBMatrix) and the corresponding
17
+ // column offsets used for the fully connect layer, which are both prepared in
18
+ // the prepacking step to save the computations in the inference. Note the
19
+ // column offsets include the sum of the B columns as well as the scalar term
20
+ // B_zero_point * K, whereas the row offsets created by
21
+ // PackAWithQuantRowOffset/PackAWithIm2Col/PackAWithRowOffset are only the sum
22
+ // of the A rows. The column offsets are needed for the asymmetric quantization
23
+ // (affine quantization) of input matrix.
24
+ // Note that in JIT mode we can think of a way to fuse col_offsets with bias.
25
+ struct TORCH_API PackedLinearWeight : public LinearPackedParamsBase {
26
+ PackedLinearWeight(
27
+ std::unique_ptr<fbgemm::PackBMatrix<int8_t>> w,
28
+ std::optional<at::Tensor> bias,
29
+ std::vector<int32_t> col_offsets,
30
+ std::vector<float> w_scale,
31
+ std::vector<int32_t> w_zp,
32
+ c10::QScheme q_scheme)
33
+ : w(std::move(w)),
34
+ bias_(std::move(bias)),
35
+ col_offsets(std::move(col_offsets)),
36
+ w_scale(std::move(w_scale)),
37
+ w_zp(std::move(w_zp)),
38
+ q_scheme(std::move(q_scheme)) {}
39
+ std::unique_ptr<fbgemm::PackBMatrix<int8_t>> w;
40
+ std::optional<at::Tensor> bias_;
41
+ std::vector<int32_t> col_offsets;
42
+ std::vector<float> w_scale;
43
+ std::vector<int32_t> w_zp;
44
+ c10::QScheme q_scheme;
45
+
46
+ at::Tensor apply(
47
+ at::Tensor input,
48
+ double output_scale,
49
+ int64_t output_zero_point) override;
50
+
51
+ at::Tensor apply_relu(
52
+ at::Tensor input,
53
+ double output_scale,
54
+ int64_t output_zero_point) override;
55
+
56
+ at::Tensor& apply_out(
57
+ const at::Tensor& input,
58
+ double output_scale,
59
+ int64_t output_zero_point,
60
+ at::Tensor& output) override;
61
+
62
+ at::Tensor& apply_relu_out(
63
+ const at::Tensor& input,
64
+ double output_scale,
65
+ int64_t output_zero_point,
66
+ at::Tensor& output) override;
67
+
68
+ at::Tensor apply_with_input_q_dq_qweight_dq_output_fp32(
69
+ at::Tensor input,
70
+ double input_scale,
71
+ int64_t input_zero_point) override;
72
+
73
+ at::Tensor apply_with_input_q_dq_qweight_dq_relu_output_fp32(
74
+ at::Tensor input,
75
+ double input_scale,
76
+ int64_t input_zero_point) override;
77
+
78
+ at::Tensor apply_dynamic(at::Tensor input, bool reduce_range = false)
79
+ override;
80
+
81
+ at::Tensor apply_dynamic_relu(at::Tensor input, bool reduce_range = false)
82
+ override;
83
+
84
+ std::tuple<at::Tensor, std::optional<at::Tensor>> unpack() override;
85
+
86
+ std::optional<at::Tensor> bias() override {
87
+ return bias_;
88
+ }
89
+
90
+ static c10::intrusive_ptr<LinearPackedParamsBase> prepack(
91
+ at::Tensor weight,
92
+ std::optional<at::Tensor> bias);
93
+
94
+ private:
95
+ template <bool ReluFused>
96
+ at::Tensor& apply_impl(
97
+ const at::Tensor& input,
98
+ double output_scale,
99
+ int64_t output_zero_point,
100
+ at::Tensor& output);
101
+
102
+ template <bool ReluFused>
103
+ at::Tensor apply_with_input_q_dq_qweight_dq_output_fp32_impl(
104
+ const at::Tensor& input,
105
+ double input_scale,
106
+ int64_t input_zero_point);
107
+
108
+ template <bool ReluFused>
109
+ at::Tensor apply_dynamic_impl(at::Tensor input, bool reduce_range = false);
110
+ };
111
+
112
+ struct TORCH_API PackedLinearWeightFp16 : public LinearPackedParamsBase {
113
+ PackedLinearWeightFp16(
114
+ std::unique_ptr<fbgemm::PackedGemmMatrixFP16> w,
115
+ std::optional<at::Tensor> bias)
116
+ : w(std::move(w)), bias_(std::move(bias)) {}
117
+
118
+ std::unique_ptr<fbgemm::PackedGemmMatrixFP16> w;
119
+ std::optional<at::Tensor> bias_;
120
+
121
+ at::Tensor apply(
122
+ at::Tensor /*input*/,
123
+ double /*output_scale*/,
124
+ int64_t /*output_zero_point*/) override {
125
+ TORCH_INTERNAL_ASSERT(false);
126
+ }
127
+ at::Tensor apply_relu(
128
+ at::Tensor /*input*/,
129
+ double /*output_scale*/,
130
+ int64_t /*output_zero_point*/) override {
131
+ TORCH_INTERNAL_ASSERT(false);
132
+ }
133
+
134
+ at::Tensor apply_dynamic(at::Tensor input, bool reduce_range = false)
135
+ override;
136
+ at::Tensor apply_dynamic_relu(at::Tensor input, bool reduce_range = false)
137
+ override;
138
+
139
+ at::Tensor& apply_dynamic_out(
140
+ const at::Tensor& input,
141
+ at::Tensor& output,
142
+ bool reduce_range = false) override;
143
+ at::Tensor& apply_dynamic_relu_out(
144
+ const at::Tensor& input,
145
+ at::Tensor& output,
146
+ bool reduce_range = false) override;
147
+
148
+ std::tuple<at::Tensor, std::optional<at::Tensor>> unpack() override;
149
+
150
+ std::optional<at::Tensor> bias() override {
151
+ return bias_;
152
+ }
153
+
154
+ static c10::intrusive_ptr<LinearPackedParamsBase> prepack(
155
+ at::Tensor weight,
156
+ std::optional<at::Tensor> bias);
157
+
158
+ void set_bias(std::optional<at::Tensor> bias) override;
159
+
160
+ private:
161
+ template <bool ReluFused>
162
+ at::Tensor& apply_dynamic_impl(const at::Tensor& input, at::Tensor& output);
163
+ };
164
+
165
+ template <int kSpatialDim = 2>
166
+ struct TORCH_API PackedConvWeight : public ConvPackedParamsBase<kSpatialDim> {
167
+ PackedConvWeight(
168
+ std::unique_ptr<fbgemm::PackWeightsForConv<kSpatialDim>> w,
169
+ std::optional<at::Tensor> bias,
170
+ torch::List<int64_t> stride,
171
+ torch::List<int64_t> padding,
172
+ torch::List<int64_t> output_padding,
173
+ torch::List<int64_t> dilation,
174
+ int64_t groups,
175
+ uint8_t transpose,
176
+ std::vector<int32_t> col_offsets,
177
+ std::vector<int64_t> kernel,
178
+ std::vector<float> w_scale,
179
+ std::vector<int32_t> w_zp,
180
+ c10::QScheme q_scheme)
181
+ : w(std::move(w)),
182
+ bias(std::move(bias)),
183
+ stride_(std::move(stride)),
184
+ padding_(std::move(padding)),
185
+ output_padding_(std::move(output_padding)),
186
+ dilation_(std::move(dilation)),
187
+ groups_(groups),
188
+ transpose_(transpose),
189
+ col_offsets(std::move(col_offsets)),
190
+ kernel(std::move(kernel)),
191
+ w_scale(std::move(w_scale)),
192
+ w_zp(std::move(w_zp)),
193
+ q_scheme(q_scheme) {}
194
+
195
+ std::unique_ptr<fbgemm::PackWeightsForConv<kSpatialDim>> w;
196
+ std::optional<at::Tensor> bias;
197
+ torch::List<int64_t> stride_;
198
+ torch::List<int64_t> padding_;
199
+ torch::List<int64_t> output_padding_;
200
+ torch::List<int64_t> dilation_;
201
+ int64_t groups_;
202
+ uint8_t transpose_;
203
+ std::vector<int32_t> col_offsets;
204
+ std::vector<int64_t> kernel;
205
+ std::vector<float> w_scale;
206
+ std::vector<int32_t> w_zp;
207
+ c10::QScheme q_scheme;
208
+
209
+ at::Tensor apply(
210
+ const at::Tensor& input,
211
+ double output_scale,
212
+ int64_t output_zero_point) override;
213
+
214
+ at::Tensor apply_relu(
215
+ const at::Tensor& input,
216
+ double output_scale,
217
+ int64_t output_zero_point) override;
218
+
219
+ at::Tensor apply_dynamic(
220
+ const at::Tensor& input,
221
+ bool reduce_range) override;
222
+
223
+ std::tuple<at::Tensor, std::optional<at::Tensor>> unpack() override;
224
+
225
+ static c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> prepack(
226
+ at::Tensor weight,
227
+ std::optional<at::Tensor> bias,
228
+ torch::List<int64_t> stride,
229
+ torch::List<int64_t> padding,
230
+ torch::List<int64_t> output_padding,
231
+ torch::List<int64_t> dilation,
232
+ int64_t groups,
233
+ bool transpose);
234
+
235
+ const float* GetBiasData(at::Tensor* bias);
236
+
237
+ void GetQuantizationParams(
238
+ float act_scale,
239
+ float out_scale,
240
+ std::vector<float>* output_multiplier_float,
241
+ std::vector<float>* act_times_w_scale);
242
+
243
+ torch::List<int64_t> stride() const override {
244
+ return stride_;
245
+ }
246
+
247
+ torch::List<int64_t> padding() const override {
248
+ return padding_;
249
+ }
250
+
251
+ torch::List<int64_t> output_padding() const override {
252
+ return output_padding_;
253
+ }
254
+
255
+ torch::List<int64_t> dilation() const override {
256
+ return dilation_;
257
+ }
258
+
259
+ int64_t groups() const override {
260
+ return groups_;
261
+ }
262
+
263
+ bool transpose() const override {
264
+ return (bool)transpose_;
265
+ }
266
+
267
+ private:
268
+ template <bool ReluFused>
269
+ at::Tensor apply_impl(
270
+ const at::Tensor& input,
271
+ double output_scale,
272
+ int64_t output_zero_point);
273
+ };
274
+
275
+ // PackWeight: Convert the weight from uint8 to int8.
276
+ inline void convert_uint8_int8(
277
+ int len,
278
+ const uint8_t* src_uint8,
279
+ int8_t* dst_int8) {
280
+ for (const auto i : c10::irange(len)) {
281
+ dst_int8[i] = static_cast<int8_t>(static_cast<int32_t>(src_uint8[i]) - 128);
282
+ }
283
+ }
284
+
285
+ // UnpackWeight: Convert the weight from int8 to uint8.
286
+ inline void convert_int8_uint8(
287
+ int len,
288
+ const int8_t* src_int8,
289
+ uint8_t* dst_uint8) {
290
+ for (const auto i : c10::irange(len)) {
291
+ dst_uint8[i] =
292
+ static_cast<uint8_t>(static_cast<int32_t>(src_int8[i]) + 128);
293
+ }
294
+ }
295
+
296
+ namespace at {
297
+ namespace native {
298
+ namespace fbgemm_utils {
299
+
300
+ template <int kSpatialDim = 2>
301
+ fbgemm::conv_param_t<kSpatialDim> MakeFbgemmConvParam(
302
+ int N,
303
+ int C,
304
+ int M,
305
+ const std::vector<int>& image_shape,
306
+ int groups,
307
+ const std::vector<int>& kernels,
308
+ const std::vector<int>& strides,
309
+ const std::vector<int>& pads,
310
+ const std::vector<int>& dilations,
311
+ const std::vector<int>& output_padding = std::vector<int>(kSpatialDim, 0),
312
+ bool transposed = false);
313
+
314
+ // TODO: Remove functions below when ChannelsLast3d is ready.
315
+ Tensor MakeStridedQTensorCPU(
316
+ const IntArrayRef& sizes,
317
+ const IntArrayRef& strides,
318
+ const TensorOptions& options,
319
+ QuantizerPtr quantizer);
320
+
321
+ Tensor MakeEmptyAffineQuantizedChannelsLast3dTensor(
322
+ int64_t N,
323
+ int64_t C,
324
+ int64_t D,
325
+ int64_t H,
326
+ int64_t W,
327
+ const TensorOptions& options,
328
+ double scale,
329
+ int64_t zero_point);
330
+
331
+ Tensor MakeEmptyPerChannelAffineQuantizedChannelsLast3dTensor(
332
+ int64_t N,
333
+ int64_t C,
334
+ int64_t D,
335
+ int64_t H,
336
+ int64_t W,
337
+ const TensorOptions& options,
338
+ const Tensor& scales,
339
+ const Tensor& zero_points);
340
+
341
+ Tensor ConvertToChannelsLast3dTensor(const Tensor& src);
342
+
343
+ template <int kSpatialDim = 2>
344
+ Tensor TransposeConvTensorUnpackConversion(const Tensor& src, int groups);
345
+
346
+ template <int kSpatialDim>
347
+ Tensor ConvertConvWeightsToChannelLastTensor(
348
+ const at::Tensor& src,
349
+ int groups,
350
+ bool transpose);
351
+ } // namespace fbgemm_utils
352
+ } // namespace native
353
+ } // namespace at
354
+
355
+ #endif // USE_FBGEMM
356
+
357
+ struct TORCH_API PackedEmbeddingBagWeight : public EmbeddingPackedParamsBase {
358
+ PackedEmbeddingBagWeight(
359
+ at::Tensor packed_w,
360
+ std::vector<float> w_scale,
361
+ std::vector<float> w_zp,
362
+ int64_t bit_rate,
363
+ c10::QScheme q_scheme,
364
+ int64_t version)
365
+ : packed_w(std::move(packed_w)),
366
+ w_scale(std::move(w_scale)),
367
+ w_zp(std::move(w_zp)),
368
+ bit_rate_(bit_rate),
369
+ q_scheme(q_scheme),
370
+ version_(version) {
371
+ // NOLINTNEXTLINE(clang-analyzer-cplusplus.Move)
372
+ if (!packed_w.is_contiguous()) {
373
+ packed_w = packed_w.contiguous();
374
+ }
375
+ }
376
+
377
+ at::Tensor packed_w;
378
+ std::vector<float> w_scale;
379
+ std::vector<float> w_zp;
380
+ int64_t bit_rate_;
381
+ c10::QScheme q_scheme;
382
+ int64_t version_;
383
+
384
+ at::Tensor unpack() override;
385
+ static c10::intrusive_ptr<EmbeddingPackedParamsBase> prepack(
386
+ at::Tensor weight);
387
+
388
+ int64_t bit_rate() const override {
389
+ return bit_rate_;
390
+ }
391
+
392
+ int64_t version() const override {
393
+ return version_;
394
+ }
395
+
396
+ at::Tensor embeddingbag_byte(
397
+ const at::Tensor& indices,
398
+ const std::optional<at::Tensor>& offsets,
399
+ bool pruned_weights,
400
+ const std::optional<at::Tensor>& per_sample_weights_,
401
+ const std::optional<at::Tensor>& compressed_indices_mapping,
402
+ bool include_last_offset,
403
+ bool is_embedding_op) override;
404
+
405
+ at::Tensor embeddingbag_4bit(
406
+ const at::Tensor& indices,
407
+ const std::optional<at::Tensor>& offsets,
408
+ bool pruned_weights,
409
+ const std::optional<at::Tensor>& per_sample_weights_,
410
+ const std::optional<at::Tensor>& compressed_indices_mapping,
411
+ bool include_last_offset,
412
+ bool is_embedding_op) override;
413
+ };
.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/init_qnnpack.h ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #ifdef USE_PYTORCH_QNNPACK
4
+
5
+ namespace at {
6
+ namespace native {
7
+
8
+ void initQNNPACK();
9
+
10
+ } // namespace native
11
+ } // namespace at
12
+
13
+ #endif
.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/qembeddingbag.h ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <ATen/core/Tensor.h>
3
+ #include <cstdint>
4
+
5
+ namespace at {
6
+ namespace native {
7
+ Tensor& embedding_bag_byte_rowwise_offsets_out(
8
+ Tensor& output,
9
+ const Tensor& weight,
10
+ const Tensor& indices,
11
+ const std::optional<Tensor>& offsets_in,
12
+ const bool /* scale_grad_by_freq */,
13
+ const int64_t /* mode */,
14
+ bool pruned_weights,
15
+ const std::optional<Tensor>& per_sample_weights_,
16
+ const std::optional<Tensor>& compressed_indices_mapping,
17
+ bool include_last_offset);
18
+
19
+ Tensor& embedding_bag_4bit_rowwise_offsets_out(
20
+ Tensor& output,
21
+ const Tensor& weight,
22
+ const Tensor& indices,
23
+ const std::optional<Tensor>& offsets_in,
24
+ const bool /* scale_grad_by_freq */,
25
+ const int64_t /* mode */,
26
+ bool pruned_weights,
27
+ const std::optional<Tensor>& per_sample_weights_,
28
+ const std::optional<Tensor>& compressed_indices_mapping,
29
+ bool include_last_offset);
30
+
31
+ Tensor& qembeddingbag_byte_unpack_out(Tensor& output, const Tensor& packed_weight);
32
+
33
+ } // native
34
+ } // at
.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/qembeddingbag_prepack.h ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <ATen/core/Tensor.h>
3
+
4
+ namespace at { namespace native {
5
+
6
+ Tensor& qembeddingbag_byte_prepack_out(Tensor& output, const Tensor& weight);
7
+
8
+ Tensor qembeddingbag_byte_prepack(const Tensor& weight);
9
+
10
+ Tensor qembeddingbag_byte_prepack_meta(const Tensor& weight);
11
+
12
+ } // namespace native
13
+ } // namespace at
.venv/lib/python3.11/site-packages/torch/include/ATen/native/transformers/attention.h ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <ATen/core/Tensor.h>
3
+ #include <c10/macros/Export.h>
4
+ #include <ATen/native/DispatchStub.h>
5
+ #include <ATen/native/transformers/attention.h>
6
+ #include <optional>
7
+
8
+ namespace at {
9
+ namespace native {
10
+
11
+ using fused_sdp_choice_fn = int64_t (*)(const Tensor& query_, const Tensor& key, const Tensor& value,
12
+ const std::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal, std::optional<double> scale, bool enable_gqa);
13
+
14
+ DECLARE_DISPATCH(fused_sdp_choice_fn, _fused_sdp_choice_stub);
15
+
16
+ TORCH_API Tensor bmm_nt(const Tensor& a, const Tensor& b);
17
+ TORCH_API Tensor masked_softmax(
18
+ Tensor& attn_scores,
19
+ std::optional<Tensor> attn_mask,
20
+ const Tensor& query,
21
+ std::optional<int64_t> mask_type = {});
22
+
23
+ using transform_bias_rescale_qkv_fn = void(*)(
24
+ at::ScalarType type,
25
+ void* _q_k_v,
26
+ const void* _qkv,
27
+ const void* _qkv_bias,
28
+ int64_t B,
29
+ int64_t T,
30
+ int64_t D,
31
+ int64_t num_head);
32
+
33
+ DECLARE_DISPATCH(transform_bias_rescale_qkv_fn, transform_bias_rescale_qkv_stub);
34
+
35
+ TORCH_API Tensor transform0213_gemm_nt_bias(
36
+ const Tensor& a,
37
+ const Tensor& b,
38
+ const Tensor& c,
39
+ const Tensor& query);
40
+
41
+ TORCH_API Tensor bmm_nn(Tensor& out, const Tensor& a, const Tensor& b);
42
+
43
+ TORCH_API void debug_assert_shape(int line, const Tensor& t, c10::IntArrayRef shape);
44
+
45
+ TORCH_API Tensor qkv_projection(
46
+ const Tensor& query,
47
+ const Tensor& key,
48
+ const Tensor& value,
49
+ const int64_t embed_dim,
50
+ const Tensor& qkv_weight);
51
+
52
+ using flash_attention_fn = void (*)(
53
+ const Tensor& output, const Tensor& logsumexp,
54
+ const Tensor& query, const Tensor& key, const Tensor& value,
55
+ double dropout_p, bool is_causal,
56
+ std::optional<Tensor> attn_mask,
57
+ std::optional<double> scale);
58
+
59
+ using flash_attention_backward_fn = void (*)(
60
+ const Tensor& grad_q, const Tensor& grad_k,
61
+ const Tensor& grad_v, const Tensor& grad_out,
62
+ const Tensor& query, const Tensor& key,
63
+ const Tensor& value, const Tensor& out, const Tensor& logsumexp,
64
+ double dropout_p, bool is_causal,
65
+ std::optional<Tensor> attn_mask,
66
+ std::optional<double> scale);
67
+
68
+ DECLARE_DISPATCH(flash_attention_fn, flash_attention_kernel);
69
+ DECLARE_DISPATCH(flash_attention_backward_fn, flash_attention_backward_kernel);
70
+
71
+ } // namespace native
72
+ } // namespace at
.venv/lib/python3.11/site-packages/torch/include/ATen/native/transformers/sdp_utils_cpp.h ADDED
@@ -0,0 +1,566 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <ATen/Context.h>
3
+ #include <ATen/NestedTensorImpl.h>
4
+ #include <ATen/TensorSubclassLikeUtils.h>
5
+ #include <ATen/TensorUtils.h>
6
+ #include <ATen/core/Tensor.h>
7
+ #include <ATen/core/grad_mode.h>
8
+ #include <ATen/native/DispatchStub.h>
9
+ #include <c10/core/ScalarType.h>
10
+
11
+ #include <c10/util/Exception.h>
12
+ #include <c10/util/env.h>
13
+ #include <c10/util/irange.h>
14
+
15
+ #include <c10/core/SymInt.h>
16
+ #include <c10/core/SymFloat.h>
17
+ #include <c10/util/string_view.h>
18
+ #include <c10/util/Array.h>
19
+ #include <cmath>
20
+ #include <cstdint>
21
+ #include <functional>
22
+
23
+ namespace sdp {
24
+
25
+ constexpr int32_t num_backends = 5;
26
+ enum class SDPBackend {
27
+ error = -1,
28
+ math = 0,
29
+ flash_attention = 1,
30
+ efficient_attention = 2,
31
+ cudnn_attention = 3,
32
+ overrideable = 4
33
+ };
34
+
35
+ // Note that if this changed make sure to update
36
+ // the templated enum in mem_eff/kernel_forward.h and mem_eff/kernel_backward.h
37
+ enum class CustomMaskType {
38
+ NoCustomMask = 0,
39
+ CausalFromTopLeft = 1,
40
+ CausalFromBottomRight = 2,
41
+ NumCustomMaskTypes,
42
+ };
43
+
44
+ struct sdp_params {
45
+ at::Tensor query;
46
+ at::Tensor key;
47
+ at::Tensor value;
48
+ std::optional<at::Tensor> attn_mask;
49
+ double dropout;
50
+ bool is_causal;
51
+ bool enable_gqa;
52
+ };
53
+
54
+ SDPBackend select_sdp_backend_cpp(sdp_params const& kernel_params);
55
+
56
+ inline c10::SymFloat calculate_scale(
57
+ const at::Tensor& query,
58
+ std::optional<double> scale) {
59
+ const auto softmax_scale = scale.has_value()
60
+ ? scale.value()
61
+ : (c10::SymFloat(1.0) / (c10::SymFloat(query.sym_size(-1)).sqrt()));
62
+ return c10::SymFloat(softmax_scale);
63
+ }
64
+
65
+ using c10::array_of;
66
+
67
+ inline bool input_requires_grad(sdp_params const& params) {
68
+ const bool any_inputs_require_grad = params.query.requires_grad() ||
69
+ params.key.requires_grad() || params.value.requires_grad();
70
+ const bool gradmode_enabled = at::GradMode::is_enabled();
71
+ return any_inputs_require_grad && gradmode_enabled;
72
+ }
73
+
74
+ inline bool has_for_nested_inputs(sdp_params const& params) {
75
+ return
76
+ (params.query.is_nested() && params.query.layout() == c10::kStrided) ||
77
+ (params.key.is_nested() && params.key.layout() == c10::kStrided) ||
78
+ (params.value.is_nested() && params.value.layout() == c10::kStrided);
79
+ }
80
+
81
+ inline bool has_for_dense_inputs(sdp_params const& params) {
82
+ return !params.query.is_nested() || !params.key.is_nested() || !params.value.is_nested();
83
+ }
84
+
85
+ inline bool has_only_dense_inputs(sdp_params const& params) {
86
+ return !params.query.is_nested() && !params.key.is_nested() && !params.value.is_nested();
87
+ }
88
+
89
+ template <typename dtype_vector>
90
+ inline bool check_tensor_dtype(
91
+ sdp_params const& params,
92
+ dtype_vector allowed_dtypes,
93
+ bool debug) {
94
+ auto query_dtype = params.query.dtype();
95
+ if (!(query_dtype == params.key.dtype() &&
96
+ query_dtype == params.value.dtype() &&
97
+ (std::find(allowed_dtypes.begin(), allowed_dtypes.end(), query_dtype) !=
98
+ allowed_dtypes.end()))) {
99
+ if (debug) {
100
+ TORCH_WARN(
101
+ "Expected query, key and value to all be of dtype: {",
102
+ c10::Join(", ", allowed_dtypes),
103
+ "}. Got ",
104
+ "Query dtype: ",
105
+ params.query.dtype(),
106
+ ", Key dtype: ",
107
+ params.key.dtype(),
108
+ ", and Value dtype: ",
109
+ params.value.dtype(),
110
+ " instead.");
111
+ }
112
+ return false;
113
+ }
114
+ return true;
115
+ }
116
+
117
+
118
+ inline bool try_broadcast_param_size(
119
+ const c10::SymInt q_size,
120
+ const c10::SymInt k_size,
121
+ const c10::SymInt v_size,
122
+ c10::string_view param_name,
123
+ bool debug) {
124
+ auto max_size = std::max({q_size, k_size, v_size});
125
+ if ((q_size != max_size && q_size != 1) ||
126
+ (k_size != max_size && k_size != 1) ||
127
+ (v_size != max_size && v_size != 1)) {
128
+ if (debug) {
129
+ TORCH_WARN(
130
+ "Both fused kernels require query, key and value to have broadcastable ",
131
+ param_name,
132
+ "got Query ",
133
+ param_name,
134
+ q_size,
135
+ ", Key ",
136
+ param_name,
137
+ k_size,
138
+ ", Value ",
139
+ param_name,
140
+ v_size,
141
+ " instead.");
142
+ }
143
+ return false;
144
+ }
145
+ return true;
146
+ }
147
+
148
+ inline bool check_for_seq_len_0_and_consistent_head_dim_nested_tensor_helper(
149
+ at::Tensor const& param,
150
+ c10::string_view param_name,
151
+ bool debug) {
152
+ const auto nt_tensor_impl = at::native::get_nested_tensor_impl(param);
153
+ const at::Tensor& sizes = nt_tensor_impl->get_nested_sizes();
154
+ auto num_head_dims = nt_tensor_impl->opt_size(1);
155
+ if (!num_head_dims.has_value()) {
156
+ // num_head_dims is ragged
157
+ if (debug) {
158
+ TORCH_WARN(
159
+ "Fused kernels do not support ragged num_head_dims, ",
160
+ param_name,
161
+ "has a ragged num_heads.");
162
+ }
163
+ return false;
164
+ }
165
+
166
+ auto* sizes_ptr = sizes.data_ptr<int64_t>();
167
+ const int64_t n_tensors = param.size(0);
168
+ const int64_t size_tensor_stride = sizes.stride(0);
169
+
170
+ // This is being called inside sdp with shape [batch, heads, {seq_len}, dim]
171
+ for (const auto i : c10::irange(n_tensors)) {
172
+ if (sizes_ptr[(i * size_tensor_stride) + 1] == 0) {
173
+ if (debug) {
174
+ TORCH_WARN(
175
+ "Fused kernels do not support seq_len == 0, ",
176
+ param_name,
177
+ "has a seq len of 0.");
178
+ }
179
+ return false;
180
+ }
181
+ }
182
+ return true;
183
+ }
184
+
185
+ inline bool check_for_seq_len_0_nested_tensor(sdp_params const& params, bool debug) {
186
+ // When this function is called we are assured that the nt is dim==4
187
+ bool q_is_safe = params.query.is_nested()
188
+ ? check_for_seq_len_0_and_consistent_head_dim_nested_tensor_helper(
189
+ params.query, "query ", debug)
190
+ : true;
191
+ // short circuit if any is unsafe
192
+ if (!q_is_safe) {
193
+ return false;
194
+ }
195
+
196
+ bool k_is_safe = params.key.is_nested()
197
+ ? check_for_seq_len_0_and_consistent_head_dim_nested_tensor_helper(
198
+ params.key, "key ", debug)
199
+ : true;
200
+ if (!k_is_safe) {
201
+ return false;
202
+ }
203
+
204
+ bool v_is_safe = params.value.is_nested()
205
+ ? check_for_seq_len_0_and_consistent_head_dim_nested_tensor_helper(
206
+ params.value, "value ", debug)
207
+ : true;
208
+ if (!v_is_safe) {
209
+ return false;
210
+ }
211
+
212
+ // We now know none of the inputs have ragged num_heads, so we can safely
213
+ // access .size(1)
214
+ auto q_num_heads = params.query.size(1);
215
+ auto k_num_heads = params.key.size(1);
216
+ auto v_num_heads = params.value.size(1);
217
+ bool same_num_heads =
218
+ q_num_heads == k_num_heads && q_num_heads == v_num_heads;
219
+
220
+ if (!same_num_heads) {
221
+ if (input_requires_grad(params)){
222
+ if (debug) {
223
+ TORCH_WARN(
224
+ "Both fused kernels do not support training with broadcasted NT inputs.");
225
+ }
226
+ return false;
227
+ }
228
+ return try_broadcast_param_size(
229
+ q_num_heads, k_num_heads, v_num_heads, "num heads ", debug);
230
+ }
231
+
232
+ return true;
233
+ }
234
+
235
+ inline bool check_nested_tensor(sdp_params const& params, bool debug) {
236
+ // Return false if have nested tensor
237
+ if (!has_only_dense_inputs(params)) {
238
+ if (debug) {
239
+ TORCH_WARN(
240
+ "Both fused kernels of cpp version currently do not support Nested Tensor inputs.");
241
+ }
242
+ return false;
243
+ }
244
+ return true;
245
+ }
246
+
247
+ inline bool check_for_dropout(sdp_params const& params, bool debug) {
248
+ if (params.dropout > 0.0) {
249
+ if (debug) {
250
+ TORCH_WARN("Both fused kernels do not support non-zero dropout.");
251
+ }
252
+ return false;
253
+ }
254
+ return true;
255
+ }
256
+
257
+ inline bool check_requires_grad_and_nested(sdp_params const& params, bool debug) {
258
+ if (input_requires_grad(params)) {
259
+ if (debug) {
260
+ TORCH_WARN(
261
+ "Memory efficient attention currently doesn't support training with NT inputs.");
262
+ }
263
+ return false;
264
+ }
265
+ return true;
266
+ }
267
+
268
+ inline bool check_for_attn_mask(sdp_params const& params, bool debug) {
269
+ if (params.attn_mask.has_value()) {
270
+ if (debug) {
271
+ TORCH_WARN("Flash Attention does not support non-null attn_mask.");
272
+ }
273
+ return false;
274
+ }
275
+ return true;
276
+ }
277
+
278
+ inline bool check_attn_mask_shape(sdp_params const& params, bool debug) {
279
+ auto attn_mask = params.attn_mask;
280
+ if (!attn_mask.has_value()) {
281
+ return true;
282
+ }
283
+ if (attn_mask.value().requires_grad()) {
284
+ return false;
285
+ }
286
+ auto batchSize = params.query.sym_size(0);
287
+ auto qSize = params.query.sym_size(2);
288
+ auto kvSize = params.key.sym_size(2);
289
+ auto num_head = params.query.sym_size(1);
290
+ if (attn_mask.value().sym_size(-2) != qSize && attn_mask.value().sym_size(-2) != 1) {
291
+ return false;
292
+ }
293
+ if (attn_mask.value().sym_size(-1) != kvSize && attn_mask.value().sym_size(-1) != 1) {
294
+ return false;
295
+ }
296
+ if (attn_mask.value().dim() == 2) {
297
+ return true;
298
+ } else if (attn_mask.value().dim() == 4) {
299
+ if ((attn_mask.value().sym_size(0) == 1 || attn_mask.value().sym_size(0) == batchSize)
300
+ && (attn_mask.value().sym_size(1) == 1 || attn_mask.value().sym_size(1) == num_head)) {
301
+ return true;
302
+ }
303
+ }
304
+ if (debug) {
305
+ TORCH_WARN("Please use the following attn mask shapes: ",
306
+ "2d - ({Q_seq_len, 1} x {KV_seq_len, 1}); ",
307
+ "4d - ({Batch, 1} x {Num_heads, 1} x {Q_seq_len, 1} x {KV_seq_len, 1})");
308
+ }
309
+ return false;
310
+ }
311
+
312
+ inline bool check_tensor_shapes(sdp_params const& params, bool debug) {
313
+ auto query_dim = params.query.dim();
314
+ if (!(query_dim == params.key.dim() && query_dim == params.value.dim() &&
315
+ (query_dim == 4))) {
316
+ if (debug) {
317
+ TORCH_WARN(
318
+ "All fused kernels requires query, key and value to be 4 dimensional, but got Query dim: ",
319
+ query_dim,
320
+ ", Key dim: ",
321
+ params.key.dim(),
322
+ ", Value dim: ",
323
+ params.value.dim(),
324
+ " instead.");
325
+ }
326
+ return false;
327
+ }
328
+ return true;
329
+ }
330
+
331
+ inline bool check_safe_kv_broadcast(at::Tensor const& param, bool debug) {
332
+ const auto nt_tensor_impl = at::native::get_nested_tensor_impl(param);
333
+ auto seq_len = nt_tensor_impl->opt_size(2);
334
+ if (!seq_len.has_value()) {
335
+ if (debug) {
336
+ TORCH_WARN(
337
+ "For both fused kernels, if one of key/value batch_size requires "
338
+ "broadcasting and the other does not, then the other must have a ",
339
+ "consistent seq_len dim.")
340
+ }
341
+ return false;
342
+ }
343
+ return true;
344
+ }
345
+
346
+ inline bool check_grouped_query_attention(sdp_params const& params, bool debug) {
347
+ const auto q_num_heads = params.query.sym_size(-3);
348
+ const auto k_num_heads = params.key.sym_size(-3);
349
+ const auto v_num_heads = params.value.sym_size(-3);
350
+ const bool same_kv_heads = k_num_heads == v_num_heads;
351
+
352
+ if (!(same_kv_heads)){
353
+ if (debug) {
354
+ TORCH_WARN(
355
+ "Both fused kernels require key and value to have the same num_heads and batch_size but got: ",
356
+ "Key sizes: ",
357
+ params.key.sizes(),
358
+ ", Value sizes: ",
359
+ params.value.sizes(),
360
+ ", Query sizes: ",
361
+ params.query.sizes(),
362
+ " instead.");
363
+ }
364
+ return false;
365
+ }
366
+ // Check if grouped query attention is supported and validate the number of
367
+ // heads
368
+ if (q_num_heads % k_num_heads != 0) {
369
+ if (debug) {
370
+ TORCH_WARN(
371
+ "FlashAttentionV2 only supports grouped query attention, where the number of heads in key/value must divide number of heads in query.",
372
+ "Got input Key sizes(): ",
373
+ params.key.sym_size(-3),
374
+ ", Value sizes(): ",
375
+ params.value.sym_size(-3),
376
+ ", Query sizes(): ",
377
+ params.query.sym_size(-3),
378
+ " instead.");
379
+ }
380
+ return false;
381
+ }
382
+ return true;
383
+ }
384
+
385
+ template <bool supports_gqa>
386
+ inline bool check_batch_size_and_num_heads_dense(sdp_params const& params, bool debug) {
387
+ // This is expected to be called after check_tensor_shapes ensuring that the
388
+ // size() calls won't error since the inputs are all 4 dimensional
389
+
390
+ auto q_batch_size = params.query.sym_size(0);
391
+ auto k_batch_size = params.key.sym_size(0);
392
+ auto v_batch_size = params.value.sym_size(0);
393
+
394
+ bool same_batch_size =
395
+ q_batch_size == k_batch_size && q_batch_size == v_batch_size;
396
+
397
+ auto q_num_heads = params.query.sym_size(-3);
398
+ auto k_num_heads = params.key.sym_size(-3);
399
+ auto v_num_heads = params.value.sym_size(-3);
400
+
401
+ bool same_num_heads =
402
+ q_num_heads == k_num_heads && q_num_heads == v_num_heads;
403
+
404
+ if (!same_batch_size){
405
+ if(debug) {
406
+ TORCH_WARN(
407
+ "For dense inputs, both fused kernels require query, key and value to have the same batch_size. ",
408
+ "Query.sizes(): ",
409
+ params.query.sizes(),
410
+ ", Key.sizes(): ",
411
+ params.key.sizes(),
412
+ ", Value.sizes(): ",
413
+ params.value.sizes(),
414
+ " instead. To broadcast dense inputs, try using unsqueeze and expand_to before passing them into the kernel.");
415
+ }
416
+ return false;
417
+ }
418
+
419
+ if(params.enable_gqa && supports_gqa){
420
+ return check_grouped_query_attention(params, debug);
421
+ }
422
+
423
+ if (!same_num_heads){
424
+ if (debug) {
425
+ TORCH_WARN(
426
+ "For dense input, both fused kernels require query, key and value to have the same num_heads. ",
427
+ "Query.sizes(): ",
428
+ params.query.sizes(),
429
+ ", Key sizes(): ",
430
+ params.key.sizes(),
431
+ ", Value sizes(): ",
432
+ params.value.sizes(),
433
+ " instead. To broadcast dense inputs, try using unsqueeze and expand_to before passing them into the kernel.");
434
+ }
435
+ return false;
436
+ }
437
+ // If all checks pass, return true
438
+ return true;
439
+ }
440
+
441
+ inline bool check_batch_size_nested(sdp_params const& params, bool debug) {
442
+ // This is expected to be called after check_tensor_shapes ensuring that the
443
+ // size() calls won't error since the inputs are all 4 dimensional
444
+ auto q_batch_size = params.query.sym_size(0);
445
+ auto k_batch_size = params.key.sym_size(0);
446
+ auto v_batch_size = params.value.sym_size(0);
447
+
448
+ bool same_batch_size =
449
+ q_batch_size == k_batch_size && q_batch_size == v_batch_size;
450
+
451
+ // num_heads logic for nested input is checked in
452
+ // check_for_seq_len_0_nested_tensor as there is handling there to make sure
453
+ // num_heads is not ragged
454
+ bool broadcastable_batch_size = true;
455
+ if (!same_batch_size) {
456
+ if (input_requires_grad(params)){
457
+ if (debug) {
458
+ TORCH_WARN(
459
+ "Both fused kernels do not support training with broadcasted NT inputs.");
460
+ }
461
+ return false;
462
+ }
463
+ // try to broadcast batchsize
464
+ broadcastable_batch_size = try_broadcast_param_size(
465
+ q_batch_size, k_batch_size, v_batch_size, "batch size ", debug);
466
+
467
+ // if only one of k or v require broadcasting of batch size, the other
468
+ // must have a consistent seq_len dim
469
+ if (broadcastable_batch_size) {
470
+ if (k_batch_size == 1 && v_batch_size != 1 &&
471
+ !check_safe_kv_broadcast(params.value, debug)) {
472
+ return false;
473
+ }
474
+ if (v_batch_size == 1 && k_batch_size != 1 &&
475
+ !check_safe_kv_broadcast(params.key, debug)) {
476
+ return false;
477
+ }
478
+ }
479
+ }
480
+ return broadcastable_batch_size;
481
+ }
482
+
483
+ inline bool check_nonzero_sequence_lengths_dense(sdp_params const& params, bool debug) {
484
+ // In some cases people will pass in 0 sized tensors, this will
485
+ // cause the fused path to error with unaligned mask
486
+ bool zero_seq_len_q = params.query.sym_size(-2) == 0;
487
+ bool zero_seq_len_k = params.key.sym_size(-2) == 0;
488
+ if (zero_seq_len_q || zero_seq_len_k) {
489
+ if (debug) {
490
+ TORCH_WARN(
491
+ "All fused kernels do not support zero seq_len_q or seq_len_kv.");
492
+ }
493
+ return false;
494
+ }
495
+ return true;
496
+ }
497
+
498
+ template<bool ignore_singleton_dim>
499
+ inline bool check_last_dim_stride_equals_1_dense(sdp_params const& params, bool debug) {
500
+ // The stride checking for NestedTensors is done within the kernel
501
+ // And .contiguous will be called if needed
502
+
503
+ // This function checks that the last dimension of the inputs to
504
+ // fused_attention have stride 1
505
+ bool qkv_strides_equal_1 = params.query.sym_stride(-1) == 1 &&
506
+ params.key.sym_stride(-1) == 1 && params.value.sym_stride(-1) == 1;
507
+
508
+ // https://github.com/pytorch/pytorch/issues/116333
509
+ // If the head_dim is size 1 the stride won't matter, but we
510
+ // check this condition before padding the head_dim to 1
511
+ if (ignore_singleton_dim){
512
+ qkv_strides_equal_1 = qkv_strides_equal_1 || params.query.sym_size(-1) == 1;
513
+ }
514
+ bool mask_stride_equal_1 = params.attn_mask.has_value()
515
+ ? params.attn_mask.value().sym_stride(-1) == 1
516
+ : true;
517
+ if (!(qkv_strides_equal_1 && mask_stride_equal_1)) {
518
+ if (debug) {
519
+ std::ostringstream epilogue_message;
520
+ if (params.attn_mask.has_value()) {
521
+ epilogue_message << ", Attn_mask.stride(-1): "
522
+ << params.attn_mask.value().sym_stride(-1);
523
+ }
524
+ epilogue_message << " instead.";
525
+ TORCH_WARN(
526
+ "All fused kernels require the last dimension of the input to have stride 1. ",
527
+ "Got Query.stride(-1): ",
528
+ params.query.sym_stride(-1),
529
+ ", Key.stride(-1): ",
530
+ params.key.sym_stride(-1),
531
+ ", Value.stride(-1): ",
532
+ params.value.sym_stride(-1),
533
+ epilogue_message.str());
534
+ }
535
+
536
+ return false;
537
+ }
538
+ return true;
539
+ }
540
+
541
+ inline bool check_runtime_disabled_flash(sdp_params const& params, bool debug) {
542
+ // We check the global context to see if user has explicitly turned of flash
543
+ // sdp kernels
544
+ if (!at::globalContext().userEnabledFlashSDP()) {
545
+ if (debug) {
546
+ TORCH_WARN("Flash attention has been runtime disabled.");
547
+ }
548
+ return false;
549
+ }
550
+ return true;
551
+ }
552
+
553
+ inline bool check_runtime_disabled_mem_efficient(sdp_params const& params, bool debug) {
554
+ // We check the global context to see if user has explicitly turned of
555
+ // mem_efficient sdp kernels
556
+ if (!at::globalContext().userEnabledMemEfficientSDP()) {
557
+ if (debug) {
558
+ TORCH_WARN("Memory Efficient attention has been runtime disabled.");
559
+ }
560
+ return false;
561
+ }
562
+ return true;
563
+ }
564
+
565
+
566
+ } // namespace sdp
.venv/lib/python3.11/site-packages/torch/include/ATen/native/utils/Factory.h ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/core/Tensor.h>
4
+
5
+ namespace at {
6
+ namespace native {
7
+ namespace mobile {
8
+
9
+ Tensor allocate_padded_contiguous_if_needed(
10
+ const Tensor& input,
11
+ c10::MemoryFormat memory_format);
12
+
13
+ // TODO: Remove this function when at::native::empty() is modified to accept a
14
+ // custom memory allocator.
15
+
16
+ at::Tensor empty_with_tail_padding(
17
+ IntArrayRef size,
18
+ const caffe2::TypeMeta dtype,
19
+ c10::MemoryFormat memory_format,
20
+ std::optional<DimnameList> maybe_names);
21
+
22
+ } // namespace mobile
23
+ } // namespace native
24
+ } // namespace at
.venv/lib/python3.11/site-packages/torch/include/ATen/native/utils/ParamUtils.h ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <c10/util/ArrayRef.h>
4
+ #include <vector>
5
+
6
+ namespace at {
7
+ namespace native {
8
+
9
+ template <typename T>
10
+ inline std::vector<T> _expand_param_if_needed(
11
+ ArrayRef<T> list_param,
12
+ const char* param_name,
13
+ int64_t expected_dim) {
14
+ if (list_param.size() == 1) {
15
+ return std::vector<T>(expected_dim, list_param[0]);
16
+ } else if ((int64_t)list_param.size() != expected_dim) {
17
+ std::ostringstream ss;
18
+ ss << "expected " << param_name << " to be a single integer value or a "
19
+ << "list of " << expected_dim << " values to match the convolution "
20
+ << "dimensions, but got " << param_name << "=" << list_param;
21
+ AT_ERROR(ss.str());
22
+ } else {
23
+ return list_param.vec();
24
+ }
25
+ }
26
+
27
+ inline std::vector<int64_t> expand_param_if_needed(
28
+ IntArrayRef list_param,
29
+ const char* param_name,
30
+ int64_t expected_dim) {
31
+ return _expand_param_if_needed(list_param, param_name, expected_dim);
32
+ }
33
+
34
+ inline std::vector<c10::SymInt> expand_param_if_needed(
35
+ SymIntArrayRef list_param,
36
+ const char* param_name,
37
+ int64_t expected_dim) {
38
+ return _expand_param_if_needed(list_param, param_name, expected_dim);
39
+ }
40
+
41
+ } // namespace native
42
+ } // namespace at
.venv/lib/python3.11/site-packages/torch/include/ATen/native/utils/ParamsHash.h ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <c10/util/irange.h>
4
+ #include <memory>
5
+ #include <mutex>
6
+
7
+ namespace at::native {
8
+
9
+ // Hashing machinery for Params
10
+ // Fowler–Noll–Vo hash function
11
+ // see
12
+ // https://en.wikipedia.org/wiki/Fowler%E2%80%93Noll%E2%80%93Vo_hash_function
13
+ template <typename Params>
14
+ struct ParamsHash {
15
+ // Params must be a POD because we read out its memory
16
+ // contents as char* when hashing
17
+ static_assert(std::is_standard_layout_v<Params>, "Params is not POD");
18
+
19
+ size_t operator()(const Params& params) const {
20
+ auto ptr = reinterpret_cast<const uint8_t*>(&params);
21
+ uint32_t value = 0x811C9DC5;
22
+ for (const auto i : c10::irange(sizeof(Params))) {
23
+ value ^= ptr[i];
24
+ value *= 0x01000193;
25
+ }
26
+ return (size_t)value;
27
+ }
28
+ };
29
+
30
+ template <typename Params>
31
+ struct ParamsEqual {
32
+ // Params must be a POD because we read out its memory
33
+ // contents as char* when comparing
34
+ static_assert(std::is_standard_layout_v<Params>, "Params is not POD");
35
+
36
+ bool operator()(const Params& a, const Params& b) const {
37
+ auto ptr1 = reinterpret_cast<const uint8_t*>(&a);
38
+ auto ptr2 = reinterpret_cast<const uint8_t*>(&b);
39
+ return memcmp(ptr1, ptr2, sizeof(Params)) == 0;
40
+ }
41
+ };
42
+
43
+ // Provide explicit byte-for-byte constructors to avoid uwittingly leaving
44
+ // padding bytes unitialized (e.g., when passing Params by value)
45
+ template <typename T>
46
+ struct ParamsWrapper {
47
+ T pod;
48
+ static_assert(
49
+ std::is_standard_layout_v<T>,
50
+ "ParamsWrapper cannot wrap non-POD data");
51
+
52
+ ParamsWrapper() {
53
+ memset(&(this->pod), 0, sizeof(this->pod));
54
+ }
55
+
56
+ ParamsWrapper(const ParamsWrapper& other) {
57
+ memcpy(&(this->pod), &(other.pod), sizeof(this->pod));
58
+ }
59
+
60
+ ParamsWrapper(ParamsWrapper&& other) noexcept {
61
+ memcpy(&(this->pod), &(other.pod), sizeof(this->pod));
62
+ }
63
+
64
+ ParamsWrapper& operator=(const ParamsWrapper& other) {
65
+ memcpy(&(this->pod), &(other.pod), sizeof(this->pod));
66
+ return *this;
67
+ }
68
+
69
+ ParamsWrapper& operator=(ParamsWrapper&& other) noexcept {
70
+ memcpy(&(this->pod), &(other.pod), sizeof(this->pod));
71
+ return *this;
72
+ }
73
+
74
+ inline friend bool operator==(
75
+ const ParamsWrapper& lhs,
76
+ const ParamsWrapper& rhs) noexcept {
77
+ auto ptr1 = reinterpret_cast<const uint8_t*>(&(lhs.pod));
78
+ auto ptr2 = reinterpret_cast<const uint8_t*>(&(rhs.pod));
79
+ return memcmp(ptr1, ptr2, sizeof(lhs.pod)) == 0;
80
+ }
81
+ };
82
+
83
+ // Wrapped version: this allows the outer struct to have custom copy and move
84
+ // constructors for additional safety
85
+ template <typename ParamsWrapper>
86
+ struct ParamsWrapperHash {
87
+ // Params must be a POD because we read out its memory
88
+ // contents as char* when hashing
89
+ static_assert(
90
+ std::is_standard_layout_v<decltype(ParamsWrapper::pod)>,
91
+ "ParamsWrapper cannot wrap non-POD data");
92
+
93
+ size_t operator()(const ParamsWrapper& params_wrapper) const {
94
+ auto ptr = reinterpret_cast<const uint8_t*>(&(params_wrapper.pod));
95
+ uint32_t value = 0x811C9DC5;
96
+ for (const auto i : c10::irange(sizeof(params_wrapper.pod))) {
97
+ value ^= ptr[i];
98
+ value *= 0x01000193;
99
+ }
100
+ return (size_t)value;
101
+ }
102
+ };
103
+
104
+ } // namespace at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_empty_per_channel_affine_quantized.h ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from Function.h
4
+
5
+ #include <ATen/Context.h>
6
+ #include <ATen/DeviceGuard.h>
7
+ #include <ATen/TensorUtils.h>
8
+ #include <ATen/TracerMode.h>
9
+ #include <ATen/core/Generator.h>
10
+ #include <ATen/core/Reduction.h>
11
+ #include <ATen/core/Tensor.h>
12
+ #include <c10/core/Scalar.h>
13
+ #include <c10/core/Storage.h>
14
+ #include <c10/core/TensorOptions.h>
15
+ #include <c10/util/Deprecated.h>
16
+ #include <optional>
17
+
18
+
19
+
20
+ #include <ATen/ops/_empty_per_channel_affine_quantized_ops.h>
21
+
22
+ namespace at {
23
+
24
+
25
+ // aten::_empty_per_channel_affine_quantized(SymInt[] size, *, Tensor scales, Tensor zero_points, int axis, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=contiguous_format) -> Tensor
26
+ inline at::Tensor _empty_per_channel_affine_quantized(at::IntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, at::TensorOptions options={}, ::std::optional<at::MemoryFormat> memory_format=c10::MemoryFormat::Contiguous) {
27
+ return at::_ops::_empty_per_channel_affine_quantized::call(c10::fromIntArrayRefSlow(size), scales, zero_points, axis, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format));
28
+ }
29
+ namespace symint {
30
+ template <typename T, typename = std::enable_if_t<std::is_same<T, int64_t>::value>>
31
+ at::Tensor _empty_per_channel_affine_quantized(at::IntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, at::TensorOptions options={}, ::std::optional<at::MemoryFormat> memory_format=c10::MemoryFormat::Contiguous) {
32
+ return at::_ops::_empty_per_channel_affine_quantized::call(c10::fromIntArrayRefSlow(size), scales, zero_points, axis, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format));
33
+ }
34
+ }
35
+
36
+ // aten::_empty_per_channel_affine_quantized(SymInt[] size, *, Tensor scales, Tensor zero_points, int axis, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=contiguous_format) -> Tensor
37
+ inline at::Tensor _empty_per_channel_affine_quantized(at::IntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, ::std::optional<at::ScalarType> dtype, ::std::optional<at::Layout> layout, ::std::optional<at::Device> device, ::std::optional<bool> pin_memory, ::std::optional<at::MemoryFormat> memory_format) {
38
+ return at::_ops::_empty_per_channel_affine_quantized::call(c10::fromIntArrayRefSlow(size), scales, zero_points, axis, dtype, layout, device, pin_memory, memory_format);
39
+ }
40
+ namespace symint {
41
+ template <typename T, typename = std::enable_if_t<std::is_same<T, int64_t>::value>>
42
+ at::Tensor _empty_per_channel_affine_quantized(at::IntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, ::std::optional<at::ScalarType> dtype, ::std::optional<at::Layout> layout, ::std::optional<at::Device> device, ::std::optional<bool> pin_memory, ::std::optional<at::MemoryFormat> memory_format) {
43
+ return at::_ops::_empty_per_channel_affine_quantized::call(c10::fromIntArrayRefSlow(size), scales, zero_points, axis, dtype, layout, device, pin_memory, memory_format);
44
+ }
45
+ }
46
+
47
+ // aten::_empty_per_channel_affine_quantized(SymInt[] size, *, Tensor scales, Tensor zero_points, int axis, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=contiguous_format) -> Tensor
48
+ inline at::Tensor _empty_per_channel_affine_quantized_symint(c10::SymIntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, at::TensorOptions options={}, ::std::optional<at::MemoryFormat> memory_format=c10::MemoryFormat::Contiguous) {
49
+ return at::_ops::_empty_per_channel_affine_quantized::call(size, scales, zero_points, axis, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format));
50
+ }
51
+ namespace symint {
52
+ template <typename T, typename = std::enable_if_t<std::is_same<T, c10::SymInt>::value>>
53
+ at::Tensor _empty_per_channel_affine_quantized(c10::SymIntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, at::TensorOptions options={}, ::std::optional<at::MemoryFormat> memory_format=c10::MemoryFormat::Contiguous) {
54
+ return at::_ops::_empty_per_channel_affine_quantized::call(size, scales, zero_points, axis, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format));
55
+ }
56
+ }
57
+
58
+ // aten::_empty_per_channel_affine_quantized(SymInt[] size, *, Tensor scales, Tensor zero_points, int axis, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=contiguous_format) -> Tensor
59
+ inline at::Tensor _empty_per_channel_affine_quantized_symint(c10::SymIntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, ::std::optional<at::ScalarType> dtype, ::std::optional<at::Layout> layout, ::std::optional<at::Device> device, ::std::optional<bool> pin_memory, ::std::optional<at::MemoryFormat> memory_format) {
60
+ return at::_ops::_empty_per_channel_affine_quantized::call(size, scales, zero_points, axis, dtype, layout, device, pin_memory, memory_format);
61
+ }
62
+ namespace symint {
63
+ template <typename T, typename = std::enable_if_t<std::is_same<T, c10::SymInt>::value>>
64
+ at::Tensor _empty_per_channel_affine_quantized(c10::SymIntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, ::std::optional<at::ScalarType> dtype, ::std::optional<at::Layout> layout, ::std::optional<at::Device> device, ::std::optional<bool> pin_memory, ::std::optional<at::MemoryFormat> memory_format) {
65
+ return at::_ops::_empty_per_channel_affine_quantized::call(size, scales, zero_points, axis, dtype, layout, device, pin_memory, memory_format);
66
+ }
67
+ }
68
+
69
+ // aten::_empty_per_channel_affine_quantized.out(SymInt[] size, *, Tensor scales, Tensor zero_points, int axis, MemoryFormat? memory_format=contiguous_format, Tensor(a!) out) -> Tensor(a!)
70
+ inline at::Tensor & _empty_per_channel_affine_quantized_out(at::Tensor & out, at::IntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, ::std::optional<at::MemoryFormat> memory_format=c10::MemoryFormat::Contiguous) {
71
+ return at::_ops::_empty_per_channel_affine_quantized_out::call(c10::fromIntArrayRefSlow(size), scales, zero_points, axis, memory_format, out);
72
+ }
73
+ namespace symint {
74
+ template <typename T, typename = std::enable_if_t<std::is_same<T, int64_t>::value>>
75
+ at::Tensor & _empty_per_channel_affine_quantized_out(at::Tensor & out, at::IntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, ::std::optional<at::MemoryFormat> memory_format=c10::MemoryFormat::Contiguous) {
76
+ return at::_ops::_empty_per_channel_affine_quantized_out::call(c10::fromIntArrayRefSlow(size), scales, zero_points, axis, memory_format, out);
77
+ }
78
+ }
79
+
80
+ // aten::_empty_per_channel_affine_quantized.out(SymInt[] size, *, Tensor scales, Tensor zero_points, int axis, MemoryFormat? memory_format=contiguous_format, Tensor(a!) out) -> Tensor(a!)
81
+ inline at::Tensor & _empty_per_channel_affine_quantized_outf(at::IntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, ::std::optional<at::MemoryFormat> memory_format, at::Tensor & out) {
82
+ return at::_ops::_empty_per_channel_affine_quantized_out::call(c10::fromIntArrayRefSlow(size), scales, zero_points, axis, memory_format, out);
83
+ }
84
+ namespace symint {
85
+ template <typename T, typename = std::enable_if_t<std::is_same<T, int64_t>::value>>
86
+ at::Tensor & _empty_per_channel_affine_quantized_outf(at::IntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, ::std::optional<at::MemoryFormat> memory_format, at::Tensor & out) {
87
+ return at::_ops::_empty_per_channel_affine_quantized_out::call(c10::fromIntArrayRefSlow(size), scales, zero_points, axis, memory_format, out);
88
+ }
89
+ }
90
+
91
+ // aten::_empty_per_channel_affine_quantized.out(SymInt[] size, *, Tensor scales, Tensor zero_points, int axis, MemoryFormat? memory_format=contiguous_format, Tensor(a!) out) -> Tensor(a!)
92
+ inline at::Tensor & _empty_per_channel_affine_quantized_symint_out(at::Tensor & out, c10::SymIntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, ::std::optional<at::MemoryFormat> memory_format=c10::MemoryFormat::Contiguous) {
93
+ return at::_ops::_empty_per_channel_affine_quantized_out::call(size, scales, zero_points, axis, memory_format, out);
94
+ }
95
+ namespace symint {
96
+ template <typename T, typename = std::enable_if_t<std::is_same<T, c10::SymInt>::value>>
97
+ at::Tensor & _empty_per_channel_affine_quantized_out(at::Tensor & out, c10::SymIntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, ::std::optional<at::MemoryFormat> memory_format=c10::MemoryFormat::Contiguous) {
98
+ return at::_ops::_empty_per_channel_affine_quantized_out::call(size, scales, zero_points, axis, memory_format, out);
99
+ }
100
+ }
101
+
102
+ // aten::_empty_per_channel_affine_quantized.out(SymInt[] size, *, Tensor scales, Tensor zero_points, int axis, MemoryFormat? memory_format=contiguous_format, Tensor(a!) out) -> Tensor(a!)
103
+ inline at::Tensor & _empty_per_channel_affine_quantized_symint_outf(c10::SymIntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, ::std::optional<at::MemoryFormat> memory_format, at::Tensor & out) {
104
+ return at::_ops::_empty_per_channel_affine_quantized_out::call(size, scales, zero_points, axis, memory_format, out);
105
+ }
106
+ namespace symint {
107
+ template <typename T, typename = std::enable_if_t<std::is_same<T, c10::SymInt>::value>>
108
+ at::Tensor & _empty_per_channel_affine_quantized_outf(c10::SymIntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, ::std::optional<at::MemoryFormat> memory_format, at::Tensor & out) {
109
+ return at::_ops::_empty_per_channel_affine_quantized_out::call(size, scales, zero_points, axis, memory_format, out);
110
+ }
111
+ }
112
+
113
+ }
.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_foreach_addcmul_native.h ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from NativeFunction.h
4
+
5
+ #include <c10/core/Scalar.h>
6
+ #include <c10/core/Storage.h>
7
+ #include <c10/core/TensorOptions.h>
8
+ #include <c10/util/Deprecated.h>
9
+ #include <optional>
10
+ #include <c10/core/QScheme.h>
11
+ #include <ATen/core/Reduction.h>
12
+ #include <ATen/core/Tensor.h>
13
+ #include <tuple>
14
+ #include <vector>
15
+
16
+
17
+ namespace at {
18
+ namespace native {
19
+ TORCH_API ::std::vector<at::Tensor> foreach_tensor_addcmul_scalar_slow(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Scalar & value=1);
20
+ TORCH_API void _foreach_addcmul_Scalar_out(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Scalar & value, at::TensorList out);
21
+ TORCH_API void foreach_tensor_addcmul_scalar_slow_(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Scalar & value=1);
22
+ TORCH_API ::std::vector<at::Tensor> foreach_tensor_addcmul_scalar_cuda(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Scalar & value=1);
23
+ TORCH_API void foreach_tensor_addcmul_scalar_cuda_(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Scalar & value=1);
24
+ TORCH_API ::std::vector<at::Tensor> foreach_tensor_addcmul_scalarlist_slow(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, at::ArrayRef<at::Scalar> scalars);
25
+ TORCH_API void _foreach_addcmul_ScalarList_out(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, at::ArrayRef<at::Scalar> scalars, at::TensorList out);
26
+ TORCH_API void foreach_tensor_addcmul_scalarlist_slow_(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, at::ArrayRef<at::Scalar> scalars);
27
+ TORCH_API ::std::vector<at::Tensor> foreach_tensor_addcmul_scalarlist_cuda(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, at::ArrayRef<at::Scalar> scalars);
28
+ TORCH_API void foreach_tensor_addcmul_scalarlist_cuda_(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, at::ArrayRef<at::Scalar> scalars);
29
+ TORCH_API ::std::vector<at::Tensor> foreach_tensor_addcmul_tensor_slow(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Tensor & scalars);
30
+ TORCH_API void _foreach_addcmul_Tensor_out(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Tensor & scalars, at::TensorList out);
31
+ TORCH_API void foreach_tensor_addcmul_tensor_slow_(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Tensor & scalars);
32
+ TORCH_API ::std::vector<at::Tensor> foreach_tensor_addcmul_tensor_cuda(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Tensor & scalars);
33
+ TORCH_API void foreach_tensor_addcmul_tensor_cuda_(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Tensor & scalars);
34
+ } // namespace native
35
+ } // namespace at
.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_foreach_expm1_ops.h ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from Operator.h
4
+
5
+ #include <tuple>
6
+ #include <vector>
7
+
8
+ // Forward declarations of any types needed in the operator signatures.
9
+ // We can't directly include these classes because it will cause circular include dependencies.
10
+ // This file is included by TensorBody.h, which defines the Tensor class.
11
+ #include <ATen/core/ATen_fwd.h>
12
+
13
+ namespace at {
14
+ namespace _ops {
15
+
16
+
17
+ struct TORCH_API _foreach_expm1 {
18
+ using schema = ::std::vector<at::Tensor> (at::TensorList);
19
+ using ptr_schema = schema*;
20
+ // See Note [static constexpr char* members for windows NVCC]
21
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::_foreach_expm1")
22
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "")
23
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "_foreach_expm1(Tensor[] self) -> Tensor[]")
24
+ static ::std::vector<at::Tensor> call(at::TensorList self);
25
+ static ::std::vector<at::Tensor> redispatch(c10::DispatchKeySet dispatchKeySet, at::TensorList self);
26
+ };
27
+
28
+ struct TORCH_API _foreach_expm1_ {
29
+ using schema = void (at::TensorList);
30
+ using ptr_schema = schema*;
31
+ // See Note [static constexpr char* members for windows NVCC]
32
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::_foreach_expm1_")
33
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "")
34
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "_foreach_expm1_(Tensor(a!)[] self) -> ()")
35
+ static void call(at::TensorList self);
36
+ static void redispatch(c10::DispatchKeySet dispatchKeySet, at::TensorList self);
37
+ };
38
+
39
+ struct TORCH_API _foreach_expm1_out {
40
+ using schema = void (at::TensorList, at::TensorList);
41
+ using ptr_schema = schema*;
42
+ // See Note [static constexpr char* members for windows NVCC]
43
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::_foreach_expm1")
44
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "out")
45
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "_foreach_expm1.out(Tensor[] self, *, Tensor(a!)[] out) -> ()")
46
+ static void call(at::TensorList self, at::TensorList out);
47
+ static void redispatch(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out);
48
+ };
49
+
50
+ }} // namespace at::_ops
.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_sparse_mask_projection_ops.h ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from Operator.h
4
+
5
+ #include <tuple>
6
+ #include <vector>
7
+
8
+ // Forward declarations of any types needed in the operator signatures.
9
+ // We can't directly include these classes because it will cause circular include dependencies.
10
+ // This file is included by TensorBody.h, which defines the Tensor class.
11
+ #include <ATen/core/ATen_fwd.h>
12
+
13
+ namespace at {
14
+ namespace _ops {
15
+
16
+
17
+ struct TORCH_API _sparse_mask_projection {
18
+ using schema = at::Tensor (const at::Tensor &, const at::Tensor &, bool);
19
+ using ptr_schema = schema*;
20
+ // See Note [static constexpr char* members for windows NVCC]
21
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::_sparse_mask_projection")
22
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "")
23
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "_sparse_mask_projection(Tensor self, Tensor mask, bool accumulate_matches=False) -> Tensor")
24
+ static at::Tensor call(const at::Tensor & self, const at::Tensor & mask, bool accumulate_matches);
25
+ static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mask, bool accumulate_matches);
26
+ };
27
+
28
+ struct TORCH_API _sparse_mask_projection_out {
29
+ using schema = at::Tensor & (const at::Tensor &, const at::Tensor &, bool, at::Tensor &);
30
+ using ptr_schema = schema*;
31
+ // See Note [static constexpr char* members for windows NVCC]
32
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::_sparse_mask_projection")
33
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "out")
34
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "_sparse_mask_projection.out(Tensor self, Tensor mask, bool accumulate_matches=False, *, Tensor(a!) out) -> Tensor(a!)")
35
+ static at::Tensor & call(const at::Tensor & self, const at::Tensor & mask, bool accumulate_matches, at::Tensor & out);
36
+ static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mask, bool accumulate_matches, at::Tensor & out);
37
+ };
38
+
39
+ }} // namespace at::_ops
.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_sparse_sum_backward_ops.h ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from Operator.h
4
+
5
+ #include <tuple>
6
+ #include <vector>
7
+
8
+ // Forward declarations of any types needed in the operator signatures.
9
+ // We can't directly include these classes because it will cause circular include dependencies.
10
+ // This file is included by TensorBody.h, which defines the Tensor class.
11
+ #include <ATen/core/ATen_fwd.h>
12
+
13
+ namespace at {
14
+ namespace _ops {
15
+
16
+
17
+ struct TORCH_API _sparse_sum_backward {
18
+ using schema = at::Tensor (const at::Tensor &, const at::Tensor &, at::IntArrayRef);
19
+ using ptr_schema = schema*;
20
+ // See Note [static constexpr char* members for windows NVCC]
21
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::_sparse_sum_backward")
22
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "")
23
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "_sparse_sum_backward(Tensor grad, Tensor self, int[] dim) -> Tensor")
24
+ static at::Tensor call(const at::Tensor & grad, const at::Tensor & self, at::IntArrayRef dim);
25
+ static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & self, at::IntArrayRef dim);
26
+ };
27
+
28
+ struct TORCH_API _sparse_sum_backward_out {
29
+ using schema = at::Tensor & (const at::Tensor &, const at::Tensor &, at::IntArrayRef, at::Tensor &);
30
+ using ptr_schema = schema*;
31
+ // See Note [static constexpr char* members for windows NVCC]
32
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::_sparse_sum_backward")
33
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "out")
34
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "_sparse_sum_backward.out(Tensor grad, Tensor self, int[] dim, *, Tensor(a!) out) -> Tensor(a!)")
35
+ static at::Tensor & call(const at::Tensor & grad, const at::Tensor & self, at::IntArrayRef dim, at::Tensor & out);
36
+ static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & self, at::IntArrayRef dim, at::Tensor & out);
37
+ };
38
+
39
+ }} // namespace at::_ops
.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_test_autograd_multiple_dispatch_view_copy_ops.h ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from Operator.h
4
+
5
+ #include <tuple>
6
+ #include <vector>
7
+
8
+ // Forward declarations of any types needed in the operator signatures.
9
+ // We can't directly include these classes because it will cause circular include dependencies.
10
+ // This file is included by TensorBody.h, which defines the Tensor class.
11
+ #include <ATen/core/ATen_fwd.h>
12
+
13
+ namespace at {
14
+ namespace _ops {
15
+
16
+
17
+ struct TORCH_API _test_autograd_multiple_dispatch_view_copy {
18
+ using schema = at::Tensor (const at::Tensor &);
19
+ using ptr_schema = schema*;
20
+ // See Note [static constexpr char* members for windows NVCC]
21
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::_test_autograd_multiple_dispatch_view_copy")
22
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "")
23
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "_test_autograd_multiple_dispatch_view_copy(Tensor self) -> Tensor")
24
+ static at::Tensor call(const at::Tensor & self);
25
+ static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self);
26
+ };
27
+
28
+ struct TORCH_API _test_autograd_multiple_dispatch_view_copy_out {
29
+ using schema = at::Tensor & (const at::Tensor &, at::Tensor &);
30
+ using ptr_schema = schema*;
31
+ // See Note [static constexpr char* members for windows NVCC]
32
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::_test_autograd_multiple_dispatch_view_copy")
33
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "out")
34
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "_test_autograd_multiple_dispatch_view_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)")
35
+ static at::Tensor & call(const at::Tensor & self, at::Tensor & out);
36
+ static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out);
37
+ };
38
+
39
+ }} // namespace at::_ops
.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_unsafe_masked_index_put_accumulate_ops.h ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from Operator.h
4
+
5
+ #include <tuple>
6
+ #include <vector>
7
+
8
+ // Forward declarations of any types needed in the operator signatures.
9
+ // We can't directly include these classes because it will cause circular include dependencies.
10
+ // This file is included by TensorBody.h, which defines the Tensor class.
11
+ #include <ATen/core/ATen_fwd.h>
12
+
13
+ namespace at {
14
+ namespace _ops {
15
+
16
+
17
+ struct TORCH_API _unsafe_masked_index_put_accumulate {
18
+ using schema = at::Tensor (const at::Tensor &, const at::Tensor &, const c10::List<::std::optional<at::Tensor>> &, const at::Tensor &);
19
+ using ptr_schema = schema*;
20
+ // See Note [static constexpr char* members for windows NVCC]
21
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::_unsafe_masked_index_put_accumulate")
22
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "")
23
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "_unsafe_masked_index_put_accumulate(Tensor self, Tensor mask, Tensor?[] indices, Tensor values) -> Tensor")
24
+ static at::Tensor call(const at::Tensor & self, const at::Tensor & mask, const c10::List<::std::optional<at::Tensor>> & indices, const at::Tensor & values);
25
+ static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mask, const c10::List<::std::optional<at::Tensor>> & indices, const at::Tensor & values);
26
+ };
27
+
28
+ }} // namespace at::_ops
.venv/lib/python3.11/site-packages/torch/include/ATen/ops/alias_ops.h ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from Operator.h
4
+
5
+ #include <tuple>
6
+ #include <vector>
7
+
8
+ // Forward declarations of any types needed in the operator signatures.
9
+ // We can't directly include these classes because it will cause circular include dependencies.
10
+ // This file is included by TensorBody.h, which defines the Tensor class.
11
+ #include <ATen/core/ATen_fwd.h>
12
+
13
+ namespace at {
14
+ namespace _ops {
15
+
16
+
17
+ struct TORCH_API alias {
18
+ using schema = at::Tensor (const at::Tensor &);
19
+ using ptr_schema = schema*;
20
+ // See Note [static constexpr char* members for windows NVCC]
21
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::alias")
22
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "")
23
+ STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "alias(Tensor(a) self) -> Tensor(a)")
24
+ static at::Tensor call(const at::Tensor & self);
25
+ static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self);
26
+ };
27
+
28
+ }} // namespace at::_ops
.venv/lib/python3.11/site-packages/torch/include/ATen/ops/arcsinh_compositeimplicitautograd_dispatch.h ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ // @generated by torchgen/gen.py from DispatchKeyFunction.h
3
+
4
+ // NB: The implementing C++ file is RegisterDispatchKey.cpp
5
+
6
+ // The only #includes we need are for custom classes that have defaults in the C++ API
7
+ #include <c10/core/MemoryFormat.h>
8
+ #include <c10/core/Scalar.h>
9
+ #include <ATen/core/Reduction.h>
10
+
11
+ // Forward declarations of any types needed in the operator signatures.
12
+ // We can't directly include these classes because it will cause circular include dependencies.
13
+ // This file is included by TensorBody.h, which defines the Tensor class.
14
+ #include <ATen/core/ATen_fwd.h>
15
+
16
+ namespace at {
17
+
18
+ namespace compositeimplicitautograd {
19
+
20
+ TORCH_API at::Tensor arcsinh(const at::Tensor & self);
21
+ TORCH_API at::Tensor & arcsinh_out(at::Tensor & out, const at::Tensor & self);
22
+ TORCH_API at::Tensor & arcsinh_outf(const at::Tensor & self, at::Tensor & out);
23
+ TORCH_API at::Tensor & arcsinh_(at::Tensor & self);
24
+
25
+ } // namespace compositeimplicitautograd
26
+ } // namespace at
.venv/lib/python3.11/site-packages/torch/include/ATen/ops/block_diag_native.h ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // @generated by torchgen/gen.py from NativeFunction.h
4
+
5
+ #include <c10/core/Scalar.h>
6
+ #include <c10/core/Storage.h>
7
+ #include <c10/core/TensorOptions.h>
8
+ #include <c10/util/Deprecated.h>
9
+ #include <optional>
10
+ #include <c10/core/QScheme.h>
11
+ #include <ATen/core/Reduction.h>
12
+ #include <ATen/core/Tensor.h>
13
+ #include <tuple>
14
+ #include <vector>
15
+
16
+
17
+ namespace at {
18
+ namespace native {
19
+ TORCH_API at::Tensor block_diag(at::TensorList tensors);
20
+ TORCH_API at::Tensor & block_diag_out(at::TensorList tensors, at::Tensor & out);
21
+ } // namespace native
22
+ } // namespace at
.venv/lib/python3.11/site-packages/torch/include/ATen/ops/cat_compositeexplicitautogradnonfunctional_dispatch.h ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ // @generated by torchgen/gen.py from DispatchKeyFunction.h
3
+
4
+ // NB: The implementing C++ file is RegisterDispatchKey.cpp
5
+
6
+ // The only #includes we need are for custom classes that have defaults in the C++ API
7
+ #include <c10/core/MemoryFormat.h>
8
+ #include <c10/core/Scalar.h>
9
+ #include <ATen/core/Reduction.h>
10
+
11
+ // Forward declarations of any types needed in the operator signatures.
12
+ // We can't directly include these classes because it will cause circular include dependencies.
13
+ // This file is included by TensorBody.h, which defines the Tensor class.
14
+ #include <ATen/core/ATen_fwd.h>
15
+
16
+ namespace at {
17
+
18
+ namespace compositeexplicitautogradnonfunctional {
19
+
20
+ TORCH_API at::Tensor cat(const at::ITensorListRef & tensors, int64_t dim=0);
21
+
22
+ } // namespace compositeexplicitautogradnonfunctional
23
+ } // namespace at