Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .venv/lib/python3.11/site-packages/torch/include/ATen/mps/EmptyTensor.h +29 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/mps/IndexKernels.h +535 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/mps/MPSAllocator.h +403 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/mps/MPSAllocatorInterface.h +64 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/mps/MPSDevice.h +84 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/mps/MPSEvent.h +100 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/mps/MPSGeneratorImpl.h +52 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/mps/MPSGuardImpl.h +179 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/mps/MPSHooks.h +60 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/mps/MPSProfiler.h +402 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/mps/MPSStream.h +133 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/CatKernel.h +12 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/Loops.h +394 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/ReduceUtils.h +238 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/AffineQuantizer.h +130 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/AffineQuantizerBase.h +47 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/ConvUtils.h +62 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/Copy.h +10 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/FakeQuantAffine.h +67 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/IndexKernel.h +14 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/PackedParams.h +147 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/BinaryOps.h +8 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/EmbeddingPackedParams.h +29 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/OnednnUtils.h +457 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/QnnpackUtils.h +527 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/QuantUtils.h +239 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/QuantizedOps.h +258 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/RuyUtils.h +21 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/XnnpackUtils.h +335 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/conv_serialization.h +414 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/fbgemm_utils.h +413 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/init_qnnpack.h +13 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/qembeddingbag.h +34 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/qembeddingbag_prepack.h +13 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/transformers/attention.h +72 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/transformers/sdp_utils_cpp.h +566 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/utils/Factory.h +24 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/utils/ParamUtils.h +42 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/utils/ParamsHash.h +104 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/ops/_empty_per_channel_affine_quantized.h +113 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/ops/_foreach_addcmul_native.h +35 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/ops/_foreach_expm1_ops.h +50 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/ops/_sparse_mask_projection_ops.h +39 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/ops/_sparse_sum_backward_ops.h +39 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/ops/_test_autograd_multiple_dispatch_view_copy_ops.h +39 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/ops/_unsafe_masked_index_put_accumulate_ops.h +28 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/ops/alias_ops.h +28 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/ops/arcsinh_compositeimplicitautograd_dispatch.h +26 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/ops/block_diag_native.h +22 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/ops/cat_compositeexplicitautogradnonfunctional_dispatch.h +23 -0
.venv/lib/python3.11/site-packages/torch/include/ATen/mps/EmptyTensor.h
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright © 2022 Apple Inc.
|
| 2 |
+
|
| 3 |
+
#pragma once
|
| 4 |
+
#include <ATen/core/TensorBase.h>
|
| 5 |
+
|
| 6 |
+
namespace at::detail {
|
| 7 |
+
|
| 8 |
+
C10_EXPORT TensorBase empty_mps(
|
| 9 |
+
IntArrayRef size,
|
| 10 |
+
std::optional<ScalarType> dtype_opt,
|
| 11 |
+
std::optional<Layout> layout_opt,
|
| 12 |
+
std::optional<Device> device_opt,
|
| 13 |
+
std::optional<bool> pin_memory_opt,
|
| 14 |
+
std::optional<c10::MemoryFormat> memory_format_opt);
|
| 15 |
+
C10_EXPORT TensorBase empty_mps(
|
| 16 |
+
IntArrayRef size, const TensorOptions &options);
|
| 17 |
+
|
| 18 |
+
C10_EXPORT TensorBase empty_strided_mps(
|
| 19 |
+
IntArrayRef size,
|
| 20 |
+
IntArrayRef stride,
|
| 21 |
+
ScalarType dtype,
|
| 22 |
+
std::optional<Device> device_opt);
|
| 23 |
+
|
| 24 |
+
C10_EXPORT TensorBase empty_strided_mps(
|
| 25 |
+
IntArrayRef size,
|
| 26 |
+
IntArrayRef stride,
|
| 27 |
+
const TensorOptions &options);
|
| 28 |
+
|
| 29 |
+
} // namespace at::detail
|
.venv/lib/python3.11/site-packages/torch/include/ATen/mps/IndexKernels.h
ADDED
|
@@ -0,0 +1,535 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
namespace at::mps {
|
| 4 |
+
|
| 5 |
+
static const char * indexing_metal_shaders = R"INDEX_METAL(
|
| 6 |
+
#include <metal_stdlib>
|
| 7 |
+
#include <metal_atomic>
|
| 8 |
+
|
| 9 |
+
using namespace metal;
|
| 10 |
+
|
| 11 |
+
struct IndexAB {
|
| 12 |
+
constant int64_t* indexArray;
|
| 13 |
+
};
|
| 14 |
+
|
| 15 |
+
template<typename T, typename OffsetsT>
|
| 16 |
+
kernel void index_select(
|
| 17 |
+
constant IndexAB * indexAB [[buffer(0)]],
|
| 18 |
+
constant void * indexSizes [[buffer(1)]],
|
| 19 |
+
constant void * indexStrides [[buffer(2)]],
|
| 20 |
+
constant OffsetsT * offsets [[buffer(3)]],
|
| 21 |
+
constant void * inputData [[buffer(4)]],
|
| 22 |
+
device void * outputData [[buffer(5)]],
|
| 23 |
+
constant uint32_t & num_indices [[buffer(6)]],
|
| 24 |
+
uint thread_index [[thread_position_in_grid]]) {
|
| 25 |
+
constant int64_t * index_sizes = (constant int64_t *)indexSizes;
|
| 26 |
+
constant int64_t * index_strides = (constant int64_t *)indexStrides;
|
| 27 |
+
int64_t offset = 0;
|
| 28 |
+
for (uint32_t i = 0; i < num_indices; i++) {
|
| 29 |
+
constant int64_t* indexArray = indexAB[i].indexArray;
|
| 30 |
+
int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)];
|
| 31 |
+
if (index < 0) {
|
| 32 |
+
index += index_sizes[i];
|
| 33 |
+
}
|
| 34 |
+
offset += index * index_strides[i];
|
| 35 |
+
}
|
| 36 |
+
device T * out = (device T*)((device char*)outputData + offsets[thread_index].x);
|
| 37 |
+
constant T * in = (constant T*)((constant char*)inputData + offsets[thread_index].y + offset);
|
| 38 |
+
*out = *in;
|
| 39 |
+
}
|
| 40 |
+
|
| 41 |
+
template<typename T, typename OffsetsT>
|
| 42 |
+
void index_put_impl(
|
| 43 |
+
constant IndexAB * indexAB,
|
| 44 |
+
constant int64_t * index_sizes,
|
| 45 |
+
constant int64_t * index_strides,
|
| 46 |
+
constant OffsetsT * offsets,
|
| 47 |
+
constant void * inputData,
|
| 48 |
+
device void * outputData,
|
| 49 |
+
constant uint32_t & num_indices,
|
| 50 |
+
uint thread_index) {
|
| 51 |
+
int64_t offset = 0;
|
| 52 |
+
for (uint32_t i = 0; i < num_indices; i++) {
|
| 53 |
+
constant int64_t* indexArray = indexAB[i].indexArray;
|
| 54 |
+
int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)];
|
| 55 |
+
|
| 56 |
+
if (index < 0) {
|
| 57 |
+
index += index_sizes[i];
|
| 58 |
+
}
|
| 59 |
+
offset += index * index_strides[i];
|
| 60 |
+
}
|
| 61 |
+
device T * out = (device T*)((device char*)outputData + offsets[thread_index].x + offset);
|
| 62 |
+
constant T * in = (constant T*)((constant char*)inputData + offsets[thread_index].y);
|
| 63 |
+
*out = *in;
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
template<typename T, typename OffsetsT>
|
| 67 |
+
kernel void index_put_serial(
|
| 68 |
+
constant IndexAB * indexAB [[buffer(0)]],
|
| 69 |
+
constant void * indexSizes [[buffer(1)]],
|
| 70 |
+
constant void * indexStrides [[buffer(2)]],
|
| 71 |
+
constant OffsetsT * offsets [[buffer(3)]],
|
| 72 |
+
constant void * inputData [[buffer(4)]],
|
| 73 |
+
device void * outputData [[buffer(5)]],
|
| 74 |
+
constant uint32_t & num_indices [[buffer(6)]],
|
| 75 |
+
constant uint * numIters [[buffer(7)]],
|
| 76 |
+
uint thread_index [[thread_position_in_grid]]) {
|
| 77 |
+
|
| 78 |
+
constant int64_t * index_sizes = (constant int64_t *)indexSizes;
|
| 79 |
+
constant int64_t * index_strides = (constant int64_t *)indexStrides;
|
| 80 |
+
|
| 81 |
+
for (uint iter_i = 0; iter_i < *numIters; iter_i++) {
|
| 82 |
+
index_put_impl<T>(indexAB, index_sizes, index_strides, offsets, inputData, outputData, num_indices, iter_i);
|
| 83 |
+
}
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
template<typename T, typename OffsetsT>
|
| 87 |
+
kernel void index_put(
|
| 88 |
+
constant IndexAB * indexAB [[buffer(0)]],
|
| 89 |
+
constant void * indexSizes [[buffer(1)]],
|
| 90 |
+
constant void * indexStrides [[buffer(2)]],
|
| 91 |
+
constant OffsetsT * offsets [[buffer(3)]],
|
| 92 |
+
constant void * inputData [[buffer(4)]],
|
| 93 |
+
device void * outputData [[buffer(5)]],
|
| 94 |
+
constant uint32_t & num_indices [[buffer(6)]],
|
| 95 |
+
uint thread_index [[thread_position_in_grid]]) {
|
| 96 |
+
|
| 97 |
+
constant int64_t * index_sizes = (constant int64_t *)indexSizes;
|
| 98 |
+
constant int64_t * index_strides = (constant int64_t *)indexStrides;
|
| 99 |
+
index_put_impl<T>(indexAB, index_sizes, index_strides, offsets, inputData, outputData, num_indices, thread_index);
|
| 100 |
+
}
|
| 101 |
+
|
| 102 |
+
#define REGISTER_INDEX_OP(DTYPE_SIZE, IDX_SIZE, DTYPE, INDEX_OP_TYPE, IDX_DTYPE) \
|
| 103 |
+
template \
|
| 104 |
+
[[host_name("index_" #INDEX_OP_TYPE "_" #DTYPE_SIZE "_" #IDX_SIZE)]] \
|
| 105 |
+
kernel void index_ ## INDEX_OP_TYPE<DTYPE, IDX_DTYPE>( \
|
| 106 |
+
constant IndexAB * indexAB [[buffer(0)]], \
|
| 107 |
+
constant void * indexSizes [[buffer(1)]], \
|
| 108 |
+
constant void * indexStrides [[buffer(2)]], \
|
| 109 |
+
constant IDX_DTYPE * offsets [[buffer(3)]], \
|
| 110 |
+
constant void * inputData [[buffer(4)]], \
|
| 111 |
+
device void * outputData [[buffer(5)]], \
|
| 112 |
+
constant uint32_t & num_indices [[buffer(6)]], \
|
| 113 |
+
uint thread_index [[thread_position_in_grid]]);
|
| 114 |
+
|
| 115 |
+
#define REGISTER_INDEX_OP_ALL_DTYPES(INDEX_OP_TYPE) \
|
| 116 |
+
REGISTER_INDEX_OP(8bit, idx32, char, INDEX_OP_TYPE, uint3); \
|
| 117 |
+
REGISTER_INDEX_OP(8bit, idx64, char, INDEX_OP_TYPE, ulong3); \
|
| 118 |
+
REGISTER_INDEX_OP(16bit, idx32, short, INDEX_OP_TYPE, uint3); \
|
| 119 |
+
REGISTER_INDEX_OP(16bit, idx64, short, INDEX_OP_TYPE, ulong3); \
|
| 120 |
+
REGISTER_INDEX_OP(32bit, idx32, int, INDEX_OP_TYPE, uint3); \
|
| 121 |
+
REGISTER_INDEX_OP(32bit, idx64, int, INDEX_OP_TYPE, ulong3); \
|
| 122 |
+
REGISTER_INDEX_OP(64bit, idx32, long, INDEX_OP_TYPE, uint3); \
|
| 123 |
+
REGISTER_INDEX_OP(64bit, idx64, long, INDEX_OP_TYPE, ulong3);
|
| 124 |
+
|
| 125 |
+
REGISTER_INDEX_OP_ALL_DTYPES(select);
|
| 126 |
+
REGISTER_INDEX_OP_ALL_DTYPES(put);
|
| 127 |
+
|
| 128 |
+
#define REGISTER_SINGLE_THREADED_INDEX_OP(DTYPE_SIZE, IDX_SIZE, DTYPE, INDEX_OP_TYPE, IDX_DTYPE) \
|
| 129 |
+
template \
|
| 130 |
+
[[host_name("index_" #INDEX_OP_TYPE "_" #DTYPE_SIZE "_" #IDX_SIZE)]] \
|
| 131 |
+
kernel void index_ ## INDEX_OP_TYPE<DTYPE, IDX_DTYPE>( \
|
| 132 |
+
constant IndexAB * indexAB [[buffer(0)]], \
|
| 133 |
+
constant void * indexSizes [[buffer(1)]], \
|
| 134 |
+
constant void * indexStrides [[buffer(2)]], \
|
| 135 |
+
constant IDX_DTYPE * offsets [[buffer(3)]], \
|
| 136 |
+
constant void * inputData [[buffer(4)]], \
|
| 137 |
+
device void * outputData [[buffer(5)]], \
|
| 138 |
+
constant uint32_t & num_indices [[buffer(6)]], \
|
| 139 |
+
constant uint * numIters [[buffer(7)]], \
|
| 140 |
+
uint thread_index [[thread_position_in_grid]]);
|
| 141 |
+
|
| 142 |
+
#define REGISTER_SINGLE_THREADED_INDEX_OP_ALL_DTYPES(INDEX_OP_TYPE) \
|
| 143 |
+
REGISTER_SINGLE_THREADED_INDEX_OP(8bit, idx32, char, INDEX_OP_TYPE, uint3); \
|
| 144 |
+
REGISTER_SINGLE_THREADED_INDEX_OP(8bit, idx64, char, INDEX_OP_TYPE, ulong3); \
|
| 145 |
+
REGISTER_SINGLE_THREADED_INDEX_OP(16bit, idx32, short, INDEX_OP_TYPE, uint3); \
|
| 146 |
+
REGISTER_SINGLE_THREADED_INDEX_OP(16bit, idx64, short, INDEX_OP_TYPE, ulong3); \
|
| 147 |
+
REGISTER_SINGLE_THREADED_INDEX_OP(32bit, idx32, int, INDEX_OP_TYPE, uint3); \
|
| 148 |
+
REGISTER_SINGLE_THREADED_INDEX_OP(32bit, idx64, int, INDEX_OP_TYPE, ulong3); \
|
| 149 |
+
REGISTER_SINGLE_THREADED_INDEX_OP(64bit, idx32, long, INDEX_OP_TYPE, uint3); \
|
| 150 |
+
REGISTER_SINGLE_THREADED_INDEX_OP(64bit, idx64, long, INDEX_OP_TYPE, ulong3);
|
| 151 |
+
|
| 152 |
+
REGISTER_SINGLE_THREADED_INDEX_OP_ALL_DTYPES(put_serial);
|
| 153 |
+
|
| 154 |
+
template<typename StridesT, typename DataT>
|
| 155 |
+
kernel void kernel_index_offsets(constant StridesT * strides [[buffer(0)]],
|
| 156 |
+
device DataT * data_offsets [[buffer(1)]],
|
| 157 |
+
constant uint * iter_shape [[buffer(2)]],
|
| 158 |
+
constant uint & num_dimensions [[buffer(3)]],
|
| 159 |
+
uint thread_index [[thread_position_in_grid]]) {
|
| 160 |
+
data_offsets[thread_index] = 0;
|
| 161 |
+
uint32_t idx = thread_index;
|
| 162 |
+
for (uint32_t dim = 0; dim < num_dimensions; dim++) {
|
| 163 |
+
uint32_t remainder = idx % iter_shape[dim];
|
| 164 |
+
idx /= iter_shape[dim];
|
| 165 |
+
|
| 166 |
+
data_offsets[thread_index] += remainder * DataT(strides[dim]);
|
| 167 |
+
}
|
| 168 |
+
}
|
| 169 |
+
|
| 170 |
+
template
|
| 171 |
+
[[host_name("kernel_index_offsets_32")]]
|
| 172 |
+
kernel void kernel_index_offsets<packed_uint3, uint3>(
|
| 173 |
+
constant packed_uint3 * strides [[buffer(0)]],
|
| 174 |
+
device uint3 * data_offsets [[buffer(1)]],
|
| 175 |
+
constant uint * iter_shape [[buffer(2)]],
|
| 176 |
+
constant uint & num_dimensions [[buffer(3)]],
|
| 177 |
+
uint thread_index [[thread_position_in_grid]]);
|
| 178 |
+
|
| 179 |
+
template
|
| 180 |
+
[[host_name("kernel_index_offsets_64")]]
|
| 181 |
+
kernel void kernel_index_offsets<packed_uint3, ulong3>(
|
| 182 |
+
constant packed_uint3 * strides [[buffer(0)]],
|
| 183 |
+
device ulong3 * data_offsets [[buffer(1)]],
|
| 184 |
+
constant uint * iter_shape [[buffer(2)]],
|
| 185 |
+
constant uint & num_dimensions [[buffer(3)]],
|
| 186 |
+
uint thread_index [[thread_position_in_grid]]);
|
| 187 |
+
|
| 188 |
+
template<typename T, typename E, typename OffsetsT>
|
| 189 |
+
kernel void index_put_accumulate_native_dtypes(
|
| 190 |
+
constant IndexAB * indexAB [[buffer(0)]],
|
| 191 |
+
constant void * indexSizes [[buffer(1)]],
|
| 192 |
+
constant void * indexStrides [[buffer(2)]],
|
| 193 |
+
constant OffsetsT * offsets [[buffer(3)]],
|
| 194 |
+
constant void * inputData [[buffer(4)]],
|
| 195 |
+
device void * outputData [[buffer(5)]],
|
| 196 |
+
constant uint32_t & num_indices [[buffer(6)]],
|
| 197 |
+
uint thread_index [[thread_position_in_grid]]) {
|
| 198 |
+
constant int64_t * index_sizes = (constant int64_t *)indexSizes;
|
| 199 |
+
constant int64_t * index_strides = (constant int64_t *)indexStrides;
|
| 200 |
+
int64_t offset = 0;
|
| 201 |
+
for (uint32_t i = 0; i < num_indices; i++) {
|
| 202 |
+
constant int64_t* indexArray = indexAB[i].indexArray;
|
| 203 |
+
int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)];
|
| 204 |
+
if (index < 0) {
|
| 205 |
+
index += index_sizes[i];
|
| 206 |
+
}
|
| 207 |
+
offset += index * index_strides[i];
|
| 208 |
+
}
|
| 209 |
+
device T * out = (device T*)((device char*)outputData + offsets[thread_index].x + offset);
|
| 210 |
+
constant E * in = (constant E*)((constant char*)inputData + offsets[thread_index].y);
|
| 211 |
+
atomic_fetch_add_explicit(out, *in, memory_order_relaxed);
|
| 212 |
+
}
|
| 213 |
+
|
| 214 |
+
template<typename T>
|
| 215 |
+
__attribute__((__always_inline__)) void atomic_fetch_add_relaxed(device void * addr, T value) {
|
| 216 |
+
device atomic_uint* uintAddr = (device atomic_uint*)addr;
|
| 217 |
+
uint expected = atomic_load_explicit(uintAddr, memory_order_relaxed);
|
| 218 |
+
T updated = as_type<T>(expected) + value;
|
| 219 |
+
while (!atomic_compare_exchange_weak_explicit(uintAddr, &expected, as_type<uint>(updated), memory_order_relaxed, memory_order_relaxed)) {
|
| 220 |
+
updated = as_type<T>(expected) + value;
|
| 221 |
+
}
|
| 222 |
+
}
|
| 223 |
+
|
| 224 |
+
template<typename T, typename OffsetsT>
|
| 225 |
+
kernel void atomic_index_put_accumulate(
|
| 226 |
+
constant IndexAB * indexAB [[buffer(0)]],
|
| 227 |
+
constant void * indexSizes [[buffer(1)]],
|
| 228 |
+
constant void * indexStrides [[buffer(2)]],
|
| 229 |
+
constant OffsetsT * offsets [[buffer(3)]],
|
| 230 |
+
constant void * inputData [[buffer(4)]],
|
| 231 |
+
device void * outputData [[buffer(5)]],
|
| 232 |
+
constant uint32_t & num_indices [[buffer(6)]],
|
| 233 |
+
uint thread_index [[thread_position_in_grid]]) {
|
| 234 |
+
constant int64_t * index_sizes = (constant int64_t *)indexSizes;
|
| 235 |
+
constant int64_t * index_strides = (constant int64_t *)indexStrides;
|
| 236 |
+
int64_t offset = 0;
|
| 237 |
+
for (uint32_t i = 0; i < num_indices; i++) {
|
| 238 |
+
constant int64_t* indexArray = indexAB[i].indexArray;
|
| 239 |
+
int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)];
|
| 240 |
+
if (index < 0) {
|
| 241 |
+
index += index_sizes[i];
|
| 242 |
+
}
|
| 243 |
+
offset += index * index_strides[i];
|
| 244 |
+
}
|
| 245 |
+
device void * out = (device void*)((device char*)outputData + offsets[thread_index].x + offset);
|
| 246 |
+
constant T * in = (constant T*)((constant char*)inputData + offsets[thread_index].y);
|
| 247 |
+
atomic_fetch_add_relaxed<T>(out, *in);
|
| 248 |
+
}
|
| 249 |
+
|
| 250 |
+
template
|
| 251 |
+
[[host_name("index_put_accumulate_32bit_float_idx32")]]
|
| 252 |
+
kernel void atomic_index_put_accumulate<float, uint3>(
|
| 253 |
+
constant IndexAB * indexAB [[buffer(0)]],
|
| 254 |
+
constant void * indexSizes [[buffer(1)]],
|
| 255 |
+
constant void * indexStrides [[buffer(2)]],
|
| 256 |
+
constant uint3 * offsets [[buffer(3)]],
|
| 257 |
+
constant void * inputData [[buffer(4)]],
|
| 258 |
+
device void * outputData [[buffer(5)]],
|
| 259 |
+
constant uint32_t & num_indices [[buffer(6)]],
|
| 260 |
+
uint thread_index [[thread_position_in_grid]]);
|
| 261 |
+
|
| 262 |
+
template
|
| 263 |
+
[[host_name("index_put_accumulate_32bit_float_idx64")]]
|
| 264 |
+
kernel void atomic_index_put_accumulate<float, ulong3>(
|
| 265 |
+
constant IndexAB * indexAB [[buffer(0)]],
|
| 266 |
+
constant void * indexSizes [[buffer(1)]],
|
| 267 |
+
constant void * indexStrides [[buffer(2)]],
|
| 268 |
+
constant ulong3 * offsets [[buffer(3)]],
|
| 269 |
+
constant void * inputData [[buffer(4)]],
|
| 270 |
+
device void * outputData [[buffer(5)]],
|
| 271 |
+
constant uint32_t & num_indices [[buffer(6)]],
|
| 272 |
+
uint thread_index [[thread_position_in_grid]]);
|
| 273 |
+
|
| 274 |
+
template
|
| 275 |
+
[[host_name("index_put_accumulate_32bit_int_idx32")]]
|
| 276 |
+
kernel void index_put_accumulate_native_dtypes<atomic_int, int, uint3>(
|
| 277 |
+
constant IndexAB * indexAB [[buffer(0)]],
|
| 278 |
+
constant void * indexSizes [[buffer(1)]],
|
| 279 |
+
constant void * indexStrides [[buffer(2)]],
|
| 280 |
+
constant uint3 * offsets [[buffer(3)]],
|
| 281 |
+
constant void * inputData [[buffer(4)]],
|
| 282 |
+
device void * outputData [[buffer(5)]],
|
| 283 |
+
constant uint32_t & num_indices [[buffer(6)]],
|
| 284 |
+
uint thread_index [[thread_position_in_grid]]);
|
| 285 |
+
|
| 286 |
+
template
|
| 287 |
+
[[host_name("index_put_accumulate_32bit_int_idx64")]]
|
| 288 |
+
kernel void index_put_accumulate_native_dtypes<atomic_int, int, ulong3>(
|
| 289 |
+
constant IndexAB * indexAB [[buffer(0)]],
|
| 290 |
+
constant void * indexSizes [[buffer(1)]],
|
| 291 |
+
constant void * indexStrides [[buffer(2)]],
|
| 292 |
+
constant ulong3 * offsets [[buffer(3)]],
|
| 293 |
+
constant void * inputData [[buffer(4)]],
|
| 294 |
+
device void * outputData [[buffer(5)]],
|
| 295 |
+
constant uint32_t & num_indices [[buffer(6)]],
|
| 296 |
+
uint thread_index [[thread_position_in_grid]]);
|
| 297 |
+
)INDEX_METAL";
|
| 298 |
+
|
| 299 |
+
static const char *SCATTER_OPS_TEMPLATE = R"METAL_SCATTER(
|
| 300 |
+
struct __attribute__ ((packed)) packed_uint5{{
|
| 301 |
+
uint32_t x; uint32_t y; uint32_t z; uint32_t w; uint32_t u;
|
| 302 |
+
}};
|
| 303 |
+
|
| 304 |
+
template<typename Y, typename X>
|
| 305 |
+
Y cast(const X x);
|
| 306 |
+
|
| 307 |
+
template<>
|
| 308 |
+
{1} cast<{1}, {0}>(const {0} x) {{
|
| 309 |
+
return {2};
|
| 310 |
+
}}
|
| 311 |
+
|
| 312 |
+
kernel void scatter_kernel_5(uint linear_index [[thread_position_in_grid]],
|
| 313 |
+
constant void * src_ [[buffer(0)]],
|
| 314 |
+
device void * dst_ [[buffer(1)]],
|
| 315 |
+
constant packed_uint5 & size [[buffer(2)]],
|
| 316 |
+
constant packed_uint5 & stride [[buffer(3)]],
|
| 317 |
+
constant uint32_t & numel [[buffer(4)]]) {{
|
| 318 |
+
if (linear_index >= numel) return;
|
| 319 |
+
|
| 320 |
+
constant {0} * src = (constant {0} *)src_;
|
| 321 |
+
device {1} * dst = (device {1} *)dst_;
|
| 322 |
+
|
| 323 |
+
packed_uint5 local_index;
|
| 324 |
+
local_index.x = linear_index / (size.u * size.w * size.z * size.y) % size.x;
|
| 325 |
+
local_index.y = linear_index / (size.u * size.w * size.z) % size.y;
|
| 326 |
+
local_index.z = linear_index / (size.u * size.w) % size.z;
|
| 327 |
+
local_index.w = linear_index / size.u % size.w;
|
| 328 |
+
local_index.u = linear_index % size.u;
|
| 329 |
+
|
| 330 |
+
packed_uint5 strided_index;
|
| 331 |
+
strided_index.x = local_index.x * stride.x;
|
| 332 |
+
strided_index.y = local_index.y * stride.y;
|
| 333 |
+
strided_index.z = local_index.z * stride.z;
|
| 334 |
+
strided_index.w = local_index.w * stride.w;
|
| 335 |
+
strided_index.u = local_index.u * stride.u;
|
| 336 |
+
|
| 337 |
+
dst[strided_index.x + strided_index.y + strided_index.z + strided_index.w + strided_index.u] = cast<{1}>(src[linear_index]);
|
| 338 |
+
}}
|
| 339 |
+
|
| 340 |
+
kernel void scatter_kernel_4(uint linear_index [[thread_position_in_grid]],
|
| 341 |
+
constant void * src_ [[buffer(0)]],
|
| 342 |
+
device void * dst_ [[buffer(1)]],
|
| 343 |
+
constant packed_uint4 & size [[buffer(2)]],
|
| 344 |
+
constant packed_uint4 & stride [[buffer(3)]],
|
| 345 |
+
constant uint32_t & numel [[buffer(4)]]) {{
|
| 346 |
+
if (linear_index >= numel) return;
|
| 347 |
+
|
| 348 |
+
constant {0} * src = (constant {0} *)src_;
|
| 349 |
+
device {1} * dst = (device {1} *)dst_;
|
| 350 |
+
|
| 351 |
+
packed_uint4 local_index;
|
| 352 |
+
local_index.x = linear_index / (size[3] * size[2] * size[1]) % size[0];
|
| 353 |
+
local_index.y = linear_index / (size[3] * size[2]) % size[1];
|
| 354 |
+
local_index.z = linear_index / size[3] % size[2];
|
| 355 |
+
local_index.w = linear_index % size[3];
|
| 356 |
+
|
| 357 |
+
const packed_uint4 strided_index = local_index * stride;
|
| 358 |
+
dst[strided_index.x + strided_index.y + strided_index.z + strided_index.w] = cast<{1}>(src[linear_index]);
|
| 359 |
+
}}
|
| 360 |
+
|
| 361 |
+
kernel void scatter_kernel_3(uint linear_index [[thread_position_in_grid]],
|
| 362 |
+
constant void * src_ [[buffer(0)]],
|
| 363 |
+
device void * dst_ [[buffer(1)]],
|
| 364 |
+
constant packed_uint3 & size [[buffer(2)]],
|
| 365 |
+
constant packed_uint3 & stride [[buffer(3)]],
|
| 366 |
+
constant uint32_t & numel [[buffer(4)]]) {{
|
| 367 |
+
if (linear_index >= numel) return;
|
| 368 |
+
|
| 369 |
+
constant {0} * src = (constant {0} *)src_;
|
| 370 |
+
device {1} * dst = (device {1} *)dst_;
|
| 371 |
+
|
| 372 |
+
packed_uint3 local_index;
|
| 373 |
+
local_index.x = linear_index / (size[2] * size[1]) % size[0];
|
| 374 |
+
local_index.y = linear_index / size[2] % size[1];
|
| 375 |
+
local_index.z = linear_index % size[2];
|
| 376 |
+
|
| 377 |
+
const packed_uint3 strided_index = local_index * stride;
|
| 378 |
+
dst[strided_index.x + strided_index.y + strided_index.z] = cast<{1}>(src[linear_index]);
|
| 379 |
+
}}
|
| 380 |
+
|
| 381 |
+
kernel void scatter_kernel_2(uint linear_index [[thread_position_in_grid]],
|
| 382 |
+
constant void * src_ [[buffer(0)]],
|
| 383 |
+
device void * dst_ [[buffer(1)]],
|
| 384 |
+
constant packed_uint2 & size [[buffer(2)]],
|
| 385 |
+
constant packed_uint2 & stride [[buffer(3)]],
|
| 386 |
+
constant uint32_t & numel [[buffer(4)]]) {{
|
| 387 |
+
if (linear_index >= numel) return;
|
| 388 |
+
|
| 389 |
+
constant {0} * src = (constant {0} *)src_;
|
| 390 |
+
device {1} * dst = (device {1} *)dst_;
|
| 391 |
+
|
| 392 |
+
packed_uint2 local_index;
|
| 393 |
+
local_index.x = linear_index / size[1] % size[0];
|
| 394 |
+
local_index.y = linear_index % size[1];
|
| 395 |
+
|
| 396 |
+
const packed_uint2 strided_index = local_index * stride;
|
| 397 |
+
dst[strided_index.x + strided_index.y] = cast<{1}>(src[linear_index]);
|
| 398 |
+
}}
|
| 399 |
+
|
| 400 |
+
kernel void scatter_kernel_1(uint linear_index [[thread_position_in_grid]],
|
| 401 |
+
constant void * src_ [[buffer(0)]],
|
| 402 |
+
device void * dst_ [[buffer(1)]],
|
| 403 |
+
constant int & size [[buffer(2)]],
|
| 404 |
+
constant int & stride [[buffer(3)]],
|
| 405 |
+
constant uint32_t & numel [[buffer(4)]]) {{
|
| 406 |
+
if (linear_index >= numel) return;
|
| 407 |
+
|
| 408 |
+
constant {0} * src = (constant {0} *)src_;
|
| 409 |
+
device {1} * dst = (device {1} *)dst_;
|
| 410 |
+
|
| 411 |
+
const int local_index = linear_index % size;
|
| 412 |
+
const int strided_index = local_index * stride;
|
| 413 |
+
dst[strided_index] = cast<{1}>(src[linear_index]);
|
| 414 |
+
}}
|
| 415 |
+
)METAL_SCATTER";
|
| 416 |
+
|
| 417 |
+
static const char *GATHER_OPS_TEMPLATE = R"METAL_GATHER(
|
| 418 |
+
struct __attribute__ ((packed)) packed_uint5{{
|
| 419 |
+
uint32_t x; uint32_t y; uint32_t z; uint32_t w; uint32_t u;
|
| 420 |
+
}};
|
| 421 |
+
|
| 422 |
+
template<typename Y, typename X>
|
| 423 |
+
Y cast(const X x);
|
| 424 |
+
|
| 425 |
+
template<>
|
| 426 |
+
{1} cast<{1}, {0}>(const {0} x) {{
|
| 427 |
+
return {2};
|
| 428 |
+
}}
|
| 429 |
+
|
| 430 |
+
kernel void gather_kernel_5(uint linear_index [[thread_position_in_grid]],
|
| 431 |
+
constant void * src_ [[buffer(0)]],
|
| 432 |
+
device void * dst_ [[buffer(1)]],
|
| 433 |
+
constant packed_uint5 & size [[buffer(2)]],
|
| 434 |
+
constant packed_uint5 & stride [[buffer(3)]],
|
| 435 |
+
constant uint32_t & numel [[buffer(4)]]) {{
|
| 436 |
+
if (linear_index >= numel) return;
|
| 437 |
+
|
| 438 |
+
constant {0} * src = (constant {0} *)src_;
|
| 439 |
+
device {1} * dst = (device {1} *)dst_;
|
| 440 |
+
|
| 441 |
+
|
| 442 |
+
packed_uint5 local_index;
|
| 443 |
+
local_index.x = linear_index / (size.u * size.w * size.z * size.y) % size.x;
|
| 444 |
+
local_index.y = linear_index / (size.u * size.w * size.z) % size.y;
|
| 445 |
+
local_index.z = linear_index / (size.u * size.w) % size.z;
|
| 446 |
+
local_index.w = linear_index / size.u % size.w;
|
| 447 |
+
local_index.u = linear_index % size.u;
|
| 448 |
+
|
| 449 |
+
packed_uint5 strided_index;
|
| 450 |
+
strided_index.x = local_index.x * stride.x;
|
| 451 |
+
strided_index.y = local_index.y * stride.y;
|
| 452 |
+
strided_index.z = local_index.z * stride.z;
|
| 453 |
+
strided_index.w = local_index.w * stride.w;
|
| 454 |
+
strided_index.u = local_index.u * stride.u;
|
| 455 |
+
|
| 456 |
+
dst[linear_index] = cast<{1}>(src[strided_index.x + strided_index.y + strided_index.z + strided_index.w + strided_index.u]);
|
| 457 |
+
}}
|
| 458 |
+
|
| 459 |
+
kernel void gather_kernel_4(uint linear_index [[thread_position_in_grid]],
|
| 460 |
+
constant void * src_ [[buffer(0)]],
|
| 461 |
+
device void * dst_ [[buffer(1)]],
|
| 462 |
+
constant packed_uint4 & size [[buffer(2)]],
|
| 463 |
+
constant packed_uint4 & stride [[buffer(3)]],
|
| 464 |
+
constant uint32_t & numel [[buffer(4)]]) {{
|
| 465 |
+
if (linear_index >= numel) return;
|
| 466 |
+
|
| 467 |
+
constant {0} * src = (constant {0} *)src_;
|
| 468 |
+
device {1} * dst = (device {1} *)dst_;
|
| 469 |
+
|
| 470 |
+
packed_uint4 local_index;
|
| 471 |
+
local_index.x = linear_index / (size[3] * size[2] * size[1]) % size[0];
|
| 472 |
+
local_index.y = linear_index / (size[3] * size[2]) % size[1];
|
| 473 |
+
local_index.z = linear_index / size[3] % size[2];
|
| 474 |
+
local_index.w = linear_index % size[3];
|
| 475 |
+
|
| 476 |
+
const packed_uint4 strided_index = local_index * stride;
|
| 477 |
+
dst[linear_index] = cast<{1}>(src[strided_index.x + strided_index.y + strided_index.z + strided_index.w]);
|
| 478 |
+
}}
|
| 479 |
+
|
| 480 |
+
kernel void gather_kernel_3(uint linear_index [[thread_position_in_grid]],
|
| 481 |
+
constant void * src_ [[buffer(0)]],
|
| 482 |
+
device void * dst_ [[buffer(1)]],
|
| 483 |
+
constant packed_uint3 & size [[buffer(2)]],
|
| 484 |
+
constant packed_uint3 & stride [[buffer(3)]],
|
| 485 |
+
constant uint32_t & numel [[buffer(4)]]) {{
|
| 486 |
+
if (linear_index >= numel) return;
|
| 487 |
+
|
| 488 |
+
constant {0} * src = (constant {0} *)src_;
|
| 489 |
+
device {1} * dst = (device {1} *)dst_;
|
| 490 |
+
|
| 491 |
+
packed_uint3 local_index;
|
| 492 |
+
local_index.x = linear_index / (size[2] * size[1]) % size[0];
|
| 493 |
+
local_index.y = linear_index / size[2] % size[1];
|
| 494 |
+
local_index.z = linear_index % size[2];
|
| 495 |
+
|
| 496 |
+
const packed_uint3 strided_index = local_index * stride;
|
| 497 |
+
dst[linear_index] = cast<{1}>(src[strided_index.x + strided_index.y + strided_index.z]);
|
| 498 |
+
}}
|
| 499 |
+
|
| 500 |
+
kernel void gather_kernel_2(uint linear_index [[thread_position_in_grid]],
|
| 501 |
+
constant void * src_ [[buffer(0)]],
|
| 502 |
+
device void * dst_ [[buffer(1)]],
|
| 503 |
+
constant packed_uint2 & size [[buffer(2)]],
|
| 504 |
+
constant packed_uint2 & stride [[buffer(3)]],
|
| 505 |
+
constant uint32_t & numel [[buffer(4)]]) {{
|
| 506 |
+
if (linear_index >= numel) return;
|
| 507 |
+
|
| 508 |
+
constant {0} * src = (constant {0} *)src_;
|
| 509 |
+
device {1} * dst = (device {1} *)dst_;
|
| 510 |
+
|
| 511 |
+
packed_uint2 local_index;
|
| 512 |
+
local_index.x = linear_index / size[1] % size[0];
|
| 513 |
+
local_index.y = linear_index % size[1];
|
| 514 |
+
|
| 515 |
+
const packed_uint2 strided_index = local_index * stride;
|
| 516 |
+
dst[linear_index] = cast<{1}>(src[strided_index.x + strided_index.y]);
|
| 517 |
+
}}
|
| 518 |
+
|
| 519 |
+
kernel void gather_kernel_1(uint linear_index [[thread_position_in_grid]],
|
| 520 |
+
constant void * src_ [[buffer(0)]],
|
| 521 |
+
device void * dst_ [[buffer(1)]],
|
| 522 |
+
constant int & size [[buffer(2)]],
|
| 523 |
+
constant int & stride [[buffer(3)]],
|
| 524 |
+
constant uint32_t & numel [[buffer(4)]]) {{
|
| 525 |
+
if (linear_index >= numel) return;
|
| 526 |
+
|
| 527 |
+
constant {0} * src = (constant {0} *)src_;
|
| 528 |
+
device {1} * dst = (device {1} *)dst_;
|
| 529 |
+
|
| 530 |
+
const int local_index = linear_index % size;
|
| 531 |
+
const int strided_index = local_index * stride;
|
| 532 |
+
dst[linear_index] = cast<{1}>(src[strided_index]);
|
| 533 |
+
}}
|
| 534 |
+
)METAL_GATHER";
|
| 535 |
+
} // namespace at::mps
|
.venv/lib/python3.11/site-packages/torch/include/ATen/mps/MPSAllocator.h
ADDED
|
@@ -0,0 +1,403 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright © 2022 Apple Inc.
|
| 2 |
+
|
| 3 |
+
#pragma once
|
| 4 |
+
|
| 5 |
+
#include <ATen/mps/MPSAllocatorInterface.h>
|
| 6 |
+
#include <ATen/mps/MPSEvent.h>
|
| 7 |
+
#include <ATen/mps/MPSStream.h>
|
| 8 |
+
|
| 9 |
+
#include <cstdio>
|
| 10 |
+
#include <mutex>
|
| 11 |
+
#include <set>
|
| 12 |
+
#include <unordered_set>
|
| 13 |
+
#include <mach/vm_page_size.h>
|
| 14 |
+
#include <c10/util/flat_hash_map.h>
|
| 15 |
+
|
| 16 |
+
// this implementation is based on CUDACachingAllocator.
|
| 17 |
+
// It utilizes Metal Heaps to improve the performance with buffer allocation.
|
| 18 |
+
// Do not include this header. Use MPSAllocatorInterface.h instead.
|
| 19 |
+
// TODO: Unify the logic with CUDACachingAllocator and remove redundant code.
|
| 20 |
+
namespace at::mps::HeapAllocator {
|
| 21 |
+
|
| 22 |
+
static const size_t kMaxSmallAlloc = MB(1); // largest "small" allocation is 1 MiB
|
| 23 |
+
static const size_t kMinLargeAlloc = MB(10); // allocations between 1 and 10 MiB may use kLargeHeap
|
| 24 |
+
static const size_t kRoundLarge = MB(2); // round up large allocations to 2 MiB
|
| 25 |
+
static const size_t kSmallHeap = MB(8); // "small" allocations are packed in 8 MiB heaps
|
| 26 |
+
static const size_t kLargeHeap = MB(32); // "large" allocations may be packed in 32 MiB heaps
|
| 27 |
+
static const size_t kXLargeHeapD = MB(128); // "extra large" allocations on Discrete devices may be packed in 128 MiB heaps
|
| 28 |
+
static const size_t kXLargeHeapU = MB(1024); // "extra large" allocations on Unified devices may be packed in 1 GiB heaps
|
| 29 |
+
static const size_t kMaxScalarAlloc = (sizeof(int64_t)); // largest "scalar" allocation
|
| 30 |
+
|
| 31 |
+
// buffer pools could be customized with a combination of usage flags
|
| 32 |
+
enum UsageFlags : uint32_t {
|
| 33 |
+
PRIVATE = 0,
|
| 34 |
+
SMALL = (1 << 0), // small heaps have sizes of kSmallHeap, and large ones kLargeHeap
|
| 35 |
+
SHARED = (1 << 1), // shared pools allocated on devices with unified memory; otherwise, private between host/device
|
| 36 |
+
MANAGED = (1 << 2), // managed storage mode
|
| 37 |
+
HAZARD = (1 << 3), // enables Automatic Hazard Tracking for the resources allocated on the pool
|
| 38 |
+
SCALAR = (1 << 4), // used to import CPU scalar values to GPU and use them in MPS Stream
|
| 39 |
+
};
|
| 40 |
+
// debug verbosity flags
|
| 41 |
+
enum DebugVerbosity : uint32_t {
|
| 42 |
+
SILENT = 0,
|
| 43 |
+
PROFILING = (1 << 0), // print generic profiling data for total system memory usage
|
| 44 |
+
ALLOCATIONS = (1 << 1), // print buffer allocations
|
| 45 |
+
RECYCLES = (1 << 2), // print buffer recycling
|
| 46 |
+
RELEASES = (1 << 3), // print buffer releases
|
| 47 |
+
LARGE_ONLY = (1 << 4), // only log large buffer pool transactions
|
| 48 |
+
};
|
| 49 |
+
|
| 50 |
+
struct HeapBlock;
|
| 51 |
+
|
| 52 |
+
struct BufferBlock {
|
| 53 |
+
id<MTLBuffer> buffer;
|
| 54 |
+
void* cpu_ptr = nullptr; // stores the pointer to CPU mapping of a Shared MTLBuffer
|
| 55 |
+
size_t size; // size after alignment
|
| 56 |
+
size_t requested_size; // requested size (before alignment)
|
| 57 |
+
// buffer shape is used for retrieving base of views in cached graphs
|
| 58 |
+
std::vector<int64_t> shape;
|
| 59 |
+
bool in_use = false;
|
| 60 |
+
HeapBlock* heap;
|
| 61 |
+
id_t buf_id;
|
| 62 |
+
// counter to candidate least recently used buffers for garbage collection
|
| 63 |
+
uint32_t gc_count = 0;
|
| 64 |
+
uint32_t use_count = 0;
|
| 65 |
+
// counter to assign unique ids to buffer blocks
|
| 66 |
+
static uint64_t buffer_counter;
|
| 67 |
+
// Metal events used to sync GPU/CPU operations on the shared-storage buffers
|
| 68 |
+
MPSEventPtr event;
|
| 69 |
+
|
| 70 |
+
BufferBlock(size_t Size, size_t RequestedSize = 0, const id<MTLBuffer> Buffer = nullptr,
|
| 71 |
+
HeapBlock* Heap = nullptr) :
|
| 72 |
+
buffer(Buffer), size(Size), requested_size(RequestedSize),
|
| 73 |
+
heap(Heap), buf_id(Buffer ? ++buffer_counter : 0) { }
|
| 74 |
+
|
| 75 |
+
static bool Comparator(const BufferBlock* a, const BufferBlock* b) {
|
| 76 |
+
return (a->size != b->size) ? a->size < b->size : (uintptr_t)a->buffer < (uintptr_t)b->buffer;
|
| 77 |
+
}
|
| 78 |
+
static size_t alignUp(size_t Size, size_t Alignment) {
|
| 79 |
+
assert(((Alignment - 1) & Alignment) == 0);
|
| 80 |
+
return ((Size + Alignment - 1) & ~(Alignment - 1));
|
| 81 |
+
}
|
| 82 |
+
uint32_t retainCount() const { return [buffer retainCount]; }
|
| 83 |
+
};
|
| 84 |
+
typedef bool (*BufferComparison)(const BufferBlock*, const BufferBlock*);
|
| 85 |
+
|
| 86 |
+
struct BufferPool;
|
| 87 |
+
struct AllocParams {
|
| 88 |
+
AllocParams(size_t Alloc_Size, size_t Requested_Size, BufferPool* Pool) :
|
| 89 |
+
search_key(Alloc_Size), pool(Pool), requested_size(Requested_Size) { }
|
| 90 |
+
size_t size() const { return search_key.size; }
|
| 91 |
+
|
| 92 |
+
BufferBlock search_key;
|
| 93 |
+
BufferPool* pool;
|
| 94 |
+
BufferBlock* buffer_block = nullptr;
|
| 95 |
+
size_t requested_size;
|
| 96 |
+
// true if we exceed the low watermark limit. In this case
|
| 97 |
+
// we apply strategies to relieve the pressure before allocation.
|
| 98 |
+
bool has_memory_pressure = false;
|
| 99 |
+
// true if we're allocating on a unified memory device
|
| 100 |
+
bool has_unified_memory = true;
|
| 101 |
+
};
|
| 102 |
+
|
| 103 |
+
struct HeapBlock {
|
| 104 |
+
id<MTLHeap> heap;
|
| 105 |
+
struct { size_t total, available; } size;
|
| 106 |
+
BufferPool* pool;
|
| 107 |
+
unsigned int n_buffers = 0;
|
| 108 |
+
id_t heap_id;
|
| 109 |
+
// indicates if we split this heap to sub-allocate 'several' buffers (otherwise single buffer)
|
| 110 |
+
bool is_split;
|
| 111 |
+
// counter to assign unique ids to heap blocks
|
| 112 |
+
static uint64_t heap_counter;
|
| 113 |
+
|
| 114 |
+
HeapBlock(size_t Size, const id<MTLHeap> Heap = nullptr, BufferPool *Pool = nullptr) :
|
| 115 |
+
heap(Heap), size({.total = Size, .available = Size}), pool(Pool),
|
| 116 |
+
heap_id(Heap ? ++heap_counter : 0), is_split(true) { }
|
| 117 |
+
|
| 118 |
+
static MTLResourceOptions getOptions(uint32_t usage) {
|
| 119 |
+
// TODO: check the caching performance of write-combined mode
|
| 120 |
+
MTLResourceOptions options = MTLResourceCPUCacheModeDefaultCache;
|
| 121 |
+
|
| 122 |
+
if (usage & UsageFlags::MANAGED)
|
| 123 |
+
options |= MTLResourceStorageModeManaged;
|
| 124 |
+
else if (usage & UsageFlags::SHARED)
|
| 125 |
+
options |= MTLResourceStorageModeShared;
|
| 126 |
+
else
|
| 127 |
+
options |= MTLResourceStorageModePrivate;
|
| 128 |
+
|
| 129 |
+
options |= (usage & UsageFlags::HAZARD) ? MTLResourceHazardTrackingModeTracked : MTLResourceHazardTrackingModeUntracked;
|
| 130 |
+
|
| 131 |
+
return options;
|
| 132 |
+
}
|
| 133 |
+
|
| 134 |
+
static HeapBlock* createHeapBlock(AllocParams& params, id<MTLDevice> device, uint32_t usage) {
|
| 135 |
+
HeapBlock *heapBlock = nullptr;
|
| 136 |
+
bool is_split = true;
|
| 137 |
+
const size_t size = params.size();
|
| 138 |
+
MTLHeapDescriptor *d = [MTLHeapDescriptor new];
|
| 139 |
+
if (d) {
|
| 140 |
+
const size_t kXLargeHeap = params.has_unified_memory ? kXLargeHeapU : kXLargeHeapD;
|
| 141 |
+
if (size <= kMaxSmallAlloc) {
|
| 142 |
+
d.size = kSmallHeap;
|
| 143 |
+
} else if (size < kMinLargeAlloc) {
|
| 144 |
+
d.size = kLargeHeap;
|
| 145 |
+
} else if (size < kXLargeHeap / 2 && !params.has_memory_pressure) {
|
| 146 |
+
d.size = kXLargeHeap;
|
| 147 |
+
} else {
|
| 148 |
+
d.size = kRoundLarge * ((size + kRoundLarge - 1) / kRoundLarge);
|
| 149 |
+
is_split = false;
|
| 150 |
+
}
|
| 151 |
+
d.storageMode = (usage & UsageFlags::SHARED) ? MTLStorageModeShared : MTLStorageModePrivate;
|
| 152 |
+
d.cpuCacheMode = MTLCPUCacheModeDefaultCache;
|
| 153 |
+
// this automatically handles Metal buffer access synchronizations at the
|
| 154 |
+
// cost of slightly lower performance.
|
| 155 |
+
d.hazardTrackingMode = (usage & UsageFlags::HAZARD) ? MTLHazardTrackingModeTracked : MTLHazardTrackingModeUntracked;
|
| 156 |
+
d.resourceOptions = getOptions(usage);
|
| 157 |
+
d.type = MTLHeapTypeAutomatic;
|
| 158 |
+
id<MTLHeap> heap = [device newHeapWithDescriptor: d];
|
| 159 |
+
if (heap) {
|
| 160 |
+
[heap setPurgeableState:MTLPurgeableStateNonVolatile];
|
| 161 |
+
const size_t heap_size = heapAvailableSize(heap);
|
| 162 |
+
heapBlock = new HeapBlock(heap_size, heap, params.pool);
|
| 163 |
+
if (heapBlock) {
|
| 164 |
+
heapBlock->is_split = is_split;
|
| 165 |
+
}
|
| 166 |
+
}
|
| 167 |
+
[d release];
|
| 168 |
+
}
|
| 169 |
+
return heapBlock;
|
| 170 |
+
}
|
| 171 |
+
static bool Comparator(const HeapBlock* a, const HeapBlock* b) {
|
| 172 |
+
return (a->size.available != b->size.available) ? a->size.available < b->size.available :
|
| 173 |
+
(uintptr_t)a->heap < (uintptr_t)b->heap;
|
| 174 |
+
}
|
| 175 |
+
static NSUInteger heapAvailableSize(id<MTLHeap> heap, size_t Alignment = vm_page_size) {
|
| 176 |
+
return [heap maxAvailableSizeWithAlignment:Alignment];
|
| 177 |
+
}
|
| 178 |
+
NSUInteger Size() {
|
| 179 |
+
return [heap size];
|
| 180 |
+
}
|
| 181 |
+
id<MTLBuffer> newMTLBuffer(size_t length, uint32_t usage) {
|
| 182 |
+
id<MTLBuffer> buf = [heap newBufferWithLength:length options:getOptions(usage)];
|
| 183 |
+
if (buf) {
|
| 184 |
+
updateAvailableSize();
|
| 185 |
+
n_buffers++;
|
| 186 |
+
}
|
| 187 |
+
return buf;
|
| 188 |
+
}
|
| 189 |
+
// returns the retainCount before releasing the buffer
|
| 190 |
+
uint32_t releaseMTLBuffer(id<MTLBuffer>& buffer) {
|
| 191 |
+
const uint32_t retainCount = [buffer retainCount];
|
| 192 |
+
[buffer release];
|
| 193 |
+
buffer = nil;
|
| 194 |
+
updateAvailableSize();
|
| 195 |
+
n_buffers--;
|
| 196 |
+
return retainCount;
|
| 197 |
+
}
|
| 198 |
+
// returns the retainCount before releasing the heap
|
| 199 |
+
uint32_t releaseMTLHeap() {
|
| 200 |
+
const uint32_t retainCount = [heap retainCount];
|
| 201 |
+
TORCH_INTERNAL_ASSERT(!n_buffers); // assert if heap isn't empty
|
| 202 |
+
[heap setPurgeableState:MTLPurgeableStateEmpty];
|
| 203 |
+
[heap release];
|
| 204 |
+
heap = nil;
|
| 205 |
+
size.available = 0;
|
| 206 |
+
return retainCount;
|
| 207 |
+
}
|
| 208 |
+
uint32_t retainCount() const { return [heap retainCount]; }
|
| 209 |
+
void updateAvailableSize() { size.available = heapAvailableSize(heap); }
|
| 210 |
+
};
|
| 211 |
+
typedef bool (*HeapComparison)(const HeapBlock*, const HeapBlock*);
|
| 212 |
+
|
| 213 |
+
struct BufferPool {
|
| 214 |
+
enum class Kind {
|
| 215 |
+
PRIVATE_SMALL,
|
| 216 |
+
PRIVATE_LARGE,
|
| 217 |
+
SHARED_SMALL,
|
| 218 |
+
SHARED_LARGE,
|
| 219 |
+
SCALAR,
|
| 220 |
+
};
|
| 221 |
+
|
| 222 |
+
BufferPool(const id<MTLDevice> Device, uint32_t Usage) :
|
| 223 |
+
device(Device), usage(Usage),
|
| 224 |
+
heaps(HeapBlock::Comparator), available_buffers(BufferBlock::Comparator) { }
|
| 225 |
+
|
| 226 |
+
const id<MTLDevice> device;
|
| 227 |
+
// usage flags to customize the pool for various purposes (see UsageFlags enum)
|
| 228 |
+
const uint32_t usage;
|
| 229 |
+
// total number of buffers in the pool
|
| 230 |
+
uint32_t n_buffers = 0;
|
| 231 |
+
// total allocations size on this pool
|
| 232 |
+
size_t allocated_size = 0;
|
| 233 |
+
// total memory available in the pool
|
| 234 |
+
size_t available_size = 0;
|
| 235 |
+
// list of heaps ordered by their "available" (not total) memory size
|
| 236 |
+
std::set<HeapBlock*, HeapComparison> heaps;
|
| 237 |
+
// list of only "available" buffers in the pool (i.e., buffers not in-use)
|
| 238 |
+
std::set<BufferBlock*, BufferComparison> available_buffers;
|
| 239 |
+
// list of buffers that are in a state of "limbo" where they've already been freed
|
| 240 |
+
// from PyTorch-side, but were not returned to pool due to still being
|
| 241 |
+
// in-use by command buffers with retainCount > 1. In this state, the buffer is
|
| 242 |
+
// neither ready to be recycled, nor could be returned to pool as available.
|
| 243 |
+
// These buffers will be returned to pool once the command buffer's
|
| 244 |
+
// completionHandler callbacks are called.
|
| 245 |
+
std::unordered_set<BufferBlock*> buffers_pending_free;
|
| 246 |
+
// list of heaps pending size update
|
| 247 |
+
std::unordered_set<HeapBlock*> heaps_pending_update;
|
| 248 |
+
};
|
| 249 |
+
|
| 250 |
+
class MPSHeapAllocatorImpl {
|
| 251 |
+
public:
|
| 252 |
+
explicit MPSHeapAllocatorImpl() :
|
| 253 |
+
m_device(at::mps::MPSDevice::getInstance()->device()),
|
| 254 |
+
m_max_buffer_size([m_device maxBufferLength]),
|
| 255 |
+
m_stream(getDefaultMPSStream()),
|
| 256 |
+
m_event_pool(getMPSEventPool()) {
|
| 257 |
+
init_allocator();
|
| 258 |
+
}
|
| 259 |
+
~MPSHeapAllocatorImpl() {
|
| 260 |
+
emptyCache();
|
| 261 |
+
}
|
| 262 |
+
// interface exposed to at::Allocator
|
| 263 |
+
id<MTLBuffer> malloc(size_t size, uint32_t usage);
|
| 264 |
+
// frees a buffer and returns it into buffer pool
|
| 265 |
+
void free(void* ptr);
|
| 266 |
+
// releases all the cached buffers and their associated heaps
|
| 267 |
+
void emptyCache();
|
| 268 |
+
// free inactive buffers that are pending to be freed
|
| 269 |
+
void freeInactiveBuffers();
|
| 270 |
+
// returns true if buffer was allocated from the shared pool
|
| 271 |
+
bool isSharedBuffer(const void* ptr);
|
| 272 |
+
// get the requested unaligned size of an MTLBuffer
|
| 273 |
+
ssize_t getUnalignedBufferSize(const void* ptr);
|
| 274 |
+
// set the shape of a base tensor from a view tensor
|
| 275 |
+
void setBufferShape(const void* ptr, const IntArrayRef& shape);
|
| 276 |
+
// retrieve the shape of a base tensor from a view tensor
|
| 277 |
+
IntArrayRef getBufferShape(const void* ptr);
|
| 278 |
+
// get the unique ID of the buffer
|
| 279 |
+
id_t getBufferId(const void* ptr);
|
| 280 |
+
// allocate a buffer from a specialized pool to import CPU scalars into GPU
|
| 281 |
+
id<MTLBuffer> allocScalarBufferWithValue(void* value, size_t size);
|
| 282 |
+
// returns a CPU-mapping of the input buffer and its retainCount,
|
| 283 |
+
// if only it has Shared storage-mode and allocated on MPSAllocator
|
| 284 |
+
std::pair<const void*, uint32_t> getSharedBufferPtr(const void* buffer);
|
| 285 |
+
// records events for a list of MTLBuffers (list is used to lock the mutex once)
|
| 286 |
+
// returns true if records any event (given if passed buffers exist and are shared-storage)
|
| 287 |
+
bool recordEvents(c10::ArrayRef<const void*> buffers);
|
| 288 |
+
// waits for the event to signal the completion of GPU execution
|
| 289 |
+
// on the passed shared buffers (list is used to lock the mutex once)
|
| 290 |
+
// returns true if actually waited on any event
|
| 291 |
+
bool waitForEvents(c10::ArrayRef<const void*> buffers);
|
| 292 |
+
// this indicates how far (in Megabytes) the current total allocations are from the
|
| 293 |
+
// low watermark limit which is used to detect if we're under memory pressure
|
| 294 |
+
// This returns zero if we've reached the low watermark limit
|
| 295 |
+
ssize_t getLowWatermarkValue();
|
| 296 |
+
// (see m_low_watermark_ratio for description)
|
| 297 |
+
void setLowWatermarkRatio(double ratio);
|
| 298 |
+
// (see m_high_watermark_ratio for description)
|
| 299 |
+
void setHighWatermarkRatio(double ratio);
|
| 300 |
+
// (see m_low_watermark_limit for description)
|
| 301 |
+
size_t getLowWatermarkLimit() const { return m_low_watermark_limit; }
|
| 302 |
+
// (see m_max_total_allowed_size for description)
|
| 303 |
+
size_t getHighWatermarkLimit() const { return m_max_total_allowed_size; }
|
| 304 |
+
// (see m_total_allocated_memory for description)
|
| 305 |
+
size_t getTotalAllocatedMemory() const { return m_total_allocated_memory; }
|
| 306 |
+
// (see m_current_allocated_memory for description)
|
| 307 |
+
size_t getCurrentAllocatedMemory() const { return m_current_allocated_memory; }
|
| 308 |
+
// total GPU memory allocated in the process by Metal driver; including
|
| 309 |
+
// implicit allocations from MPS/MPSGraph frameworks and MPSHeapAllocatorImpl.
|
| 310 |
+
size_t getDriverAllocatedMemory() const { return current_allocated_size(); }
|
| 311 |
+
// recommended Max memory for Metal
|
| 312 |
+
size_t getRecommendedMaxMemory() const { return max_device_size(); }
|
| 313 |
+
// (see enum DebugVerbosity for description)
|
| 314 |
+
uint32_t getDebugVerbosity() const { return m_debug_verbosity; }
|
| 315 |
+
// returns the device that we allocate from
|
| 316 |
+
inline id<MTLDevice> Device() const { return m_device; }
|
| 317 |
+
|
| 318 |
+
// TODO: make a common function to do size unit conversions in PyTorch.
|
| 319 |
+
inline std::string format_size(uint64_t size) const;
|
| 320 |
+
|
| 321 |
+
private:
|
| 322 |
+
// (see m_high_watermark_ratio for description)
|
| 323 |
+
constexpr static double default_high_watermark_ratio = 1.7;
|
| 324 |
+
// we set the allowed upper bound to twice the size of recommendedMaxWorkingSetSize.
|
| 325 |
+
constexpr static double default_high_watermark_upper_bound = 2.0;
|
| 326 |
+
// (see m_low_watermark_ratio for description)
|
| 327 |
+
// on unified memory, we could allocate beyond the recommendedMaxWorkingSetSize
|
| 328 |
+
constexpr static double default_low_watermark_ratio_unified = 1.4;
|
| 329 |
+
constexpr static double default_low_watermark_ratio_discrete = 1.0;
|
| 330 |
+
|
| 331 |
+
const id<MTLDevice> m_device;
|
| 332 |
+
std::recursive_mutex m_mutex;
|
| 333 |
+
// allocated buffers by device pointer
|
| 334 |
+
ska::flat_hash_map<const void*, BufferBlock*> m_allocated_buffers;
|
| 335 |
+
// using a container for pools to simplify iterating them
|
| 336 |
+
ska::flat_hash_map<BufferPool::Kind, std::unique_ptr<BufferPool>> m_pools;
|
| 337 |
+
// total memory allocated by HeapAllocator (including blocks in pools)
|
| 338 |
+
size_t m_total_allocated_memory = 0;
|
| 339 |
+
// currently active memory allocations in use (i.e., blocks not in pools)
|
| 340 |
+
size_t m_current_allocated_memory = 0;
|
| 341 |
+
// max buffer size allowed by Metal
|
| 342 |
+
size_t m_max_buffer_size = 0;
|
| 343 |
+
// maximum total size allowed to be allocated
|
| 344 |
+
size_t m_max_total_allowed_size = 0;
|
| 345 |
+
// high watermark ratio is a hard limit for the total allowed allocations
|
| 346 |
+
// 0. : disables high watermark limit (may cause system failure if system-wide OOM occurs)
|
| 347 |
+
// 1. : recommended maximum allocation size (i.e., device.recommendedMaxWorkingSetSize)
|
| 348 |
+
// >1.: allows limits beyond the device.recommendedMaxWorkingSetSize
|
| 349 |
+
// e.g., value 0.95 means we allocate up to 95% of recommended maximum
|
| 350 |
+
// allocation size; beyond that, the allocations would fail with OOM error.
|
| 351 |
+
double m_high_watermark_ratio;
|
| 352 |
+
// low watermark ratio is a soft limit to attempt limiting memory allocations up to the lower watermark
|
| 353 |
+
// level by garbage collection or committing command buffers more frequently (a.k.a, adaptive commit).
|
| 354 |
+
// Value between 0 to m_high_watermark_ratio (setting 0.0 disables adaptive commit and garbage collection)
|
| 355 |
+
// e.g., value 0.9 means we 'attempt' to limit allocations up to 90% of recommended maximum
|
| 356 |
+
// allocation size.
|
| 357 |
+
double m_low_watermark_ratio;
|
| 358 |
+
// low watermark size limit (in Bytes) at the time we initialize the allocator
|
| 359 |
+
size_t m_low_watermark_limit;
|
| 360 |
+
// use "PYTORCH_DEBUG_MPS_ALLOCATOR" env-var to set debug verbosity
|
| 361 |
+
uint32_t m_debug_verbosity;
|
| 362 |
+
// default MPS stream
|
| 363 |
+
MPSStream* m_stream;
|
| 364 |
+
// we hold a reference to MPSEventPool so it could get destroyed after MPSAllocator
|
| 365 |
+
std::shared_ptr<MPSEventPool> m_event_pool;
|
| 366 |
+
|
| 367 |
+
void init_allocator();
|
| 368 |
+
void init_buffer_pools();
|
| 369 |
+
HeapBlock* get_free_heap(AllocParams& params);
|
| 370 |
+
bool get_free_buffer(AllocParams& params);
|
| 371 |
+
BufferBlock* get_allocated_buffer_block(const void* ptr);
|
| 372 |
+
BufferBlock* alloc_buffer_block(size_t size, uint32_t usage);
|
| 373 |
+
bool alloc_buffer(AllocParams& params);
|
| 374 |
+
void free_buffer(BufferBlock* buffer_block);
|
| 375 |
+
// returns true if the container heap is also released
|
| 376 |
+
bool release_buffer(BufferBlock* buffer_block, bool remove_empty_heap = true);
|
| 377 |
+
void release_buffers(BufferPool& pool);
|
| 378 |
+
bool release_available_cached_buffers(AllocParams& params);
|
| 379 |
+
bool release_cached_buffers();
|
| 380 |
+
// free unused cached blocks to reclaim GPU memory if memory pressure is high
|
| 381 |
+
void garbage_collect_cached_buffers(AllocParams& params);
|
| 382 |
+
// returns the suitable buffer pool type for the usage or
|
| 383 |
+
// requested/allocated sizes
|
| 384 |
+
BufferPool& get_pool(size_t requested_size, size_t aligned_size, uint32_t usage);
|
| 385 |
+
// returns the aligned allocation size that is optimized
|
| 386 |
+
// for the buffers to get reused frequently
|
| 387 |
+
size_t get_allocation_size(size_t size, uint32_t usage) const;
|
| 388 |
+
// maximum size of device memory available for allocation in current process
|
| 389 |
+
// Note: the recommendedMaxWorkingSetSize is typically 75% of the total system memory.
|
| 390 |
+
size_t max_device_size() const { return [m_device recommendedMaxWorkingSetSize]; }
|
| 391 |
+
// there are implicit allocations from MPS backend, so we need to query the 'device' for
|
| 392 |
+
// total allocated size instead of manually tracking in MPSAllocator
|
| 393 |
+
size_t current_allocated_size() const { return [m_device currentAllocatedSize]; }
|
| 394 |
+
|
| 395 |
+
bool trigger_memory_callbacks(BufferBlock* buffer_block, IMpsAllocatorCallback::EventType event) const {
|
| 396 |
+
for (const auto& name : MPSAllocatorCallbacksRegistry()->Keys()) {
|
| 397 |
+
MPSAllocatorCallbacksRegistry()->Create(name)->executeMPSAllocatorCallback(buffer_block ? buffer_block->buffer : nullptr, event);
|
| 398 |
+
}
|
| 399 |
+
return true;
|
| 400 |
+
}
|
| 401 |
+
};
|
| 402 |
+
|
| 403 |
+
} // namespace at::mps::HeapAllocator
|
.venv/lib/python3.11/site-packages/torch/include/ATen/mps/MPSAllocatorInterface.h
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright © 2023 Apple Inc.
|
| 2 |
+
|
| 3 |
+
#pragma once
|
| 4 |
+
|
| 5 |
+
#include <c10/core/Allocator.h>
|
| 6 |
+
#include <c10/util/Registry.h>
|
| 7 |
+
#include <ATen/core/ATen_fwd.h>
|
| 8 |
+
|
| 9 |
+
#define MB(x) (x * 1048576UL)
|
| 10 |
+
|
| 11 |
+
namespace at::mps {
|
| 12 |
+
|
| 13 |
+
// this is a public interface to access MPSAllocator.
|
| 14 |
+
// Do not declare methods that would depend on MPS or Metal frameworks.
|
| 15 |
+
class IMPSAllocator : public c10::Allocator {
|
| 16 |
+
public:
|
| 17 |
+
// see the comments in MPSAllocator.h for the description of these methods.
|
| 18 |
+
virtual void emptyCache() const = 0;
|
| 19 |
+
virtual void freeInactiveBuffers() const = 0;
|
| 20 |
+
virtual ssize_t getUnalignedBufferSize(const void* ptr) const = 0;
|
| 21 |
+
virtual IntArrayRef getBufferShape(const void* ptr) const = 0;
|
| 22 |
+
virtual id_t getBufferId(const void* ptr) const = 0;
|
| 23 |
+
virtual void setBufferShape(const void* ptr, const IntArrayRef& shape) const = 0;
|
| 24 |
+
virtual bool isSharedBuffer(const void* ptr) const = 0;
|
| 25 |
+
virtual bool isSharedStorageSupported() const = 0;
|
| 26 |
+
virtual c10::DataPtr allocScalarBufferWithValue(void* value, size_t size) const = 0;
|
| 27 |
+
virtual std::string formatSize(size_t size) const = 0;
|
| 28 |
+
virtual void setLowWatermarkRatio(double ratio) const = 0;
|
| 29 |
+
virtual void setHighWatermarkRatio(double ratio) const = 0;
|
| 30 |
+
virtual ssize_t getLowWatermarkValue() const = 0;
|
| 31 |
+
virtual size_t getLowWatermarkLimit() const = 0;
|
| 32 |
+
virtual size_t getHighWatermarkLimit() const = 0;
|
| 33 |
+
virtual size_t getTotalAllocatedMemory() const = 0;
|
| 34 |
+
virtual size_t getCurrentAllocatedMemory() const = 0;
|
| 35 |
+
virtual size_t getDriverAllocatedMemory() const = 0;
|
| 36 |
+
virtual size_t getRecommendedMaxMemory() const = 0;
|
| 37 |
+
virtual std::pair<const void*, uint32_t> getSharedBufferPtr(const void* ptr) const = 0;
|
| 38 |
+
virtual bool recordEvents(c10::ArrayRef<const void*> buffers) const = 0;
|
| 39 |
+
virtual bool waitForEvents(c10::ArrayRef<const void*> buffers) const = 0;
|
| 40 |
+
};
|
| 41 |
+
|
| 42 |
+
class IMpsAllocatorCallback {
|
| 43 |
+
public:
|
| 44 |
+
enum class EventType {
|
| 45 |
+
ALLOCATED, // buffer got allocated to be used immediately
|
| 46 |
+
RECYCLED, // buffer pulled from free list to be reused
|
| 47 |
+
FREED, // buffer put to free list for future recycling
|
| 48 |
+
RELEASED, // buffer memory released
|
| 49 |
+
ALLOCATION_FAILED // buffer allocation failed
|
| 50 |
+
};
|
| 51 |
+
virtual ~IMpsAllocatorCallback() = default;
|
| 52 |
+
virtual void executeMPSAllocatorCallback(void* ptr, EventType event) = 0;
|
| 53 |
+
};
|
| 54 |
+
|
| 55 |
+
// MPS allocator will execute every registered callback when a block of memory is freed.
|
| 56 |
+
C10_DECLARE_REGISTRY(MPSAllocatorCallbacksRegistry, IMpsAllocatorCallback);
|
| 57 |
+
#define REGISTER_MPS_ALLOCATOR_CALLBACK(name, ...) \
|
| 58 |
+
C10_REGISTER_CLASS(MPSAllocatorCallbacksRegistry, name, __VA_ARGS__);
|
| 59 |
+
|
| 60 |
+
IMPSAllocator* getIMPSAllocator(bool sharedAllocator = false);
|
| 61 |
+
|
| 62 |
+
bool isMPSPinnedPtr(const void* data);
|
| 63 |
+
|
| 64 |
+
} // namespace at::mps
|
.venv/lib/python3.11/site-packages/torch/include/ATen/mps/MPSDevice.h
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright © 2022 Apple Inc.
|
| 2 |
+
|
| 3 |
+
#pragma once
|
| 4 |
+
#include <c10/core/Allocator.h>
|
| 5 |
+
#include <c10/macros/Macros.h>
|
| 6 |
+
#include <c10/util/Exception.h>
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
#ifdef __OBJC__
|
| 10 |
+
#include <Foundation/Foundation.h>
|
| 11 |
+
#include <Metal/Metal.h>
|
| 12 |
+
#include <MetalPerformanceShaders/MetalPerformanceShaders.h>
|
| 13 |
+
typedef id<MTLDevice> MTLDevice_t;
|
| 14 |
+
typedef id<MTLLibrary> MTLLibrary_t;
|
| 15 |
+
typedef id<MTLComputePipelineState> MTLComputePipelineState_t;
|
| 16 |
+
typedef id<MTLLibrary> MTLLibrary_t;
|
| 17 |
+
#else
|
| 18 |
+
typedef void* MTLDevice;
|
| 19 |
+
typedef void* MTLDevice_t;
|
| 20 |
+
typedef void* MTLLibrary_t;
|
| 21 |
+
typedef void* MTLComputePipelineState_t;
|
| 22 |
+
typedef void* MTLLibrary_t;
|
| 23 |
+
#endif
|
| 24 |
+
|
| 25 |
+
namespace at::mps {
|
| 26 |
+
|
| 27 |
+
// Helper enum to check if a MPSGraph op is supported in a given macOS version
|
| 28 |
+
enum class MacOSVersion : uint32_t {
|
| 29 |
+
MACOS_VER_13_1_PLUS = 0,
|
| 30 |
+
MACOS_VER_13_2_PLUS,
|
| 31 |
+
MACOS_VER_13_3_PLUS,
|
| 32 |
+
MACOS_VER_14_0_PLUS,
|
| 33 |
+
MACOS_VER_14_4_PLUS,
|
| 34 |
+
MACOS_VER_15_0_PLUS,
|
| 35 |
+
};
|
| 36 |
+
|
| 37 |
+
//-----------------------------------------------------------------
|
| 38 |
+
// MPSDevice
|
| 39 |
+
//
|
| 40 |
+
// MPSDevice is a singleton class that returns the default device
|
| 41 |
+
//-----------------------------------------------------------------
|
| 42 |
+
|
| 43 |
+
class TORCH_API MPSDevice {
|
| 44 |
+
public:
|
| 45 |
+
/**
|
| 46 |
+
* MPSDevice should not be cloneable.
|
| 47 |
+
*/
|
| 48 |
+
MPSDevice(MPSDevice& other) = delete;
|
| 49 |
+
/**
|
| 50 |
+
* MPSDevice should not be assignable.
|
| 51 |
+
*/
|
| 52 |
+
void operator=(const MPSDevice&) = delete;
|
| 53 |
+
/**
|
| 54 |
+
* Gets single instance of the Device.
|
| 55 |
+
*/
|
| 56 |
+
static MPSDevice* getInstance();
|
| 57 |
+
/**
|
| 58 |
+
* Returns the single device.
|
| 59 |
+
*/
|
| 60 |
+
MTLDevice_t device() {
|
| 61 |
+
return _mtl_device;
|
| 62 |
+
}
|
| 63 |
+
/**
|
| 64 |
+
* Returns whether running on Ventura or newer
|
| 65 |
+
*/
|
| 66 |
+
bool isMacOS13Plus(MacOSVersion version) const;
|
| 67 |
+
|
| 68 |
+
MTLComputePipelineState_t metalIndexingPSO(const std::string &kernel);
|
| 69 |
+
MTLLibrary_t getMetalIndexingLibrary();
|
| 70 |
+
|
| 71 |
+
~MPSDevice();
|
| 72 |
+
|
| 73 |
+
private:
|
| 74 |
+
static MPSDevice* _device;
|
| 75 |
+
MTLDevice_t _mtl_device;
|
| 76 |
+
MTLLibrary_t _mtl_indexing_library;
|
| 77 |
+
MPSDevice();
|
| 78 |
+
};
|
| 79 |
+
|
| 80 |
+
TORCH_API bool is_available();
|
| 81 |
+
TORCH_API bool is_macos_13_or_newer(MacOSVersion version);
|
| 82 |
+
TORCH_API at::Allocator* GetMPSAllocator(bool useSharedAllocator = false);
|
| 83 |
+
|
| 84 |
+
} // namespace at::mps
|
.venv/lib/python3.11/site-packages/torch/include/ATen/mps/MPSEvent.h
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright © 2023 Apple Inc.
|
| 2 |
+
|
| 3 |
+
#pragma once
|
| 4 |
+
|
| 5 |
+
#include <ATen/mps/MPSStream.h>
|
| 6 |
+
#include <ctime>
|
| 7 |
+
#include <stack>
|
| 8 |
+
|
| 9 |
+
namespace at::mps {
|
| 10 |
+
|
| 11 |
+
// NOTE: don't create instances of this class directly.
|
| 12 |
+
// Use MPSEventPool to acquire instances of MPSEvent.
|
| 13 |
+
class MPSEvent {
|
| 14 |
+
public:
|
| 15 |
+
explicit MPSEvent(id_t ID, MPSStream* stream, bool enable_timing);
|
| 16 |
+
~MPSEvent();
|
| 17 |
+
|
| 18 |
+
// records an event on the stream
|
| 19 |
+
void record(bool needsLock, bool syncEvent = false);
|
| 20 |
+
// makes all future work submitted to the stream wait for this event.
|
| 21 |
+
bool wait(bool needsLock, bool syncEvent = false);
|
| 22 |
+
// schedules a notifyListener callback for the event.
|
| 23 |
+
bool notify(bool needsLock, MTLSharedEventNotificationBlock block);
|
| 24 |
+
// checks if events are already signaled.
|
| 25 |
+
bool query() const;
|
| 26 |
+
// blocks the CPU thread until all the GPU work that were scheduled
|
| 27 |
+
// prior to recording this event are completed.
|
| 28 |
+
bool synchronize();
|
| 29 |
+
// resets this event with new parameters in case it gets reused from the event pool
|
| 30 |
+
void reset(MPSStream* stream, bool enable_timing);
|
| 31 |
+
// returns the unique ID of the event instance
|
| 32 |
+
id_t getID() const { return m_id; }
|
| 33 |
+
// returns the completion timestamp of the event
|
| 34 |
+
uint64_t getCompletionTime() const { return m_completion_time; }
|
| 35 |
+
// if already recorded, waits for cpu_sync_cv to be signaled
|
| 36 |
+
void waitForCpuSync();
|
| 37 |
+
|
| 38 |
+
private:
|
| 39 |
+
id_t m_id;
|
| 40 |
+
// enables measuring the completion time of the notifyListener of this event
|
| 41 |
+
bool m_enable_timing;
|
| 42 |
+
uint64_t m_signalCounter = 0;
|
| 43 |
+
MPSStream* m_stream = nullptr;
|
| 44 |
+
MTLSharedEvent_t m_event = nullptr;
|
| 45 |
+
MTLSharedEventListener* m_listener = nullptr;
|
| 46 |
+
// used to sync the events created on this Stream with CPU
|
| 47 |
+
std::mutex m_cpu_sync_mutex{};
|
| 48 |
+
std::condition_variable m_cpu_sync_cv{};
|
| 49 |
+
// CondVar predicate to sync the events created on this Stream with CPU
|
| 50 |
+
bool m_cpu_sync_completed = false;
|
| 51 |
+
// used to compute elapsed time
|
| 52 |
+
uint64_t m_completion_time = 0;
|
| 53 |
+
|
| 54 |
+
void recordLocked(bool syncEvent);
|
| 55 |
+
bool waitLocked(bool syncEvent);
|
| 56 |
+
bool notifyLocked(MTLSharedEventNotificationBlock block);
|
| 57 |
+
void notifyCpuSync();
|
| 58 |
+
static uint64_t getTime() {
|
| 59 |
+
return clock_gettime_nsec_np(CLOCK_MONOTONIC_RAW);
|
| 60 |
+
}
|
| 61 |
+
};
|
| 62 |
+
|
| 63 |
+
typedef std::unique_ptr<MPSEvent, std::function<void(MPSEvent*)>> MPSEventPtr;
|
| 64 |
+
|
| 65 |
+
class MPSEventPool {
|
| 66 |
+
public:
|
| 67 |
+
explicit MPSEventPool(MPSStream* default_stream);
|
| 68 |
+
~MPSEventPool();
|
| 69 |
+
|
| 70 |
+
MPSEventPtr acquireEvent(bool enable_timing, MPSStream* stream);
|
| 71 |
+
void emptyCache();
|
| 72 |
+
|
| 73 |
+
// these are mainly used for MPSHooks and torch.mps.Event() bindings
|
| 74 |
+
id_t acquireEvent(bool enable_timing);
|
| 75 |
+
void releaseEvent(id_t event_id);
|
| 76 |
+
void recordEvent(id_t event_id, bool syncEvent);
|
| 77 |
+
void waitForEvent(id_t event_id, bool syncEvent);
|
| 78 |
+
void synchronizeEvent(id_t event_id);
|
| 79 |
+
bool queryEvent(id_t event_id);
|
| 80 |
+
// returns elapsed time between two recorded events in milliseconds
|
| 81 |
+
double elapsedTime(id_t start_event_id, id_t end_event_id);
|
| 82 |
+
|
| 83 |
+
private:
|
| 84 |
+
MPSStream* m_default_stream = nullptr;
|
| 85 |
+
std::recursive_mutex m_mutex;
|
| 86 |
+
std::stack<std::unique_ptr<MPSEvent>> m_pool{};
|
| 87 |
+
// dictionary to associate event IDs with event objects
|
| 88 |
+
// used to retain in-use events out of the pool
|
| 89 |
+
// for torch.mps.Event() bindings.
|
| 90 |
+
std::unordered_map<id_t, MPSEventPtr> m_in_use_events{};
|
| 91 |
+
uint64_t m_event_counter = 0;
|
| 92 |
+
std::function<void(MPSEvent*)> m_default_deleter;
|
| 93 |
+
|
| 94 |
+
MPSEvent* getInUseEvent(id_t event_id, bool locked = true);
|
| 95 |
+
};
|
| 96 |
+
|
| 97 |
+
// shared_ptr is used to get MPSEventPool destroyed after dependent instances
|
| 98 |
+
std::shared_ptr<MPSEventPool> getMPSEventPool();
|
| 99 |
+
|
| 100 |
+
} // namespace at::mps
|
.venv/lib/python3.11/site-packages/torch/include/ATen/mps/MPSGeneratorImpl.h
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright © 2022 Apple Inc.
|
| 2 |
+
|
| 3 |
+
#pragma once
|
| 4 |
+
|
| 5 |
+
#include <ATen/core/Generator.h>
|
| 6 |
+
#include <ATen/core/PhiloxRNGEngine.h>
|
| 7 |
+
#include <c10/core/GeneratorImpl.h>
|
| 8 |
+
#include <optional>
|
| 9 |
+
|
| 10 |
+
namespace at {
|
| 11 |
+
namespace mps::detail {
|
| 12 |
+
|
| 13 |
+
constexpr uint32_t PHILOX_STATE_N = 7;
|
| 14 |
+
struct rng_data_pod {
|
| 15 |
+
std::array<uint32_t, PHILOX_STATE_N> state{1};
|
| 16 |
+
uint64_t seed = default_rng_seed_val;
|
| 17 |
+
};
|
| 18 |
+
|
| 19 |
+
TORCH_API const Generator& getDefaultMPSGenerator();
|
| 20 |
+
TORCH_API Generator createMPSGenerator(uint64_t seed_val = default_rng_seed_val);
|
| 21 |
+
|
| 22 |
+
} // namespace mps::detail
|
| 23 |
+
|
| 24 |
+
struct TORCH_API MPSGeneratorImpl : public c10::GeneratorImpl {
|
| 25 |
+
// Constructors
|
| 26 |
+
MPSGeneratorImpl(uint64_t seed_in = default_rng_seed_val);
|
| 27 |
+
~MPSGeneratorImpl() override = default;
|
| 28 |
+
|
| 29 |
+
// MPSGeneratorImpl methods
|
| 30 |
+
std::shared_ptr<MPSGeneratorImpl> clone() const;
|
| 31 |
+
void set_current_seed(uint64_t seed) override;
|
| 32 |
+
void set_offset(uint64_t offset) override;
|
| 33 |
+
uint64_t get_offset() const override;
|
| 34 |
+
uint64_t current_seed() const override;
|
| 35 |
+
uint64_t seed() override;
|
| 36 |
+
void set_state(const c10::TensorImpl& new_state) override;
|
| 37 |
+
c10::intrusive_ptr<c10::TensorImpl> get_state() const override;
|
| 38 |
+
void update_philox_counters();
|
| 39 |
+
|
| 40 |
+
void set_engine(at::Philox4_32 engine) { engine_ = engine; };
|
| 41 |
+
at::Philox4_32 engine() { return engine_; };
|
| 42 |
+
uint32_t* state_data() { return data_.state.data(); }
|
| 43 |
+
static DeviceType device_type() { return DeviceType::MPS; };
|
| 44 |
+
|
| 45 |
+
private:
|
| 46 |
+
mps::detail::rng_data_pod data_;
|
| 47 |
+
at::Philox4_32 engine_;
|
| 48 |
+
|
| 49 |
+
MPSGeneratorImpl* clone_impl() const override;
|
| 50 |
+
};
|
| 51 |
+
|
| 52 |
+
} // namespace at
|
.venv/lib/python3.11/site-packages/torch/include/ATen/mps/MPSGuardImpl.h
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright © 2022 Apple Inc.
|
| 2 |
+
|
| 3 |
+
#pragma once
|
| 4 |
+
#include <c10/core/impl/DeviceGuardImplInterface.h>
|
| 5 |
+
#include <c10/macros/Macros.h>
|
| 6 |
+
#include <c10/util/Exception.h>
|
| 7 |
+
#include <ATen/Context.h>
|
| 8 |
+
#include <ATen/mps/MPSStream.h>
|
| 9 |
+
#include <ATen/mps/MPSEvent.h>
|
| 10 |
+
|
| 11 |
+
#ifdef __OBJC__
|
| 12 |
+
#include <Foundation/Foundation.h>
|
| 13 |
+
#include <Metal/Metal.h>
|
| 14 |
+
#include <MetalPerformanceShaders/MetalPerformanceShaders.h>
|
| 15 |
+
#endif
|
| 16 |
+
|
| 17 |
+
#include <ATen/Tensor.h>
|
| 18 |
+
#include <c10/core/MemoryFormat.h>
|
| 19 |
+
#include <c10/core/Storage.h>
|
| 20 |
+
#include <c10/core/TensorImpl.h>
|
| 21 |
+
#include <sys/_types/_size_t.h>
|
| 22 |
+
#include <memory>
|
| 23 |
+
#include <c10/core/UndefinedTensorImpl.h>
|
| 24 |
+
#include <c10/util/intrusive_ptr.h>
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
namespace at::mps {
|
| 28 |
+
|
| 29 |
+
typedef MPSEvent* mpsEvent_t;
|
| 30 |
+
|
| 31 |
+
// TODO: Move the MPSGuardImpl to inherit from NoOpDeviceGuardImpl
|
| 32 |
+
// https://github.com/pytorch/pytorch/issues/77170
|
| 33 |
+
struct TORCH_API MPSGuardImpl final : public c10::impl::DeviceGuardImplInterface {
|
| 34 |
+
static constexpr c10::DeviceType static_type = c10::DeviceType::MPS;
|
| 35 |
+
|
| 36 |
+
// constructor
|
| 37 |
+
MPSGuardImpl() {}
|
| 38 |
+
explicit MPSGuardImpl(c10::DeviceType t) {
|
| 39 |
+
TORCH_INTERNAL_ASSERT(t == c10::DeviceType::MPS);
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
// returns the type
|
| 43 |
+
c10::DeviceType type() const override {
|
| 44 |
+
return c10::DeviceType::MPS;
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
Device exchangeDevice(Device d) const override {
|
| 48 |
+
return Device(c10::DeviceType::MPS, 0);
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
Device getDevice() const override {
|
| 52 |
+
return Device(c10::DeviceType::MPS, 0);
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
std::optional<Device> uncheckedGetDevice() const noexcept {
|
| 56 |
+
return Device(c10::DeviceType::MPS, 0);
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
void setDevice(Device d) const override {
|
| 60 |
+
TORCH_INTERNAL_ASSERT(d.is_mps());
|
| 61 |
+
}
|
| 62 |
+
|
| 63 |
+
void uncheckedSetDevice(Device d) const noexcept override {
|
| 64 |
+
// TODO: Currently setting only device 0
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
Stream getStream(Device d) const noexcept override {
|
| 68 |
+
return Stream(Stream::DEFAULT, Device(c10::DeviceType::MPS, 0));
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
Stream getNewStream(Device, int priority = 0) const override {
|
| 72 |
+
(void)priority;
|
| 73 |
+
return Stream(Stream::DEFAULT, Device(c10::DeviceType::MPS, 0));
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
Stream getDefaultStream(Device d) const override {
|
| 77 |
+
return Stream(Stream::DEFAULT, Device(c10::DeviceType::MPS, 0));
|
| 78 |
+
}
|
| 79 |
+
|
| 80 |
+
// NB: These do NOT set the current device
|
| 81 |
+
Stream exchangeStream(Stream s) const noexcept override {
|
| 82 |
+
return Stream(Stream::DEFAULT, Device(c10::DeviceType::MPS, 0));
|
| 83 |
+
}
|
| 84 |
+
DeviceIndex deviceCount() const noexcept override {
|
| 85 |
+
if (at::hasMPS()) {
|
| 86 |
+
//TODO: extend it for multi-device case
|
| 87 |
+
return 1;
|
| 88 |
+
} else {
|
| 89 |
+
return 0;
|
| 90 |
+
}
|
| 91 |
+
}
|
| 92 |
+
|
| 93 |
+
// Event-related functions
|
| 94 |
+
void createEvent(
|
| 95 |
+
mpsEvent_t* event,
|
| 96 |
+
const EventFlag flag) const;
|
| 97 |
+
|
| 98 |
+
void destroyEvent(
|
| 99 |
+
void* event,
|
| 100 |
+
const DeviceIndex device_index) const noexcept override;
|
| 101 |
+
|
| 102 |
+
void record(
|
| 103 |
+
void** event,
|
| 104 |
+
const Stream& stream,
|
| 105 |
+
const DeviceIndex device_index,
|
| 106 |
+
const EventFlag flag) const override;
|
| 107 |
+
|
| 108 |
+
void block(
|
| 109 |
+
void* event,
|
| 110 |
+
const Stream& stream) const override;
|
| 111 |
+
|
| 112 |
+
bool queryEvent(void* event) const override;
|
| 113 |
+
|
| 114 |
+
};
|
| 115 |
+
|
| 116 |
+
/// A variant of OptionalDeviceGuard that is specialized for MPS.
|
| 117 |
+
struct OptionalMPSGuard {
|
| 118 |
+
explicit OptionalMPSGuard() : guard_() {}
|
| 119 |
+
|
| 120 |
+
explicit OptionalMPSGuard(std::optional<Device> device_opt)
|
| 121 |
+
: guard_(device_opt) {}
|
| 122 |
+
|
| 123 |
+
/// Set the current MPS device to the passed device index, if it is not
|
| 124 |
+
/// nullopt
|
| 125 |
+
explicit OptionalMPSGuard(std::optional<DeviceIndex> device_index_opt)
|
| 126 |
+
: guard_(device_index_opt) {}
|
| 127 |
+
|
| 128 |
+
// Copy is not allowed
|
| 129 |
+
OptionalMPSGuard(const OptionalMPSGuard&) = delete;
|
| 130 |
+
OptionalMPSGuard& operator=(const OptionalMPSGuard&) = delete;
|
| 131 |
+
OptionalMPSGuard(OptionalMPSGuard&& other) = delete;
|
| 132 |
+
OptionalMPSGuard& operator=(OptionalMPSGuard&& other) = delete;
|
| 133 |
+
|
| 134 |
+
/// Sets the MPS device to the given device, initializing the guard if it
|
| 135 |
+
/// is not already initialized. Errors if the given device is not a MPS
|
| 136 |
+
/// device.
|
| 137 |
+
void set_device(Device device) {
|
| 138 |
+
guard_.set_device(device);
|
| 139 |
+
}
|
| 140 |
+
|
| 141 |
+
/// Sets the MPS device to the given device, initializing the guard if it is
|
| 142 |
+
/// not already initialized. Errors if the given device is not a MPS device.
|
| 143 |
+
void reset_device(Device device) {
|
| 144 |
+
guard_.reset_device(device);
|
| 145 |
+
}
|
| 146 |
+
|
| 147 |
+
/// Sets the MPS device to the given device index, initializing the guard if
|
| 148 |
+
/// it is not already initialized.
|
| 149 |
+
void set_index(DeviceIndex device_index) {
|
| 150 |
+
guard_.set_index(device_index);
|
| 151 |
+
}
|
| 152 |
+
|
| 153 |
+
/// Returns the device that was set immediately prior to initialization of the
|
| 154 |
+
/// guard, or nullopt if the guard is uninitialized.
|
| 155 |
+
std::optional<Device> original_device() const {
|
| 156 |
+
return guard_.original_device();
|
| 157 |
+
}
|
| 158 |
+
|
| 159 |
+
/// Returns the most recent device that was set using this device guard,
|
| 160 |
+
/// either from construction, or via set_device, if the guard is initialized,
|
| 161 |
+
/// or nullopt if the guard is uninitialized.
|
| 162 |
+
std::optional<Device> current_device() const {
|
| 163 |
+
return guard_.current_device();
|
| 164 |
+
}
|
| 165 |
+
|
| 166 |
+
/// Restore the original MPS device, resetting this guard to uninitialized
|
| 167 |
+
/// state.
|
| 168 |
+
void reset() {
|
| 169 |
+
guard_.reset();
|
| 170 |
+
}
|
| 171 |
+
|
| 172 |
+
private:
|
| 173 |
+
c10::impl::InlineOptionalDeviceGuard<MPSGuardImpl> guard_;
|
| 174 |
+
};
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
C10_REGISTER_GUARD_IMPL(MPS, MPSGuardImpl);
|
| 178 |
+
|
| 179 |
+
} // namespace at::mps
|
.venv/lib/python3.11/site-packages/torch/include/ATen/mps/MPSHooks.h
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright © 2022 Apple Inc.
|
| 2 |
+
|
| 3 |
+
#pragma once
|
| 4 |
+
|
| 5 |
+
#include <ATen/detail/MPSHooksInterface.h>
|
| 6 |
+
#include <ATen/Generator.h>
|
| 7 |
+
#include <ATen/mps/MPSEvent.h>
|
| 8 |
+
#include <optional>
|
| 9 |
+
|
| 10 |
+
namespace at::mps {
|
| 11 |
+
|
| 12 |
+
// The real implementation of MPSHooksInterface
|
| 13 |
+
struct MPSHooks : public at::MPSHooksInterface {
|
| 14 |
+
MPSHooks(at::MPSHooksArgs) {}
|
| 15 |
+
void initMPS() const override;
|
| 16 |
+
|
| 17 |
+
// MPSDevice interface
|
| 18 |
+
bool hasMPS() const override;
|
| 19 |
+
bool isOnMacOSorNewer(unsigned major, unsigned minor) const override;
|
| 20 |
+
|
| 21 |
+
// MPSGeneratorImpl interface
|
| 22 |
+
const Generator& getDefaultMPSGenerator() const override;
|
| 23 |
+
|
| 24 |
+
// MPSStream interface
|
| 25 |
+
void deviceSynchronize() const override;
|
| 26 |
+
void commitStream() const override;
|
| 27 |
+
void* getCommandBuffer() const override;
|
| 28 |
+
void* getDispatchQueue() const override;
|
| 29 |
+
|
| 30 |
+
// MPSAllocator interface
|
| 31 |
+
Allocator* getMPSDeviceAllocator() const override;
|
| 32 |
+
void emptyCache() const override;
|
| 33 |
+
size_t getCurrentAllocatedMemory() const override;
|
| 34 |
+
size_t getDriverAllocatedMemory() const override;
|
| 35 |
+
size_t getRecommendedMaxMemory() const override;
|
| 36 |
+
void setMemoryFraction(double ratio) const override;
|
| 37 |
+
bool isPinnedPtr(const void* data) const override;
|
| 38 |
+
Allocator* getPinnedMemoryAllocator() const override;
|
| 39 |
+
|
| 40 |
+
// MPSProfiler interface
|
| 41 |
+
void profilerStartTrace(const std::string& mode, bool waitUntilCompleted) const override;
|
| 42 |
+
void profilerStopTrace() const override;
|
| 43 |
+
|
| 44 |
+
// MPSEvent interface
|
| 45 |
+
uint32_t acquireEvent(bool enable_timing) const override;
|
| 46 |
+
void releaseEvent(uint32_t event_id) const override;
|
| 47 |
+
void recordEvent(uint32_t event_id) const override;
|
| 48 |
+
void waitForEvent(uint32_t event_id) const override;
|
| 49 |
+
void synchronizeEvent(uint32_t event_id) const override;
|
| 50 |
+
bool queryEvent(uint32_t event_id) const override;
|
| 51 |
+
double elapsedTimeOfEvents(uint32_t start_event_id, uint32_t end_event_id) const override;
|
| 52 |
+
|
| 53 |
+
// Compatibility with Accelerator API
|
| 54 |
+
bool hasPrimaryContext(DeviceIndex device_index) const override {
|
| 55 |
+
// When MPS is available, it is always in use for the one device.
|
| 56 |
+
return true;
|
| 57 |
+
}
|
| 58 |
+
};
|
| 59 |
+
|
| 60 |
+
} // namespace at::mps
|
.venv/lib/python3.11/site-packages/torch/include/ATen/mps/MPSProfiler.h
ADDED
|
@@ -0,0 +1,402 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright © 2022 Apple Inc.
|
| 2 |
+
|
| 3 |
+
#pragma once
|
| 4 |
+
|
| 5 |
+
#include <ATen/Tensor.h>
|
| 6 |
+
#include <ATen/mps/MPSStream.h>
|
| 7 |
+
#include <ATen/mps/MPSAllocatorInterface.h>
|
| 8 |
+
|
| 9 |
+
#include <os/signpost.h>
|
| 10 |
+
#include <os/log.h>
|
| 11 |
+
|
| 12 |
+
#include <atomic>
|
| 13 |
+
#include <ctime>
|
| 14 |
+
#include <sstream>
|
| 15 |
+
#include <string>
|
| 16 |
+
#include <unordered_map>
|
| 17 |
+
#include <utility>
|
| 18 |
+
|
| 19 |
+
namespace at::mps {
|
| 20 |
+
|
| 21 |
+
namespace Profiler {
|
| 22 |
+
|
| 23 |
+
struct BaseInfo {
|
| 24 |
+
// profiling info types
|
| 25 |
+
enum class Type {
|
| 26 |
+
GRAPH,
|
| 27 |
+
KERNEL,
|
| 28 |
+
COPY,
|
| 29 |
+
CPU_FALLBACK,
|
| 30 |
+
};
|
| 31 |
+
|
| 32 |
+
BaseInfo(Type infoType, uint64_t Id, const uintptr_t Handle) :
|
| 33 |
+
type(infoType), profileId(Id), handle(Handle) { }
|
| 34 |
+
virtual ~BaseInfo() = default;
|
| 35 |
+
|
| 36 |
+
// type of profiling info
|
| 37 |
+
Type type;
|
| 38 |
+
// unique profile ID for execution instances of operations or copies
|
| 39 |
+
uint64_t profileId;
|
| 40 |
+
// ID generated by os_signpost
|
| 41 |
+
// since it's possible to use event and interval-based signposts at the
|
| 42 |
+
// same time, we need separate IDs for each.
|
| 43 |
+
os_signpost_id_t eventSignpostId = 0, intervalSignpostId = 0;
|
| 44 |
+
// accumulated GPU time in ms (obtained from CompletionHandler's "GPUEndTime - GPUStartTime")
|
| 45 |
+
std::atomic<double> totalGpuTime{0.0};
|
| 46 |
+
// accumulated Scheduling time in ms (obtained from CompletionHandler's "KernelEndTime - KernelStartTime")
|
| 47 |
+
std::atomic<double> totalSchedulingTime{0.0};
|
| 48 |
+
// indicates if the operation or copy execution has completed
|
| 49 |
+
std::atomic_bool completed{false};
|
| 50 |
+
// handle used to identify the profile info's instance (usually the pointer)
|
| 51 |
+
const uintptr_t handle;
|
| 52 |
+
|
| 53 |
+
virtual const std::string toString(double gpuTime = 0, double schedulingTime = 0) const;
|
| 54 |
+
// builds a string for a tensor (format: Device:ScalarType[tensor.sizes()])
|
| 55 |
+
static std::string buildTensorString(const Tensor& tensor, bool includeBufferId = false) {
|
| 56 |
+
if (tensor.defined()) {
|
| 57 |
+
std::stringstream tensorStr;
|
| 58 |
+
auto deviceType = tensor.device().type();
|
| 59 |
+
tensorStr << c10::DeviceTypeName(deviceType);
|
| 60 |
+
// see comments for INCLUDE_BUFFER_ID
|
| 61 |
+
if (includeBufferId && deviceType == at::kMPS) {
|
| 62 |
+
id<MTLBuffer> buffer = __builtin_bit_cast(id<MTLBuffer>, tensor.storage().data());
|
| 63 |
+
tensorStr << "(buf#" << (getIMPSAllocator()->getBufferId(buffer))
|
| 64 |
+
<< ":" << buffer.retainCount << ")";
|
| 65 |
+
}
|
| 66 |
+
tensorStr << ":"
|
| 67 |
+
<< tensor.scalar_type() << tensor.sizes();
|
| 68 |
+
return tensorStr.str();
|
| 69 |
+
} else {
|
| 70 |
+
return "undefined";
|
| 71 |
+
}
|
| 72 |
+
}
|
| 73 |
+
static uint64_t getTime() {
|
| 74 |
+
return clock_gettime_nsec_np(CLOCK_MONOTONIC_RAW);
|
| 75 |
+
}
|
| 76 |
+
};
|
| 77 |
+
|
| 78 |
+
struct OperationInfo : BaseInfo {
|
| 79 |
+
OperationInfo(const void* Handle, bool IsGraph, uint64_t Id, const std::string& StrKey) :
|
| 80 |
+
BaseInfo(IsGraph ? Type::GRAPH : Type::KERNEL, Id, uintptr_t(Handle)), strKey(StrKey) { }
|
| 81 |
+
|
| 82 |
+
uint64_t runCount = 0;
|
| 83 |
+
std::string strKey;
|
| 84 |
+
|
| 85 |
+
const std::string toString(double gpuTime = 0, double schedulingTime = 0) const override;
|
| 86 |
+
|
| 87 |
+
// builds a string for a kernel
|
| 88 |
+
static std::string buildKernelString(const std::string& kernelName,
|
| 89 |
+
const TensorList& tensors,
|
| 90 |
+
bool includeBufferId = false) {
|
| 91 |
+
std::stringstream kernelStr;
|
| 92 |
+
kernelStr << kernelName;
|
| 93 |
+
for (const Tensor& tensor: tensors) {
|
| 94 |
+
kernelStr << ":" << BaseInfo::buildTensorString(tensor, includeBufferId);
|
| 95 |
+
}
|
| 96 |
+
return kernelStr.str();
|
| 97 |
+
}
|
| 98 |
+
};
|
| 99 |
+
|
| 100 |
+
struct CpuFbInfo : BaseInfo {
|
| 101 |
+
CpuFbInfo(uint64_t Id, const std::string& OpName) :
|
| 102 |
+
BaseInfo(Type::CPU_FALLBACK, Id, 0), opName(OpName) { }
|
| 103 |
+
|
| 104 |
+
uint64_t runCount = 0;
|
| 105 |
+
// the current and total overhead of copies in bytes required to convert the Op's
|
| 106 |
+
// input tensors from MPS to CPU and then output from CPU back to MPS
|
| 107 |
+
size_t currentCopyOverhead = 0;
|
| 108 |
+
size_t totalCopyOverhead = 0;
|
| 109 |
+
std::string opName;
|
| 110 |
+
std::string strKey;
|
| 111 |
+
uint64_t startTime = 0;
|
| 112 |
+
|
| 113 |
+
const std::string toString(double gpuTime = 0, double schedulingTime = 0) const override;
|
| 114 |
+
|
| 115 |
+
void updateCopyOverhead(const TensorList& tensors) {
|
| 116 |
+
currentCopyOverhead = 0;
|
| 117 |
+
for (const Tensor& tensor: tensors) {
|
| 118 |
+
if (tensor.defined()) {
|
| 119 |
+
currentCopyOverhead += tensor.nbytes();
|
| 120 |
+
}
|
| 121 |
+
}
|
| 122 |
+
totalCopyOverhead += currentCopyOverhead;
|
| 123 |
+
}
|
| 124 |
+
};
|
| 125 |
+
|
| 126 |
+
struct CopyInfo : BaseInfo {
|
| 127 |
+
enum class Kind {
|
| 128 |
+
MPS_TO_MPS,
|
| 129 |
+
MPS_TO_CPU,
|
| 130 |
+
CPU_TO_MPS,
|
| 131 |
+
};
|
| 132 |
+
|
| 133 |
+
CopyInfo(const void* Handle, size_t Length, uint64_t Id, bool IsNonBlocking, bool UsesBlitter) :
|
| 134 |
+
BaseInfo(Type::COPY, Id, uintptr_t(Handle)), kind(Kind::MPS_TO_MPS),
|
| 135 |
+
length(Length), isNonBlocking(IsNonBlocking), usesBlitter(UsesBlitter) { }
|
| 136 |
+
|
| 137 |
+
Kind kind;
|
| 138 |
+
size_t length;
|
| 139 |
+
bool isNonBlocking;
|
| 140 |
+
bool usesBlitter;
|
| 141 |
+
std::string srcStrKey;
|
| 142 |
+
std::string dstStrKey;
|
| 143 |
+
// for copies that don't use blitters, we measure CPU time
|
| 144 |
+
uint64_t startTime = 0;
|
| 145 |
+
|
| 146 |
+
const std::string toString(double gpuTime = 0, double schedulingTime = 0) const override;
|
| 147 |
+
|
| 148 |
+
static std::string buildTensorString(const void* buffer, const OptionalTensorRef tensor, bool includeBufferId = false);
|
| 149 |
+
|
| 150 |
+
static bool isStorageOnMPS(const void* buffer, const OptionalTensorRef tensor) {
|
| 151 |
+
if (tensor.has_value()) {
|
| 152 |
+
return tensor->device().type() == at::kMPS;
|
| 153 |
+
}
|
| 154 |
+
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(buffer);
|
| 155 |
+
// getUnalignedBufferSize() returns -1 if input buffer is not on MPS device
|
| 156 |
+
return getIMPSAllocator()->getUnalignedBufferSize(buffer) >= 0;
|
| 157 |
+
}
|
| 158 |
+
|
| 159 |
+
static Kind getCopyKind(const void* srcBuffer, const void* dstBuffer,
|
| 160 |
+
const OptionalTensorRef srcTensor, const OptionalTensorRef dstTensor) {
|
| 161 |
+
const bool isSrcOnMPS = isStorageOnMPS(srcBuffer, srcTensor);
|
| 162 |
+
const bool isDstOnMPS = isStorageOnMPS(dstBuffer, dstTensor);
|
| 163 |
+
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(isSrcOnMPS || isDstOnMPS);
|
| 164 |
+
if (isSrcOnMPS && !isDstOnMPS) {
|
| 165 |
+
return Kind::MPS_TO_CPU;
|
| 166 |
+
} else if (!isSrcOnMPS && isDstOnMPS) {
|
| 167 |
+
return Kind::CPU_TO_MPS;
|
| 168 |
+
}
|
| 169 |
+
return Kind::MPS_TO_MPS;
|
| 170 |
+
}
|
| 171 |
+
};
|
| 172 |
+
|
| 173 |
+
struct CopyStat : CopyInfo {
|
| 174 |
+
explicit CopyStat(std::string CopyKindStr) :
|
| 175 |
+
CopyInfo(nullptr, 0, 0, false, false), kindStr(std::move(CopyKindStr)) {}
|
| 176 |
+
// total number of copies
|
| 177 |
+
size_t totalCount = 0;
|
| 178 |
+
// number of Scalar copies (i.e., less than sizeof(int64))
|
| 179 |
+
size_t scalarsCount = 0;
|
| 180 |
+
// number of blocking copies (i.e., require syncing to GPU)
|
| 181 |
+
size_t blockingCount = 0;
|
| 182 |
+
// number of copies that used memcpy(), instead of Metal Blit Encoder
|
| 183 |
+
size_t memcpyCount = 0;
|
| 184 |
+
// accumulated GPU time in ms for the scalar copies
|
| 185 |
+
std::atomic<double> scalarsGpuTime{0.0};
|
| 186 |
+
// copy kind in string type
|
| 187 |
+
std::string kindStr;
|
| 188 |
+
};
|
| 189 |
+
|
| 190 |
+
class MPSProfiler {
|
| 191 |
+
public:
|
| 192 |
+
// lower 16 bits used for profiler options
|
| 193 |
+
enum ProfileOptions : uint32_t {
|
| 194 |
+
OPTIONS_NONE = 0,
|
| 195 |
+
// ALL_* means, all signpost types (RUN_OPERATION|BLIT_COPY|CPU_FALLBACK, etc.)
|
| 196 |
+
// (used for convenience to not compute bit flags by OR-ing manually)
|
| 197 |
+
// trace all signpost types using events
|
| 198 |
+
ALL_SIGNPOST_EVENTS = (1 << 0),
|
| 199 |
+
// trace all signpost types using intervals
|
| 200 |
+
ALL_SIGNPOST_INTERVALS = (1 << 1),
|
| 201 |
+
// always wait for command buffer to finish executing after each commit
|
| 202 |
+
WAIT_UNTIL_COMPLETED = (1 << 2),
|
| 203 |
+
// for interval-based signposts, include the scheduling portion of
|
| 204 |
+
// Graph/Kernel/Copy executions as well.
|
| 205 |
+
// if flag is disable, only "GPU run time" is included in interval,
|
| 206 |
+
// and not schedule time.
|
| 207 |
+
INCLUDE_SCHEDULE_INTERVAL = (1 << 3),
|
| 208 |
+
|
| 209 |
+
// use these if you need to trace signposts types individually (rarely required)
|
| 210 |
+
// trace signpost using intervals
|
| 211 |
+
USE_INTERVALS = (1 << 4),
|
| 212 |
+
// trace signpost by emitting events
|
| 213 |
+
USE_EVENTS = (1 << 5),
|
| 214 |
+
// used for sanity check (Change this when new option added)
|
| 215 |
+
OPTIONS_COUNT = (USE_EVENTS << 1) - 1,
|
| 216 |
+
};
|
| 217 |
+
|
| 218 |
+
// when adding new types, #define the type string in MPSProfiler.mm as well.
|
| 219 |
+
// upper 16 bits used for event types
|
| 220 |
+
enum SignpostTypes : uint32_t {
|
| 221 |
+
SIGNPOST_NONE = 0,
|
| 222 |
+
// trace signposts for PyTorch operation executions
|
| 223 |
+
RUN_OPERATION = (1 << 16),
|
| 224 |
+
// trace signposts for blitter copies
|
| 225 |
+
BLIT_COPY = (1 << 17),
|
| 226 |
+
// trace signposts for ops that fall back on CPU
|
| 227 |
+
CPU_FALLBACK = (1 << 18),
|
| 228 |
+
// used for sanity check (Change this when new type added)
|
| 229 |
+
SIGNPOST_COUNT = (CPU_FALLBACK << 1) - 1,
|
| 230 |
+
};
|
| 231 |
+
|
| 232 |
+
enum LogOptions : uint32_t {
|
| 233 |
+
LOG_NONE = 0,
|
| 234 |
+
|
| 235 |
+
// Info logging options during execution
|
| 236 |
+
// -------------------------------------
|
| 237 |
+
// prints operation info (id/key/run_count) during execution
|
| 238 |
+
OPERATION_INFO = (1 << 0),
|
| 239 |
+
// prints copy info (src/dst tensors/buffers, size, etc.) during execution
|
| 240 |
+
COPY_INFO = (1 << 1),
|
| 241 |
+
// prints CPU Fallback info (id/runCount/opName/copyOverhead) during execution
|
| 242 |
+
CPU_FALLBACK_INFO = (1 << 2),
|
| 243 |
+
|
| 244 |
+
// Profiling Statistics logging options when process terminates
|
| 245 |
+
// ------------------------------------------------------------
|
| 246 |
+
// prints all stats (OPERATION_STATS, COPY_STATS, CPU_FALLBACK_STATS) before process terminates
|
| 247 |
+
// this is convenient to not combine following stats bit flags manually
|
| 248 |
+
ALL_STATS = (1 << 3),
|
| 249 |
+
// prints operation stats (GPU times, run count, etc.) before process terminates
|
| 250 |
+
OPERATION_STATS = (1 << 4),
|
| 251 |
+
// prints copies stats (GPU times, copy kinds, sizes, etc.) before process terminates
|
| 252 |
+
COPY_STATS = (1 << 5),
|
| 253 |
+
// prints CPU Fallback stats (CPU times, run times, size of MPS<->CPU copies
|
| 254 |
+
// for tensors, etc.) before process terminates
|
| 255 |
+
CPU_FALLBACK_STATS = (1 << 6),
|
| 256 |
+
|
| 257 |
+
// Metadata format options when logging the info
|
| 258 |
+
// ---------------------------------------------
|
| 259 |
+
// if enabled, includes GPU run time in metadata (i.e., GPUEndTime-GPUStartTime
|
| 260 |
+
// from Metal Command Buffers) (e.g., [GPU=0.324 ms])
|
| 261 |
+
INCLUDE_GPU_TIME = (1 << 7),
|
| 262 |
+
// if enabled, includes GPU scheduling time in metadata separately
|
| 263 |
+
// (i.e., KernelEndTime-KernelStartTime from Metal Command Buffers)
|
| 264 |
+
// e.g., [GPU=0.324 ms, KRNL=0.036 ms]
|
| 265 |
+
INCLUDE_KERNEL_TIME = (1 << 8),
|
| 266 |
+
// if enabled, includes the unique buffer ID in metadata for the storage
|
| 267 |
+
// of a tensor that was allocated on MPSAllocator. This is useful (along with
|
| 268 |
+
// the EV "PYTORCH_DEBUG_MPS_ALLOCATOR") to identify buffers that are involved
|
| 269 |
+
// with various operations.
|
| 270 |
+
INCLUDE_BUFFER_ID = (1 << 9),
|
| 271 |
+
|
| 272 |
+
// used for sanity check (Change this when new option added)
|
| 273 |
+
LOG_COUNT = (INCLUDE_BUFFER_ID << 1) - 1,
|
| 274 |
+
};
|
| 275 |
+
|
| 276 |
+
explicit MPSProfiler();
|
| 277 |
+
~MPSProfiler();
|
| 278 |
+
|
| 279 |
+
// the handle is either "MPSGraph*" or "id<MTLComputePipelineState>" for Metal Kernels
|
| 280 |
+
// the beginProfile*() functions return a profileId which is unique per graph/kernel/copy
|
| 281 |
+
uint64_t beginProfileKernel(const void* handle, const std::string& strKey, bool isGraph);
|
| 282 |
+
uint64_t beginProfileKernel(const void* handle, const std::string& kernelName, const TensorList& tensors);
|
| 283 |
+
uint64_t beginProfileCopy(const void* srcBuffer, const void* dstBuffer,
|
| 284 |
+
const OptionalTensorRef srcTensor,
|
| 285 |
+
const OptionalTensorRef dstTensor,
|
| 286 |
+
size_t length, bool isNonBlocking, bool usesBlitter = true);
|
| 287 |
+
uint64_t beginProfileCPUFallback(const std::string& opName, const TensorList& tensors);
|
| 288 |
+
void beginProfileGPUInterval(const void* handle);
|
| 289 |
+
|
| 290 |
+
void endProfileCopy(uint64_t profileId, SyncType syncType);
|
| 291 |
+
void endProfileKernel(const void* handle, SyncType syncType = SyncType::NONE);
|
| 292 |
+
void endProfileCPUFallback(const std::string& opName);
|
| 293 |
+
|
| 294 |
+
// these are used to hook into Python bindings for torch.mps.profiler module.
|
| 295 |
+
// this enables generating OS Signpost traces from MPSProfiler on-demand
|
| 296 |
+
// during runtime (instead of environment variables).
|
| 297 |
+
// The "mode" could be either "interval", "event", or both "interval,event"
|
| 298 |
+
// for interval-based and/or event-based signpost tracing.
|
| 299 |
+
void StartTrace(const std::string& mode, bool waitUntilCompleted);
|
| 300 |
+
void StopTrace();
|
| 301 |
+
|
| 302 |
+
// Abstractions for GPU trace capturing
|
| 303 |
+
bool isCaptureEnabled() const;
|
| 304 |
+
bool isCapturing() const;
|
| 305 |
+
void startCapture(const std::string& name, MPSStream* stream = nullptr);
|
| 306 |
+
void stopCapture(MPSStream* stream = nullptr);
|
| 307 |
+
|
| 308 |
+
// convenience functions to indicate whether signpost tracing or
|
| 309 |
+
// logging are enabled for the SignpostTypes
|
| 310 |
+
bool isOperationProfilingEnabled() const {
|
| 311 |
+
return (m_signpost_types & SignpostTypes::RUN_OPERATION) ||
|
| 312 |
+
(m_log_options & (LogOptions::OPERATION_INFO | LogOptions::OPERATION_STATS));
|
| 313 |
+
}
|
| 314 |
+
bool isCopyProfilingEnabled() const {
|
| 315 |
+
return (m_signpost_types & SignpostTypes::BLIT_COPY) ||
|
| 316 |
+
(m_log_options & (LogOptions::COPY_INFO | LogOptions::COPY_STATS));
|
| 317 |
+
}
|
| 318 |
+
bool isCPUFallbackProfilingEnabled() const {
|
| 319 |
+
return (m_signpost_types & SignpostTypes::CPU_FALLBACK) ||
|
| 320 |
+
(m_log_options & (LogOptions::CPU_FALLBACK_INFO | LogOptions::CPU_FALLBACK_STATS));
|
| 321 |
+
}
|
| 322 |
+
bool isSignpostTracingEnabled() const {
|
| 323 |
+
return (m_signpost_types != SignpostTypes::SIGNPOST_NONE);
|
| 324 |
+
}
|
| 325 |
+
|
| 326 |
+
private:
|
| 327 |
+
// indicates what type of signpost types are enabled and traced by MPS profiler.
|
| 328 |
+
uint32_t m_signpost_types = 0;
|
| 329 |
+
uint32_t m_profile_options = 0;
|
| 330 |
+
uint32_t m_log_options = 0;
|
| 331 |
+
uint64_t m_kernel_counter = 0;
|
| 332 |
+
uint64_t m_graph_counter = 0;
|
| 333 |
+
uint64_t m_cpu_fb_counter = 0;
|
| 334 |
+
uint64_t m_copy_counter = 0;
|
| 335 |
+
// technically, it's possible to trace both events and intervals at the same time
|
| 336 |
+
// so we use separate os_log categories for them
|
| 337 |
+
os_log_t m_os_log_events;
|
| 338 |
+
os_log_t m_os_log_intervals;
|
| 339 |
+
// stats logging could run either from destructor or signal handler
|
| 340 |
+
// so this is used to check if logging has already started.
|
| 341 |
+
std::atomic_bool hasLoggedStats{false};
|
| 342 |
+
// indicates there are pending completionHandler callbacks that haven't been called yet.
|
| 343 |
+
std::atomic_bool hasPendingCompletionHandlers{false};
|
| 344 |
+
// used to capture sigint signal to log profiling stats
|
| 345 |
+
static struct sigaction currentSigint, previousSigint;
|
| 346 |
+
|
| 347 |
+
// We use the following lists for two reasons:
|
| 348 |
+
// 1- for interval-based signposts the "begin" point won't be in same function
|
| 349 |
+
// as the "end" point where we need to be able to retrieve signpost's info
|
| 350 |
+
// 2- if Operations info need to be logged when process ends using LogOptions::OPERATION_INFO.
|
| 351 |
+
|
| 352 |
+
// the pointer key for this map is either "MPSGraph*" or "id<MTLComputePipelineState>" for Metal Kernels
|
| 353 |
+
// this list is retained and could be logged along with aggregate profiling numbers when the process ends.
|
| 354 |
+
std::unordered_map<uintptr_t, std::unique_ptr<OperationInfo>> m_op_info_list{};
|
| 355 |
+
// the string key for this map is the op name that we fall back to execute on CPU
|
| 356 |
+
// this list is retained and could be logged along with aggregate profiling numbers when the process ends.
|
| 357 |
+
std::unordered_map<std::string, std::unique_ptr<CpuFbInfo>> m_cpu_fb_info_list{};
|
| 358 |
+
// this list contains the info for copies, and its key is the unique profileId
|
| 359 |
+
// which is generated from m_copy_counter
|
| 360 |
+
// The copyInfo list is not retained.
|
| 361 |
+
std::unordered_map<uint64_t, std::unique_ptr<CopyInfo>> m_copy_info_list{};
|
| 362 |
+
// a short list that contains copy stats
|
| 363 |
+
std::unordered_map<CopyInfo::Kind, std::unique_ptr<CopyStat>> m_copy_stat_list{};
|
| 364 |
+
|
| 365 |
+
mutable MTLCaptureManager *captureManager = nil;
|
| 366 |
+
unsigned captureCount = 0;
|
| 367 |
+
|
| 368 |
+
void initialize();
|
| 369 |
+
void beginProfileExecution(BaseInfo& info, bool cpuExecution = false);
|
| 370 |
+
void endProfileExecution(BaseInfo& info, os_signpost_id_t event_signpost_id,
|
| 371 |
+
os_signpost_id_t interval_signpost_id,
|
| 372 |
+
double gpuTime, double schedulingTime);
|
| 373 |
+
void addProfilerScheduledHandler(BaseInfo& info);
|
| 374 |
+
void addProfilerCompletedHandler(BaseInfo& info, SyncType syncType);
|
| 375 |
+
void emitSignpostEvent(SignpostTypes signpost_type, os_signpost_id_t signpost_id,
|
| 376 |
+
const std::string& msg) const;
|
| 377 |
+
void beginSignpostInterval(SignpostTypes signpost_type, os_signpost_id_t signpost_id,
|
| 378 |
+
const std::string& msg) const;
|
| 379 |
+
void endSignpostInterval(SignpostTypes signpost_type, os_signpost_id_t signpost_id) const;
|
| 380 |
+
|
| 381 |
+
void updateCopyStats(const CopyInfo& copyInfo, double gpuTime, double schedulingTime);
|
| 382 |
+
// returns true if logging the profiling info "during the execution" is enabled
|
| 383 |
+
bool isProfileInfoLoggingEnabled(BaseInfo::Type infoType, bool isExecutionEnded);
|
| 384 |
+
// logs all the profiling stats that are enabled
|
| 385 |
+
void logProfilingStats();
|
| 386 |
+
// logs kernel profiling stats when the process ends.
|
| 387 |
+
void logOperationsProfilingStats(std::FILE* f) const;
|
| 388 |
+
// logs CPU Fallback profiling stats when the process ends.
|
| 389 |
+
void logCPUFallbackProfilingStats(std::FILE* f) const;
|
| 390 |
+
// logs copy profiling stats when the process ends.
|
| 391 |
+
void logCopyProfilingStats(std::FILE* f) const;
|
| 392 |
+
|
| 393 |
+
os_signpost_id_t generateSignpostId(os_signpost_type_t signpostType, const void* ptr = nullptr);
|
| 394 |
+
static SignpostTypes getSignpostType(BaseInfo::Type infoType);
|
| 395 |
+
static void handleIntSignal(int signal);
|
| 396 |
+
};
|
| 397 |
+
|
| 398 |
+
} // namespace Profiler
|
| 399 |
+
|
| 400 |
+
Profiler::MPSProfiler& getMPSProfiler();
|
| 401 |
+
|
| 402 |
+
} // namespace at::mps
|
.venv/lib/python3.11/site-packages/torch/include/ATen/mps/MPSStream.h
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright © 2022 Apple Inc.
|
| 2 |
+
|
| 3 |
+
#pragma once
|
| 4 |
+
|
| 5 |
+
#include <cstdint>
|
| 6 |
+
#include <utility>
|
| 7 |
+
|
| 8 |
+
#include <c10/core/DeviceGuard.h>
|
| 9 |
+
#include <c10/util/Exception.h>
|
| 10 |
+
#include <c10/core/Stream.h>
|
| 11 |
+
#include <ATen/mps/MPSDevice.h>
|
| 12 |
+
|
| 13 |
+
#ifdef __OBJC__
|
| 14 |
+
#include <Foundation/Foundation.h>
|
| 15 |
+
#include <Metal/Metal.h>
|
| 16 |
+
#include <MetalPerformanceShaders/MetalPerformanceShaders.h>
|
| 17 |
+
#include <MetalPerformanceShadersGraph/MetalPerformanceShadersGraph.h>
|
| 18 |
+
typedef id<MTLCommandQueue> MTLCommandQueue_t;
|
| 19 |
+
typedef id<MTLCommandBuffer> MTLCommandBuffer_t;
|
| 20 |
+
typedef id<MTLComputeCommandEncoder> MTLComputeCommandEncoder_t;
|
| 21 |
+
typedef id<MTLSharedEvent> MTLSharedEvent_t;
|
| 22 |
+
typedef id<MTLDevice> MTLDevice_t;
|
| 23 |
+
#else
|
| 24 |
+
typedef void* MTLCommandQueue_t;
|
| 25 |
+
typedef void* MTLCommandQueue;
|
| 26 |
+
typedef void* MTLCommandBuffer_t;
|
| 27 |
+
typedef void* MTLCommandBuffer;
|
| 28 |
+
typedef void* MTLComputeCommandEncoder_t;
|
| 29 |
+
typedef void* MTLSharedEvent_t;
|
| 30 |
+
typedef void* dispatch_queue_t;
|
| 31 |
+
typedef void* MTLDevice_t;
|
| 32 |
+
#define nil NULL;
|
| 33 |
+
#endif
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
namespace at::mps {
|
| 37 |
+
|
| 38 |
+
//-----------------------------------------------------------------
|
| 39 |
+
// MPSStream
|
| 40 |
+
//-----------------------------------------------------------------
|
| 41 |
+
|
| 42 |
+
enum class SyncType {
|
| 43 |
+
NONE, // no commit to command buffer
|
| 44 |
+
COMMIT, // commit and flush the command buffer
|
| 45 |
+
COMMIT_AND_WAIT, // flush and wait for command buffer execution to finish
|
| 46 |
+
COMMIT_AND_CONTINUE,// commit and continue with a new underlying command buffer
|
| 47 |
+
COMMIT_ADAPTIVE, // commit adaptively based on available memory
|
| 48 |
+
};
|
| 49 |
+
|
| 50 |
+
class TORCH_API MPSStream
|
| 51 |
+
{
|
| 52 |
+
public:
|
| 53 |
+
enum Unchecked { UNCHECKED };
|
| 54 |
+
|
| 55 |
+
/// Construct a MPSStream from a Stream. This construction is checked,
|
| 56 |
+
/// and will raise an error if the Stream is not, in fact, a MPS stream.
|
| 57 |
+
explicit MPSStream(Stream stream);
|
| 58 |
+
|
| 59 |
+
~MPSStream();
|
| 60 |
+
MTLCommandQueue_t commandQueue() const { return _commandQueue; };
|
| 61 |
+
dispatch_queue_t queue() const { return _serialQueue; }
|
| 62 |
+
|
| 63 |
+
MPSCommandBuffer* commandBuffer();
|
| 64 |
+
MTLComputeCommandEncoder_t commandEncoder();
|
| 65 |
+
void endKernelCoalescing();
|
| 66 |
+
void synchronize(SyncType syncType);
|
| 67 |
+
void fill(id<MTLBuffer> buffer, uint8_t value, size_t length, size_t offset, SyncType syncType = SyncType::NONE);
|
| 68 |
+
void copy(id<MTLBuffer> srcBuffer, id<MTLBuffer> dstBuffer,
|
| 69 |
+
size_t length, size_t srcOffset, size_t dstOffset,
|
| 70 |
+
uint64_t profileId, SyncType syncType = SyncType::NONE);
|
| 71 |
+
void copy_and_sync(id<MTLBuffer> srcBuffer, id<MTLBuffer> dstBuffer,
|
| 72 |
+
size_t length, size_t srcOffset, size_t dstOffset,
|
| 73 |
+
bool non_blocking, uint64_t profileId);
|
| 74 |
+
void executeMPSGraph(MPSGraph* mpsGraph, NSDictionary* feeds, NSDictionary* results, SyncType syncType = SyncType::NONE);
|
| 75 |
+
void addCompletedHandler(MTLCommandBufferHandler block);
|
| 76 |
+
|
| 77 |
+
/// Get the MPS device index that this stream is associated with.
|
| 78 |
+
c10::DeviceIndex device_index() const { return _stream.device_index(); }
|
| 79 |
+
|
| 80 |
+
MTLCommandQueue_t stream() const { return _commandQueue; };
|
| 81 |
+
|
| 82 |
+
MTLDevice_t device() const { return [_commandQueue device];}
|
| 83 |
+
|
| 84 |
+
/// Explicit conversion to Stream.
|
| 85 |
+
Stream unwrap() const { return _stream; }
|
| 86 |
+
|
| 87 |
+
private:
|
| 88 |
+
Stream _stream;
|
| 89 |
+
MTLCommandQueue_t _commandQueue = nil;
|
| 90 |
+
MPSCommandBuffer* _commandBuffer = nil;
|
| 91 |
+
MPSCommandBuffer* _prevCommandBuffer = nil;
|
| 92 |
+
MTLComputeCommandEncoder_t _commandEncoder = nil;
|
| 93 |
+
MPSGraphExecutionDescriptor *_executionDescriptor = nil;
|
| 94 |
+
MPSGraphCompilationDescriptor *_compilationDescriptor = nil;
|
| 95 |
+
dispatch_queue_t _serialQueue = nullptr;
|
| 96 |
+
// CommitAndContinue is enabled by default
|
| 97 |
+
bool _enableCommitAndContinue = true;
|
| 98 |
+
|
| 99 |
+
// use synchronize() to access any of these commit functions outside MPSStream
|
| 100 |
+
void commit();
|
| 101 |
+
void commitAndWait();
|
| 102 |
+
void commitAndContinue();
|
| 103 |
+
void flush();
|
| 104 |
+
};
|
| 105 |
+
|
| 106 |
+
/**
|
| 107 |
+
* Get the current MPS stream
|
| 108 |
+
*/
|
| 109 |
+
TORCH_API MPSStream* getCurrentMPSStream();
|
| 110 |
+
|
| 111 |
+
/**
|
| 112 |
+
* Get the default MPS stream
|
| 113 |
+
*/
|
| 114 |
+
TORCH_API MPSStream* getDefaultMPSStream();
|
| 115 |
+
|
| 116 |
+
//-----------------------------------------------------------------
|
| 117 |
+
// MPSStreamImpl
|
| 118 |
+
//-----------------------------------------------------------------
|
| 119 |
+
|
| 120 |
+
class TORCH_API MPSStreamImpl
|
| 121 |
+
{
|
| 122 |
+
public:
|
| 123 |
+
/**
|
| 124 |
+
* Gets single instance of the MPSStream.
|
| 125 |
+
*/
|
| 126 |
+
static MPSStream* getInstance();
|
| 127 |
+
|
| 128 |
+
private:
|
| 129 |
+
static MPSStream* _stream;
|
| 130 |
+
MPSStreamImpl();
|
| 131 |
+
};
|
| 132 |
+
|
| 133 |
+
} // namespace at::mps
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/CatKernel.h
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/core/Tensor.h>
|
| 4 |
+
#include <ATen/native/DispatchStub.h>
|
| 5 |
+
#include <ATen/core/IListRef.h>
|
| 6 |
+
|
| 7 |
+
namespace at::native {
|
| 8 |
+
|
| 9 |
+
using cat_serial_fn = void(*)(const Tensor &, const MaterializedITensorListRef&, int64_t);
|
| 10 |
+
DECLARE_DISPATCH(cat_serial_fn, cat_serial_stub);
|
| 11 |
+
|
| 12 |
+
} // namespace at::native
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/Loops.h
ADDED
|
@@ -0,0 +1,394 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// This file provides two functions to help write elementwise kernels:
|
| 4 |
+
//
|
| 5 |
+
// cpu_kernel(TensorIterator iter, <lambda>)
|
| 6 |
+
// cpu_kernel_vec(TensorIterator iter, <lambda>, <vec_lambda>)
|
| 7 |
+
//
|
| 8 |
+
// Both functions may generate vectorized code. The cpu_kernel implementation
|
| 9 |
+
// relies on the compiler's auto-vectorization. The cpu_kernel_vec
|
| 10 |
+
// implementation uses x86 SIMD intrinsics when available. These functions
|
| 11 |
+
// are only intended to be used in the ATen/native/cpu subdirectory, since files
|
| 12 |
+
// in other directories are not compiled with AVX/AVX2 enabled. See README.md
|
| 13 |
+
// for more details.
|
| 14 |
+
//
|
| 15 |
+
// For example, to write a multiplication kernel for float:
|
| 16 |
+
//
|
| 17 |
+
// cpu_kernel(iter, [](float a, float b) { return a * b; });
|
| 18 |
+
//
|
| 19 |
+
// Or you may write:
|
| 20 |
+
//
|
| 21 |
+
// cpu_kernel_vec(iter,
|
| 22 |
+
// [](float a, float b) { return a * b; },
|
| 23 |
+
// [](Vectorized<float> a, Vectorized<float> b) { return a * b; });
|
| 24 |
+
//
|
| 25 |
+
// See BinaryOpsKernel.cpp for the complete implementation
|
| 26 |
+
//
|
| 27 |
+
//
|
| 28 |
+
|
| 29 |
+
#include <cstdint>
|
| 30 |
+
#include <c10/util/C++17.h>
|
| 31 |
+
#include <c10/util/Load.h>
|
| 32 |
+
#include <c10/util/irange.h>
|
| 33 |
+
#include <ATen/detail/FunctionTraits.h>
|
| 34 |
+
#include <ATen/native/cpu/IsContiguous.h>
|
| 35 |
+
#include <ATen/native/TensorIterator.h>
|
| 36 |
+
#include <ATen/native/TensorIteratorDynamicCasting.h>
|
| 37 |
+
#include <ATen/cpu/vec/vec.h>
|
| 38 |
+
|
| 39 |
+
#include <utility>
|
| 40 |
+
|
| 41 |
+
namespace at::native { inline namespace CPU_CAPABILITY {
|
| 42 |
+
|
| 43 |
+
using namespace vec;
|
| 44 |
+
|
| 45 |
+
template <typename traits, std::size_t... INDEX>
|
| 46 |
+
typename traits::ArgsTuple
|
| 47 |
+
dereference_impl(char* C10_RESTRICT data[], const int64_t* strides, int64_t i,
|
| 48 |
+
std::index_sequence<INDEX...>) {
|
| 49 |
+
return std::make_tuple(
|
| 50 |
+
c10::load<typename traits::template arg<INDEX>::type>(
|
| 51 |
+
data[INDEX] + i * strides[INDEX])...);
|
| 52 |
+
}
|
| 53 |
+
|
| 54 |
+
template <typename traits>
|
| 55 |
+
typename traits::ArgsTuple
|
| 56 |
+
dereference(char* C10_RESTRICT data[], const int64_t* strides, int64_t i) {
|
| 57 |
+
using Indices = std::make_index_sequence<traits::arity>;
|
| 58 |
+
return dereference_impl<traits>(data, strides, i, Indices{});
|
| 59 |
+
}
|
| 60 |
+
|
| 61 |
+
template <typename traits, std::size_t... INDEX>
|
| 62 |
+
typename traits::ArgsTuple
|
| 63 |
+
dereference_vec_impl(char* C10_RESTRICT data[],
|
| 64 |
+
const typename traits::result_type& opt_scalar,
|
| 65 |
+
size_t S,
|
| 66 |
+
int64_t i,
|
| 67 |
+
std::index_sequence<INDEX...>) {
|
| 68 |
+
using Vec = typename traits::result_type;
|
| 69 |
+
using scalar_t = typename Vec::value_type;
|
| 70 |
+
return std::make_tuple(
|
| 71 |
+
S == INDEX + 1 ?
|
| 72 |
+
opt_scalar :
|
| 73 |
+
Vec::loadu(data[INDEX] + i * sizeof(scalar_t))...);
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
template <typename traits>
|
| 77 |
+
typename traits::ArgsTuple
|
| 78 |
+
dereference_vec(char* C10_RESTRICT data[], const typename traits::result_type& opt_scalar, size_t S, int64_t i) {
|
| 79 |
+
using Indices = std::make_index_sequence<traits::arity>;
|
| 80 |
+
return dereference_vec_impl<traits>(data, opt_scalar, S, i, Indices{});
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
template <typename func_t,
|
| 84 |
+
std::enable_if_t<!std::is_void_v<typename function_traits<func_t>::result_type>>* = nullptr>
|
| 85 |
+
inline void
|
| 86 |
+
execute_op(char* C10_RESTRICT data[], const int64_t* strides, int64_t i, int64_t n, func_t&& op) {
|
| 87 |
+
using traits = function_traits<func_t>;
|
| 88 |
+
using result_type = typename traits::result_type;
|
| 89 |
+
for (; i < n; i++) {
|
| 90 |
+
result_type* out_ptr = (result_type*)(data[0] + i * strides[0]);
|
| 91 |
+
*out_ptr = c10::guts::apply(op, dereference<traits>(
|
| 92 |
+
&data[1],
|
| 93 |
+
&strides[1],
|
| 94 |
+
i));
|
| 95 |
+
}
|
| 96 |
+
}
|
| 97 |
+
|
| 98 |
+
template <typename func_t,
|
| 99 |
+
std::enable_if_t<std::is_void_v<typename function_traits<func_t>::result_type>>* = nullptr>
|
| 100 |
+
inline void
|
| 101 |
+
execute_op(char* C10_RESTRICT data[], const int64_t* strides, int64_t i, int64_t n, func_t&& op) {
|
| 102 |
+
using traits = function_traits<func_t>;
|
| 103 |
+
for (; i < n; i++) {
|
| 104 |
+
c10::guts::apply(op, dereference<traits>(
|
| 105 |
+
&data[0],
|
| 106 |
+
&strides[0],
|
| 107 |
+
i));
|
| 108 |
+
}
|
| 109 |
+
}
|
| 110 |
+
|
| 111 |
+
// Basic loop operation (one output, N inputs). May be auto-vectorized
|
| 112 |
+
// by the compiler. Supports inputs and outputs of different types.
|
| 113 |
+
template <typename func_t>
|
| 114 |
+
inline void
|
| 115 |
+
basic_loop(char* C10_RESTRICT data[], const int64_t* strides_, int64_t i, int64_t n, func_t&& op) {
|
| 116 |
+
using traits = function_traits<func_t>;
|
| 117 |
+
constexpr int ntensors = traits::arity + 1;
|
| 118 |
+
|
| 119 |
+
// Copying strides to temporary array helps auto vectorization in older GCC
|
| 120 |
+
// versions.
|
| 121 |
+
int64_t strides[ntensors];
|
| 122 |
+
for (const auto arg : c10::irange(ntensors)) {
|
| 123 |
+
strides[arg] = strides_[arg];
|
| 124 |
+
}
|
| 125 |
+
|
| 126 |
+
execute_op(data, strides, i, n, std::forward<func_t>(op));
|
| 127 |
+
}
|
| 128 |
+
|
| 129 |
+
// the recursive variadic template for iterating over the returned tuple
|
| 130 |
+
template<class T, size_t N>
|
| 131 |
+
struct TupleOutput {
|
| 132 |
+
static void handle(char *C10_RESTRICT data[], const int64_t *strides, int64_t i,
|
| 133 |
+
const T &tuple) {
|
| 134 |
+
TupleOutput<T, N - 1>::handle(data, strides, i, tuple);
|
| 135 |
+
|
| 136 |
+
auto output = std::get<N - 1>(tuple);
|
| 137 |
+
using output_type = decltype(output);
|
| 138 |
+
output_type * out_ptr = (output_type *)(data[N - 1] + i * strides[N - 1]);
|
| 139 |
+
*out_ptr = output;
|
| 140 |
+
}
|
| 141 |
+
};
|
| 142 |
+
|
| 143 |
+
// Base case for the above recursive template
|
| 144 |
+
template<class T>
|
| 145 |
+
struct TupleOutput<T, 1> {
|
| 146 |
+
static void handle(char *C10_RESTRICT data[], const int64_t *strides, int64_t i,
|
| 147 |
+
const T &tuple) {
|
| 148 |
+
auto output = std::get<0>(tuple);
|
| 149 |
+
using output_type = decltype(output);
|
| 150 |
+
output_type* out_ptr = (output_type *)(data[0] + i * strides[0]);
|
| 151 |
+
*out_ptr = output;
|
| 152 |
+
}
|
| 153 |
+
};
|
| 154 |
+
|
| 155 |
+
template<class... Args>
|
| 156 |
+
void handle_tuple_outputs(char* C10_RESTRICT data[],
|
| 157 |
+
const int64_t* strides,
|
| 158 |
+
int64_t i,
|
| 159 |
+
const std::tuple<Args...> &tuple) {
|
| 160 |
+
TupleOutput<decltype(tuple), sizeof...(Args)>::handle(data, strides, i, tuple);
|
| 161 |
+
}
|
| 162 |
+
|
| 163 |
+
// Loop operation for `cpu_kernel_multiple_outputs`.
|
| 164 |
+
// 1. Use `c10::guts::apply` to make dynamic method invocation
|
| 165 |
+
// for the lambda passed in `cpu_kernel_multiple_outputs`.
|
| 166 |
+
// 2. Iterate over the members of the returned tuple, set the corresponding
|
| 167 |
+
// output tensor by the tuple member in `handle_tuple_outputs` function.
|
| 168 |
+
template <typename func_t>
|
| 169 |
+
inline void
|
| 170 |
+
multiple_outputs_loop(char* C10_RESTRICT data[], const int64_t* strides_, int64_t i, int64_t n, func_t&& op) {
|
| 171 |
+
using traits = function_traits<func_t>;
|
| 172 |
+
|
| 173 |
+
using result_type = typename traits::result_type;
|
| 174 |
+
constexpr int num_outputs = std::tuple_size<result_type>::value;
|
| 175 |
+
constexpr int ntensors = traits::arity + num_outputs;
|
| 176 |
+
|
| 177 |
+
// Copying strides to temporary array helps auto vectorization in older GCC
|
| 178 |
+
// versions.
|
| 179 |
+
int64_t strides[ntensors];
|
| 180 |
+
for (const auto arg : c10::irange(ntensors)) {
|
| 181 |
+
strides[arg] = strides_[arg];
|
| 182 |
+
}
|
| 183 |
+
|
| 184 |
+
for (; i < n; i++) {
|
| 185 |
+
auto output = c10::guts::apply(op, dereference<traits>(
|
| 186 |
+
&data[num_outputs],
|
| 187 |
+
&strides[num_outputs],
|
| 188 |
+
i));
|
| 189 |
+
handle_tuple_outputs(data, strides, i, output);
|
| 190 |
+
}
|
| 191 |
+
}
|
| 192 |
+
|
| 193 |
+
// Explicitly vectorized loop implementation. All inputs and outputs must be
|
| 194 |
+
// the same type and contiguous with one exception: a single input may be
|
| 195 |
+
// a scalar (stride 0). It's position is indicated by the argument `S`. If `S`
|
| 196 |
+
// is 0, then there are no scalar inputs.
|
| 197 |
+
template <typename func_t, typename vec_func_t>
|
| 198 |
+
inline void
|
| 199 |
+
vectorized_loop(char** C10_RESTRICT data_, int64_t n, int64_t S, func_t&& op, vec_func_t&& vop) {
|
| 200 |
+
using traits = function_traits<vec_func_t>;
|
| 201 |
+
using scalar_t = typename function_traits<func_t>::result_type;
|
| 202 |
+
using Vec = Vectorized<scalar_t>;
|
| 203 |
+
constexpr int ntensors = traits::arity + 1;
|
| 204 |
+
|
| 205 |
+
char* C10_RESTRICT data[ntensors];
|
| 206 |
+
for (const auto arg : c10::irange(ntensors)) {
|
| 207 |
+
data[arg] = data_[arg];
|
| 208 |
+
}
|
| 209 |
+
|
| 210 |
+
Vec opt_scalar = Vec(S > 0 ? *(scalar_t*)data[S] : scalar_t(0));
|
| 211 |
+
int64_t i = 0;
|
| 212 |
+
for (; i <= n - 2 * Vec::size(); i += 2 * Vec::size()) {
|
| 213 |
+
auto args1 = dereference_vec<traits>(&data[1], opt_scalar, S, i);
|
| 214 |
+
auto args2 = dereference_vec<traits>(&data[1], opt_scalar, S, i + Vec::size());
|
| 215 |
+
auto out1 = c10::guts::apply(vop, std::move(args1));
|
| 216 |
+
auto out2 = c10::guts::apply(vop, std::move(args2));
|
| 217 |
+
out1.store(data[0] + i * sizeof(scalar_t));
|
| 218 |
+
out2.store(data[0] + (i + Vec::size()) * sizeof(scalar_t));
|
| 219 |
+
}
|
| 220 |
+
if (i < n) {
|
| 221 |
+
int64_t strides[ntensors];
|
| 222 |
+
for (const auto arg : c10::irange(ntensors)) {
|
| 223 |
+
strides[arg] = (S > 0 && arg == S) ? 0 : sizeof(scalar_t);
|
| 224 |
+
}
|
| 225 |
+
basic_loop(data, strides, i, n, std::forward<func_t>(op));
|
| 226 |
+
}
|
| 227 |
+
}
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
// Recursion terminator for the unrolled scalar-argument search below: no
// candidate indices remain, so invoke the callback with 0, meaning "no
// contiguous-with-scalar layout matched".
template <typename traits, typename cb_t>
inline void unroll_contiguous_scalar_checks(
    const int64_t* /*strides*/,
    std::index_sequence<>,
    cb_t&& cb) {
  cb(0);
}
|
| 237 |
+
|
| 238 |
+
// Compile-time-unrolled search for a layout that is contiguous except for a
// single scalar (stride-0) argument. Tests argument INDEX0 + 1 first; if the
// strides match, invokes `cb` with that 1-based argument index, otherwise
// recurses over the remaining indices (the base case above yields cb(0)).
template <typename traits, typename cb_t, size_t INDEX0, size_t ...INDEX>
inline void unroll_contiguous_scalar_checks(
    const int64_t* strides,
    std::index_sequence<INDEX0, INDEX...>,
    cb_t&& cb) {
  if (is_contiguous_scalar<traits, INDEX0 + 1>(strides)) {
    cb(INDEX0 + 1);
  } else {
    unroll_contiguous_scalar_checks<traits>(strides, std::index_sequence<INDEX...>{}, std::forward<cb_t>(cb));
  }
}
|
| 249 |
+
|
| 250 |
+
// 2-D loop functor handed to TensorIterator::for_each. Holds a scalar op and
// its vectorized counterpart and, per 2-D tile, picks the fastest applicable
// inner loop: fully vectorized, vectorized-with-one-scalar-arg, or the
// generic strided basic_loop.
template <typename op_t, typename vop_t>
struct VectorizedLoop2d {
  op_t op;     // scalar (element-wise) functor
  vop_t vop;   // vectorized functor operating on Vectorized<scalar_t>

  using traits = function_traits<op_t>;
  static constexpr int ntensors = traits::arity + 1;  // output + inputs
  using data_t = std::array<char*, ntensors>;

  VectorizedLoop2d(op_t op, vop_t vop):
    op(std::move(op)), vop(std::move(vop)) {}

  // Step every operand pointer to the next row using the outer strides.
  static void advance(data_t &data, const int64_t *outer_strides) {
    for (const auto arg : c10::irange(data.size())) {
      data[arg] += outer_strides[arg];
    }
  }

  // `strides` packs inner strides [0, ntensors) followed by outer strides
  // [ntensors, 2*ntensors); size0 is the inner extent, size1 the outer.
  void operator()(char** base, const int64_t *strides, int64_t size0, int64_t size1) {
    data_t data;
    std::copy_n(base, ntensors, data.data());
    const int64_t *outer_strides = &strides[ntensors];

    if (is_contiguous<traits>(strides)) {
      // All operands contiguous: vectorize every row (S = 0, no scalar arg).
      for (const auto i C10_UNUSED : c10::irange(size1)) {
        vectorized_loop(data.data(), size0, 0, op, vop);
        advance(data, outer_strides);
      }
    } else {
      // Check whether exactly one input is a scalar (stride 0) while the
      // rest are contiguous; idx is its 1-based position, or 0 for no match.
      using Indices = std::make_index_sequence<traits::arity>;
      unroll_contiguous_scalar_checks<traits>(strides, Indices{}, [&](size_t idx) {
        if (idx) {
          for (const auto i C10_UNUSED : c10::irange(size1)) {
            vectorized_loop(data.data(), size0, idx, op, vop);
            advance(data, outer_strides);
          }
        } else {
          // Arbitrary strides: fall back to the scalar strided loop.
          for (const auto i C10_UNUSED : c10::irange(size1)) {
            basic_loop(data.data(), strides, 0, size0, op);
            advance(data, outer_strides);
          }
        }
      });
    }
  }
};
|
| 296 |
+
|
| 297 |
+
// Factory that deduces op/vop types so callers don't have to spell out the
// VectorizedLoop2d template arguments.
template <typename op_t, typename vop_t>
VectorizedLoop2d<op_t, vop_t> make_vectorized_loop2d(
    op_t &&op, vop_t &&vop) {
  return VectorizedLoop2d<op_t, vop_t>(std::forward<op_t>(op), std::forward<vop_t>(vop));
}
|
| 302 |
+
|
| 303 |
+
// Runs an element-wise kernel `op` over `iter` in parallel (granularity
// controlled by `grain_size`). The functor's arity must match the iterator's
// input count, and exactly one output is expected. No vectorization here;
// see cpu_kernel_vec for the vectorized variant.
template <typename func_t>
void cpu_kernel(TensorIteratorBase& iter, func_t&& op, int64_t grain_size = at::internal::GRAIN_SIZE) {
  using traits = function_traits<func_t>;
  // this could be extended to work with void return types
  TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity);
  TORCH_INTERNAL_ASSERT(iter.noutputs() == 1);
  // dynamic casting not currently supported on CPU
  TORCH_INTERNAL_ASSERT(!needs_dynamic_casting<func_t>::check(iter));

  iter.for_each([&](char** data, const int64_t* strides, int64_t n) {
    // basic loop can handle 1d slices with arbitrary strides, and 1d slices is all that
    // iter.for_each is ever sending to the loop lambda
    basic_loop(data, strides, 0, n, op);
  }, grain_size);
  iter.cast_outputs();
}
|
| 319 |
+
|
| 320 |
+
// This function helps write elementwise kernels that require multiple outputs.
// It follows the similar structure of cpu_kernel.
// Instead of the `basic_loop` function, a new `multiple_outputs_loop` function is
// manipulated to handle multiple return values.
// For now the `needs_dynamic_casting` check is not added as the passed lambda (`func_t`)
// of `multiple_outputs_loop` returns `std::tuple` instead of `scalar_t`.
// The `gpu_kernel_multiple_outputs` is also implemented without this check,
// We could extend `needs_dynamic_casting` to support both `std::tuple` and
// `thrust::tuple` in the future.
template <typename func_t>
void cpu_kernel_multiple_outputs(TensorIteratorBase& iter, func_t&& op, int64_t grain_size = at::internal::GRAIN_SIZE) {
  using traits = function_traits<func_t>;
  TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity);

  iter.for_each([&](char** data, const int64_t* strides, int64_t n) {
    // multiple_outputs_loop scatters the returned tuple elements to the
    // iterator's output operands.
    multiple_outputs_loop(data, strides, 0, n, op);
  }, grain_size);
  iter.cast_outputs();
}
|
| 339 |
+
|
| 340 |
+
// Parallel element-wise kernel with an explicitly vectorized inner loop:
// `op` is the scalar functor, `vop` its Vectorized<scalar_t> counterpart.
// `check_dynamic_cast` can be set to false by kernels (like Fill) that
// perform their own dynamic casting.
template <bool check_dynamic_cast=true, typename func_t, typename vec_func_t>
void cpu_kernel_vec(TensorIteratorBase& iter, func_t&& op, vec_func_t&& vop, int64_t grain_size = at::internal::GRAIN_SIZE) {
  using traits = function_traits<func_t>;
  // this could be extended to work with void return types
  TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity);
  TORCH_INTERNAL_ASSERT(iter.noutputs() == 1);
  // dynamic casting not currently supported on CPU, but some kernels (like Fill)
  // explicitly dynamic_cast, so we give the opt-out of checking.
  if constexpr (check_dynamic_cast) {
    TORCH_INTERNAL_ASSERT(!needs_dynamic_casting<func_t>::check(iter));
  }

  iter.for_each(make_vectorized_loop2d(std::forward<func_t>(op), std::forward<vec_func_t>(vop)), grain_size);
  iter.cast_outputs();
}
|
| 355 |
+
|
| 356 |
+
// Single-threaded variant of cpu_kernel restricted to the element range
// `range`. Unlike cpu_kernel, a void-returning `op` is allowed, in which
// case the iterator must have zero outputs.
template <typename func_t>
void cpu_serial_kernel(TensorIteratorBase& iter, func_t&& op, const Range& range) {
  using traits = function_traits<func_t>;
  constexpr bool result_void = std::is_void_v<typename traits::result_type>;
  TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity &&
                        ((result_void && iter.noutputs() == 0) || (!result_void && iter.noutputs() == 1)));
  // dynamic casting not currently supported on CPU
  TORCH_INTERNAL_ASSERT(!needs_dynamic_casting<func_t>::check(iter));

  iter.serial_for_each([&](char** data, const int64_t* strides, int64_t n) {
    basic_loop(data, strides, 0, n, op);
  }, range);
  iter.cast_outputs();
}
|
| 370 |
+
|
| 371 |
+
template <typename func_t>
|
| 372 |
+
void cpu_serial_kernel(TensorIteratorBase& iter, func_t&& op) {
|
| 373 |
+
cpu_serial_kernel(iter, std::forward<func_t>(op), {0, iter.numel()});
|
| 374 |
+
}
|
| 375 |
+
|
| 376 |
+
// Single-threaded, vectorized kernel over the element range `range`.
// `op` is the scalar functor and `vop` its vectorized counterpart.
template <typename func_t, typename vec_func_t>
void cpu_serial_kernel_vec(TensorIteratorBase& iter, func_t&& op, vec_func_t&& vop, const Range& range) {
  using traits = function_traits<func_t>;
  // this could be extended to work with void return types
  TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity);
  TORCH_INTERNAL_ASSERT(iter.noutputs() == 1);
  // dynamic casting not currently supported on CPU
  TORCH_INTERNAL_ASSERT(!needs_dynamic_casting<func_t>::check(iter));

  iter.serial_for_each(make_vectorized_loop2d(std::forward<func_t>(op), std::forward<vec_func_t>(vop)), range);
  iter.cast_outputs();
}
|
| 388 |
+
|
| 389 |
+
// Convenience overload: runs the serial vectorized kernel over the full
// [0, numel) range of the iterator.
template <typename func_t, typename vec_func_t>
void cpu_serial_kernel_vec(TensorIteratorBase& iter, func_t&& op, vec_func_t&& vop) {
  cpu_serial_kernel_vec(iter, std::forward<func_t>(op), std::forward<vec_func_t>(vop), {0, iter.numel()});
}
|
| 393 |
+
|
| 394 |
+
}} // namespace at::native::<anonymous>
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/ReduceUtils.h
ADDED
|
@@ -0,0 +1,238 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/Parallel.h>
|
| 4 |
+
#include <ATen/NumericUtils.h>
|
| 5 |
+
#include <ATen/cpu/vec/vec.h>
|
| 6 |
+
#include <ATen/cpu/vec/functional.h>
|
| 7 |
+
#include <ATen/native/ReductionType.h>
|
| 8 |
+
#include <c10/util/irange.h>
|
| 9 |
+
#include <ATen/OpMathType.h>
|
| 10 |
+
#include <ATen/native/cpu/utils.h>
|
| 11 |
+
#include <ATen/OpMathType.h>
|
| 12 |
+
|
| 13 |
+
namespace at::native {
|
| 14 |
+
inline namespace CPU_CAPABILITY {
|
| 15 |
+
|
| 16 |
+
using namespace vec;
|
| 17 |
+
|
| 18 |
+
// Dispatches on a runtime ReductionType: for each enumerator, binds the
// compile-time constant `reduce` (visible inside the invoked lambda body)
// and calls the lambda, so reduction kernels can be instantiated per type.
// NOTE(review): no default case — an unhandled enumerator falls through and
// the immediately-invoked lambda returns nothing; presumably all enumerators
// are covered here.
#define AT_DISPATCH_REDUCTION_TYPES(op, ...)                                   \
  [&] {                                                                        \
    switch (op) {                                                              \
      case ReductionType::SUM: {                                               \
        static constexpr auto reduce = ReductionType::SUM;                     \
        return __VA_ARGS__();                                                  \
      }                                                                        \
      case ReductionType::MEAN: {                                              \
        static constexpr auto reduce = ReductionType::MEAN;                    \
        return __VA_ARGS__();                                                  \
      }                                                                        \
      case ReductionType::MIN: {                                               \
        static constexpr auto reduce = ReductionType::MIN;                     \
        return __VA_ARGS__();                                                  \
      }                                                                        \
      case ReductionType::MAX: {                                               \
        static constexpr auto reduce = ReductionType::MAX;                     \
        return __VA_ARGS__();                                                  \
      }                                                                        \
      case ReductionType::PROD: {                                              \
        static constexpr auto reduce = ReductionType::PROD;                    \
        return __VA_ARGS__();                                                  \
      }                                                                        \
    }                                                                          \
  }()
|
| 43 |
+
|
| 44 |
+
// Identity element for each reduction: 0 for SUM/MEAN, 1 for PROD,
// -inf for MAX and +inf for MIN (so the first real value always wins).
template <typename scalar_t, ReductionType reduce>
inline vec_scalar_t<scalar_t> init_value() {
  using acc_t = vec_scalar_t<scalar_t>;
  acc_t val;
  if (reduce == ReductionType::SUM ||
      reduce == ReductionType::MEAN) {
    val = static_cast<acc_t>(0);
  } else if (reduce == ReductionType::PROD) {
    val = static_cast<acc_t>(1);
  } else if (reduce == ReductionType::MAX) {
    val = -std::numeric_limits<acc_t>::infinity();
  } else {
    TORCH_INTERNAL_ASSERT(reduce == ReductionType::MIN);
    val = std::numeric_limits<acc_t>::infinity();
  }
  return val;
}
|
| 61 |
+
|
| 62 |
+
template <typename scalar_t, ReductionType reduce>
|
| 63 |
+
inline vec_scalar_t<scalar_t> init_value(const std::optional<Scalar>& initial) {
|
| 64 |
+
using acc_t = vec_scalar_t<scalar_t>;
|
| 65 |
+
if (initial.has_value()) {
|
| 66 |
+
return initial.value().to<acc_t>();
|
| 67 |
+
} else {
|
| 68 |
+
return init_value<scalar_t, reduce>();
|
| 69 |
+
}
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
// Fills `out[0..size)` with the constant `val`, using the vectorized map
// (the lambda ignores its input and always produces Vec(val)).
template <typename scalar_t>
inline void init(scalar_t* out, int64_t size, const vec_scalar_t<scalar_t>& val) {
  using Vec = Vectorized<vec_scalar_t<scalar_t>>;
  map<scalar_t>(
      [val](Vec x) { return Vec(val); },
      out,
      out,
      size);
}
|
| 81 |
+
|
| 82 |
+
// Fills `out` with the reduction's initial value, honoring an optional
// caller-provided starting value.
template <typename scalar_t, ReductionType reduce>
inline void init(scalar_t* out, int64_t size, const std::optional<Scalar>& initial) {
  using acc_t = vec_scalar_t<scalar_t>;
  acc_t val = init_value<scalar_t, reduce>(initial);
  init(out, size, val);
}
|
| 88 |
+
|
| 89 |
+
// overload with `include_self`, used by scatter_reduce
// When include_self is true the destination already holds valid data that
// participates in the reduction, so it must not be overwritten.
template <typename scalar_t, ReductionType reduce>
inline void init(scalar_t* out, int64_t size, bool include_self = false) {
  using acc_t = vec_scalar_t<scalar_t>;
  if (!include_self) {
    acc_t val = init_value<scalar_t, reduce>();
    init(out, size, val);
  }
}
|
| 98 |
+
|
| 99 |
+
// Initializes an opmath-typed accumulation buffer: either with the
// reduction identity (when the destination's own values are excluded) or by
// up-converting the existing destination values into the buffer.
template <typename scalar_t, ReductionType reduce>
inline void _init(scalar_t* self_ptr, at::opmath_type<scalar_t>* buffer_ptr, int64_t size, bool include_self) {
  if (!include_self) {
    init<at::opmath_type<scalar_t>, reduce>(buffer_ptr, size, include_self);
  } else {
    vec::convert(self_ptr, buffer_ptr, size);
  }
}
|
| 107 |
+
|
| 108 |
+
template <typename scalar_t>
|
| 109 |
+
inline typename std::enable_if<!std::is_same<scalar_t, Vec2>::value, scalar_t>::type
|
| 110 |
+
_max(const scalar_t& x, const scalar_t& y) {
|
| 111 |
+
return at::_isnan(y) ? y : std::max(x, y);
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
// Vectorized max; delegates to vec::maximum which already propagates NaN.
template <typename scalar_t>
inline Vectorized<scalar_t> _max(const Vectorized<scalar_t>& x, const Vectorized<scalar_t>& y) {
  // vec::maximum propagates NaN
  return vec::maximum(x, y);
}
|
| 119 |
+
|
| 120 |
+
// Max overload restricted to Vec2 (presumably a pair-of-vectors type for
// reduced-precision floats — confirm against its definition); `maximum` is
// found via ADL and propagates NaN.
template <typename vec_t>
inline typename std::enable_if<std::is_same<vec_t, Vec2>::value, Vec2>::type
_max(const vec_t& x, const vec_t& y) {
  // vec::maximum propagates NaN
  return maximum(x, y);
}
|
| 126 |
+
|
| 127 |
+
template <typename scalar_t>
|
| 128 |
+
inline typename std::enable_if<!std::is_same<scalar_t, Vec2>::value, scalar_t>::type
|
| 129 |
+
_min(const scalar_t& x, const scalar_t& y) {
|
| 130 |
+
return at::_isnan(y) ? y : std::min(x, y);
|
| 131 |
+
}
|
| 132 |
+
|
| 133 |
+
// Vectorized min; delegates to vec::minimum which already propagates NaN.
template <typename scalar_t>
inline Vectorized<scalar_t> _min(const Vectorized<scalar_t>& x, const Vectorized<scalar_t>& y) {
  // vec::minimum propagates NaN
  return vec::minimum(x, y);
}
|
| 138 |
+
|
| 139 |
+
// Min overload restricted to Vec2 (presumably a pair-of-vectors type for
// reduced-precision floats — confirm against its definition); `minimum` is
// found via ADL and propagates NaN.
template <typename vec_t>
inline typename std::enable_if<std::is_same<vec_t, Vec2>::value, Vec2>::type
_min(const vec_t& x, const vec_t& y) {
  // vec::minimum propagates NaN
  return minimum(x, y);
}
|
| 145 |
+
|
| 146 |
+
// output[i] = vec_fun(input[i], accum(input2[i])) for reduced-precision
// `input_data2` (bf16/fp16) against a higher-precision accumulator stream.
// Each reduced-precision vector converts into TWO accumulator-width vectors
// (convert_to_float yields a pair), so one Vec of input2 pairs with two aVec
// loads of input/output.
template <typename scalar_t, typename accumut, typename Op,
          typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
inline void map_acc(
    const Op& vec_fun,
    accumut* output_data,
    const accumut* input_data,
    const scalar_t* input_data2,
    int64_t size) {
  using Vec = vec::Vectorized<scalar_t>;
  using aVec = vec::Vectorized<accumut>;
  int64_t d = 0;
  constexpr int64_t kVecSize = Vec::size();    // elements per reduced-precision vector
  constexpr int64_t kaVecSize = aVec::size();  // elements per accumulator vector (kVecSize / 2)
  for (d = 0; d < size - (size % kVecSize); d += kVecSize) {
    Vec data2_vec = Vec::loadu(input_data2 + d);
    auto [data2_avec0, data2_avec1] = convert_to_float<scalar_t>(data2_vec);
    aVec input_vec0 = aVec::loadu(input_data + d);
    aVec input_vec1 = aVec::loadu(input_data + d + kaVecSize);
    vec_fun(input_vec0, data2_avec0).store(output_data + d);
    vec_fun(input_vec1, data2_avec1).store(output_data + d + kaVecSize);
  }
  // Tail: masked loads/stores for the remaining < kVecSize elements. The
  // tail may still span one or two accumulator vectors.
  if (size - d > 0) {
    int64_t tail_size = size - d;
    Vec data2_vec = Vec::loadu(input_data2 + d, tail_size);
    auto [data2_avec0, data2_avec1] = convert_to_float<scalar_t>(data2_vec);
    if (tail_size > kaVecSize) {
      // First accumulator vector is full; only the second is partial.
      aVec input_vec0 = aVec::loadu(input_data + d);
      aVec input_vec1 = aVec::loadu(input_data + d + kaVecSize, tail_size - kaVecSize);
      vec_fun(input_vec0, data2_avec0).store(output_data + d);
      vec_fun(input_vec1, data2_avec1).store(output_data + d + kaVecSize, tail_size - kaVecSize);
    } else {
      // Everything fits in a single partial accumulator vector.
      aVec input_vec0 = aVec::loadu(input_data + d, tail_size);
      vec_fun(input_vec0, data2_avec0).store(output_data + d, tail_size);
    }
  }
}
|
| 182 |
+
|
| 183 |
+
// for Max and Min, propagate NaN:
// Combine an accumulated value `x` with a new value `y` according to the
// reduction. MEAN accumulates like SUM; the division by the count happens
// later in write(). Works for both scalars and Vectorized types via the
// _max/_min overloads above.
template <typename T, ReductionType reduce>
inline T update(const T& x, const T& y) {
  if (reduce == ReductionType::SUM ||
      reduce == ReductionType::MEAN) {
    return x + y;
  } else if (reduce == ReductionType::PROD) {
    return x * y;
  } else if (reduce == ReductionType::MAX) {
    return _max(x, y);
  } else {
    TORCH_INTERNAL_ASSERT(reduce == ReductionType::MIN);
    return _min(x, y);
  }
}
|
| 198 |
+
|
| 199 |
+
// Element-wise accumulation of a row: out[i] = update(out[i], data[i]) for
// i in [0, K), vectorized via map2.
template <typename scalar_t, ReductionType reduce>
inline void update(scalar_t* out, const scalar_t* data, int64_t K) {
  using Vec = vec::Vectorized<vec_scalar_t<scalar_t>>;
  map2<scalar_t>(
      [](Vec x, Vec y) { return update<Vec, reduce>(x, y); },
      out,
      out,
      data,
      K);
}
|
| 209 |
+
|
| 210 |
+
// Reduced-precision overload: accumulates bf16/fp16 `data` into an
// opmath-typed (wider) `out` buffer, converting on the fly via map_acc.
template <typename scalar_t, ReductionType reduce,
          typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
inline void update(at::opmath_type<scalar_t>* out, const scalar_t* data, int64_t K) {
  using opmath_t = at::opmath_type<scalar_t>;
  using Vec = vec::Vectorized<opmath_t>;
  map_acc<scalar_t, opmath_t>(
      [](Vec x, Vec y) { return update<Vec, reduce>(x, y); },
      out,
      out,
      data,
      K);
}
|
| 222 |
+
|
| 223 |
+
// Finalizes a reduced row of length K. Only MEAN needs post-processing
// (divide the accumulated sum by the element count); all other reductions
// are already complete. A zero count leaves the row untouched.
template <typename scalar_t, ReductionType reduce>
inline void write(scalar_t* out, int64_t count, int64_t K) {
  using Vec = vec::Vectorized<vec_scalar_t<scalar_t>>;
  if (reduce == ReductionType::MEAN) {
    if (count > 0) {
      vec::map<scalar_t>(
          [count](Vec x) { return x / Vec(count); },
          out,
          out,
          K);
    }
  }
}
|
| 236 |
+
|
| 237 |
+
} // namespace CPU_CAPABILITY
|
| 238 |
+
} // namespace at::native
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/AffineQuantizer.h
ADDED
|
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/core/Tensor.h>
|
| 4 |
+
#include <ATen/Dispatch.h>
|
| 5 |
+
#include <ATen/native/DispatchStub.h>
|
| 6 |
+
#include <ATen/native/quantized/AffineQuantizerBase.h>
|
| 7 |
+
|
| 8 |
+
namespace at {
|
| 9 |
+
namespace native {
|
| 10 |
+
|
| 11 |
+
Tensor& quantize_tensor_per_tensor_affine(
|
| 12 |
+
const Tensor& rtensor,
|
| 13 |
+
Tensor& qtensor,
|
| 14 |
+
double scale,
|
| 15 |
+
int64_t zero_point);
|
| 16 |
+
Tensor& quantize_tensor_per_channel_affine(
|
| 17 |
+
const Tensor& rtensor,
|
| 18 |
+
Tensor& qtensor,
|
| 19 |
+
const Tensor& scales,
|
| 20 |
+
Tensor zero_points,
|
| 21 |
+
int64_t axis);
|
| 22 |
+
|
| 23 |
+
Tensor& quantize_tensor_per_channel_float_qparams(
|
| 24 |
+
const Tensor& rtensor,
|
| 25 |
+
Tensor& qtensor,
|
| 26 |
+
const Tensor& scales,
|
| 27 |
+
const Tensor& zero_points,
|
| 28 |
+
int64_t axis);
|
| 29 |
+
|
| 30 |
+
Tensor& dequantize_tensor_per_tensor_affine(
|
| 31 |
+
const Tensor& qtensor,
|
| 32 |
+
Tensor& rtensor,
|
| 33 |
+
double scale,
|
| 34 |
+
int64_t zero_point);
|
| 35 |
+
Tensor& dequantize_tensor_per_channel_affine(
|
| 36 |
+
const Tensor& qtensor,
|
| 37 |
+
Tensor& rtensor,
|
| 38 |
+
const Tensor& scales,
|
| 39 |
+
Tensor zero_points,
|
| 40 |
+
int64_t axis);
|
| 41 |
+
Tensor& dequantize_tensor_per_channel_float_qparams(
|
| 42 |
+
const Tensor& qtensor,
|
| 43 |
+
Tensor& rtensor,
|
| 44 |
+
const Tensor& scales,
|
| 45 |
+
const Tensor& zero_points,
|
| 46 |
+
int64_t axis);
|
| 47 |
+
|
| 48 |
+
using quantize_tensor_per_tensor_affine_fn =
|
| 49 |
+
void (*)(const Tensor& rtensor, Tensor& qtensor, double scale, int64_t zero_point);
|
| 50 |
+
|
| 51 |
+
using quantize_tensor_per_channel_affine_fn = void (*)(
|
| 52 |
+
const Tensor& rtensor,
|
| 53 |
+
Tensor& qtensor,
|
| 54 |
+
const Tensor& scales,
|
| 55 |
+
const Tensor& zero_points,
|
| 56 |
+
int64_t axis);
|
| 57 |
+
|
| 58 |
+
using quantize_tensor_per_channel_float_qparams_fn = void (*)(
|
| 59 |
+
const Tensor& rtensor,
|
| 60 |
+
Tensor& qtensor,
|
| 61 |
+
const Tensor& scales,
|
| 62 |
+
const Tensor& zero_points,
|
| 63 |
+
int64_t axis);
|
| 64 |
+
|
| 65 |
+
using dequantize_tensor_per_tensor_affine_fn =
|
| 66 |
+
void (*)(const Tensor& qtensor, Tensor& rtensor, double scale, int64_t zero_point);
|
| 67 |
+
|
| 68 |
+
using dequantize_tensor_per_channel_affine_fn = void (*)(
|
| 69 |
+
const Tensor& qtensor,
|
| 70 |
+
Tensor& rtensor,
|
| 71 |
+
const Tensor& scales,
|
| 72 |
+
const Tensor& zero_points,
|
| 73 |
+
int64_t axis);
|
| 74 |
+
|
| 75 |
+
using dequantize_tensor_per_channel_float_qparams_fn = void (*)(
|
| 76 |
+
const Tensor& qtensor,
|
| 77 |
+
Tensor& rtensor,
|
| 78 |
+
const Tensor& scales,
|
| 79 |
+
const Tensor& zero_points,
|
| 80 |
+
int64_t axis);
|
| 81 |
+
|
| 82 |
+
using quantize_tensor_per_tensor_affine_sub_byte_fn =
|
| 83 |
+
void (*)(const Tensor& rtensor, Tensor& qtensor, float scale, float zero_point);
|
| 84 |
+
|
| 85 |
+
using dequantize_tensor_per_tensor_affine_sub_byte_fn =
|
| 86 |
+
void (*)(const Tensor& qtensor, Tensor& rtensor, float scale, float zero_point);
|
| 87 |
+
|
| 88 |
+
DECLARE_DISPATCH(
|
| 89 |
+
quantize_tensor_per_tensor_affine_fn,
|
| 90 |
+
quantize_tensor_per_tensor_affine_stub);
|
| 91 |
+
DECLARE_DISPATCH(
|
| 92 |
+
quantize_tensor_per_channel_affine_fn,
|
| 93 |
+
quantize_tensor_per_channel_affine_stub);
|
| 94 |
+
DECLARE_DISPATCH(
|
| 95 |
+
quantize_tensor_per_channel_float_qparams_fn,
|
| 96 |
+
quantize_tensor_per_channel_float_qparams_stub);
|
| 97 |
+
|
| 98 |
+
DECLARE_DISPATCH(
|
| 99 |
+
dequantize_tensor_per_tensor_affine_fn,
|
| 100 |
+
dequantize_tensor_per_tensor_affine_stub);
|
| 101 |
+
DECLARE_DISPATCH(
|
| 102 |
+
dequantize_tensor_per_channel_affine_fn,
|
| 103 |
+
dequantize_tensor_per_channel_affine_stub);
|
| 104 |
+
DECLARE_DISPATCH(
|
| 105 |
+
dequantize_tensor_per_channel_float_qparams_fn,
|
| 106 |
+
dequantize_tensor_per_channel_float_qparams_stub);
|
| 107 |
+
|
| 108 |
+
DECLARE_DISPATCH(
|
| 109 |
+
quantize_tensor_per_tensor_affine_sub_byte_fn,
|
| 110 |
+
quantize_tensor_per_tensor_affine_sub_byte_stub);
|
| 111 |
+
|
| 112 |
+
DECLARE_DISPATCH(
|
| 113 |
+
dequantize_tensor_per_tensor_affine_sub_byte_fn,
|
| 114 |
+
dequantize_tensor_per_tensor_affine_sub_byte_stub);
|
| 115 |
+
|
| 116 |
+
template <typename T>
|
| 117 |
+
TORCH_API Tensor quantize_tensor(
|
| 118 |
+
Tensor rtensor,
|
| 119 |
+
Tensor qtensor,
|
| 120 |
+
double scale,
|
| 121 |
+
int64_t zero_point);
|
| 122 |
+
template <typename T>
|
| 123 |
+
TORCH_API Tensor dequantize_tensor(
|
| 124 |
+
Tensor qtensor,
|
| 125 |
+
Tensor rtensor,
|
| 126 |
+
double scale,
|
| 127 |
+
int64_t zero_point);
|
| 128 |
+
|
| 129 |
+
} // namespace native
|
| 130 |
+
} // namespace at
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/AffineQuantizerBase.h
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <c10/macros/Export.h>
|
| 3 |
+
#include <c10/core/ScalarType.h>
|
| 4 |
+
|
| 5 |
+
namespace at {
|
| 6 |
+
namespace native {
|
| 7 |
+
|
| 8 |
+
// Quantize a float value into a uint value given scale and zero_point
|
| 9 |
+
template <typename T>
|
| 10 |
+
TORCH_API T quantize_val(double scale, int64_t zero_point, float value);
|
| 11 |
+
// TODO combine this with quantize_val once the numerics for ARM are aligned
|
| 12 |
+
// with it
|
| 13 |
+
template <typename T>
|
| 14 |
+
T quantize_val_arm(
|
| 15 |
+
const float scale,
|
| 16 |
+
const int32_t zero_point,
|
| 17 |
+
const float value);
|
| 18 |
+
template <typename T, int precision = 8>
|
| 19 |
+
void quantize_vec(
|
| 20 |
+
double scale,
|
| 21 |
+
int64_t zero_point,
|
| 22 |
+
const float* src,
|
| 23 |
+
T* dst,
|
| 24 |
+
size_t count = 8);
|
| 25 |
+
template <typename T>
|
| 26 |
+
TORCH_API float dequantize_val(double scale, int64_t zero_point, T value);
|
| 27 |
+
template <typename T>
|
| 28 |
+
TORCH_API float dequantize_vec(
|
| 29 |
+
double scale,
|
| 30 |
+
int64_t zero_point,
|
| 31 |
+
const T* src,
|
| 32 |
+
float* dst,
|
| 33 |
+
size_t count = 8);
|
| 34 |
+
template <typename SRC_T, typename DST_T>
|
| 35 |
+
TORCH_API DST_T requantize_val(double, int64_t, double, int64_t, SRC_T src);
|
| 36 |
+
|
| 37 |
+
// Given a multiplier and a zero_point, requantize int32_t computed values back
|
| 38 |
+
// to quantized values. See comment above
|
| 39 |
+
// make_per_tensor_affine_quantizer function for the usage of int64_t
|
| 40 |
+
template <typename DST_T>
|
| 41 |
+
TORCH_API DST_T
|
| 42 |
+
requantize_from_int(double multiplier, int64_t zero_point, int64_t src);
|
| 43 |
+
|
| 44 |
+
int quantize_val_float_qparams(float scale, float zero_point, float value, int qmin, int qmax);
|
| 45 |
+
|
| 46 |
+
} // namespace native
|
| 47 |
+
} // namespace at
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/ConvUtils.h
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <ATen/core/List.h>
|
| 3 |
+
#include <ATen/native/ConvUtils.h>
|
| 4 |
+
|
| 5 |
+
namespace at::native::quantized {
|
| 6 |
+
namespace {
|
| 7 |
+
// MakeConvOutputShape used from both CPU and CUDA libraries
|
| 8 |
+
// and exporting symbol from torch_cpu would probably take more storage
|
| 9 |
+
// than duplicating implementation which likely be inlined away
|
| 10 |
+
template <int kSpatialDim>
|
| 11 |
+
at::SmallVector<int64_t, kSpatialDim + 2> MakeConvOutputShape(
|
| 12 |
+
int N, // mini-batch
|
| 13 |
+
int M, // output channels
|
| 14 |
+
const std::array<int64_t, kSpatialDim>& input_image_shape,
|
| 15 |
+
const std::vector<int64_t>& kernel,
|
| 16 |
+
const torch::List<int64_t>& stride,
|
| 17 |
+
const torch::List<int64_t>& padding,
|
| 18 |
+
const torch::List<int64_t>& dilation);
|
| 19 |
+
|
| 20 |
+
#if defined(USE_CUDA) || defined(USE_PYTORCH_QNNPACK)
|
| 21 |
+
// 2-D specialization: computes the NCHW output shape of a convolution using
// the standard formula out = (in + 2*pad - dilation*(kernel-1) - 1)/stride + 1.
template <>
at::SmallVector<int64_t, 4> MakeConvOutputShape<2>(
    int N, // mini-batch
    int M, // output channels
    const std::array<int64_t, 2>& input_image_shape,
    const std::vector<int64_t>& kernel,
    const at::List<int64_t>& stride,
    const at::List<int64_t>& padding,
    const at::List<int64_t>& dilation) {
  const int H = input_image_shape[0];
  const int W = input_image_shape[1];
  const int64_t Y_H =
      (H + 2 * padding[0] - dilation[0] * (kernel[0] - 1) - 1) / stride[0] + 1;
  const int64_t Y_W =
      (W + 2 * padding[1] - dilation[1] * (kernel[1] - 1) - 1) / stride[1] + 1;
  return {N, M, Y_H, Y_W};
}
|
| 38 |
+
|
| 39 |
+
// 3-D specialization: computes the NCDHW output shape of a convolution with
// the same per-dimension formula as the 2-D case above.
// NOTE(review): `dilation` is declared as torch::List here while the other
// list parameters use at::List — both must alias the same type for this to
// compile, but the inconsistency is worth unifying upstream.
template <>
at::SmallVector<int64_t, 5> MakeConvOutputShape<3>(
    int N, // mini-batch
    int M, // output channels
    const std::array<int64_t, 3>& input_image_shape,
    const std::vector<int64_t>& kernel,
    const at::List<int64_t>& stride,
    const at::List<int64_t>& padding,
    const torch::List<int64_t>& dilation) {
  const int D = input_image_shape[0];
  const int H = input_image_shape[1];
  const int W = input_image_shape[2];
  const int64_t Y_D =
      (D + 2 * padding[0] - dilation[0] * (kernel[0] - 1) - 1) / stride[0] + 1;
  const int64_t Y_H =
      (H + 2 * padding[1] - dilation[1] * (kernel[1] - 1) - 1) / stride[1] + 1;
  const int64_t Y_W =
      (W + 2 * padding[2] - dilation[2] * (kernel[2] - 1) - 1) / stride[2] + 1;
  return {N, M, Y_D, Y_H, Y_W};
}
|
| 59 |
+
|
| 60 |
+
#endif
|
| 61 |
+
} // anonymous namespace
|
| 62 |
+
} // namespace at::native::quantized
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/Copy.h
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/core/Tensor.h>
|
| 4 |
+
|
| 5 |
+
namespace at {
|
| 6 |
+
namespace native {
|
| 7 |
+
|
| 8 |
+
Tensor& quantized_copy_from_float_(Tensor& self, const Tensor& src);
|
| 9 |
+
}
|
| 10 |
+
} // namespace at
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/FakeQuantAffine.h
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/core/Tensor.h>
|
| 4 |
+
#include <ATen/Dispatch.h>
|
| 5 |
+
#include <ATen/native/DispatchStub.h>
|
| 6 |
+
|
| 7 |
+
namespace at {
|
| 8 |
+
|
| 9 |
+
struct TensorIterator;
|
| 10 |
+
|
| 11 |
+
namespace native {
|
| 12 |
+
|
| 13 |
+
using fake_quant_tensor_cachemask_fn = void (*)(
|
| 14 |
+
Tensor& output,
|
| 15 |
+
Tensor& mask,
|
| 16 |
+
const Tensor& input,
|
| 17 |
+
float sc,
|
| 18 |
+
int64_t z_point,
|
| 19 |
+
int64_t quant_min,
|
| 20 |
+
int64_t quant_max);
|
| 21 |
+
|
| 22 |
+
using fake_quant_tensor_cachemask_tensor_qparams_fn = void (*)(
|
| 23 |
+
Tensor& output,
|
| 24 |
+
Tensor& mask,
|
| 25 |
+
const Tensor& input,
|
| 26 |
+
const Tensor& sc,
|
| 27 |
+
const Tensor& z_point,
|
| 28 |
+
const Tensor& fake_quant_enabled,
|
| 29 |
+
int64_t quant_min,
|
| 30 |
+
int64_t quant_max);
|
| 31 |
+
|
| 32 |
+
using fake_quant_learnable_grad_tensor_fn = void (*)(
|
| 33 |
+
TensorIterator& iter,
|
| 34 |
+
float scale,
|
| 35 |
+
float inv_scale,
|
| 36 |
+
int64_t zero_point,
|
| 37 |
+
int64_t quant_min,
|
| 38 |
+
int64_t quant_max,
|
| 39 |
+
float grad_factor);
|
| 40 |
+
|
| 41 |
+
DECLARE_DISPATCH(fake_quant_tensor_cachemask_fn, fake_quant_tensor_cachemask_stub);
|
| 42 |
+
DECLARE_DISPATCH(fake_quant_tensor_cachemask_tensor_qparams_fn, fake_quant_tensor_cachemask_tensor_qparams_stub);
|
| 43 |
+
DECLARE_DISPATCH(fake_quant_learnable_grad_tensor_fn, fake_quant_grad_learnable_tensor_stub);
|
| 44 |
+
|
| 45 |
+
using fake_quant_per_channel_fn = void (*)(
|
| 46 |
+
TensorIterator &iter,
|
| 47 |
+
int64_t quant_min,
|
| 48 |
+
int64_t quant_max);
|
| 49 |
+
|
| 50 |
+
using fake_quant_per_channel_cachemask_fn = void (*)(
|
| 51 |
+
TensorIterator &iter,
|
| 52 |
+
TensorIterator &iter_mask,
|
| 53 |
+
int64_t quant_min,
|
| 54 |
+
int64_t quant_max);
|
| 55 |
+
|
| 56 |
+
DECLARE_DISPATCH(fake_quant_per_channel_cachemask_fn, fake_quant_per_channel_cachemask_stub);
|
| 57 |
+
|
| 58 |
+
using fake_quant_learnable_per_channel_fn = void (*)(
|
| 59 |
+
TensorIterator &iter,
|
| 60 |
+
int64_t quant_min,
|
| 61 |
+
int64_t quant_max,
|
| 62 |
+
float grad_factor);
|
| 63 |
+
|
| 64 |
+
DECLARE_DISPATCH(fake_quant_learnable_per_channel_fn, fake_quant_grad_learnable_channel_stub);
|
| 65 |
+
|
| 66 |
+
} // namespace native
|
| 67 |
+
} // namespace at
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/IndexKernel.h
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <ATen/native/TensorIterator.h>
|
| 3 |
+
|
| 4 |
+
namespace at {
|
| 5 |
+
namespace native {
|
| 6 |
+
using masked_fill_kernel_quantized_fn = void(*)(TensorIterator& iter, const Scalar& value, double scale, int zero_point);
|
| 7 |
+
using index_put_kernel_quantized_fn = void(*)(TensorIterator& iter, IntArrayRef index_size, IntArrayRef index_stride, bool accumulate, double scale, int zero_point);
|
| 8 |
+
|
| 9 |
+
DECLARE_DISPATCH(masked_fill_kernel_quantized_fn, masked_fill_kernel_quantized_stub);
|
| 10 |
+
DECLARE_DISPATCH(index_put_kernel_quantized_fn, index_put_kernel_quantized_stub);
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
} // native
|
| 14 |
+
} // at
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/PackedParams.h
ADDED
|
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/core/Tensor.h>
|
| 4 |
+
#include <ATen/core/ivalue.h>
|
| 5 |
+
|
| 6 |
+
struct LinearPackedParamsBase : public torch::jit::CustomClassHolder {
|
| 7 |
+
virtual at::Tensor apply(
|
| 8 |
+
at::Tensor input,
|
| 9 |
+
double output_scale,
|
| 10 |
+
int64_t output_zero_point) = 0;
|
| 11 |
+
virtual at::Tensor apply_relu(
|
| 12 |
+
at::Tensor input,
|
| 13 |
+
double output_scale,
|
| 14 |
+
int64_t output_zero_point) = 0;
|
| 15 |
+
|
| 16 |
+
// out variant of LinearPackedParamsBase::apply
|
| 17 |
+
virtual at::Tensor& apply_out(
|
| 18 |
+
const at::Tensor& /*input*/,
|
| 19 |
+
double /*output_scale*/,
|
| 20 |
+
int64_t /*output_zero_point*/,
|
| 21 |
+
at::Tensor& output) {
|
| 22 |
+
throw std::runtime_error(
|
| 23 |
+
"apply_out is not implemented for this packed "
|
| 24 |
+
"parameter type");
|
| 25 |
+
return output;
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
virtual at::Tensor& apply_relu_out(
|
| 29 |
+
const at::Tensor& /*input*/,
|
| 30 |
+
double /*output_scale*/,
|
| 31 |
+
int64_t /*output_zero_point*/,
|
| 32 |
+
at::Tensor& output) {
|
| 33 |
+
throw std::runtime_error(
|
| 34 |
+
"apply_relu_out is not implemented for this packed "
|
| 35 |
+
"parameter type");
|
| 36 |
+
return output;
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
// Corresponding pattern (the ops with `*` are part of the pattern that
|
| 40 |
+
// represents the computation of quantized::linear_with_input_q_dq_qweight_dq_output_fp32):
|
| 41 |
+
// input -> q* -> dq* -> linear* ->
|
| 42 |
+
// qweight -> dq* /
|
| 43 |
+
//
|
| 44 |
+
// After fusion:
|
| 45 |
+
// input -> quantized::linear_with_input_q_dq_qweight_dq_output_fp32* ->
|
| 46 |
+
// qweight /
|
| 47 |
+
//
|
| 48 |
+
// Additional Note: the weight is packed as well
|
| 49 |
+
// Params:
|
| 50 |
+
// X: float32 Tensor, will be quantized to quint8 in the op
|
| 51 |
+
// W_prepack: packed qint8 quantized weight and bias
|
| 52 |
+
// Returns:
|
| 53 |
+
// Y: float32 Tensor
|
| 54 |
+
virtual at::Tensor apply_with_input_q_dq_qweight_dq_output_fp32(
|
| 55 |
+
at::Tensor input,
|
| 56 |
+
double input_scale,
|
| 57 |
+
int64_t input_zero_point) {
|
| 58 |
+
throw std::runtime_error(
|
| 59 |
+
"apply_with_input_q_dq_qweight_dq_output_fp32 is not implemented for this packed "
|
| 60 |
+
"parameter type");
|
| 61 |
+
return {};
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
// Corresponding pattern (the ops with `*` are part of the pattern that
|
| 65 |
+
// represents the computation of quantized::linear_with_input_q_dq_qweight_dq_relu_output_fp32):
|
| 66 |
+
// input -> q* -> dq* -> linear* -> relu* ->
|
| 67 |
+
// qweight -> dq* /
|
| 68 |
+
//
|
| 69 |
+
// After fusion:
|
| 70 |
+
// input -> quantized::linear_with_input_q_dq_qweight_dq_relu_output_fp32* ->
|
| 71 |
+
// qweight /
|
| 72 |
+
//
|
| 73 |
+
// Additional Note: the weight is packed as well
|
| 74 |
+
// Params:
|
| 75 |
+
// input: float32 Tensor, will be quantized to quint8 in the op
|
| 76 |
+
// Returns:
|
| 77 |
+
// float32 Tensor
|
| 78 |
+
virtual at::Tensor apply_with_input_q_dq_qweight_dq_relu_output_fp32(
|
| 79 |
+
at::Tensor input,
|
| 80 |
+
double input_scale,
|
| 81 |
+
int64_t input_zero_point) {
|
| 82 |
+
throw std::runtime_error(
|
| 83 |
+
"apply_with_input_q_dq_qweight_dq_relu_output_fp32 is not implemented for this packed "
|
| 84 |
+
"parameter type");
|
| 85 |
+
return {};
|
| 86 |
+
}
|
| 87 |
+
|
| 88 |
+
virtual at::Tensor apply_dynamic(
|
| 89 |
+
at::Tensor input,
|
| 90 |
+
bool reduce_range = false) = 0;
|
| 91 |
+
virtual at::Tensor apply_dynamic_relu(
|
| 92 |
+
at::Tensor input,
|
| 93 |
+
bool reduce_range = false) = 0;
|
| 94 |
+
|
| 95 |
+
virtual at::Tensor& apply_dynamic_out(
|
| 96 |
+
const at::Tensor& /* input */,
|
| 97 |
+
at::Tensor& output,
|
| 98 |
+
bool /* reduce_range */) {
|
| 99 |
+
throw std::runtime_error(
|
| 100 |
+
"apply_dynamic_out is not implemented for this packed "
|
| 101 |
+
"parameter type");
|
| 102 |
+
return output;
|
| 103 |
+
}
|
| 104 |
+
virtual at::Tensor& apply_dynamic_relu_out(
|
| 105 |
+
const at::Tensor& /* input */,
|
| 106 |
+
at::Tensor& output,
|
| 107 |
+
bool /* reduce_range */) {
|
| 108 |
+
throw std::runtime_error(
|
| 109 |
+
"apply_dynamic_relu_out is not implemented for this packed "
|
| 110 |
+
"parameter type");
|
| 111 |
+
return output;
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
virtual std::tuple<at::Tensor, std::optional<at::Tensor>> unpack() = 0;
|
| 115 |
+
|
| 116 |
+
virtual std::optional<at::Tensor> bias() = 0;
|
| 117 |
+
|
| 118 |
+
virtual void set_bias(std::optional<at::Tensor> /*bias*/) {
|
| 119 |
+
throw std::runtime_error(
|
| 120 |
+
"set_bias is not implemented for this packed "
|
| 121 |
+
"parameter type");
|
| 122 |
+
}
|
| 123 |
+
};
|
| 124 |
+
|
| 125 |
+
template <int kSpatialDim = 2>
|
| 126 |
+
struct ConvPackedParamsBase : public torch::jit::CustomClassHolder {
|
| 127 |
+
virtual at::Tensor apply(
|
| 128 |
+
const at::Tensor& input,
|
| 129 |
+
double output_scale,
|
| 130 |
+
int64_t output_zero_point) = 0;
|
| 131 |
+
virtual at::Tensor apply_relu(
|
| 132 |
+
const at::Tensor& input,
|
| 133 |
+
double output_scale,
|
| 134 |
+
int64_t output_zero_point) = 0;
|
| 135 |
+
virtual at::Tensor apply_dynamic(
|
| 136 |
+
const at::Tensor& input,
|
| 137 |
+
bool reduce_range) = 0;
|
| 138 |
+
|
| 139 |
+
virtual std::tuple<at::Tensor, std::optional<at::Tensor>> unpack() = 0;
|
| 140 |
+
|
| 141 |
+
virtual torch::List<int64_t> stride() const = 0;
|
| 142 |
+
virtual torch::List<int64_t> padding() const = 0;
|
| 143 |
+
virtual torch::List<int64_t> output_padding() const = 0;
|
| 144 |
+
virtual torch::List<int64_t> dilation() const = 0;
|
| 145 |
+
virtual int64_t groups() const = 0;
|
| 146 |
+
virtual bool transpose() const = 0;
|
| 147 |
+
};
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/BinaryOps.h
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <ATen/core/Tensor.h>
|
| 2 |
+
|
| 3 |
+
namespace at {
|
| 4 |
+
namespace native {
|
| 5 |
+
TORCH_API Tensor
|
| 6 |
+
quantized_add(Tensor qa, Tensor qb, double scale, int64_t zero_point);
|
| 7 |
+
}
|
| 8 |
+
} // namespace at
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/EmbeddingPackedParams.h
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/core/Tensor.h>
|
| 4 |
+
#include <ATen/core/ivalue.h>
|
| 5 |
+
|
| 6 |
+
struct EmbeddingPackedParamsBase : public torch::jit::CustomClassHolder {
|
| 7 |
+
virtual at::Tensor embeddingbag_byte(
|
| 8 |
+
const at::Tensor& indices,
|
| 9 |
+
const std::optional<at::Tensor>& offsets,
|
| 10 |
+
bool pruned_weights,
|
| 11 |
+
const std::optional<at::Tensor>& per_sample_weights_,
|
| 12 |
+
const std::optional<at::Tensor>& compressed_indices_mapping,
|
| 13 |
+
bool include_last_offset,
|
| 14 |
+
bool is_embedding_op) = 0;
|
| 15 |
+
|
| 16 |
+
virtual at::Tensor embeddingbag_4bit(
|
| 17 |
+
const at::Tensor& indices,
|
| 18 |
+
const std::optional<at::Tensor>& offsets,
|
| 19 |
+
bool pruned_weights,
|
| 20 |
+
const std::optional<at::Tensor>& per_sample_weights_,
|
| 21 |
+
const std::optional<at::Tensor>& compressed_indices_mapping,
|
| 22 |
+
bool include_last_offset,
|
| 23 |
+
bool is_embedding_op) = 0;
|
| 24 |
+
|
| 25 |
+
virtual at::Tensor unpack() = 0;
|
| 26 |
+
|
| 27 |
+
virtual int64_t bit_rate() const = 0;
|
| 28 |
+
virtual int64_t version() const = 0;
|
| 29 |
+
};
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/OnednnUtils.h
ADDED
|
@@ -0,0 +1,457 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/Config.h>
|
| 4 |
+
#if AT_MKLDNN_ENABLED()
|
| 5 |
+
#include <ATen/Tensor.h>
|
| 6 |
+
#include <ATen/native/quantized/PackedParams.h>
|
| 7 |
+
#include <ideep.hpp>
|
| 8 |
+
#include <cpuinfo.h>
|
| 9 |
+
|
| 10 |
+
#include <c10/util/CallOnce.h>
|
| 11 |
+
|
| 12 |
+
using PrimitiveCacheKey = std::tuple<
|
| 13 |
+
double, // input_scale
|
| 14 |
+
int64_t, // input_zero_point
|
| 15 |
+
std::vector<int64_t>, // input_shape
|
| 16 |
+
double, // output_scale
|
| 17 |
+
int64_t, // output_zero_point
|
| 18 |
+
int64_t, // OMP_number_of_threads
|
| 19 |
+
double, // accum_scale
|
| 20 |
+
int64_t>; // accum_zero_point
|
| 21 |
+
|
| 22 |
+
enum CacheKeyIndex {
|
| 23 |
+
InputScale,
|
| 24 |
+
InputZeroPoint,
|
| 25 |
+
InputShape,
|
| 26 |
+
OutputScale,
|
| 27 |
+
OutputZeroPoint,
|
| 28 |
+
NumOfThreads,
|
| 29 |
+
};
|
| 30 |
+
|
| 31 |
+
// Base class of primitive cache
|
| 32 |
+
struct PrimitiveCache {
|
| 33 |
+
PrimitiveCacheKey key;
|
| 34 |
+
|
| 35 |
+
bool hit(const PrimitiveCacheKey& key) {
|
| 36 |
+
return this->key == key;
|
| 37 |
+
}
|
| 38 |
+
};
|
| 39 |
+
|
| 40 |
+
using LinearParams = ideep::matmul_forward_params;
|
| 41 |
+
using Conv = dnnl::convolution_forward;
|
| 42 |
+
using ConvDesc = dnnl::convolution_forward::primitive_desc;
|
| 43 |
+
using ConvParams = ideep::convolution_forward_params;
|
| 44 |
+
using Deconv = dnnl::deconvolution_forward;
|
| 45 |
+
using DeconvDesc = dnnl::deconvolution_forward::primitive_desc;
|
| 46 |
+
using DeconvParams = ideep::deconv_forward_params;
|
| 47 |
+
|
| 48 |
+
struct LinearPrimitiveCache : PrimitiveCache {
|
| 49 |
+
LinearPrimitiveCache() {}
|
| 50 |
+
|
| 51 |
+
LinearPrimitiveCache(
|
| 52 |
+
const PrimitiveCacheKey& key,
|
| 53 |
+
const LinearParams& param) {
|
| 54 |
+
this->key = key;
|
| 55 |
+
this->param = param;
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
LinearParams param;
|
| 59 |
+
|
| 60 |
+
// For dynamic qlinear, scale and zero point
|
| 61 |
+
// are set at execution time. So we only need to compare
|
| 62 |
+
// the rest part of key.
|
| 63 |
+
bool hit_dynamic(const PrimitiveCacheKey& new_key) {
|
| 64 |
+
auto cached_input_shape = std::get<InputShape>(this->key);
|
| 65 |
+
auto new_input_shape = std::get<InputShape>(new_key);
|
| 66 |
+
return (
|
| 67 |
+
cached_input_shape == new_input_shape &&
|
| 68 |
+
std::get<NumOfThreads>(this->key) == std::get<NumOfThreads>(new_key));
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
LinearParams& get_param() {
|
| 72 |
+
return param;
|
| 73 |
+
}
|
| 74 |
+
};
|
| 75 |
+
|
| 76 |
+
struct ConvPrimitiveCache : PrimitiveCache {
|
| 77 |
+
ConvPrimitiveCache() {}
|
| 78 |
+
|
| 79 |
+
ConvPrimitiveCache(
|
| 80 |
+
const PrimitiveCacheKey& key,
|
| 81 |
+
const ConvParams& params) {
|
| 82 |
+
this->key = key;
|
| 83 |
+
this->params = params;
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
ConvParams params;
|
| 87 |
+
|
| 88 |
+
ConvParams& get_params() {
|
| 89 |
+
return params;
|
| 90 |
+
}
|
| 91 |
+
};
|
| 92 |
+
|
| 93 |
+
struct DeconvPrimitiveCache : PrimitiveCache {
|
| 94 |
+
DeconvPrimitiveCache() {}
|
| 95 |
+
|
| 96 |
+
DeconvPrimitiveCache(
|
| 97 |
+
const PrimitiveCacheKey& key,
|
| 98 |
+
const DeconvParams& params) {
|
| 99 |
+
this->key = key;
|
| 100 |
+
this->params = params;
|
| 101 |
+
}
|
| 102 |
+
|
| 103 |
+
DeconvParams params;
|
| 104 |
+
|
| 105 |
+
DeconvParams& get_params() {
|
| 106 |
+
return params;
|
| 107 |
+
}
|
| 108 |
+
};
|
| 109 |
+
|
| 110 |
+
enum PostOps {
|
| 111 |
+
NoPostOp,
|
| 112 |
+
Relu,
|
| 113 |
+
LeakyRelu,
|
| 114 |
+
Tanh,
|
| 115 |
+
Gelu
|
| 116 |
+
};
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
struct PackedLinearWeightsOnednn : public LinearPackedParamsBase {
|
| 120 |
+
PackedLinearWeightsOnednn(
|
| 121 |
+
std::unique_ptr<ideep::tensor> weight,
|
| 122 |
+
std::optional<ideep::tensor> bias,
|
| 123 |
+
at::Tensor orig_weight,
|
| 124 |
+
std::optional<at::Tensor> orig_bias)
|
| 125 |
+
: weight_(std::move(weight)),
|
| 126 |
+
bias_(std::move(bias)),
|
| 127 |
+
orig_weight_(std::move(orig_weight)),
|
| 128 |
+
orig_bias_(std::move(orig_bias)) {
|
| 129 |
+
cache_initialized_flag = std::make_unique<c10::once_flag>();
|
| 130 |
+
}
|
| 131 |
+
std::unique_ptr<ideep::tensor> weight_;
|
| 132 |
+
std::optional<ideep::tensor> bias_;
|
| 133 |
+
at::Tensor orig_weight_;
|
| 134 |
+
std::optional<at::Tensor> orig_bias_;
|
| 135 |
+
|
| 136 |
+
at::Tensor apply(
|
| 137 |
+
at::Tensor input,
|
| 138 |
+
double output_scale,
|
| 139 |
+
int64_t output_zero_point) override;
|
| 140 |
+
at::Tensor apply_relu(
|
| 141 |
+
at::Tensor input,
|
| 142 |
+
double output_scale,
|
| 143 |
+
int64_t output_zero_point) override;
|
| 144 |
+
|
| 145 |
+
at::Tensor apply_dynamic(at::Tensor input, bool reduce_range=false) override;
|
| 146 |
+
at::Tensor apply_dynamic_relu(at::Tensor input, bool reduce_range=false) override;
|
| 147 |
+
|
| 148 |
+
at::Tensor apply_leaky_relu(
|
| 149 |
+
at::Tensor input,
|
| 150 |
+
double output_scale,
|
| 151 |
+
int64_t output_zero_point,
|
| 152 |
+
double negative_slope);
|
| 153 |
+
|
| 154 |
+
at::Tensor apply_tanh(
|
| 155 |
+
at::Tensor input,
|
| 156 |
+
double output_scale,
|
| 157 |
+
int64_t output_zero_point);
|
| 158 |
+
|
| 159 |
+
std::tuple<at::Tensor, std::optional<at::Tensor>> unpack() override;
|
| 160 |
+
|
| 161 |
+
std::optional<at::Tensor> bias() override {
|
| 162 |
+
return orig_bias_;
|
| 163 |
+
}
|
| 164 |
+
|
| 165 |
+
static c10::intrusive_ptr<LinearPackedParamsBase> prepack(
|
| 166 |
+
at::Tensor weight,
|
| 167 |
+
std::optional<at::Tensor> bias);
|
| 168 |
+
|
| 169 |
+
private:
|
| 170 |
+
LinearPrimitiveCache prim_cache;
|
| 171 |
+
std::unique_ptr<c10::once_flag> cache_initialized_flag;
|
| 172 |
+
|
| 173 |
+
template <PostOps post_op>
|
| 174 |
+
at::Tensor apply_impl(
|
| 175 |
+
at::Tensor input,
|
| 176 |
+
double output_scale,
|
| 177 |
+
int64_t output_zero_point,
|
| 178 |
+
torch::List<at::Scalar> post_op_args = torch::List<at::Scalar>());
|
| 179 |
+
|
| 180 |
+
template <bool ReluFused>
|
| 181 |
+
at::Tensor apply_dynamic_impl(at::Tensor input, bool reduce_range=false);
|
| 182 |
+
|
| 183 |
+
LinearPrimitiveCache& get_cache() {
|
| 184 |
+
return prim_cache;
|
| 185 |
+
}
|
| 186 |
+
};
|
| 187 |
+
|
| 188 |
+
template <int kSpatialDim = 2>
|
| 189 |
+
struct PackedConvWeightsOnednn : public ConvPackedParamsBase<kSpatialDim> {
|
| 190 |
+
PackedConvWeightsOnednn(
|
| 191 |
+
std::unique_ptr<ideep::tensor> weight,
|
| 192 |
+
std::optional<ideep::tensor> bias,
|
| 193 |
+
at::Tensor orig_weight,
|
| 194 |
+
std::optional<at::Tensor> orig_bias,
|
| 195 |
+
torch::List<int64_t> stride,
|
| 196 |
+
torch::List<int64_t> padding,
|
| 197 |
+
torch::List<int64_t> output_padding,
|
| 198 |
+
torch::List<int64_t> dilation,
|
| 199 |
+
int64_t groups,
|
| 200 |
+
uint8_t transpose)
|
| 201 |
+
: weight_(std::move(weight)),
|
| 202 |
+
bias_(std::move(bias)),
|
| 203 |
+
orig_weight_(std::move(orig_weight)),
|
| 204 |
+
orig_bias_(std::move(orig_bias)),
|
| 205 |
+
stride_(std::move(stride)),
|
| 206 |
+
padding_(std::move(padding)),
|
| 207 |
+
output_padding_(std::move(output_padding)),
|
| 208 |
+
dilation_(std::move(dilation)),
|
| 209 |
+
groups_(groups),
|
| 210 |
+
transpose_(transpose) {
|
| 211 |
+
cache_initialized_flag = std::make_unique<c10::once_flag>();
|
| 212 |
+
}
|
| 213 |
+
|
| 214 |
+
std::unique_ptr<ideep::tensor> weight_;
|
| 215 |
+
std::optional<ideep::tensor> bias_;
|
| 216 |
+
at::Tensor orig_weight_;
|
| 217 |
+
std::optional<at::Tensor> orig_bias_;
|
| 218 |
+
torch::List<int64_t> stride_;
|
| 219 |
+
torch::List<int64_t> padding_;
|
| 220 |
+
torch::List<int64_t> output_padding_;
|
| 221 |
+
torch::List<int64_t> dilation_;
|
| 222 |
+
int64_t groups_;
|
| 223 |
+
uint8_t transpose_;
|
| 224 |
+
|
| 225 |
+
at::Tensor apply(
|
| 226 |
+
const at::Tensor& input,
|
| 227 |
+
double output_scale,
|
| 228 |
+
int64_t output_zero_point) override;
|
| 229 |
+
|
| 230 |
+
at::Tensor apply_relu(
|
| 231 |
+
const at::Tensor& input,
|
| 232 |
+
double output_scale,
|
| 233 |
+
int64_t output_zero_point) override;
|
| 234 |
+
|
| 235 |
+
at::Tensor apply_dynamic(
|
| 236 |
+
const at::Tensor& input,
|
| 237 |
+
bool reduce_range) override;
|
| 238 |
+
|
| 239 |
+
at::Tensor apply_add(
|
| 240 |
+
const at::Tensor& input,
|
| 241 |
+
const at::Tensor& accum,
|
| 242 |
+
double output_scale,
|
| 243 |
+
int64_t output_zero_point);
|
| 244 |
+
|
| 245 |
+
at::Tensor apply_add_relu(
|
| 246 |
+
const at::Tensor& input,
|
| 247 |
+
const at::Tensor& accum,
|
| 248 |
+
double output_scale,
|
| 249 |
+
int64_t output_zero_point);
|
| 250 |
+
|
| 251 |
+
std::tuple<at::Tensor, std::optional<at::Tensor>> unpack() override;
|
| 252 |
+
|
| 253 |
+
static c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> prepack(
|
| 254 |
+
at::Tensor weight,
|
| 255 |
+
std::optional<at::Tensor> bias,
|
| 256 |
+
torch::List<int64_t> stride,
|
| 257 |
+
torch::List<int64_t> padding,
|
| 258 |
+
torch::List<int64_t> output_padding,
|
| 259 |
+
torch::List<int64_t> dilation,
|
| 260 |
+
int64_t groups,
|
| 261 |
+
bool transpose);
|
| 262 |
+
|
| 263 |
+
torch::List<int64_t> stride() const override {
|
| 264 |
+
return stride_;
|
| 265 |
+
}
|
| 266 |
+
|
| 267 |
+
torch::List<int64_t> padding() const override {
|
| 268 |
+
return padding_;
|
| 269 |
+
}
|
| 270 |
+
|
| 271 |
+
torch::List<int64_t> output_padding() const override {
|
| 272 |
+
return output_padding_;
|
| 273 |
+
}
|
| 274 |
+
|
| 275 |
+
torch::List<int64_t> dilation() const override {
|
| 276 |
+
return dilation_;
|
| 277 |
+
}
|
| 278 |
+
|
| 279 |
+
int64_t groups() const override {
|
| 280 |
+
return groups_;
|
| 281 |
+
}
|
| 282 |
+
|
| 283 |
+
bool transpose() const override {
|
| 284 |
+
return (bool)transpose_;
|
| 285 |
+
}
|
| 286 |
+
|
| 287 |
+
private:
|
| 288 |
+
ConvPrimitiveCache conv_prim_cache;
|
| 289 |
+
DeconvPrimitiveCache deconv_prim_cache;
|
| 290 |
+
std::unique_ptr<c10::once_flag> cache_initialized_flag;
|
| 291 |
+
|
| 292 |
+
template <bool ReluFused>
|
| 293 |
+
at::Tensor apply_impl(
|
| 294 |
+
const at::Tensor& input,
|
| 295 |
+
const std::optional<at::Tensor>& accum,
|
| 296 |
+
double output_scale,
|
| 297 |
+
int64_t output_zero_point);
|
| 298 |
+
|
| 299 |
+
ConvPrimitiveCache& get_conv_cache() {
|
| 300 |
+
assert(!transpose());
|
| 301 |
+
return conv_prim_cache;
|
| 302 |
+
}
|
| 303 |
+
|
| 304 |
+
DeconvPrimitiveCache& get_deconv_cache() {
|
| 305 |
+
assert(transpose());
|
| 306 |
+
return deconv_prim_cache;
|
| 307 |
+
}
|
| 308 |
+
};
|
| 309 |
+
|
| 310 |
+
namespace onednn_utils {
|
| 311 |
+
|
| 312 |
+
inline ideep::attr_t create_attr_by_post_op(
|
| 313 |
+
const c10::string_view& binary_post_op,
|
| 314 |
+
double binary_alpha,
|
| 315 |
+
double input1_scale,
|
| 316 |
+
int64_t input1_zero_point,
|
| 317 |
+
const ideep::tensor::desc& input1_desc,
|
| 318 |
+
const c10::string_view& unary_post_op,
|
| 319 |
+
const torch::List<std::optional<at::Scalar>>& unary_post_op_args,
|
| 320 |
+
const c10::string_view& unary_post_op_algorithm) {
|
| 321 |
+
using ideep::tensor;
|
| 322 |
+
if (binary_post_op == "none") {
|
| 323 |
+
if (unary_post_op == "relu") {
|
| 324 |
+
return ideep::attr_t::fuse_relu();
|
| 325 |
+
} else if (unary_post_op == "leaky_relu") {
|
| 326 |
+
TORCH_CHECK(
|
| 327 |
+
unary_post_op_args.size() == 1,
|
| 328 |
+
"onednn qlinear: expect one argument for post op leaky_relu but got ", unary_post_op_args.size(), " args");
|
| 329 |
+
auto alpha = unary_post_op_args[0].value().to<float>();
|
| 330 |
+
return ideep::attr_t::fuse_relu_v2(alpha);
|
| 331 |
+
} else if (unary_post_op == "tanh") {
|
| 332 |
+
return ideep::attr_t::fuse_tanh();
|
| 333 |
+
} else if (unary_post_op == "gelu") {
|
| 334 |
+
TORCH_CHECK(
|
| 335 |
+
unary_post_op_algorithm == "none" || unary_post_op_algorithm == "tanh",
|
| 336 |
+
"onednn qlinear: algorithm for post op gelu must be none or tanh but got ", unary_post_op_algorithm);
|
| 337 |
+
auto post_algorithm = unary_post_op_algorithm == "none" ?
|
| 338 |
+
dnnl::algorithm::eltwise_gelu_erf :
|
| 339 |
+
dnnl::algorithm::eltwise_gelu_tanh;
|
| 340 |
+
return ideep::attr_t::fuse_gelu_v2(0.f, 0.f, post_algorithm);
|
| 341 |
+
} else if (unary_post_op == "hardtanh") {
|
| 342 |
+
TORCH_CHECK(
|
| 343 |
+
unary_post_op_args.size() == 2 &&
|
| 344 |
+
unary_post_op_args[0].has_value() &&
|
| 345 |
+
unary_post_op_args[1].has_value(),
|
| 346 |
+
"hardtanh is expected to have two scalar input: min_val and max_val");
|
| 347 |
+
auto lower_bound_value =
|
| 348 |
+
unary_post_op_args[0].value().to<float>();
|
| 349 |
+
auto upper_bound_value =
|
| 350 |
+
unary_post_op_args[1].value().to<float>();
|
| 351 |
+
return ideep::attr_t::fuse_clamp(lower_bound_value, upper_bound_value);
|
| 352 |
+
} else if (unary_post_op == "hardswish") {
|
| 353 |
+
return ideep::attr_t::fuse_hardswish();
|
| 354 |
+
} else if (unary_post_op == "swish") {
|
| 355 |
+
return ideep::attr_t::fuse_swish();
|
| 356 |
+
} else {
|
| 357 |
+
TORCH_CHECK(
|
| 358 |
+
unary_post_op == "none",
|
| 359 |
+
"onednn qlinear: unsupported unary post op ", unary_post_op);
|
| 360 |
+
}
|
| 361 |
+
} else if (binary_post_op == "sum") {
|
| 362 |
+
if (unary_post_op == "none") {
|
| 363 |
+
return ideep::attr_t::fuse_sum(input1_scale, input1_zero_point);
|
| 364 |
+
} else if (unary_post_op == "relu") {
|
| 365 |
+
return ideep::attr_t::residual_with_sum_zero_point(input1_scale, input1_zero_point);
|
| 366 |
+
} else {
|
| 367 |
+
TORCH_CHECK(
|
| 368 |
+
false,
|
| 369 |
+
"onednn qlinear: unsupported unary post op ", unary_post_op, " with binary post op sum");
|
| 370 |
+
}
|
| 371 |
+
} else if (binary_post_op == "add") {
|
| 372 |
+
if (unary_post_op == "none") {
|
| 373 |
+
return ideep::attr_t::fuse_binary(ideep::algorithm::binary_add, input1_desc);
|
| 374 |
+
} else if (unary_post_op == "relu") {
|
| 375 |
+
ideep::post_ops po;
|
| 376 |
+
po.append_binary(ideep::algorithm::binary_add, input1_desc);
|
| 377 |
+
po.append_eltwise(ideep::algorithm::eltwise_relu, 0, 0);
|
| 378 |
+
return ideep::attr_t::attr_post_ops(po);
|
| 379 |
+
} else {
|
| 380 |
+
TORCH_CHECK(
|
| 381 |
+
false,
|
| 382 |
+
"onednn qlinear: unsupported unary post op ", unary_post_op, " with binary post op add");
|
| 383 |
+
}
|
| 384 |
+
} else {
|
| 385 |
+
TORCH_CHECK(
|
| 386 |
+
false,
|
| 387 |
+
"onednn qlinear: unsupported binary post op ", binary_post_op);
|
| 388 |
+
}
|
| 389 |
+
return ideep::attr_t();
|
| 390 |
+
}
|
| 391 |
+
|
| 392 |
+
// ONEDNN requires symmetric quantization of weight
|
| 393 |
+
// Use this util function to check.
|
| 394 |
+
inline bool is_weight_symmetric_quant(
|
| 395 |
+
const at::Tensor& weight,
|
| 396 |
+
bool is_transposed_conv) {
|
| 397 |
+
bool is_symmetric = true;
|
| 398 |
+
const auto qtype = weight.qscheme();
|
| 399 |
+
if (qtype == c10::kPerTensorAffine) {
|
| 400 |
+
is_symmetric &= (weight.q_zero_point() == 0);
|
| 401 |
+
} else if (qtype == c10::kPerChannelAffine) {
|
| 402 |
+
if (is_transposed_conv) {
|
| 403 |
+
// This case is currently not supported in PyTorch
|
| 404 |
+
// but we do not want to raise an error in this util function.
|
| 405 |
+
is_symmetric = false;
|
| 406 |
+
} else {
|
| 407 |
+
auto output_channels = weight.size(0);
|
| 408 |
+
for (int i = 0; i < output_channels; ++i) {
|
| 409 |
+
auto zp = weight.q_per_channel_zero_points()[i].item<int32_t>();
|
| 410 |
+
is_symmetric &= (zp == 0);
|
| 411 |
+
}
|
| 412 |
+
}
|
| 413 |
+
} else {
|
| 414 |
+
// This case is currently not supported in PyTorch
|
| 415 |
+
// but we do not want to raise an error in this util function.
|
| 416 |
+
is_symmetric = false;
|
| 417 |
+
}
|
| 418 |
+
return is_symmetric;
|
| 419 |
+
}
|
| 420 |
+
|
| 421 |
+
// When qengine is x86, use this util func to check if onednn kernel
|
| 422 |
+
// is preferred than fbgemm's to get better performance.
|
| 423 |
+
inline bool should_use_onednn_quant(
|
| 424 |
+
const at::Tensor& weight,
|
| 425 |
+
bool is_transposed_conv,
|
| 426 |
+
int groups,
|
| 427 |
+
torch::List<int64_t> output_padding) {
|
| 428 |
+
// Performance of onednn is only validated on Linux right now.
|
| 429 |
+
// Also, the heuristics for dispatching are based on perf data on Linux.
|
| 430 |
+
// So, for x86 qengine, we always use fbgemm kernels if OS is not Linux.
|
| 431 |
+
// TODO Support more OSs.
|
| 432 |
+
#if !defined(__linux__)
|
| 433 |
+
return false;
|
| 434 |
+
#else
|
| 435 |
+
bool vnni_available = cpuinfo_has_x86_avx512vnni();
|
| 436 |
+
bool w_sym_quant =
|
| 437 |
+
is_weight_symmetric_quant(weight, is_transposed_conv);
|
| 438 |
+
bool opad_all_zero =
|
| 439 |
+
std::all_of(output_padding.begin(), output_padding.end(), [](int i) { return i==0; });
|
| 440 |
+
return vnni_available && (groups <= 100) && w_sym_quant && opad_all_zero;
|
| 441 |
+
#endif
|
| 442 |
+
}
|
| 443 |
+
|
| 444 |
+
} // onednn_utils
|
| 445 |
+
|
| 446 |
+
at::Tensor _qconv_prepack_onednn(
|
| 447 |
+
at::Tensor weight, // from CPU backend instead of QuantizedCPU
|
| 448 |
+
at::Tensor weight_scales, // Weight zero points must be 0 for onednn
|
| 449 |
+
double input_scale,
|
| 450 |
+
int64_t input_zero_point,
|
| 451 |
+
torch::List<int64_t> stride,
|
| 452 |
+
torch::List<int64_t> padding,
|
| 453 |
+
torch::List<int64_t> dilation,
|
| 454 |
+
int64_t groups,
|
| 455 |
+
std::optional<torch::List<int64_t>> input_shape=std::nullopt);
|
| 456 |
+
|
| 457 |
+
#endif // #if AT_MKLDNN_ENABLED()
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/QnnpackUtils.h
ADDED
|
@@ -0,0 +1,527 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#ifdef USE_PYTORCH_QNNPACK
|
| 4 |
+
#include <ATen/core/Tensor.h>
|
| 5 |
+
#include <c10/util/irange.h>
|
| 6 |
+
#include <pytorch_qnnpack.h>
|
| 7 |
+
#include <qnnpack_func.h>
|
| 8 |
+
#include <ATen/native/quantized/cpu/XnnpackUtils.h>
|
| 9 |
+
#include <ATen/native/quantized/PackedParams.h>
|
| 10 |
+
#include <ATen/native/utils/Factory.h>
|
| 11 |
+
|
| 12 |
+
#ifndef AT_PER_OPERATOR_HEADERS
|
| 13 |
+
#include <ATen/Functions.h>
|
| 14 |
+
#else
|
| 15 |
+
#include <ATen/ops/empty.h>
|
| 16 |
+
#endif
|
| 17 |
+
|
| 18 |
+
#include <utility>
|
| 19 |
+
// Number of extra output channels QNNPACK needs as buffering past the real
// data; zero-point/scale arrays are padded by this amount. Made constexpr:
// it is a compile-time constant and must never be mutated by includers.
inline constexpr int kPaddingChannels = 8;
|
| 20 |
+
// Custom deleter so a raw pytorch_qnnp_operator_t can be owned by a
// std::unique_ptr; tears the operator down through QNNPACK's own routine.
struct QnnpackOperatorDeleter {
  void operator()(pytorch_qnnp_operator_t op) {
    pytorch_qnnp_delete_operator(op);
  }
};
|
| 25 |
+
|
| 26 |
+
// PackedWeight struct for QNNPACK stores the original Weight and Bias as
|
| 27 |
+
// QNNPACK currently does not support an unpack function.
|
| 28 |
+
// For PyTorch Mobile, once the model is scripted and serialized we don't need
|
| 29 |
+
// to call unpack, so we can save some memory by checking for this case and free
|
| 30 |
+
// the original weights after packing.
|
| 31 |
+
// Input scale is set to null in pre-pack step. QNNPACK needs bias quantized
|
| 32 |
+
// with input scale which is available at runtime in pytorch. During runtime if
|
| 33 |
+
// input scale value changes then we requantize bias with the updated scale. For
|
| 34 |
+
// inference we expect the graph to be static so the input scale should not
|
| 35 |
+
// change across consecutive inference calls.
|
| 36 |
+
struct PackedLinearWeightsQnnp : public LinearPackedParamsBase {
  // Takes ownership of the pre-packed B matrix and the original quantized
  // weight; bias is re-allocated padded/contiguous as the mobile kernels
  // require. per_channel_/q_scheme are derived from the weight's qscheme.
  // NOTE: members read `this->orig_weight` (not the moved-from parameter),
  // relying on declaration-order initialization — do not reorder members.
  PackedLinearWeightsQnnp(
      std::unique_ptr<qnnpack::PackBMatrix> w,
      at::Tensor orig_weight,
      at::Tensor bias,
      std::optional<double> input_scale,
      at::Tensor w_scales,
      std::vector<uint8_t>&& w_zps)
      : w(std::move(w)),
        orig_weight(std::move(orig_weight)),
        bias_(at::native::mobile::allocate_padded_contiguous_if_needed(
            bias, bias.suggest_memory_format())),
        per_channel_(this->orig_weight.qscheme() == at::kPerChannelAffine),
        input_scale(std::move(input_scale)),
        w_scales(std::move(w_scales)),
        w_zero_points(std::move(w_zps)),
        q_scheme(this->orig_weight.qscheme()) {
    weight_sizes = this->orig_weight.sizes().vec();
  }

  std::unique_ptr<qnnpack::PackBMatrix> w;  // QNNPACK pre-packed weight matrix
  at::Tensor orig_weight;                   // original quantized weight (kept; no unpack in QNNPACK)
  at::Tensor bias_;                         // padded/contiguous copy of bias
  bool per_channel_;
  // Input scale is unknown at prepack time; filled in at first run and used
  // to detect scale changes that require bias requantization.
  std::optional<double> input_scale;
  at::Tensor w_scales;
  std::vector<uint8_t> w_zero_points;
  std::vector<float> requantization_scales;
  std::vector<int64_t> weight_sizes;
  c10::QScheme q_scheme;

  at::Tensor apply(
      at::Tensor input,
      double output_scale,
      int64_t output_zero_point) override;
  at::Tensor apply_relu(
      at::Tensor input,
      double output_scale,
      int64_t output_zero_point) override;

  // Dynamic-quantization paths: fp32 in/out, weights quantized.
  at::Tensor apply_dynamic(at::Tensor input, bool reduce_range=false) override;
  at::Tensor apply_dynamic_relu(at::Tensor input, bool reduce_range=false) override;

  std::tuple<at::Tensor, std::optional<at::Tensor>> unpack() override;

  std::optional<at::Tensor> bias() override {
    return bias_;
  }

  static c10::intrusive_ptr<LinearPackedParamsBase> prepack(
      at::Tensor weight,
      std::optional<at::Tensor> bias);

  bool per_channel() const {
    return per_channel_;
  }

 private:
  std::mutex qnnp_mutex_;  // guards lazy (re)initialization in apply paths

#ifdef USE_XNNPACK
  xnnpack_operator xnnp_linear_op;

  template <typename scalar_t, bool kReluFused>
  at::Tensor apply_impl_xnnp(
      const at::Tensor& input,
      double output_scale,
      int64_t output_zero_point);
#endif // USE_XNNPACK

  // Shared implementation for apply / apply_relu.
  template <bool ReluFused>
  at::Tensor apply_impl(
      at::Tensor input,
      double output_scale,
      int64_t output_zero_point);

  // Shared implementation for apply_dynamic / apply_dynamic_relu.
  template <bool ReluFused>
  at::Tensor apply_dynamic_impl(at::Tensor input, bool reduce_range);
};
|
| 115 |
+
|
| 116 |
+
// Packed convolution parameters for the QNNPACK backend. The constructor
// selects the QNNPACK micro-kernel type from the conv geometry and
// pre-populates a pytorch_qnnp_operator struct (owned via unique_ptr with a
// custom deleter), including the zero-padding buffer the kernels read from.
template <int kSpatialDim = 2>
struct PackedConvWeightsQnnp : public ConvPackedParamsBase<kSpatialDim> {
  PackedConvWeightsQnnp(
      std::unique_ptr<qnnpack::PrePackConvWeights> w,
      at::Tensor orig_weight,
      at::Tensor bias,
      torch::List<int64_t> stride,
      torch::List<int64_t> padding,
      torch::List<int64_t> output_padding,
      torch::List<int64_t> dilation,
      int64_t groups,
      bool transpose,
      std::optional<double> input_scale,
      std::vector<int64_t> kernel,
      at::Tensor w_scale,
      std::vector<uint8_t>&& w_zps,
      bool is_per_channel)
      : w(std::move(w)),
        orig_weight(std::move(orig_weight)),
        bias(std::move(bias)),
        stride_(std::move(stride)),
        padding_(std::move(padding)),
        output_padding_(std::move(output_padding)),
        dilation_(std::move(dilation)),
        groups_(groups),
        transpose_(transpose),
        is_per_channel_(is_per_channel),
        input_scale(input_scale),
        kernel_(std::move(kernel)),
        w_scales(std::move(w_scale)),
        w_zero_points(std::move(w_zps)) {
    const bool any_padding = std::any_of(
        padding_.begin(), padding_.end(), [](const auto& e) { return e != 0; });
    const size_t kernel_size =
        std::accumulate(kernel_.begin(), kernel_.end(), 1, std::multiplies<>());

    // Weight layout differs for transposed conv: dim 0 is input channels
    // there, output channels otherwise.
    const size_t group_input_channels = transpose
        ? this->orig_weight.size(0) / groups
        : this->orig_weight.size(1);
    const size_t group_output_channels = transpose
        ? this->orig_weight.size(1)
        : this->orig_weight.size(0) / groups;

    const size_t kernel_depth = kSpatialDim == 3 ? kernel_[0] : 1;
    const size_t kernel_height = kernel_[kSpatialDim - 2];
    const size_t kernel_width = kernel_[kSpatialDim - 1];

    // Micro-kernel selection: transposed conv always uses the generic conv
    // kernel; otherwise prefer depthwise, then GEMM for 1x1/stride-1/no-pad.
    pytorch_qnnp_ukernel_type ukernel_type;
    if (transpose_) {
      ukernel_type = pytorch_qnnp_ukernel_type_conv;
    } else {
      ukernel_type = pytorch_qnnp_ukernel_type_none;

      const bool has_depthwise_dimensions =
          (kSpatialDim == 2 &&
           ((kernel_height == 3 && kernel_width == 3) ||
            (kernel_height == 5 && kernel_width == 5))) ||
          (kSpatialDim == 3 && kernel_height == 3 && kernel_width == 3 &&
           kernel_depth == 3);
      const bool has_depthwise_grouping =
          group_input_channels == 1 && group_output_channels == 1 && groups > 1;

      if (has_depthwise_dimensions && has_depthwise_grouping) {
        ukernel_type = pytorch_qnnp_ukernel_type_dwconv;
      } else if (
          kernel_size == 1 &&
          std::all_of(
              stride_.begin(),
              stride_.end(),
              [](const auto& e) { return e == 1; }) &&
          !any_padding) {
        // NOTE(review): `group_input_channels >= SIZE_MAX` can only be true
        // when it equals SIZE_MAX, so the XZP GEMM path is effectively
        // disabled and plain GEMM is always chosen here.
        ukernel_type = group_input_channels >= SIZE_MAX
            ? pytorch_qnnp_ukernel_type_xzp_gemm
            : pytorch_qnnp_ukernel_type_gemm;
      } else {
        ukernel_type = pytorch_qnnp_ukernel_type_conv;
      }
    }

    if (is_per_channel && ukernel_type == pytorch_qnnp_ukernel_type_xzp_gemm) {
      TORCH_INTERNAL_ASSERT(
          false, "Per channel quantized weights are not supported for XZP kernels");
    }

    pytorch_qnnp_operator_t convolution{nullptr};
    // Initially all the params are set to zero.
    convolution = static_cast<pytorch_qnnp_operator_t>(
        calloc(1, sizeof(struct pytorch_qnnp_operator)));
    if (convolution == nullptr) {
      TORCH_INTERNAL_ASSERT(
          false, "failed to allocate %zu bytes for pytorch_qnnp_operator structure",
          sizeof(struct pytorch_qnnp_operator));
    }

    // Hand ownership to the unique_ptr immediately so later early exits
    // cannot leak the operator.
    convolution_op =
        std::unique_ptr<pytorch_qnnp_operator, QnnpackOperatorDeleter>(
            convolution);

    // NOLINTNEXTLINE(clang-analyzer-core.NullDereference)
    convolution->ukernel_type = ukernel_type;
    convolution->groups = groups;
    convolution->group_input_channels = group_input_channels;
    convolution->group_output_channels = group_output_channels;
    convolution->kernel_depth = kernel_depth;
    convolution->kernel_height = kernel_height;
    convolution->kernel_width = kernel_width;
    convolution->stride_depth = kSpatialDim == 3 ? stride_[0] : 1;
    convolution->stride_height = stride_[kSpatialDim - 2];
    convolution->stride_width = stride_[kSpatialDim - 1];
    convolution->dilation_depth = kSpatialDim == 3 ? dilation_[0] : 1;
    convolution->dilation_height = dilation_[kSpatialDim - 2];
    convolution->dilation_width = dilation_[kSpatialDim - 1];
    convolution->input_padding_height = padding_[kSpatialDim - 2];
    convolution->input_padding_width = padding_[kSpatialDim - 1];
    convolution->input_padding_depth = kSpatialDim == 3 ? padding_[0] : 0;
    convolution->per_channel = is_per_channel_;
    convolution->transpose = transpose_;

    // Round group_input_channels up to a multiple of kr (kr is a power of
    // two, so `& -kr` masks down after adding kr-1).
    const uint32_t kr = pytorch_qnnp_params.q8conv.kr;
    const size_t k_stride = (group_input_channels + (kr - 1)) & -kr;

    size_t zero_size = sizeof(uint8_t) * k_stride;
    size_t zero_offset = 0;

    if (transpose_) {
      convolution->adjustment_width = output_padding_[1];
      convolution->adjustment_height = output_padding_[0];
      // Extra 8-byte slack when channels are few — mirrors QNNPACK's own
      // buffering requirements for small channel counts.
      if (group_input_channels < 8) {
        zero_size += 8;
        zero_offset = 8;
      }
    } else {
      zero_buffer_size = 0;  // redundant — unconditionally overwritten below
      if (any_padding) {
        zero_size = 0;
        zero_offset = 0;
        if (ukernel_type == pytorch_qnnp_ukernel_type_dwconv) {
          const uint32_t cr = pytorch_qnnp_params.q8dw9.cr;
          const size_t group_stride = (groups + (cr - 1)) & -cr;
          if (groups >= 8) {
            zero_size = sizeof(uint8_t) * group_stride;
            zero_offset = 0;
          } else {
            zero_size = sizeof(uint8_t) * group_stride + 8;
            zero_offset = sizeof(uint8_t) * 8;
          }
        } else if (
            ukernel_type == pytorch_qnnp_ukernel_type_conv ||
            ukernel_type == pytorch_qnnp_ukernel_type_gemm) {
          if (group_input_channels >= 8) {
            zero_size = sizeof(uint8_t) * k_stride;
            zero_offset = 0;
          } else {
            zero_size = sizeof(uint8_t) * k_stride + 8;
            zero_offset = 8;
          }
        }
      }
    }

    // NOLINTNEXTLINE(clang-analyzer-optin.portability.UnixAPI)
    void* zero_buffer = malloc(zero_size);
    if (zero_buffer == nullptr) {
      pytorch_qnnp_delete_operator(convolution);
      TORCH_INTERNAL_ASSERT(
          false, "failed to allocate %zu bytes for zero padding",
          zero_size);
    }
    // Need to set to input zero point
    // memset(zero_buffer, input_zero_point, zero_size);
    zero_buffer_size = zero_size;
    convolution->zero_buffer = zero_buffer;
    convolution->zero_pointer = (void*)((uintptr_t)zero_buffer + zero_offset);
  }

  std::unique_ptr<pytorch_qnnp_operator, QnnpackOperatorDeleter> convolution_op;
#ifdef USE_XNNPACK
  xnnpack_operator xnnp_convolution_op;
#endif // USE_XNNPACK
  std::unique_ptr<qnnpack::PrePackConvWeights> w;  // QNNPACK pre-packed weights
  at::Tensor orig_weight;  // kept because QNNPACK has no unpack
  at::Tensor bias;
  torch::List<int64_t> stride_;
  torch::List<int64_t> padding_;
  torch::List<int64_t> output_padding_;
  torch::List<int64_t> dilation_;
  int64_t groups_;
  bool transpose_;
  bool is_per_channel_;
  // Unknown at prepack; captured at first run to requantize bias when needed.
  std::optional<double> input_scale;
  std::vector<int64_t> kernel_;
  at::Tensor w_scales;
  std::vector<uint8_t> w_zero_points;
  std::vector<float> requantization_scales;
  size_t zero_buffer_size;

  at::Tensor apply(
      const at::Tensor& input,
      double output_scale,
      int64_t output_zero_point) override;

  at::Tensor apply_relu(
      const at::Tensor& input,
      double output_scale,
      int64_t output_zero_point) override;

  at::Tensor apply_dynamic(
      const at::Tensor& input,
      bool reduce_range=false) override;

  std::tuple<at::Tensor, std::optional<at::Tensor>> unpack() override;

  static c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> prepack(
      at::Tensor weight,
      std::optional<at::Tensor> bias,
      torch::List<int64_t> stride,
      torch::List<int64_t> padding,
      torch::List<int64_t> output_padding,
      torch::List<int64_t> dilation,
      int64_t groups,
      bool transpose);

  torch::List<int64_t> stride() const override {
    return stride_;
  }

  torch::List<int64_t> padding() const override {
    return padding_;
  }

  torch::List<int64_t> output_padding() const override {
    return output_padding_;
  }

  torch::List<int64_t> dilation() const override {
    return dilation_;
  }

  int64_t groups() const override {
    return groups_;
  }

  bool transpose() const override {
    return transpose_;
  }

  bool per_channel() const {
    return is_per_channel_;
  }

 private:
  std::mutex qnnp_mutex_;  // guards lazy (re)initialization in apply paths
  // Shared implementation for apply / apply_relu.
  template <bool ReluFused>
  at::Tensor apply_impl(
      const at::Tensor& input,
      double output_scale,
      int64_t output_zero_point);

#ifdef USE_XNNPACK
  template <typename scalar_t, bool ReluFused>
  at::Tensor apply_impl_xnnp(
      const at::Tensor& input,
      double output_scale,
      int64_t output_zero_point);
#endif // USE_XNNPACK
};
|
| 382 |
+
|
| 383 |
+
// Fused post-op applied to the quantized output: none, or ReLU (clamp at 0).
enum class Activation : uint8_t { NONE = 0, RELU = 1 };
|
| 384 |
+
|
| 385 |
+
// Round to nearest under the current FP rounding mode (ties-to-even by
// default). Old Android NDKs lack std::nearbyint, so use the C functions
// there (the unused template parameter keeps the call-site syntax uniform).
#if defined(__ANDROID__) && !defined(__NDK_MAJOR__)
template <class T>
inline float Round(const float x) {
  return ::nearbyintf(x);
}
inline double Round(const double x) {
  return ::nearbyint(x);
}
#else
template <class T>
inline T Round(const T x) {
  return std::nearbyint(x);
}
#endif
|
| 399 |
+
|
| 400 |
+
template<typename T>
|
| 401 |
+
inline T QuantizeValue(float scale, int32_t zero_point, float value) {
|
| 402 |
+
const int32_t qmin = std::numeric_limits<T>::min();
|
| 403 |
+
const int32_t qmax = std::numeric_limits<T>::max();
|
| 404 |
+
auto r = zero_point + static_cast<int32_t>(Round(value / scale));
|
| 405 |
+
r = std::max(r, qmin);
|
| 406 |
+
r = std::min(r, qmax);
|
| 407 |
+
return static_cast<T>(r);
|
| 408 |
+
}
|
| 409 |
+
|
| 410 |
+
template<typename T>
|
| 411 |
+
inline std::pair<T, T> activationLimits(
|
| 412 |
+
float scale,
|
| 413 |
+
int32_t zero_point,
|
| 414 |
+
Activation Ac) {
|
| 415 |
+
switch (Ac) {
|
| 416 |
+
case Activation::NONE:
|
| 417 |
+
return {std::numeric_limits<T>::min(),
|
| 418 |
+
std::numeric_limits<T>::max()};
|
| 419 |
+
case Activation::RELU:
|
| 420 |
+
return {QuantizeValue<T>(scale, zero_point, 0.0),
|
| 421 |
+
std::numeric_limits<T>::max()};
|
| 422 |
+
default:
|
| 423 |
+
#ifdef _MSC_VER
|
| 424 |
+
__assume(0);
|
| 425 |
+
#else
|
| 426 |
+
__builtin_unreachable();
|
| 427 |
+
#endif
|
| 428 |
+
}
|
| 429 |
+
}
|
| 430 |
+
|
| 431 |
+
namespace at {
|
| 432 |
+
namespace native {
|
| 433 |
+
namespace qnnp_avgpool_helper {
|
| 434 |
+
Tensor qnnpack_avg_pool2d(
|
| 435 |
+
Tensor input,
|
| 436 |
+
IntArrayRef kernel_size,
|
| 437 |
+
IntArrayRef stride,
|
| 438 |
+
IntArrayRef padding,
|
| 439 |
+
bool ceil_mode,
|
| 440 |
+
bool count_include_pad,
|
| 441 |
+
std::optional<int64_t> divisor_override);
|
| 442 |
+
} // qnnp_avgpool_helper
|
| 443 |
+
} // namespace native
|
| 444 |
+
} // namespace at
|
| 445 |
+
|
| 446 |
+
namespace {
|
| 447 |
+
// Fill `requant_scales` (growing it if needed) with per-channel
// requantization scales: weight_scale[i] * input_scale / output_scale.
// Each scale is validated to be finite, normal and positive, as required by
// the QNNPACK kernels. The same vector is also returned by value for
// convenience (note this copies; callers on hot paths should use the
// in/out parameter).
C10_UNUSED std::vector<float> generate_requantization_scales(
    const at::Tensor& weight_scales,
    const float input_scale,
    const float output_scale,
    std::vector<float>& requant_scales) {
  // Since weight scale is allocated with padding
  // weight_scales.numel() gives us padded num elements.
  const auto num_output_channels_padded = weight_scales.numel();
  float *const weight_scales_data = weight_scales.data_ptr<float>();
  if (static_cast<int64_t>(requant_scales.size()) < num_output_channels_padded) {
    requant_scales.resize(num_output_channels_padded);
  }
  // Hoisted out of the loop: the reciprocal is channel-independent.
  const float inverse_output_scale = 1.f / output_scale;
  for (const auto i : c10::irange(num_output_channels_padded)) {
    requant_scales[i] = (weight_scales_data[i] * input_scale) * inverse_output_scale;
    TORCH_CHECK(
        (requant_scales[i] > 0.0f && std::isnormal(requant_scales[i])),
        "failed to create op with requantization scale: ",
        requant_scales[i],
        ": requantization scale must be finite and positive");
  }
  return requant_scales;
}
|
| 470 |
+
|
| 471 |
+
// Build the (zero_points, scales) pair QNNPACK expects from a contiguous
// quantized weight tensor. Both outputs are padded by kPaddingChannels past
// the real output-channel count; zero points are shifted by +128 (int8 ->
// uint8 domain) and padded entries are zero / 1.f respectively.
C10_UNUSED std::pair<std::vector<uint8_t>, at::Tensor> make_zero_points_and_scales_tensor(
    const at::Tensor& weight_contig,
    bool transpose = false,
    uint32_t groups = 1
  ) {
  // For transposed conv, dim 1 holds per-group output channels.
  const int out_ch_idx = transpose ? 1 : 0;
  const auto num_output_channels = weight_contig.size(out_ch_idx) * (transpose ? groups : 1);
  // Add 8 to account for bufferring needed by QNNPACK.
  const auto num_output_channels_padded = num_output_channels + kPaddingChannels;
  const auto qtype = weight_contig.qscheme();
  std::vector<uint8_t> weight_zp(num_output_channels_padded, 0);
  // Adjust weight zero point, similar to weight data.
  if (qtype == at::kPerTensorAffine) {
    for (const auto i : c10::irange(num_output_channels)) {
      weight_zp[i] = (uint8_t)(weight_contig.q_zero_point() + 128);
    }
  } else if (qtype == at::kPerChannelAffine) {
    TORCH_CHECK(
        weight_contig.q_per_channel_zero_points().scalar_type() == at::kLong,
        "Per channel zero points dtype must be long int.");
    const int64_t* per_channel_zero_points =
        weight_contig.q_per_channel_zero_points().data_ptr<int64_t>();
    for (const auto i : c10::irange(num_output_channels)) {
      weight_zp[i] = (uint8_t)(per_channel_zero_points[i] + 128);
    }
  } else {
    TORCH_INTERNAL_ASSERT(false, "Unsupported quantization scheme.");
  }
  at:: Tensor weight_scales =
      at::empty(
          {num_output_channels_padded},
          at::device(at::kCPU).dtype(at::kFloat));
  float *const weight_scales_data = weight_scales.data_ptr<float>();
  if (qtype == at::kPerTensorAffine) {
    for (const auto i : c10::irange(num_output_channels)) {
      weight_scales_data[i] = weight_contig.q_scale();
    }
  } else if (qtype == at::kPerChannelAffine) {
    TORCH_CHECK(
        weight_contig.q_per_channel_scales().scalar_type() == at::kDouble,
        "Per channel scales dtype must be double.");
    const double *const per_channel_scales =
        weight_contig.q_per_channel_scales().data_ptr<double>();
    for (const auto i : c10::irange(num_output_channels)) {
      weight_scales_data[i] = static_cast<float>(per_channel_scales[i]);
    }
  } else {
    TORCH_INTERNAL_ASSERT(false, "Unsupported quantization scheme.");
  }
  // Neutral scale of 1.f for the padding channels so downstream math is a
  // no-op there.
  for (const auto i : c10::irange(num_output_channels, num_output_channels_padded)) {
    weight_scales_data[i] = 1.f;
  }
  return {weight_zp, weight_scales};
}
|
| 525 |
+
} // namespace
|
| 526 |
+
|
| 527 |
+
#endif
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/QuantUtils.h
ADDED
|
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/core/Tensor.h>
|
| 4 |
+
#include <ATen/core/List.h>
|
| 5 |
+
#include <ATen/TensorOperators.h>
|
| 6 |
+
#include <c10/util/irange.h>
|
| 7 |
+
#include <algorithm>
|
| 8 |
+
#include <cmath>
|
| 9 |
+
|
| 10 |
+
#ifndef AT_PER_OPERATOR_HEADERS
|
| 11 |
+
#include <ATen/Functions.h>
|
| 12 |
+
#include <ATen/NativeFunctions.h>
|
| 13 |
+
#else
|
| 14 |
+
#include <ATen/ops/quantize_per_tensor_native.h>
|
| 15 |
+
#include <ATen/ops/quantize_per_channel_native.h>
|
| 16 |
+
#include <ATen/ops/zeros.h>
|
| 17 |
+
#endif
|
| 18 |
+
|
| 19 |
+
namespace quant_utils {
|
| 20 |
+
namespace {
|
| 21 |
+
// Reinterpret a raw IEEE half-precision bit pattern as a single-precision
// float. Fields: 1 sign bit, 5 exponent bits (bias 15), 10 fraction bits.
// NOTE(review): the zero-exponent case is treated like a normal number
// (implicit leading 1), so fp16 zero/subnormals are not decoded exactly —
// matches the original, which is used for saturation thresholds where this
// does not matter.
float RawUint16ToFp16(unsigned short value) {
  const unsigned short sign_bit = value >> 15;
  const unsigned short exp_field = (value >> 10) & 0x1f;
  const unsigned short frac_field = value & 0x3ff;

  // Mantissa in [1, 2): implicit 1 plus fraction scaled by 2^-10.
  const float mantissa = 1 + frac_field * 0.0009765625f;
  const float magnitude = std::ldexp(mantissa, exp_field - 0xf);
  return sign_bit ? -magnitude : magnitude;
}
|
| 35 |
+
|
| 36 |
+
// Clamp *element into [-max_val, max_val] in place. Returns true iff the
// value was out of range and had to be saturated.
template <typename T>
bool CheckAndSaturate(T max_val, T* element) {
  T& v = *element;
  if (v > max_val) {
    v = max_val;
    return true;
  }
  if (v < -max_val) {
    v = -max_val;
    return true;
  }
  return false;
}
|
| 48 |
+
}
|
| 49 |
+
using namespace std;
|
| 50 |
+
// A structure to hold quantization parameters 'scale' and 'zero_point'.
|
| 51 |
+
// The meaning of these values is as the constants in the quantization equation
|
| 52 |
+
//
|
| 53 |
+
// real_value = scale * (quantized_value - zero_point)
|
| 54 |
+
//
|
| 55 |
+
// In other words, 'zero_point' is the quantized value that corresponds
|
| 56 |
+
// to the real value 0, and 'scale' is the difference of real values
|
| 57 |
+
// corresponding to consecutive quantized values.
|
| 58 |
+
// Affine quantization parameters: real_value = scale * (q - zero_point).
struct TensorQuantizationParams {
  double scale;             // step between consecutive representable reals
  std::int32_t zero_point;  // quantized value that maps to real 0
  int precision;            // quantized bit-width — presumably set by callers that derive qmin/qmax; confirm
};
|
| 63 |
+
|
| 64 |
+
// Use fp16_min as the small scale cutoff because we don't want to use scales in
|
| 65 |
+
// fp16 subnormal range. This is to be consistent with Glow and FakeLowP
|
| 66 |
+
// implementation for NNPI.
|
| 67 |
+
constexpr float SMALL_SCALE_THRESHOLD = 6.1e-5f;
|
| 68 |
+
|
| 69 |
+
// Following implementation should be identical to fbgemm::ChooseQuantizationParams
|
| 70 |
+
inline TensorQuantizationParams ChooseQuantizationParams(
|
| 71 |
+
float min,
|
| 72 |
+
float max,
|
| 73 |
+
int32_t qmin,
|
| 74 |
+
int32_t qmax,
|
| 75 |
+
bool preserve_sparsity = false,
|
| 76 |
+
bool force_scale_power_of_two = false,
|
| 77 |
+
bool reduce_range = false) {
|
| 78 |
+
TORCH_CHECK(
|
| 79 |
+
min <= max,
|
| 80 |
+
"In ChooseQuantizationParams, min should be less than or equal to max");
|
| 81 |
+
|
| 82 |
+
if (reduce_range) {
|
| 83 |
+
qmin = qmin/2;
|
| 84 |
+
qmax = qmax/2;
|
| 85 |
+
}
|
| 86 |
+
if (min < 0 && max > 0 && preserve_sparsity) {
|
| 87 |
+
int symmetric_qmin = -((qmax - qmin) / 2 + 1);
|
| 88 |
+
int symmetric_qmax = (qmax - qmin) / 2;
|
| 89 |
+
double max_scale =
|
| 90 |
+
std::max(fabs(min / symmetric_qmin), fabs(max / symmetric_qmax));
|
| 91 |
+
min = max_scale * symmetric_qmin;
|
| 92 |
+
max = max_scale * symmetric_qmax;
|
| 93 |
+
}
|
| 94 |
+
|
| 95 |
+
// We extend the [min, max] interval to ensure that it contains 0.
|
| 96 |
+
// Otherwise, we would not meet the requirement that 0 be an exactly
|
| 97 |
+
// representable value.
|
| 98 |
+
min = std::min(min, 0.f);
|
| 99 |
+
max = std::max(max, 0.f);
|
| 100 |
+
|
| 101 |
+
TORCH_CHECK(
|
| 102 |
+
qmin < qmax,
|
| 103 |
+
"In ChooseQuantizationParams, qmin should be less than qmax");
|
| 104 |
+
|
| 105 |
+
// Use double precision for intermediate computation but use single precision
|
| 106 |
+
// in final number to reflect the actual number used during quantization.
|
| 107 |
+
double scale = (static_cast<double>(max) - min) / (qmax - qmin);
|
| 108 |
+
// If scale is 0 or too small so its reciprocal is infinity, we arbitrary
|
| 109 |
+
// adjust the scale to 0.1 . We want to avoid scale's reciprocal being
|
| 110 |
+
// infinity because some of fbgemm code pre-computes scale's reciprocal to do
|
| 111 |
+
// multiplication instead of division in the time critical part of code.
|
| 112 |
+
if (float(scale) == 0.0f || std::isinf(1.0f / float(scale))) {
|
| 113 |
+
scale = 0.1;
|
| 114 |
+
}
|
| 115 |
+
TORCH_CHECK(scale > 0, "quantization scale should be > 0");
|
| 116 |
+
|
| 117 |
+
if (force_scale_power_of_two) {
|
| 118 |
+
if (scale < 1) {
|
| 119 |
+
scale = 1.0 / (1 << static_cast<int>(floor(log(1.0 / scale) / log(2))));
|
| 120 |
+
} else {
|
| 121 |
+
scale = 1 << static_cast<int>(ceil(log(scale) / log(2)));
|
| 122 |
+
}
|
| 123 |
+
}
|
| 124 |
+
|
| 125 |
+
// Cut off small scale
|
| 126 |
+
if (scale < SMALL_SCALE_THRESHOLD) {
|
| 127 |
+
float org_scale = scale;
|
| 128 |
+
scale = SMALL_SCALE_THRESHOLD;
|
| 129 |
+
// Adjust the min and max based on the new scale
|
| 130 |
+
if (min == 0.0f) {
|
| 131 |
+
max = SMALL_SCALE_THRESHOLD * (qmax - qmin);
|
| 132 |
+
} else if (max == 0.0f) {
|
| 133 |
+
min = -SMALL_SCALE_THRESHOLD * (qmax - qmin);
|
| 134 |
+
} else {
|
| 135 |
+
float amplifier = SMALL_SCALE_THRESHOLD / org_scale;
|
| 136 |
+
min *= amplifier;
|
| 137 |
+
max *= amplifier;
|
| 138 |
+
}
|
| 139 |
+
}
|
| 140 |
+
|
| 141 |
+
// Zero-point computation.
|
| 142 |
+
// First the initial floating-point computation. The zero-point can be
|
| 143 |
+
// determined from solving an affine equation for any known pair
|
| 144 |
+
// (real value, corresponding quantized value).
|
| 145 |
+
// We know two such pairs: (rmin, qmin) and (rmax, qmax).
|
| 146 |
+
// The arithmetic error on the zero point computed from either pair
|
| 147 |
+
// will be roughly machine_epsilon * (sum of absolute values of terms)
|
| 148 |
+
// so we want to use the variant that adds the smaller terms.
|
| 149 |
+
double zero_point_from_min = qmin - min / static_cast<double>(scale);
|
| 150 |
+
double zero_point_from_max = qmax - max / static_cast<double>(scale);
|
| 151 |
+
double zero_point_from_min_error =
|
| 152 |
+
std::abs(qmin) - std::abs(min / static_cast<double>(scale));
|
| 153 |
+
double zero_point_from_max_error =
|
| 154 |
+
std::abs(qmax) - std::abs(max / static_cast<double>(scale));
|
| 155 |
+
double initial_zero_point =
|
| 156 |
+
zero_point_from_min_error < zero_point_from_max_error
|
| 157 |
+
? zero_point_from_min
|
| 158 |
+
: zero_point_from_max;
|
| 159 |
+
|
| 160 |
+
// for symmetric quantization (preserve_sparsity == true), we force zero_point
|
| 161 |
+
// to be a middle value between qmin and qmax.
|
| 162 |
+
// If either min or max is 0, then we just use 0 as zero_point.
|
| 163 |
+
if (min < 0 && max > 0 && preserve_sparsity) {
|
| 164 |
+
initial_zero_point = static_cast<double>(qmin + qmax) / 2;
|
| 165 |
+
}
|
| 166 |
+
|
| 167 |
+
// Now we need to nudge the zero point to be an integer
|
| 168 |
+
// (our zero points are integer, and this is motivated by the requirement
|
| 169 |
+
// to be able to represent the real value "0" exactly as a quantized value,
|
| 170 |
+
// which is required in multiple places, for example in Im2col with zero
|
| 171 |
+
// padding).
|
| 172 |
+
int32_t nudged_zero_point = 0;
|
| 173 |
+
if (initial_zero_point < qmin) {
|
| 174 |
+
nudged_zero_point = qmin;
|
| 175 |
+
} else if (initial_zero_point > qmax) {
|
| 176 |
+
nudged_zero_point = qmax;
|
| 177 |
+
} else {
|
| 178 |
+
nudged_zero_point = nearbyint(initial_zero_point);
|
| 179 |
+
}
|
| 180 |
+
|
| 181 |
+
TensorQuantizationParams result;
|
| 182 |
+
result.scale = scale;
|
| 183 |
+
result.zero_point = nudged_zero_point;
|
| 184 |
+
return result;
|
| 185 |
+
}
|
| 186 |
+
|
| 187 |
+
// This function helps to convert the Conv1D dimensions usable by the Conv2d op.
|
| 188 |
+
constexpr int64_t kConv1dSqueezeDim = 0;
|
| 189 |
+
static C10_UNUSED torch::List<int64_t> MakeArgForConv1d(const torch::List<int64_t>& arg,
|
| 190 |
+
int64_t base_value) {
|
| 191 |
+
TORCH_CHECK(!arg.empty(), "Argument must have elements.");
|
| 192 |
+
torch::List<int64_t> result({arg.get(0), base_value});
|
| 193 |
+
if (arg.size() == 1) {
|
| 194 |
+
result[1] = arg.get(0);
|
| 195 |
+
} else {
|
| 196 |
+
result[1] = arg.get(1);
|
| 197 |
+
}
|
| 198 |
+
result[kConv1dSqueezeDim] = base_value;
|
| 199 |
+
return result;
|
| 200 |
+
}
|
| 201 |
+
|
| 202 |
+
// The range for using FP16 quantization of weights requires that the elements
|
| 203 |
+
// should be in the range of [5.96e-8, 65504]. If it is out of range, then the
|
| 204 |
+
// number will be saturated to max or min representable values by FP16.
|
| 205 |
+
inline void HandleWeightsSaturation(int64_t N, float* weight) {
|
| 206 |
+
const float kFp16Max = RawUint16ToFp16(0x7BFF);
|
| 207 |
+
bool found_out_of_range = false;
|
| 208 |
+
for (const auto i : c10::irange(N)) {
|
| 209 |
+
bool saturate = CheckAndSaturate<float>(kFp16Max, weight + i);
|
| 210 |
+
if (saturate) {
|
| 211 |
+
found_out_of_range = true;
|
| 212 |
+
}
|
| 213 |
+
}
|
| 214 |
+
if (found_out_of_range) {
|
| 215 |
+
TORCH_WARN("FOUND weight out of range ");
|
| 216 |
+
}
|
| 217 |
+
}
|
| 218 |
+
|
| 219 |
+
// Util function for quantizing bias.
|
| 220 |
+
inline at::Tensor QuantizeBias(
|
| 221 |
+
bool is_per_channel,
|
| 222 |
+
const at::Tensor& bias,
|
| 223 |
+
const at::Tensor& weight_contig,
|
| 224 |
+
double input_scale) {
|
| 225 |
+
at::Tensor qbias;
|
| 226 |
+
if (is_per_channel) {
|
| 227 |
+
auto bias_quant_scales =
|
| 228 |
+
weight_contig.q_per_channel_scales() * input_scale;
|
| 229 |
+
auto bias_zp = at::zeros(bias_quant_scales.sizes(), c10::kInt);
|
| 230 |
+
qbias = at::native::quantize_per_channel(
|
| 231 |
+
bias, bias_quant_scales, bias_zp, 0, c10::kQInt32);
|
| 232 |
+
} else {
|
| 233 |
+
qbias = at::native::quantize_per_tensor(
|
| 234 |
+
bias, weight_contig.q_scale() * input_scale, 0, c10::kQInt32);
|
| 235 |
+
}
|
| 236 |
+
return qbias;
|
| 237 |
+
}
|
| 238 |
+
|
| 239 |
+
} // namespace quant_utils
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/QuantizedOps.h
ADDED
|
@@ -0,0 +1,258 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <ATen/core/Tensor.h>
|
| 3 |
+
#include <ATen/core/IListRef.h>
|
| 4 |
+
#include <ATen/Dispatch.h>
|
| 5 |
+
#include <ATen/TensorIterator.h>
|
| 6 |
+
#include <ATen/native/Activation.h>
|
| 7 |
+
#include <ATen/native/DispatchStub.h>
|
| 8 |
+
|
| 9 |
+
namespace at {
|
| 10 |
+
namespace native {
|
| 11 |
+
|
| 12 |
+
using qrelu_fn = void (*)(const at::Tensor& /*qx*/, at::Tensor& /*qy*/);
|
| 13 |
+
using qrelu_leaky_fn = void (*)(Tensor& /*out*/, const Tensor& /*qx*/,
|
| 14 |
+
const Scalar& /*negval_*/);
|
| 15 |
+
using qgelu_fn = void (*)(const at::Tensor& /*qx*/, at::Tensor& /*qy*/, GeluType /* approximate */);
|
| 16 |
+
using qsigmoid_fn = void (*)(const at::Tensor& /*qx*/, at::Tensor& /*qy*/, double output_scale, int64_t output_zero_point);
|
| 17 |
+
using qhardsigmoid_fn = void (*)(const at::Tensor& /*qx*/, at::Tensor& /*qy*/);
|
| 18 |
+
using qclamp_fn = void (*)(
|
| 19 |
+
const at::Tensor& /*qx*/,
|
| 20 |
+
const Scalar& min,
|
| 21 |
+
const Scalar& max,
|
| 22 |
+
at::Tensor& /*qy*/);
|
| 23 |
+
using qclamp_minmax_fn = void (*)(
|
| 24 |
+
const at::Tensor& /*qx*/,
|
| 25 |
+
const Scalar& /*min or max*/,
|
| 26 |
+
at::Tensor& /*qy*/);
|
| 27 |
+
using qthreshold_fn = void (*)(
|
| 28 |
+
const at::Tensor& /*qx*/,
|
| 29 |
+
const Scalar& threshold,
|
| 30 |
+
const Scalar& value,
|
| 31 |
+
at::Tensor& /*qy*/);
|
| 32 |
+
using qtanh_fn = void (*)(const at::Tensor& /*qx*/, at::Tensor& /*qy*/);
|
| 33 |
+
using qelu_fn = void(*)(
|
| 34 |
+
const at::Tensor& /*qx*/,
|
| 35 |
+
const Scalar& /*alpha*/,
|
| 36 |
+
const Scalar& /*scale*/,
|
| 37 |
+
const Scalar& /*input_scale*/,
|
| 38 |
+
at::Tensor& /*qy*/);
|
| 39 |
+
using qbinary_fn =
|
| 40 |
+
void (*)(Tensor& /*out*/, const Tensor& /*self*/, const Tensor& /*other*/);
|
| 41 |
+
using qadd_scalar_fn =
|
| 42 |
+
void (*)(Tensor& /*out*/, const Tensor& /*self*/, const Scalar& other /*other*/);
|
| 43 |
+
using qhardswish_fn = void (*)(const at::Tensor& /*qx*/, at::Tensor& /*qy*/);
|
| 44 |
+
using qdropout_fn = void(*)(
|
| 45 |
+
const at::Tensor& /*qx*/,
|
| 46 |
+
const Scalar& /*p*/,
|
| 47 |
+
bool training /*training*/,
|
| 48 |
+
at::Tensor& /*qy*/);
|
| 49 |
+
using qmaxpool_2d_fn = void (*)(
|
| 50 |
+
const Tensor& qx,
|
| 51 |
+
int64_t iC, // input/output channels
|
| 52 |
+
int64_t iH,
|
| 53 |
+
int64_t iW, // input sizes
|
| 54 |
+
int64_t oH,
|
| 55 |
+
int64_t oW, // output sizes
|
| 56 |
+
int64_t kH,
|
| 57 |
+
int64_t kW, // kernel size
|
| 58 |
+
int64_t sH,
|
| 59 |
+
int64_t sW, // strides
|
| 60 |
+
int64_t pH,
|
| 61 |
+
int64_t pW, // padding
|
| 62 |
+
int64_t dH,
|
| 63 |
+
int64_t dW, // dilation
|
| 64 |
+
Tensor& qy);
|
| 65 |
+
using qmaxpool_3d_fn = void (*)(
|
| 66 |
+
const Tensor& qx,
|
| 67 |
+
int64_t iC, // input/output channels
|
| 68 |
+
int64_t iT,
|
| 69 |
+
int64_t iH,
|
| 70 |
+
int64_t iW, // input sizes
|
| 71 |
+
int64_t oT,
|
| 72 |
+
int64_t oH,
|
| 73 |
+
int64_t oW, // output sizes
|
| 74 |
+
int64_t kT,
|
| 75 |
+
int64_t kH,
|
| 76 |
+
int64_t kW, // kernel size
|
| 77 |
+
int64_t sT,
|
| 78 |
+
int64_t sH,
|
| 79 |
+
int64_t sW, // strides
|
| 80 |
+
int64_t pT,
|
| 81 |
+
int64_t pH,
|
| 82 |
+
int64_t pW, // padding
|
| 83 |
+
int64_t dT,
|
| 84 |
+
int64_t dH,
|
| 85 |
+
int64_t dW, // dilation
|
| 86 |
+
Tensor& qy);
|
| 87 |
+
using qadaptive_avg_pool2d_fn = void (*)(
|
| 88 |
+
const Tensor& qx,
|
| 89 |
+
Tensor& qy,
|
| 90 |
+
int64_t sizeB,
|
| 91 |
+
int64_t sizeC,
|
| 92 |
+
int64_t isizeH,
|
| 93 |
+
int64_t isizeW,
|
| 94 |
+
int64_t osizeH,
|
| 95 |
+
int64_t osizeW,
|
| 96 |
+
int64_t istrideB,
|
| 97 |
+
int64_t istrideC,
|
| 98 |
+
int64_t istrideH,
|
| 99 |
+
int64_t istrideW);
|
| 100 |
+
using qadaptive_avg_pool3d_fn = void (*)(
|
| 101 |
+
const Tensor& qx,
|
| 102 |
+
Tensor& qy,
|
| 103 |
+
int64_t sizeB,
|
| 104 |
+
int64_t sizeC,
|
| 105 |
+
int64_t isizeD,
|
| 106 |
+
int64_t isizeH,
|
| 107 |
+
int64_t isizeW,
|
| 108 |
+
int64_t osizeD,
|
| 109 |
+
int64_t osizeH,
|
| 110 |
+
int64_t osizeW,
|
| 111 |
+
int64_t istrideB,
|
| 112 |
+
int64_t istrideC,
|
| 113 |
+
int64_t istrideD,
|
| 114 |
+
int64_t istrideH,
|
| 115 |
+
int64_t istrideW);
|
| 116 |
+
using qavg_pool2d_fn = void (*)(
|
| 117 |
+
const Tensor& qx,
|
| 118 |
+
Tensor& qy,
|
| 119 |
+
int64_t nBatch,
|
| 120 |
+
int64_t nInputPlane,
|
| 121 |
+
int64_t inputWidth,
|
| 122 |
+
int64_t inputHeight,
|
| 123 |
+
int64_t outputWidth,
|
| 124 |
+
int64_t outputHeight,
|
| 125 |
+
int kW,
|
| 126 |
+
int kH,
|
| 127 |
+
int dW,
|
| 128 |
+
int dH,
|
| 129 |
+
int padW,
|
| 130 |
+
int padH,
|
| 131 |
+
bool count_include_pad,
|
| 132 |
+
std::optional<int64_t> divisor_override);
|
| 133 |
+
|
| 134 |
+
using qavg_pool3d_fn = void (*)(
|
| 135 |
+
const Tensor& qx,
|
| 136 |
+
Tensor& qy,
|
| 137 |
+
int64_t nBatch,
|
| 138 |
+
int64_t nInputPlane,
|
| 139 |
+
int64_t inputWidth,
|
| 140 |
+
int64_t inputHeight,
|
| 141 |
+
int64_t inputDepth,
|
| 142 |
+
int64_t outputWidth,
|
| 143 |
+
int64_t outputHeight,
|
| 144 |
+
int64_t outputDepth,
|
| 145 |
+
int kW,
|
| 146 |
+
int kH,
|
| 147 |
+
int kD,
|
| 148 |
+
int dW,
|
| 149 |
+
int dH,
|
| 150 |
+
int dD,
|
| 151 |
+
int padW,
|
| 152 |
+
int padH,
|
| 153 |
+
int padD,
|
| 154 |
+
bool count_include_pad,
|
| 155 |
+
std::optional<int64_t> divisor_override);
|
| 156 |
+
|
| 157 |
+
using qupsample_bilinear2d_fn = void (*)(
|
| 158 |
+
Tensor& output,
|
| 159 |
+
const Tensor& input,
|
| 160 |
+
int64_t input_height,
|
| 161 |
+
int64_t input_width,
|
| 162 |
+
int64_t output_height,
|
| 163 |
+
int64_t output_width,
|
| 164 |
+
int64_t nbatch,
|
| 165 |
+
int64_t channels,
|
| 166 |
+
bool align_corners,
|
| 167 |
+
std::optional<double> scales_h,
|
| 168 |
+
std::optional<double> scales_w);
|
| 169 |
+
|
| 170 |
+
using qcat_nhwc_fn = Tensor (*)(
|
| 171 |
+
const MaterializedITensorListRef& qxs,
|
| 172 |
+
int64_t dim,
|
| 173 |
+
double scale,
|
| 174 |
+
int64_t zero_point);
|
| 175 |
+
using qtopk_fn = void(*)(Tensor&, Tensor&, const Tensor&, int64_t, int64_t, bool, bool);
|
| 176 |
+
|
| 177 |
+
using qbatch_norm_fn = void(*)(int64_t, int64_t, int64_t, int64_t, int64_t, const Tensor&, const Tensor&, const Tensor&, Tensor&);
|
| 178 |
+
|
| 179 |
+
using qnormalize_fn = void (*)(
|
| 180 |
+
const Tensor& /* X */,
|
| 181 |
+
const Tensor& /* gamma */,
|
| 182 |
+
const Tensor& /* beta */,
|
| 183 |
+
bool /* affine_per_channel */,
|
| 184 |
+
int /* num_channels */,
|
| 185 |
+
int /* num_groups */,
|
| 186 |
+
int64_t /* M */,
|
| 187 |
+
int64_t /* N */,
|
| 188 |
+
double /* eps */,
|
| 189 |
+
Tensor* /* Y */);
|
| 190 |
+
|
| 191 |
+
using qmean_inner_dim_fn = void (*)(
|
| 192 |
+
const Tensor& /* X */,
|
| 193 |
+
OptionalIntArrayRef /* opt_dim */,
|
| 194 |
+
bool /* keepdim */,
|
| 195 |
+
std::optional<ScalarType> /* opt_dtype */,
|
| 196 |
+
Tensor& /* Y */);
|
| 197 |
+
|
| 198 |
+
using qstd_inner_dim_fn = void (*)(
|
| 199 |
+
const Tensor& /* X */,
|
| 200 |
+
OptionalIntArrayRef /* dim */,
|
| 201 |
+
const std::optional<Scalar>& /* correction */,
|
| 202 |
+
bool /* keepdim */,
|
| 203 |
+
Tensor& /* Y */);
|
| 204 |
+
|
| 205 |
+
using qnormalize_nhwc_fn = void (*)(
|
| 206 |
+
const Tensor& /* X */,
|
| 207 |
+
const Tensor& /* gamma */,
|
| 208 |
+
const Tensor& /* beta */,
|
| 209 |
+
bool /* affine_per_channel */,
|
| 210 |
+
int /* num_channels */,
|
| 211 |
+
int /* num_groups */,
|
| 212 |
+
int64_t /* M */,
|
| 213 |
+
int64_t /* N */,
|
| 214 |
+
double /* eps */,
|
| 215 |
+
Tensor* /* Y */);
|
| 216 |
+
|
| 217 |
+
using qprelu_fn = void (*)(Tensor& /*out*/, const Tensor& /*qx*/,
|
| 218 |
+
const Tensor& /*qw*/);
|
| 219 |
+
|
| 220 |
+
DECLARE_DISPATCH(qadaptive_avg_pool2d_fn, qadaptive_avg_pool2d_nhwc_stub);
|
| 221 |
+
DECLARE_DISPATCH(qadaptive_avg_pool3d_fn, qadaptive_avg_pool3d_ndhwc_stub);
|
| 222 |
+
DECLARE_DISPATCH(qadd_scalar_fn, qadd_scalar_relu_stub);
|
| 223 |
+
DECLARE_DISPATCH(qadd_scalar_fn, qadd_scalar_stub);
|
| 224 |
+
DECLARE_DISPATCH(qavg_pool2d_fn, qavg_pool2d_nhwc_stub);
|
| 225 |
+
DECLARE_DISPATCH(qavg_pool3d_fn, qavg_pool3d_nhwc_stub);
|
| 226 |
+
DECLARE_DISPATCH(qbatch_norm_fn, qbatch_norm_relu_stub);
|
| 227 |
+
DECLARE_DISPATCH(qbatch_norm_fn, qbatch_norm_stub);
|
| 228 |
+
DECLARE_DISPATCH(qbinary_fn, qadd_relu_stub);
|
| 229 |
+
DECLARE_DISPATCH(qbinary_fn, qadd_stub);
|
| 230 |
+
DECLARE_DISPATCH(qbinary_fn, qmul_relu_stub);
|
| 231 |
+
DECLARE_DISPATCH(qbinary_fn, qmul_stub);
|
| 232 |
+
DECLARE_DISPATCH(qcat_nhwc_fn, qcat_nhwc_stub);
|
| 233 |
+
DECLARE_DISPATCH(qcat_nhwc_fn, qcat_relu_nhwc_stub);
|
| 234 |
+
DECLARE_DISPATCH(qclamp_fn, qclamp_stub);
|
| 235 |
+
DECLARE_DISPATCH(qclamp_minmax_fn, qclamp_min_stub);
|
| 236 |
+
DECLARE_DISPATCH(qclamp_minmax_fn, qclamp_max_stub);
|
| 237 |
+
DECLARE_DISPATCH(qelu_fn, qelu_stub);
|
| 238 |
+
DECLARE_DISPATCH(qhardsigmoid_fn, qhardsigmoid_stub);
|
| 239 |
+
DECLARE_DISPATCH(qhardswish_fn, qhardswish_stub);
|
| 240 |
+
DECLARE_DISPATCH(qdropout_fn, qdropout_stub);
|
| 241 |
+
DECLARE_DISPATCH(qmaxpool_2d_fn, qmaxpool_2d_nhwc_stub);
|
| 242 |
+
DECLARE_DISPATCH(qmaxpool_3d_fn, qmaxpool_3d_nthwc_stub);
|
| 243 |
+
DECLARE_DISPATCH(qnormalize_fn, quantized_normalize_stub);
|
| 244 |
+
DECLARE_DISPATCH(qnormalize_nhwc_fn, quantized_groupnorm_nhwc_stub);
|
| 245 |
+
DECLARE_DISPATCH(qrelu_fn, qrelu_stub);
|
| 246 |
+
DECLARE_DISPATCH(qrelu_leaky_fn, qrelu_leaky_stub);
|
| 247 |
+
DECLARE_DISPATCH(qgelu_fn, qgelu_stub);
|
| 248 |
+
DECLARE_DISPATCH(qsigmoid_fn, qsigmoid_stub);
|
| 249 |
+
DECLARE_DISPATCH(qtanh_fn, qtanh_stub);
|
| 250 |
+
DECLARE_DISPATCH(qthreshold_fn, qthreshold_stub);
|
| 251 |
+
DECLARE_DISPATCH(qtopk_fn, qtopk_stub);
|
| 252 |
+
DECLARE_DISPATCH(qupsample_bilinear2d_fn, qupsample_bilinear2d_nhwc_stub);
|
| 253 |
+
DECLARE_DISPATCH(qmean_inner_dim_fn, qmean_inner_dim_stub);
|
| 254 |
+
DECLARE_DISPATCH(qstd_inner_dim_fn, qstd_inner_dim_stub);
|
| 255 |
+
DECLARE_DISPATCH(qprelu_fn, qprelu_stub);
|
| 256 |
+
|
| 257 |
+
} // namespace native
|
| 258 |
+
} // namespace at
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/RuyUtils.h
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#ifdef USE_RUY_QMATMUL
|
| 4 |
+
|
| 5 |
+
#include <ruy/ruy.h>
|
| 6 |
+
|
| 7 |
+
namespace at {
|
| 8 |
+
namespace native {
|
| 9 |
+
namespace ruy_utils {
|
| 10 |
+
|
| 11 |
+
ruy::Context* get_ruy_context();
|
| 12 |
+
|
| 13 |
+
void quantize_multiplier(double scale,
|
| 14 |
+
int* multiplier_fixedpoint,
|
| 15 |
+
int* multiplier_exponent);
|
| 16 |
+
|
| 17 |
+
} // namespace ruy_utils
|
| 18 |
+
} // namespace native
|
| 19 |
+
} // namespace
|
| 20 |
+
|
| 21 |
+
#endif // USE_RUY_QMATMUL
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/XnnpackUtils.h
ADDED
|
@@ -0,0 +1,335 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#ifdef USE_XNNPACK
|
| 4 |
+
#include <cstdint>
|
| 5 |
+
|
| 6 |
+
#include <ATen/core/Tensor.h>
|
| 7 |
+
#include <ATen/native/xnnpack/Common.h>
|
| 8 |
+
|
| 9 |
+
using xnnpack_operator = at::native::xnnpack::Operator;
|
| 10 |
+
|
| 11 |
+
namespace at {
|
| 12 |
+
namespace native {
|
| 13 |
+
namespace xnnp_utils {
|
| 14 |
+
|
| 15 |
+
/*
|
| 16 |
+
* Return shape in the same order as the memory format
|
| 17 |
+
* e.g. channels_last will return NHWC instead of NCHW
|
| 18 |
+
*/
|
| 19 |
+
std::vector<size_t> get_mem_format_aware_shape(const at::Tensor& in);
|
| 20 |
+
|
| 21 |
+
/*
|
| 22 |
+
* Input is always int8_t, output can be [int8_t, uint8_t].
|
| 23 |
+
* input + offset = output
|
| 24 |
+
* int8_t + 128 = uint8_t
|
| 25 |
+
* int8_t + 0 = int8_t
|
| 26 |
+
*/
|
| 27 |
+
template <typename PT>
|
| 28 |
+
void q8_copy_int8_weight_and_add_offset(const at::Tensor& in, at::Tensor& out);
|
| 29 |
+
|
| 30 |
+
template <int kSpatialDim>
|
| 31 |
+
Tensor convert_conv_weights_to_channel_last_tensor(
|
| 32 |
+
const at::Tensor& src,
|
| 33 |
+
int groups,
|
| 34 |
+
bool transpose);
|
| 35 |
+
|
| 36 |
+
/*
|
| 37 |
+
* Series of create wrapper functions to call xnn_create_[de]conv* functions.
|
| 38 |
+
*/
|
| 39 |
+
C10_ALWAYS_INLINE
|
| 40 |
+
enum xnn_status xnnp_create_convolution2d_nhwc(
|
| 41 |
+
uint32_t pad_top,
|
| 42 |
+
uint32_t pad_right,
|
| 43 |
+
uint32_t pad_bottom,
|
| 44 |
+
uint32_t pad_left,
|
| 45 |
+
uint32_t kernel_h,
|
| 46 |
+
uint32_t kernel_w,
|
| 47 |
+
uint32_t stride_h,
|
| 48 |
+
uint32_t stride_w,
|
| 49 |
+
uint32_t dilation_h,
|
| 50 |
+
uint32_t dilation_w,
|
| 51 |
+
uint32_t groups,
|
| 52 |
+
size_t group_input_channels,
|
| 53 |
+
size_t group_output_channels,
|
| 54 |
+
size_t ip_chan_stride,
|
| 55 |
+
size_t op_chan_stride,
|
| 56 |
+
int8_t izp,
|
| 57 |
+
float ip_scale,
|
| 58 |
+
int8_t kzp,
|
| 59 |
+
const float* k_scales,
|
| 60 |
+
const int8_t* kernel,
|
| 61 |
+
const int32_t* bias,
|
| 62 |
+
int8_t ozp,
|
| 63 |
+
float op_scale,
|
| 64 |
+
int8_t op_min,
|
| 65 |
+
int8_t op_max,
|
| 66 |
+
uint32_t flags,
|
| 67 |
+
xnn_operator_t* op,
|
| 68 |
+
bool per_channel,
|
| 69 |
+
bool transpose) {
|
| 70 |
+
/* Symmetric quantization forces kzp = 0 */
|
| 71 |
+
TORCH_CHECK(!kzp, "XNNPACK Q[SC]8 conv kernels expects kernel zero point to be zero."
|
| 72 |
+
"But got: ", kzp);
|
| 73 |
+
|
| 74 |
+
if (transpose) {
|
| 75 |
+
TORCH_CHECK(!per_channel, "XNNPACK Q[SC]8 does not have a per channel deconvolution!");
|
| 76 |
+
return xnn_create_deconvolution2d_nhwc_qs8(
|
| 77 |
+
pad_top, /* uint32_t output_padding_top */
|
| 78 |
+
pad_right, /* uint32_t output_padding_right */
|
| 79 |
+
pad_bottom, /* uint32_t output_padding_bottom */
|
| 80 |
+
pad_left, /* uint32_t output_padding_left */
|
| 81 |
+
kernel_h, /* uint32_t kernel_height */
|
| 82 |
+
kernel_w, /* uint32_t kernel_width */
|
| 83 |
+
stride_h, /* uint32_t stride_height */
|
| 84 |
+
stride_w, /* uint32_t stride_width */
|
| 85 |
+
dilation_h, /* uint32_t dilation_height */
|
| 86 |
+
dilation_w, /* uint32_t dilation_width */
|
| 87 |
+
groups, /* uint32_t groups */
|
| 88 |
+
group_input_channels, /* size_t group_input_channels */
|
| 89 |
+
group_output_channels, /* size_t group_output_channels */
|
| 90 |
+
ip_chan_stride, /* size_t input_pixel_stride */
|
| 91 |
+
op_chan_stride, /* size_t output_pixel_stride */
|
| 92 |
+
izp, /* int8_t input_zero_point */
|
| 93 |
+
ip_scale, /* float input_scale */
|
| 94 |
+
k_scales[0], /* float kernel_scale */
|
| 95 |
+
kernel, /* const int8_t* kernel */
|
| 96 |
+
bias, /* const int32_t* bias */
|
| 97 |
+
ozp, /* int8_t output_zero_point */
|
| 98 |
+
op_scale, /* float output_scale */
|
| 99 |
+
op_min, /* int8_t output_min */
|
| 100 |
+
op_max, /* int8_t output_max */
|
| 101 |
+
flags, /* uint32_t flags */
|
| 102 |
+
nullptr, /* xnn_caches_t caches */
|
| 103 |
+
nullptr, /* xnn_weights_cache_t weights_cache */
|
| 104 |
+
op); /* xnn_operator_t* deconvolution_op_out */
|
| 105 |
+
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
if (!per_channel) {
|
| 109 |
+
return xnn_create_convolution2d_nhwc_qs8(
|
| 110 |
+
pad_top, /* uint32_t input_padding_top */
|
| 111 |
+
pad_right, /* uint32_t input_padding_right */
|
| 112 |
+
pad_bottom, /* uint32_t input_padding_bottom */
|
| 113 |
+
pad_left, /* uint32_t input_padding_left */
|
| 114 |
+
kernel_h, /* uint32_t kernel_height */
|
| 115 |
+
kernel_w, /* uint32_t kernel_width */
|
| 116 |
+
stride_h, /* uint32_t subsampling_height */
|
| 117 |
+
stride_w, /* uint32_t subsampling_width */
|
| 118 |
+
dilation_h, /* uint32_t dilation_height */
|
| 119 |
+
dilation_w, /* uint32_t dilation_width */
|
| 120 |
+
groups, /* uint32_t groups */
|
| 121 |
+
group_input_channels, /* size_t group_input_channels */
|
| 122 |
+
group_output_channels, /* size_t group_output_channels*/
|
| 123 |
+
ip_chan_stride, /* size_t input_channel_stride */
|
| 124 |
+
op_chan_stride, /* size_t output_channel_stride */
|
| 125 |
+
izp, /* int8_t input_zero_point */
|
| 126 |
+
ip_scale, /* float input_scale */
|
| 127 |
+
k_scales[0], /* float kernel_scale */
|
| 128 |
+
kernel, /* const int8_t* kernel */
|
| 129 |
+
bias, /* const int32_t* bias */
|
| 130 |
+
ozp, /* int8_t output_zero_point */
|
| 131 |
+
op_scale, /* float output_scale */
|
| 132 |
+
op_min, /* int8_t output_min */
|
| 133 |
+
op_max, /* int8_t output_max */
|
| 134 |
+
flags, /* uint32_t flags */
|
| 135 |
+
nullptr, /* xnn_caches_t caches */
|
| 136 |
+
nullptr, /* xnn_weights_cache_t weights_cache */
|
| 137 |
+
op); /* xnn_operator_t* convolution_op_out */
|
| 138 |
+
} else { /* per_channel */
|
| 139 |
+
return xnn_create_convolution2d_nhwc_qs8_qc8w(
|
| 140 |
+
pad_top, /* uint32_t input_padding_top */
|
| 141 |
+
pad_right, /* uint32_t input_padding_right */
|
| 142 |
+
pad_bottom, /* uint32_t input_padding_bottom */
|
| 143 |
+
pad_left, /* uint32_t input_padding_left */
|
| 144 |
+
kernel_h, /* uint32_t kernel_height */
|
| 145 |
+
kernel_w, /* uint32_t kernel_width */
|
| 146 |
+
stride_h, /* uint32_t subsampling_height */
|
| 147 |
+
stride_w, /* uint32_t subsampling_width */
|
| 148 |
+
dilation_h, /* uint32_t dilation_height */
|
| 149 |
+
dilation_w, /* uint32_t dilation_width */
|
| 150 |
+
groups, /* uint32_t groups */
|
| 151 |
+
group_input_channels, /* size_t group_input_channels */
|
| 152 |
+
group_output_channels, /* size_t group_output_channels*/
|
| 153 |
+
ip_chan_stride, /* size_t input_channel_stride */
|
| 154 |
+
op_chan_stride, /* size_t output_channel_stride */
|
| 155 |
+
izp, /* int8_t input_zero_point */
|
| 156 |
+
ip_scale, /* float input_scale */
|
| 157 |
+
k_scales, /* const float* kernel_scale */
|
| 158 |
+
kernel, /* const int8_t* kernel */
|
| 159 |
+
bias, /* const int32_t* bias */
|
| 160 |
+
ozp, /* int8_t output_zero_point */
|
| 161 |
+
op_scale, /* float output_scale */
|
| 162 |
+
op_min, /* int8_t output_min */
|
| 163 |
+
op_max, /* int8_t output_max */
|
| 164 |
+
flags, /* uint32_t flags */
|
| 165 |
+
nullptr, /* xnn_caches_t caches */
|
| 166 |
+
nullptr, /* xnn_weights_cache_t weights_cache */
|
| 167 |
+
op); /* xnn_operator_t* convolution_op_out */
|
| 168 |
+
}
|
| 169 |
+
}
|
| 170 |
+
|
| 171 |
+
/*
|
| 172 |
+
* Series of reshape wrapper functions to call xnn_reshape_[de]conv* functions.
|
| 173 |
+
*/
|
| 174 |
+
C10_ALWAYS_INLINE
|
| 175 |
+
enum xnn_status xnnp_reshape_convolution2d_nhwc(
|
| 176 |
+
xnn_operator_t op,
|
| 177 |
+
size_t batch,
|
| 178 |
+
size_t in_h,
|
| 179 |
+
size_t in_w,
|
| 180 |
+
pthreadpool_t pt_pool,
|
| 181 |
+
bool per_channel = false,
|
| 182 |
+
bool transpose = false,
|
| 183 |
+
uint32_t adj_h = 0,
|
| 184 |
+
uint32_t adj_w = 0) {
|
| 185 |
+
if(transpose) {
|
| 186 |
+
TORCH_CHECK(!per_channel, "XNNPACK Q[SC]8 does not have a per channel deconvolution!");
|
| 187 |
+
return xnn_reshape_deconvolution2d_nhwc_qs8(
|
| 188 |
+
op, /* xnn_operator_t deconvolution_op */
|
| 189 |
+
batch, /* size_t batch_size */
|
| 190 |
+
in_h, /* size_t input_height */
|
| 191 |
+
in_w, /* size_t input_width */
|
| 192 |
+
adj_h, /* uint32_t adjustment_height */
|
| 193 |
+
adj_w, /* uint32_t adjustment_width */
|
| 194 |
+
nullptr, /* size_t* output_height_out */
|
| 195 |
+
nullptr, /* size_t* output_width_out */
|
| 196 |
+
pt_pool); /* pthreadpool_t threadpool */
|
| 197 |
+
}
|
| 198 |
+
|
| 199 |
+
size_t workspace_size = SIZE_MAX;
|
| 200 |
+
size_t workspace_alignment = SIZE_MAX;
|
| 201 |
+
|
| 202 |
+
if (!per_channel) {
|
| 203 |
+
return xnn_reshape_convolution2d_nhwc_qs8(
|
| 204 |
+
op, /* xnn_operator_t convolution_op */
|
| 205 |
+
batch, /* size_t batch_size */
|
| 206 |
+
in_h, /* size_t input_height */
|
| 207 |
+
in_w, /* size_t input_width */
|
| 208 |
+
&workspace_size, /* size_t* workspace_size */
|
| 209 |
+
&workspace_alignment, /* size_t* workspace_alignment */
|
| 210 |
+
nullptr, /* size_t* output_height_out */
|
| 211 |
+
nullptr, /* size_t* output_width_out */
|
| 212 |
+
pt_pool); /* pthreadpool_t threadpool */
|
| 213 |
+
} else { /* per_channel */
|
| 214 |
+
return xnn_reshape_convolution2d_nhwc_qs8_qc8w(
|
| 215 |
+
op, /* xnn_operator_t convolution_op */
|
| 216 |
+
batch, /* size_t batch_size */
|
| 217 |
+
in_h, /* size_t input_height */
|
| 218 |
+
in_w, /* size_t input_width */
|
| 219 |
+
&workspace_size, /* size_t* workspace_size */
|
| 220 |
+
&workspace_alignment, /* size_t* workspace_alignment */
|
| 221 |
+
nullptr, /* size_t* output_height_out */
|
| 222 |
+
nullptr, /* size_t* output_width_out */
|
| 223 |
+
pt_pool); /* pthreadpool_t threadpool */
|
| 224 |
+
}
|
| 225 |
+
}
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
/*
|
| 229 |
+
* Series of setup wrapper functions to call xnn_setup_[de]conv* functions.
|
| 230 |
+
*/
|
| 231 |
+
C10_ALWAYS_INLINE
|
| 232 |
+
enum xnn_status xnnp_setup_convolution2d_nhwc(
|
| 233 |
+
xnn_operator_t op,
|
| 234 |
+
const int8_t* inp,
|
| 235 |
+
int8_t* outp,
|
| 236 |
+
bool per_channel = false,
|
| 237 |
+
bool transpose = false) {
|
| 238 |
+
if(transpose) {
|
| 239 |
+
TORCH_CHECK(!per_channel, "XNNPACK Q[SC]8 does not have a per channel deconvolution!");
|
| 240 |
+
|
| 241 |
+
return xnn_setup_deconvolution2d_nhwc_qs8(
|
| 242 |
+
op, /* xnn_operator_t deconvolution_op */
|
| 243 |
+
inp, /* const int8_t* input */
|
| 244 |
+
outp); /* int8_t* output */
|
| 245 |
+
}
|
| 246 |
+
|
| 247 |
+
if (!per_channel) {
|
| 248 |
+
return xnn_setup_convolution2d_nhwc_qs8(
|
| 249 |
+
op, /* xnn_operator_t deconvolution_op */
|
| 250 |
+
nullptr, /* void workspace */
|
| 251 |
+
inp, /* const int8_t* input */
|
| 252 |
+
outp); /* int8_t* output */
|
| 253 |
+
} else { /* per_channel */
|
| 254 |
+
return xnn_setup_convolution2d_nhwc_qs8_qc8w(
|
| 255 |
+
op, /* xnn_operator_t deconvolution_op */
|
| 256 |
+
nullptr, /* void workspace */
|
| 257 |
+
inp, /* const int8_t* input */
|
| 258 |
+
outp); /* int8_t* output */
|
| 259 |
+
}
|
| 260 |
+
}
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
/*
|
| 264 |
+
* Series of wrapper functions to call xnn_create* and xnn_setup*
|
| 265 |
+
* functions for linear
|
| 266 |
+
*/
|
| 267 |
+
C10_ALWAYS_INLINE
|
| 268 |
+
enum xnn_status xnnp_create_fully_connected_nc(
|
| 269 |
+
size_t input_channels,
|
| 270 |
+
size_t output_channels,
|
| 271 |
+
size_t input_stride,
|
| 272 |
+
size_t output_stride,
|
| 273 |
+
int8_t input_zero_point,
|
| 274 |
+
float input_scale,
|
| 275 |
+
int8_t kernel_zero_point,
|
| 276 |
+
float kernel_scale,
|
| 277 |
+
const int8_t* kernel,
|
| 278 |
+
const int32_t* bias,
|
| 279 |
+
int8_t output_zero_point,
|
| 280 |
+
float output_scale,
|
| 281 |
+
int8_t output_min,
|
| 282 |
+
int8_t output_max,
|
| 283 |
+
uint32_t flags,
|
| 284 |
+
xnn_operator_t* fully_connected_op_out) {
|
| 285 |
+
/* Symmetric quantization forces kzp = 0 */
|
| 286 |
+
TORCH_CHECK(!kernel_zero_point, "XNNPACK QS8 linear kernel expects kernel zero point to be zero."
|
| 287 |
+
"But got: ", kernel_zero_point);
|
| 288 |
+
return xnn_create_fully_connected_nc_qs8(
|
| 289 |
+
input_channels, /* size_t input_channels */
|
| 290 |
+
output_channels, /* size_t output_channels */
|
| 291 |
+
input_stride, /* size_t input_stride */
|
| 292 |
+
output_stride, /* size_t output_stride */
|
| 293 |
+
input_zero_point, /* int8_t input_zero_point */
|
| 294 |
+
input_scale, /* float input_scale */
|
| 295 |
+
kernel_scale, /* float kernel_scale */
|
| 296 |
+
kernel, /* const int8_t* kernel */
|
| 297 |
+
bias, /* const int32_t* bias */
|
| 298 |
+
output_zero_point, /* int8_t output_zero_point */
|
| 299 |
+
output_scale, /* float output_scale */
|
| 300 |
+
output_min, /* int8_t output_min */
|
| 301 |
+
output_max, /* int8_t output_max */
|
| 302 |
+
flags, /* uint32_t flags */
|
| 303 |
+
nullptr, /* xnn_caches_t caches */
|
| 304 |
+
nullptr, /* xnn_weights_cache_t */
|
| 305 |
+
fully_connected_op_out); /* xnn_operator_t* fully_connected_op_out */
|
| 306 |
+
}
|
| 307 |
+
|
| 308 |
+
// Reshapes an XNNPACK QS8 fully-connected operator for a new batch size.
// Pure pass-through to xnn_reshape_fully_connected_nc_qs8; `threadpool`
// is forwarded so XNNPACK can parallelize any reshape-time work.
C10_ALWAYS_INLINE
enum xnn_status xnnp_reshape_fully_connected_nc(
    xnn_operator_t fully_connected_op,
    size_t batch_size,
    pthreadpool_t threadpool) {
  return xnn_reshape_fully_connected_nc_qs8(
      fully_connected_op, /* xnn_operator_t fully_connected_op */
      batch_size, /* size_t batch_size */
      threadpool); /* pthreadpool_t threadpool */
}
|
| 318 |
+
|
| 319 |
+
// Binds input/output buffers to an XNNPACK QS8 fully-connected operator.
// Pure pass-through to xnn_setup_fully_connected_nc_qs8; the operator must
// already have been created (and reshaped) by the wrappers above.
C10_ALWAYS_INLINE
enum xnn_status xnnp_setup_fully_connected_nc(
    xnn_operator_t fully_connected_op,
    const int8_t* input,
    int8_t* output) {
  return xnn_setup_fully_connected_nc_qs8(
      fully_connected_op, /* xnn_operator_t fully_connected_op */
      input, /* const int8_t* input */
      output /* int8_t* output */
  );
}
|
| 330 |
+
|
| 331 |
+
} // namespace xnnp_utils
|
| 332 |
+
} // namespace native
|
| 333 |
+
} // namespace at
|
| 334 |
+
|
| 335 |
+
#endif // USE_XNNPACK
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/conv_serialization.h
ADDED
|
@@ -0,0 +1,414 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/core/Tensor.h>
|
| 4 |
+
#include <ATen/core/List.h>
|
| 5 |
+
#include <ATen/native/quantized/cpu/fbgemm_utils.h>
|
| 6 |
+
#include <ATen/native/quantized/cpu/QnnpackUtils.h>
|
| 7 |
+
#include <ATen/native/quantized/cpu/OnednnUtils.h>
|
| 8 |
+
#include <c10/util/irange.h>
|
| 9 |
+
#if !defined(__s390x__) && !defined(__powerpc__)
|
| 10 |
+
#include <cpuinfo.h>
|
| 11 |
+
#endif
|
| 12 |
+
|
| 13 |
+
#ifndef AT_PER_OPERATOR_HEADERS
|
| 14 |
+
#include <ATen/Functions.h>
|
| 15 |
+
#else
|
| 16 |
+
#include <ATen/ops/from_blob.h>
|
| 17 |
+
#endif
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
#include <tuple>
|
| 21 |
+
|
| 22 |
+
/* Convolution prepacked parameters serialization.
|
| 23 |
+
*
|
| 24 |
+
* Version 1
|
| 25 |
+
*
|
| 26 |
+
* - Fields:
|
| 27 |
+
* 1. weight
|
| 28 |
+
* 2. bias
|
| 29 |
+
* 3. stride x kSpatialDim
|
| 30 |
+
* 4. padding x kSpatialDim
|
| 31 |
+
* 5. dilation x kSpatialDim
|
| 32 |
+
* 6. groups
|
| 33 |
+
*
|
| 34 |
+
* Version 2
|
| 35 |
+
*
|
| 36 |
+
* - Fields:
|
| 37 |
+
* 0. version (string)
|
| 38 |
+
* 1. list of non-optional tensors
|
| 39 |
+
* 0: packed parameters (int16_t)
|
| 40 |
+
* - kSpatialDim
|
| 41 |
+
* - stride x kSpatialDim
|
| 42 |
+
* - padding x kSpatialDim
|
| 43 |
+
* - dilation x kSpatialDim
|
| 44 |
+
* - output_padding x kSpatialDim
|
| 45 |
+
* - groups
|
| 46 |
+
* - transpose (0 or 1)
|
| 47 |
+
* 1: weight
|
| 48 |
+
* 2. list of optional tensors
|
| 49 |
+
* 0: bias
|
| 50 |
+
*
|
| 51 |
+
* Version 3
|
| 52 |
+
*
|
| 53 |
+
* - Fields:
|
| 54 |
+
* 0. version (int64_t)
|
| 55 |
+
* 1. list of int64_t configuration values
|
| 56 |
+
* - kSpatialDim
|
| 57 |
+
* - stride x kSpatialDim
|
| 58 |
+
* - padding x kSpatialDim
|
| 59 |
+
* - dilation x kSpatialDim
|
| 60 |
+
* - output_padding x kSpatialDim
|
| 61 |
+
* - groups
|
| 62 |
+
* - flags (bitmask)
|
| 63 |
+
* - (1 << 0) transpose (1 = yes)
|
| 64 |
+
* 2. list of optional tensors
|
| 65 |
+
* 0: None (helps with type inference)
|
| 66 |
+
* 1: weight (this must be present)
|
| 67 |
+
* 2: bias
|
| 68 |
+
*/
|
| 69 |
+
|
| 70 |
+
using ConvParamsSerializationTypeV2 = std::tuple<
|
| 71 |
+
// version, for versions 2 and up
|
| 72 |
+
std::string,
|
| 73 |
+
// non-optional tensors
|
| 74 |
+
std::vector<at::Tensor>,
|
| 75 |
+
// optional tensors
|
| 76 |
+
std::vector<std::optional<at::Tensor>>>;
|
| 77 |
+
|
| 78 |
+
using ConvParamsSerializationTypeV3 = std::tuple<
|
| 79 |
+
// version, int for versions 3 and up
|
| 80 |
+
int64_t,
|
| 81 |
+
// configuration values
|
| 82 |
+
std::vector<int64_t>,
|
| 83 |
+
// optional tensors
|
| 84 |
+
std::vector<std::optional<at::Tensor>>>;
|
| 85 |
+
|
| 86 |
+
// Parses any historical conv packed params format into
// the current format.
//
// Accepts a serialized IValue in format version 1, 2, or 3 (see the
// file-level comment for the field layout of each version) and returns the
// equivalent V3 tuple: (version=3, int64_t config values, optional tensors).
template <uint32_t kSpatialDim>
ConvParamsSerializationTypeV3 parse_conv_serialized_state(c10::IValue v) {

  // determine the version based on IValue contents:
  //   v1 starts with a tensor, v2 with the string "2", v3 with the int 3.
  int version = -1;
  if (v.isTuple()) {
    const auto& elements = v.toTupleRef().elements();
    if (!elements.empty()) {
      auto firstElement = elements[0];
      if (firstElement.isTensor()) {
        version = 1;
      } else if (firstElement.isString()) {
        const std::string& version_str = firstElement.toStringRef();
        // note: not parsing the string to automatically handle bad
        // inputs
        if (version_str == "2") {
          version = 2;
        }
      } else if (firstElement.isInt()) {
        auto raw_version = firstElement.toInt();
        if (raw_version == 3) {
          version = 3;
        }
      }
    }
  }
  TORCH_INTERNAL_ASSERT(version != -1, "Unable to parse serialization version");

  if (version == 1) {
    // version 1 - convert to version 3 manually

    const auto& elements = v.toTupleRef().elements();

    // v1 layout: weight, bias, stride list, padding list, dilation list,
    // groups — each numeric field stored as a (scalar) tensor.
    at::Tensor weight = elements[0].toTensor();
    std::optional<at::Tensor> bias = elements[1].toOptional<at::Tensor>();
    torch::List<at::Tensor> stride_x_kSpatialDim = elements[2].toTensorList();
    torch::List<at::Tensor> padding_x_kSpatialDim = elements[3].toTensorList();
    torch::List<at::Tensor> dilation_x_kSpatialDim = elements[4].toTensorList();
    at::Tensor groups = elements[5].toTensor();

    // Flatten the scalar tensors into the V3 flat int64_t config vector.
    std::vector<int64_t> config_vals;
    // + kSpatialDim for output_padding, + 3 for kSpatialDim/groups/transpose.
    config_vals.reserve(
        stride_x_kSpatialDim.size() + padding_x_kSpatialDim.size() +
        dilation_x_kSpatialDim.size() + kSpatialDim + 3);
    config_vals.push_back(kSpatialDim);
    for (const auto i : c10::irange(stride_x_kSpatialDim.size())) {
      auto stride = stride_x_kSpatialDim.get(i);
      config_vals.push_back(stride[0].item<int16_t>());
    }
    for (const auto i : c10::irange(padding_x_kSpatialDim.size())) {
      auto padding = padding_x_kSpatialDim.get(i);
      config_vals.push_back(padding[0].item<int16_t>());
    }
    for (const auto i : c10::irange(dilation_x_kSpatialDim.size())) {
      auto dilation = dilation_x_kSpatialDim.get(i);
      config_vals.push_back(dilation[0].item<int16_t>());
    }
    // output_padding does not exist in v1, so we fill in a default value
    for (C10_UNUSED const auto i : c10::irange(kSpatialDim)) {
      config_vals.push_back(0);
    }
    config_vals.push_back(groups[0].item<int16_t>());
    // transpose does not exist in v1, so we fill in a default value
    config_vals.push_back(0);

    // V3 tensor list: slot 0 intentionally empty (helps type inference),
    // slot 1 weight, slot 2 bias.
    std::vector<std::optional<at::Tensor>> tensors;
    tensors.emplace_back();
    tensors.emplace_back(weight);
    tensors.emplace_back(bias);

    int64_t version = 3;
    return std::tie(version, config_vals, tensors);
  } else if (version == 2) {
    // version 2
    const auto& elements = v.toTupleRef().elements();
    std::vector<at::Tensor> non_optional = elements[1].toTensorList().vec();
    std::vector<std::optional<at::Tensor>> optional;

    // The optional-tensor list may have been serialized either as a plain
    // TensorList or as a generic List of optional tensors.
    if (elements[2].isTensorList()) {
      for (const auto& elem : elements[2].toTensorList()) {
        optional.emplace_back(static_cast<at::Tensor>(elem));
      }
    } else {
      for (const auto& elem : elements[2].toList()) {
        optional.emplace_back(static_cast<c10::IValue>(elem).toOptional<at::Tensor>());
      }
    }
    // create default optional value for bias
    if (optional.empty()) {
      optional.emplace_back();
    }

    // v2 packs the config values into an int16_t tensor; widen to int64_t.
    auto config_a = non_optional[0].accessor<int16_t, 1>();
    std::vector<int64_t> config_vals;
    config_vals.reserve(config_a.size(0));
    for (const auto i : c10::irange(config_a.size(0))) {
      config_vals.emplace_back(config_a[i]);
    }

    auto weight = non_optional[1];
    auto bias = optional[0];

    std::vector<std::optional<at::Tensor>> tensors;
    tensors.emplace_back();
    tensors.emplace_back(weight);
    tensors.emplace_back(bias);

    int64_t version = 3;
    return std::tie(version, config_vals, tensors);
  } else if (version == 3) {
    // Already in the current format; convert the IValue directly.
    return v.to<ConvParamsSerializationTypeV3>();
  } else {
    TORCH_INTERNAL_ASSERT(false, "Unexpected serialized qconv version: ",
        version);
  }
}
|
| 204 |
+
|
| 205 |
+
#define QCONV_SERIALIZATION_VERSION 2
|
| 206 |
+
|
| 207 |
+
#if QCONV_SERIALIZATION_VERSION == 2
|
| 208 |
+
using ConvParamsSerializationType = ConvParamsSerializationTypeV2;
|
| 209 |
+
|
| 210 |
+
// Serializes prepacked conv params into the V2 format:
// (version string "2", [packed int16 config tensor, weight], [bias]).
template <uint32_t kSpatialDim>
ConvParamsSerializationTypeV2 serialize_conv(
    const c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>>& params) {

  std::string version = "2";
  std::vector<at::Tensor> non_optional;
  std::vector<std::optional<at::Tensor>> optional;

  // create a packed int16_t tensor for conv params
  // (layout: kSpatialDim, stride..., padding..., dilation...,
  //  output_padding..., groups, transpose)
  std::vector<int16_t> params_vec;
  params_vec.push_back(kSpatialDim);
  auto stride = params->stride().vec();
  params_vec.insert(params_vec.end(), stride.begin(), stride.end());
  auto padding = params->padding().vec();
  params_vec.insert(params_vec.end(), padding.begin(), padding.end());
  auto dilation = params->dilation().vec();
  params_vec.insert(params_vec.end(), dilation.begin(), dilation.end());
  auto output_padding = params->output_padding().vec();
  params_vec.insert(params_vec.end(), output_padding.begin(),
                    output_padding.end());
  params_vec.push_back(params->groups());
  params_vec.push_back(params->transpose());
  int64_t vec_size = params_vec.size();
  // at::kShort == int16, matching params_vec's element type.
  at::Tensor params_tensor = at::from_blob(
      params_vec.data(), {vec_size},
      at::TensorOptions().dtype(at::kShort))
      // clone to retain ownership of the data
      .clone();

  auto [weight, bias] = params->unpack();

  non_optional.emplace_back(std::move(params_tensor));
  non_optional.emplace_back(std::move(weight));
  optional.emplace_back(std::move(bias));

  return std::tie(version, non_optional, optional);
}
|
| 247 |
+
|
| 248 |
+
#elif QCONV_SERIALIZATION_VERSION == 3
|
| 249 |
+
using ConvParamsSerializationType = ConvParamsSerializationTypeV3;
|
| 250 |
+
|
| 251 |
+
// Serializes prepacked conv params into the V3 format:
// (version 3, flat int64_t config vector, [none, weight, bias]).
template <uint32_t kSpatialDim>
ConvParamsSerializationTypeV3 serialize_conv(
    const c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>>& params) {
  // Flat config layout: kSpatialDim, stride..., padding..., dilation...,
  // output_padding..., groups, flags (bit 0 = transpose).
  std::vector<int64_t> config_vals;
  config_vals.push_back(kSpatialDim);
  auto stride = params->stride().vec();
  config_vals.insert(config_vals.end(), stride.begin(), stride.end());
  auto padding = params->padding().vec();
  config_vals.insert(config_vals.end(), padding.begin(), padding.end());
  auto dilation = params->dilation().vec();
  config_vals.insert(config_vals.end(), dilation.begin(), dilation.end());
  auto output_padding = params->output_padding().vec();
  config_vals.insert(config_vals.end(), output_padding.begin(),
                     output_padding.end());
  config_vals.push_back(params->groups());
  config_vals.push_back(params->transpose());

  auto [weight, bias] = params->unpack();

  // Slot 0 intentionally empty (helps with type inference on load).
  std::vector<std::optional<at::Tensor>> tensors;
  tensors.emplace_back();
  tensors.emplace_back(weight);
  tensors.emplace_back(bias);

  int64_t version = 3;
  return std::tie(version, config_vals, tensors);
}
|
| 278 |
+
|
| 279 |
+
#else
|
| 280 |
+
#error "Invalid qconv serialization version."
|
| 281 |
+
#endif
|
| 282 |
+
|
| 283 |
+
// Reconstructs prepacked conv params from a V3 serialized state by
// re-running weight prepacking for whichever quantized engine is active
// (X86/FBGEMM/QNNPACK/ONEDNN). Asserts on any non-V3 input — callers must
// first normalize older formats via parse_conv_serialized_state.
template <uint32_t kSpatialDim>
c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> deserialize_conv(
    ConvParamsSerializationTypeV3 state) {
  auto [version, config_vals, tensors] = state;
  TORCH_INTERNAL_ASSERT(version == 3, "Unexpected serialized qconv version: ", version);

  // V3 tensor list: [none, weight, bias].
  TORCH_CHECK(tensors.size() == 3, "Wrong number of tensors", tensors.size());
  std::optional<at::Tensor> weight = tensors[1];
  std::optional<at::Tensor> bias = tensors[2];
  TORCH_INTERNAL_ASSERT(weight, "Weight should always be present in serialized qconv.");

  // Walk the flat config vector in serialization order.
  torch::List<int64_t> stride, padding, output_padding, dilation;
  // skip kSpatialDim
  int idx = 1;
  for (C10_UNUSED const auto i : c10::irange(kSpatialDim)) {
    stride.emplace_back(config_vals.at(idx));
    idx++;
  }
  for (C10_UNUSED const auto i : c10::irange(kSpatialDim)) {
    padding.emplace_back(config_vals.at(idx));
    idx++;
  }
  for (C10_UNUSED const auto i : c10::irange(kSpatialDim)) {
    dilation.emplace_back(config_vals.at(idx));
    idx++;
  }
  for (C10_UNUSED const auto i : c10::irange(kSpatialDim)) {
    TORCH_INTERNAL_ASSERT(idx < static_cast<int64_t>(config_vals.size()),
        "Unexpected index = ", idx, " for config_vals of size ",
        config_vals.size());
    output_padding.emplace_back(config_vals.at(idx));
    idx++;
  }
  int64_t groups = config_vals.at(idx);
  idx++;
  int64_t flags = config_vals.at(idx);
  idx++;
  // Everything must have been consumed exactly.
  TORCH_INTERNAL_ASSERT(idx == static_cast<int64_t>(config_vals.size()),
      "Unexpected length of config_vals, expected ",
      idx,
      " got ",
      config_vals.size());

  // Only bit 0 (transpose) is defined; reject unknown flag bits.
  bool transpose = flags & (1 << 0);

  int64_t other_flags = flags & ~(1 << 0);
  TORCH_INTERNAL_ASSERT(other_flags == 0, "Unexpected flags set in ", flags, ".");

  auto& ctx = at::globalContext();

#ifdef USE_FBGEMM
  // X86 engine: prefer oneDNN when the heuristic says it is profitable,
  // otherwise fall back to FBGEMM prepacking.
  if (ctx.qEngine() == at::QEngine::X86) {
#if AT_MKLDNN_ENABLED()
    bool use_onednn = onednn_utils::should_use_onednn_quant(
        weight.value(), transpose, groups, output_padding);
    if (use_onednn) {
      return PackedConvWeightsOnednn<kSpatialDim>::prepack(
        weight.value(),
        bias,
        stride,
        padding,
        output_padding,
        dilation,
        groups,
        transpose
      );
    }
#endif
    return PackedConvWeight<kSpatialDim>::prepack(
      weight.value(),
      bias,
      stride,
      padding,
      output_padding,
      dilation,
      groups,
      transpose
    );
  } // x86
#endif

#ifdef USE_FBGEMM
  if (ctx.qEngine() == at::QEngine::FBGEMM) {
    return PackedConvWeight<kSpatialDim>::prepack(
      weight.value(),
      bias,
      stride,
      padding,
      output_padding,
      dilation,
      groups,
      transpose
    );
  }
#endif // USE_FBGEMM
#ifdef USE_PYTORCH_QNNPACK
  if (ctx.qEngine() == at::QEngine::QNNPACK) {
    TORCH_CHECK(
        kSpatialDim == 2,
        "prepack/__setstate__: QNNPACK only supports Conv2d "
        "now.");
    return PackedConvWeightsQnnp<kSpatialDim>::prepack(
      weight.value(),
      bias,
      stride,
      padding,
      output_padding,
      dilation,
      groups,
      transpose
    );
  }
#endif // USE_PYTORCH_QNNPACK
#if AT_MKLDNN_ENABLED()
  if (ctx.qEngine() == at::QEngine::ONEDNN) {
    return PackedConvWeightsOnednn<kSpatialDim>::prepack(
      weight.value(),
      bias,
      stride,
      padding,
      output_padding,
      dilation,
      groups,
      transpose
    );
  }
#endif // AT_MKLDNN_ENABLED()
  // No engine matched — fail loudly rather than return a null params object.
  TORCH_CHECK(
    false,
    "Didn't find engine for when deserializing ConvPackedParams: ",
    toString(ctx.qEngine()));
}
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/fbgemm_utils.h
ADDED
|
@@ -0,0 +1,413 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/Tensor.h>
|
| 4 |
+
#include <ATen/native/quantized/PackedParams.h>
|
| 5 |
+
#include <ATen/native/quantized/cpu/EmbeddingPackedParams.h>
|
| 6 |
+
#include <c10/core/QScheme.h>
|
| 7 |
+
#include <c10/util/irange.h>
|
| 8 |
+
|
| 9 |
+
#ifdef USE_FBGEMM
|
| 10 |
+
#include <fbgemm/Fbgemm.h>
|
| 11 |
+
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Winconsistent-missing-destructor-override")
|
| 12 |
+
#include <fbgemm/FbgemmFP16.h>
|
| 13 |
+
C10_DIAGNOSTIC_POP()
|
| 14 |
+
#include <fbgemm/QuantUtils.h>
|
| 15 |
+
|
| 16 |
+
// The struct for the packed weight matrix (PackBMatrix) and the corresponding
|
| 17 |
+
// column offsets used for the fully connect layer, which are both prepared in
|
| 18 |
+
// the prepacking step to save the computations in the inference. Note the
|
| 19 |
+
// column offsets include the sum of the B columns as well as the scalar term
|
| 20 |
+
// B_zero_point * K, whereas the row offsets created by
|
| 21 |
+
// PackAWithQuantRowOffset/PackAWithIm2Col/PackAWithRowOffset are only the sum
|
| 22 |
+
// of the A rows. The column offsets are needed for the asymmetric quantization
|
| 23 |
+
// (affine quantization) of input matrix.
|
| 24 |
+
// Note that in JIT mode we can think of a way to fuse col_offsets with bias.
|
| 25 |
+
// Prepacked int8 linear weights for the FBGEMM backend. Holds the packed
// B matrix plus the quantization metadata needed at inference time; the
// apply* member functions are defined out-of-line in the FBGEMM kernels.
struct TORCH_API PackedLinearWeight : public LinearPackedParamsBase {
  PackedLinearWeight(
      std::unique_ptr<fbgemm::PackBMatrix<int8_t>> w,
      std::optional<at::Tensor> bias,
      std::vector<int32_t> col_offsets,
      std::vector<float> w_scale,
      std::vector<int32_t> w_zp,
      c10::QScheme q_scheme)
      : w(std::move(w)),
        bias_(std::move(bias)),
        col_offsets(std::move(col_offsets)),
        w_scale(std::move(w_scale)),
        w_zp(std::move(w_zp)),
        q_scheme(std::move(q_scheme)) {}
  // Packed weight matrix (see the file-level comment on col_offsets).
  std::unique_ptr<fbgemm::PackBMatrix<int8_t>> w;
  std::optional<at::Tensor> bias_;
  // Per-column offsets precomputed at prepack time (includes B column sums
  // and the B_zero_point * K term).
  std::vector<int32_t> col_offsets;
  // Weight scales / zero points; one entry per tensor or per channel
  // depending on q_scheme.
  std::vector<float> w_scale;
  std::vector<int32_t> w_zp;
  c10::QScheme q_scheme;

  at::Tensor apply(
      at::Tensor input,
      double output_scale,
      int64_t output_zero_point) override;

  at::Tensor apply_relu(
      at::Tensor input,
      double output_scale,
      int64_t output_zero_point) override;

  at::Tensor& apply_out(
      const at::Tensor& input,
      double output_scale,
      int64_t output_zero_point,
      at::Tensor& output) override;

  at::Tensor& apply_relu_out(
      const at::Tensor& input,
      double output_scale,
      int64_t output_zero_point,
      at::Tensor& output) override;

  at::Tensor apply_with_input_q_dq_qweight_dq_output_fp32(
      at::Tensor input,
      double input_scale,
      int64_t input_zero_point) override;

  at::Tensor apply_with_input_q_dq_qweight_dq_relu_output_fp32(
      at::Tensor input,
      double input_scale,
      int64_t input_zero_point) override;

  at::Tensor apply_dynamic(at::Tensor input, bool reduce_range = false)
      override;

  at::Tensor apply_dynamic_relu(at::Tensor input, bool reduce_range = false)
      override;

  std::tuple<at::Tensor, std::optional<at::Tensor>> unpack() override;

  std::optional<at::Tensor> bias() override {
    return bias_;
  }

  // Builds a PackedLinearWeight from an unpacked quantized weight (+ bias).
  static c10::intrusive_ptr<LinearPackedParamsBase> prepack(
      at::Tensor weight,
      std::optional<at::Tensor> bias);

 private:
  // Shared implementation behind apply/apply_relu and the *_out variants.
  template <bool ReluFused>
  at::Tensor& apply_impl(
      const at::Tensor& input,
      double output_scale,
      int64_t output_zero_point,
      at::Tensor& output);

  template <bool ReluFused>
  at::Tensor apply_with_input_q_dq_qweight_dq_output_fp32_impl(
      const at::Tensor& input,
      double input_scale,
      int64_t input_zero_point);

  template <bool ReluFused>
  at::Tensor apply_dynamic_impl(at::Tensor input, bool reduce_range = false);
};
|
| 111 |
+
|
| 112 |
+
// Prepacked fp16 linear weights for FBGEMM's dynamic-quantization linear.
// Only the dynamic apply paths are supported; the static quantized apply /
// apply_relu entry points are deliberately unimplemented and assert.
struct TORCH_API PackedLinearWeightFp16 : public LinearPackedParamsBase {
  PackedLinearWeightFp16(
      std::unique_ptr<fbgemm::PackedGemmMatrixFP16> w,
      std::optional<at::Tensor> bias)
      : w(std::move(w)), bias_(std::move(bias)) {}

  std::unique_ptr<fbgemm::PackedGemmMatrixFP16> w;
  std::optional<at::Tensor> bias_;

  // Static-quantized apply is not supported for fp16 weights.
  at::Tensor apply(
      at::Tensor /*input*/,
      double /*output_scale*/,
      int64_t /*output_zero_point*/) override {
    TORCH_INTERNAL_ASSERT(false);
  }
  at::Tensor apply_relu(
      at::Tensor /*input*/,
      double /*output_scale*/,
      int64_t /*output_zero_point*/) override {
    TORCH_INTERNAL_ASSERT(false);
  }

  at::Tensor apply_dynamic(at::Tensor input, bool reduce_range = false)
      override;
  at::Tensor apply_dynamic_relu(at::Tensor input, bool reduce_range = false)
      override;

  at::Tensor& apply_dynamic_out(
      const at::Tensor& input,
      at::Tensor& output,
      bool reduce_range = false) override;
  at::Tensor& apply_dynamic_relu_out(
      const at::Tensor& input,
      at::Tensor& output,
      bool reduce_range = false) override;

  std::tuple<at::Tensor, std::optional<at::Tensor>> unpack() override;

  std::optional<at::Tensor> bias() override {
    return bias_;
  }

  // Builds a PackedLinearWeightFp16 from an unpacked fp16 weight (+ bias).
  static c10::intrusive_ptr<LinearPackedParamsBase> prepack(
      at::Tensor weight,
      std::optional<at::Tensor> bias);

  void set_bias(std::optional<at::Tensor> bias) override;

 private:
  // Shared implementation behind the dynamic apply / *_out variants.
  template <bool ReluFused>
  at::Tensor& apply_dynamic_impl(const at::Tensor& input, at::Tensor& output);
};
|
| 164 |
+
|
| 165 |
+
template <int kSpatialDim = 2>
|
| 166 |
+
struct TORCH_API PackedConvWeight : public ConvPackedParamsBase<kSpatialDim> {
|
| 167 |
+
PackedConvWeight(
|
| 168 |
+
std::unique_ptr<fbgemm::PackWeightsForConv<kSpatialDim>> w,
|
| 169 |
+
std::optional<at::Tensor> bias,
|
| 170 |
+
torch::List<int64_t> stride,
|
| 171 |
+
torch::List<int64_t> padding,
|
| 172 |
+
torch::List<int64_t> output_padding,
|
| 173 |
+
torch::List<int64_t> dilation,
|
| 174 |
+
int64_t groups,
|
| 175 |
+
uint8_t transpose,
|
| 176 |
+
std::vector<int32_t> col_offsets,
|
| 177 |
+
std::vector<int64_t> kernel,
|
| 178 |
+
std::vector<float> w_scale,
|
| 179 |
+
std::vector<int32_t> w_zp,
|
| 180 |
+
c10::QScheme q_scheme)
|
| 181 |
+
: w(std::move(w)),
|
| 182 |
+
bias(std::move(bias)),
|
| 183 |
+
stride_(std::move(stride)),
|
| 184 |
+
padding_(std::move(padding)),
|
| 185 |
+
output_padding_(std::move(output_padding)),
|
| 186 |
+
dilation_(std::move(dilation)),
|
| 187 |
+
groups_(groups),
|
| 188 |
+
transpose_(transpose),
|
| 189 |
+
col_offsets(std::move(col_offsets)),
|
| 190 |
+
kernel(std::move(kernel)),
|
| 191 |
+
w_scale(std::move(w_scale)),
|
| 192 |
+
w_zp(std::move(w_zp)),
|
| 193 |
+
q_scheme(q_scheme) {}
|
| 194 |
+
|
| 195 |
+
std::unique_ptr<fbgemm::PackWeightsForConv<kSpatialDim>> w;
|
| 196 |
+
std::optional<at::Tensor> bias;
|
| 197 |
+
torch::List<int64_t> stride_;
|
| 198 |
+
torch::List<int64_t> padding_;
|
| 199 |
+
torch::List<int64_t> output_padding_;
|
| 200 |
+
torch::List<int64_t> dilation_;
|
| 201 |
+
int64_t groups_;
|
| 202 |
+
uint8_t transpose_;
|
| 203 |
+
std::vector<int32_t> col_offsets;
|
| 204 |
+
std::vector<int64_t> kernel;
|
| 205 |
+
std::vector<float> w_scale;
|
| 206 |
+
std::vector<int32_t> w_zp;
|
| 207 |
+
c10::QScheme q_scheme;
|
| 208 |
+
|
| 209 |
+
at::Tensor apply(
|
| 210 |
+
const at::Tensor& input,
|
| 211 |
+
double output_scale,
|
| 212 |
+
int64_t output_zero_point) override;
|
| 213 |
+
|
| 214 |
+
at::Tensor apply_relu(
|
| 215 |
+
const at::Tensor& input,
|
| 216 |
+
double output_scale,
|
| 217 |
+
int64_t output_zero_point) override;
|
| 218 |
+
|
| 219 |
+
at::Tensor apply_dynamic(
|
| 220 |
+
const at::Tensor& input,
|
| 221 |
+
bool reduce_range) override;
|
| 222 |
+
|
| 223 |
+
std::tuple<at::Tensor, std::optional<at::Tensor>> unpack() override;
|
| 224 |
+
|
| 225 |
+
static c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> prepack(
|
| 226 |
+
at::Tensor weight,
|
| 227 |
+
std::optional<at::Tensor> bias,
|
| 228 |
+
torch::List<int64_t> stride,
|
| 229 |
+
torch::List<int64_t> padding,
|
| 230 |
+
torch::List<int64_t> output_padding,
|
| 231 |
+
torch::List<int64_t> dilation,
|
| 232 |
+
int64_t groups,
|
| 233 |
+
bool transpose);
|
| 234 |
+
|
| 235 |
+
const float* GetBiasData(at::Tensor* bias);
|
| 236 |
+
|
| 237 |
+
void GetQuantizationParams(
|
| 238 |
+
float act_scale,
|
| 239 |
+
float out_scale,
|
| 240 |
+
std::vector<float>* output_multiplier_float,
|
| 241 |
+
std::vector<float>* act_times_w_scale);
|
| 242 |
+
|
| 243 |
+
torch::List<int64_t> stride() const override {
|
| 244 |
+
return stride_;
|
| 245 |
+
}
|
| 246 |
+
|
| 247 |
+
torch::List<int64_t> padding() const override {
|
| 248 |
+
return padding_;
|
| 249 |
+
}
|
| 250 |
+
|
| 251 |
+
torch::List<int64_t> output_padding() const override {
|
| 252 |
+
return output_padding_;
|
| 253 |
+
}
|
| 254 |
+
|
| 255 |
+
torch::List<int64_t> dilation() const override {
|
| 256 |
+
return dilation_;
|
| 257 |
+
}
|
| 258 |
+
|
| 259 |
+
int64_t groups() const override {
|
| 260 |
+
return groups_;
|
| 261 |
+
}
|
| 262 |
+
|
| 263 |
+
bool transpose() const override {
|
| 264 |
+
return (bool)transpose_;
|
| 265 |
+
}
|
| 266 |
+
|
| 267 |
+
private:
|
| 268 |
+
template <bool ReluFused>
|
| 269 |
+
at::Tensor apply_impl(
|
| 270 |
+
const at::Tensor& input,
|
| 271 |
+
double output_scale,
|
| 272 |
+
int64_t output_zero_point);
|
| 273 |
+
};
|
| 274 |
+
|
| 275 |
+
// PackWeight: Convert the weight from uint8 to int8.
|
| 276 |
+
inline void convert_uint8_int8(
|
| 277 |
+
int len,
|
| 278 |
+
const uint8_t* src_uint8,
|
| 279 |
+
int8_t* dst_int8) {
|
| 280 |
+
for (const auto i : c10::irange(len)) {
|
| 281 |
+
dst_int8[i] = static_cast<int8_t>(static_cast<int32_t>(src_uint8[i]) - 128);
|
| 282 |
+
}
|
| 283 |
+
}
|
| 284 |
+
|
| 285 |
+
// UnpackWeight: Convert the weight from int8 to uint8.
|
| 286 |
+
inline void convert_int8_uint8(
|
| 287 |
+
int len,
|
| 288 |
+
const int8_t* src_int8,
|
| 289 |
+
uint8_t* dst_uint8) {
|
| 290 |
+
for (const auto i : c10::irange(len)) {
|
| 291 |
+
dst_uint8[i] =
|
| 292 |
+
static_cast<uint8_t>(static_cast<int32_t>(src_int8[i]) + 128);
|
| 293 |
+
}
|
| 294 |
+
}
|
| 295 |
+
|
| 296 |
+
namespace at {
|
| 297 |
+
namespace native {
|
| 298 |
+
namespace fbgemm_utils {
|
| 299 |
+
|
| 300 |
+
template <int kSpatialDim = 2>
|
| 301 |
+
fbgemm::conv_param_t<kSpatialDim> MakeFbgemmConvParam(
|
| 302 |
+
int N,
|
| 303 |
+
int C,
|
| 304 |
+
int M,
|
| 305 |
+
const std::vector<int>& image_shape,
|
| 306 |
+
int groups,
|
| 307 |
+
const std::vector<int>& kernels,
|
| 308 |
+
const std::vector<int>& strides,
|
| 309 |
+
const std::vector<int>& pads,
|
| 310 |
+
const std::vector<int>& dilations,
|
| 311 |
+
const std::vector<int>& output_padding = std::vector<int>(kSpatialDim, 0),
|
| 312 |
+
bool transposed = false);
|
| 313 |
+
|
| 314 |
+
// TODO: Remove functions below when ChannelsLast3d is ready.
|
| 315 |
+
Tensor MakeStridedQTensorCPU(
|
| 316 |
+
const IntArrayRef& sizes,
|
| 317 |
+
const IntArrayRef& strides,
|
| 318 |
+
const TensorOptions& options,
|
| 319 |
+
QuantizerPtr quantizer);
|
| 320 |
+
|
| 321 |
+
Tensor MakeEmptyAffineQuantizedChannelsLast3dTensor(
|
| 322 |
+
int64_t N,
|
| 323 |
+
int64_t C,
|
| 324 |
+
int64_t D,
|
| 325 |
+
int64_t H,
|
| 326 |
+
int64_t W,
|
| 327 |
+
const TensorOptions& options,
|
| 328 |
+
double scale,
|
| 329 |
+
int64_t zero_point);
|
| 330 |
+
|
| 331 |
+
Tensor MakeEmptyPerChannelAffineQuantizedChannelsLast3dTensor(
|
| 332 |
+
int64_t N,
|
| 333 |
+
int64_t C,
|
| 334 |
+
int64_t D,
|
| 335 |
+
int64_t H,
|
| 336 |
+
int64_t W,
|
| 337 |
+
const TensorOptions& options,
|
| 338 |
+
const Tensor& scales,
|
| 339 |
+
const Tensor& zero_points);
|
| 340 |
+
|
| 341 |
+
Tensor ConvertToChannelsLast3dTensor(const Tensor& src);
|
| 342 |
+
|
| 343 |
+
template <int kSpatialDim = 2>
|
| 344 |
+
Tensor TransposeConvTensorUnpackConversion(const Tensor& src, int groups);
|
| 345 |
+
|
| 346 |
+
template <int kSpatialDim>
|
| 347 |
+
Tensor ConvertConvWeightsToChannelLastTensor(
|
| 348 |
+
const at::Tensor& src,
|
| 349 |
+
int groups,
|
| 350 |
+
bool transpose);
|
| 351 |
+
} // namespace fbgemm_utils
|
| 352 |
+
} // namespace native
|
| 353 |
+
} // namespace at
|
| 354 |
+
|
| 355 |
+
#endif // USE_FBGEMM
|
| 356 |
+
|
| 357 |
+
struct TORCH_API PackedEmbeddingBagWeight : public EmbeddingPackedParamsBase {
|
| 358 |
+
PackedEmbeddingBagWeight(
|
| 359 |
+
at::Tensor packed_w,
|
| 360 |
+
std::vector<float> w_scale,
|
| 361 |
+
std::vector<float> w_zp,
|
| 362 |
+
int64_t bit_rate,
|
| 363 |
+
c10::QScheme q_scheme,
|
| 364 |
+
int64_t version)
|
| 365 |
+
: packed_w(std::move(packed_w)),
|
| 366 |
+
w_scale(std::move(w_scale)),
|
| 367 |
+
w_zp(std::move(w_zp)),
|
| 368 |
+
bit_rate_(bit_rate),
|
| 369 |
+
q_scheme(q_scheme),
|
| 370 |
+
version_(version) {
|
| 371 |
+
// NOLINTNEXTLINE(clang-analyzer-cplusplus.Move)
|
| 372 |
+
if (!packed_w.is_contiguous()) {
|
| 373 |
+
packed_w = packed_w.contiguous();
|
| 374 |
+
}
|
| 375 |
+
}
|
| 376 |
+
|
| 377 |
+
at::Tensor packed_w;
|
| 378 |
+
std::vector<float> w_scale;
|
| 379 |
+
std::vector<float> w_zp;
|
| 380 |
+
int64_t bit_rate_;
|
| 381 |
+
c10::QScheme q_scheme;
|
| 382 |
+
int64_t version_;
|
| 383 |
+
|
| 384 |
+
at::Tensor unpack() override;
|
| 385 |
+
static c10::intrusive_ptr<EmbeddingPackedParamsBase> prepack(
|
| 386 |
+
at::Tensor weight);
|
| 387 |
+
|
| 388 |
+
int64_t bit_rate() const override {
|
| 389 |
+
return bit_rate_;
|
| 390 |
+
}
|
| 391 |
+
|
| 392 |
+
int64_t version() const override {
|
| 393 |
+
return version_;
|
| 394 |
+
}
|
| 395 |
+
|
| 396 |
+
at::Tensor embeddingbag_byte(
|
| 397 |
+
const at::Tensor& indices,
|
| 398 |
+
const std::optional<at::Tensor>& offsets,
|
| 399 |
+
bool pruned_weights,
|
| 400 |
+
const std::optional<at::Tensor>& per_sample_weights_,
|
| 401 |
+
const std::optional<at::Tensor>& compressed_indices_mapping,
|
| 402 |
+
bool include_last_offset,
|
| 403 |
+
bool is_embedding_op) override;
|
| 404 |
+
|
| 405 |
+
at::Tensor embeddingbag_4bit(
|
| 406 |
+
const at::Tensor& indices,
|
| 407 |
+
const std::optional<at::Tensor>& offsets,
|
| 408 |
+
bool pruned_weights,
|
| 409 |
+
const std::optional<at::Tensor>& per_sample_weights_,
|
| 410 |
+
const std::optional<at::Tensor>& compressed_indices_mapping,
|
| 411 |
+
bool include_last_offset,
|
| 412 |
+
bool is_embedding_op) override;
|
| 413 |
+
};
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/init_qnnpack.h
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#ifdef USE_PYTORCH_QNNPACK
|
| 4 |
+
|
| 5 |
+
namespace at {
|
| 6 |
+
namespace native {
|
| 7 |
+
|
| 8 |
+
void initQNNPACK();
|
| 9 |
+
|
| 10 |
+
} // namespace native
|
| 11 |
+
} // namespace at
|
| 12 |
+
|
| 13 |
+
#endif
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/qembeddingbag.h
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <ATen/core/Tensor.h>
|
| 3 |
+
#include <cstdint>
|
| 4 |
+
|
| 5 |
+
namespace at {
|
| 6 |
+
namespace native {
|
| 7 |
+
Tensor& embedding_bag_byte_rowwise_offsets_out(
|
| 8 |
+
Tensor& output,
|
| 9 |
+
const Tensor& weight,
|
| 10 |
+
const Tensor& indices,
|
| 11 |
+
const std::optional<Tensor>& offsets_in,
|
| 12 |
+
const bool /* scale_grad_by_freq */,
|
| 13 |
+
const int64_t /* mode */,
|
| 14 |
+
bool pruned_weights,
|
| 15 |
+
const std::optional<Tensor>& per_sample_weights_,
|
| 16 |
+
const std::optional<Tensor>& compressed_indices_mapping,
|
| 17 |
+
bool include_last_offset);
|
| 18 |
+
|
| 19 |
+
Tensor& embedding_bag_4bit_rowwise_offsets_out(
|
| 20 |
+
Tensor& output,
|
| 21 |
+
const Tensor& weight,
|
| 22 |
+
const Tensor& indices,
|
| 23 |
+
const std::optional<Tensor>& offsets_in,
|
| 24 |
+
const bool /* scale_grad_by_freq */,
|
| 25 |
+
const int64_t /* mode */,
|
| 26 |
+
bool pruned_weights,
|
| 27 |
+
const std::optional<Tensor>& per_sample_weights_,
|
| 28 |
+
const std::optional<Tensor>& compressed_indices_mapping,
|
| 29 |
+
bool include_last_offset);
|
| 30 |
+
|
| 31 |
+
Tensor& qembeddingbag_byte_unpack_out(Tensor& output, const Tensor& packed_weight);
|
| 32 |
+
|
| 33 |
+
} // native
|
| 34 |
+
} // at
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/qembeddingbag_prepack.h
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <ATen/core/Tensor.h>
|
| 3 |
+
|
| 4 |
+
namespace at { namespace native {
|
| 5 |
+
|
| 6 |
+
Tensor& qembeddingbag_byte_prepack_out(Tensor& output, const Tensor& weight);
|
| 7 |
+
|
| 8 |
+
Tensor qembeddingbag_byte_prepack(const Tensor& weight);
|
| 9 |
+
|
| 10 |
+
Tensor qembeddingbag_byte_prepack_meta(const Tensor& weight);
|
| 11 |
+
|
| 12 |
+
} // namespace native
|
| 13 |
+
} // namespace at
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/transformers/attention.h
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <ATen/core/Tensor.h>
|
| 3 |
+
#include <c10/macros/Export.h>
|
| 4 |
+
#include <ATen/native/DispatchStub.h>
|
| 5 |
+
#include <ATen/native/transformers/attention.h>
|
| 6 |
+
#include <optional>
|
| 7 |
+
|
| 8 |
+
namespace at {
|
| 9 |
+
namespace native {
|
| 10 |
+
|
| 11 |
+
using fused_sdp_choice_fn = int64_t (*)(const Tensor& query_, const Tensor& key, const Tensor& value,
|
| 12 |
+
const std::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal, std::optional<double> scale, bool enable_gqa);
|
| 13 |
+
|
| 14 |
+
DECLARE_DISPATCH(fused_sdp_choice_fn, _fused_sdp_choice_stub);
|
| 15 |
+
|
| 16 |
+
TORCH_API Tensor bmm_nt(const Tensor& a, const Tensor& b);
|
| 17 |
+
TORCH_API Tensor masked_softmax(
|
| 18 |
+
Tensor& attn_scores,
|
| 19 |
+
std::optional<Tensor> attn_mask,
|
| 20 |
+
const Tensor& query,
|
| 21 |
+
std::optional<int64_t> mask_type = {});
|
| 22 |
+
|
| 23 |
+
using transform_bias_rescale_qkv_fn = void(*)(
|
| 24 |
+
at::ScalarType type,
|
| 25 |
+
void* _q_k_v,
|
| 26 |
+
const void* _qkv,
|
| 27 |
+
const void* _qkv_bias,
|
| 28 |
+
int64_t B,
|
| 29 |
+
int64_t T,
|
| 30 |
+
int64_t D,
|
| 31 |
+
int64_t num_head);
|
| 32 |
+
|
| 33 |
+
DECLARE_DISPATCH(transform_bias_rescale_qkv_fn, transform_bias_rescale_qkv_stub);
|
| 34 |
+
|
| 35 |
+
TORCH_API Tensor transform0213_gemm_nt_bias(
|
| 36 |
+
const Tensor& a,
|
| 37 |
+
const Tensor& b,
|
| 38 |
+
const Tensor& c,
|
| 39 |
+
const Tensor& query);
|
| 40 |
+
|
| 41 |
+
TORCH_API Tensor bmm_nn(Tensor& out, const Tensor& a, const Tensor& b);
|
| 42 |
+
|
| 43 |
+
TORCH_API void debug_assert_shape(int line, const Tensor& t, c10::IntArrayRef shape);
|
| 44 |
+
|
| 45 |
+
TORCH_API Tensor qkv_projection(
|
| 46 |
+
const Tensor& query,
|
| 47 |
+
const Tensor& key,
|
| 48 |
+
const Tensor& value,
|
| 49 |
+
const int64_t embed_dim,
|
| 50 |
+
const Tensor& qkv_weight);
|
| 51 |
+
|
| 52 |
+
using flash_attention_fn = void (*)(
|
| 53 |
+
const Tensor& output, const Tensor& logsumexp,
|
| 54 |
+
const Tensor& query, const Tensor& key, const Tensor& value,
|
| 55 |
+
double dropout_p, bool is_causal,
|
| 56 |
+
std::optional<Tensor> attn_mask,
|
| 57 |
+
std::optional<double> scale);
|
| 58 |
+
|
| 59 |
+
using flash_attention_backward_fn = void (*)(
|
| 60 |
+
const Tensor& grad_q, const Tensor& grad_k,
|
| 61 |
+
const Tensor& grad_v, const Tensor& grad_out,
|
| 62 |
+
const Tensor& query, const Tensor& key,
|
| 63 |
+
const Tensor& value, const Tensor& out, const Tensor& logsumexp,
|
| 64 |
+
double dropout_p, bool is_causal,
|
| 65 |
+
std::optional<Tensor> attn_mask,
|
| 66 |
+
std::optional<double> scale);
|
| 67 |
+
|
| 68 |
+
DECLARE_DISPATCH(flash_attention_fn, flash_attention_kernel);
|
| 69 |
+
DECLARE_DISPATCH(flash_attention_backward_fn, flash_attention_backward_kernel);
|
| 70 |
+
|
| 71 |
+
} // namespace native
|
| 72 |
+
} // namespace at
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/transformers/sdp_utils_cpp.h
ADDED
|
@@ -0,0 +1,566 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <ATen/Context.h>
|
| 3 |
+
#include <ATen/NestedTensorImpl.h>
|
| 4 |
+
#include <ATen/TensorSubclassLikeUtils.h>
|
| 5 |
+
#include <ATen/TensorUtils.h>
|
| 6 |
+
#include <ATen/core/Tensor.h>
|
| 7 |
+
#include <ATen/core/grad_mode.h>
|
| 8 |
+
#include <ATen/native/DispatchStub.h>
|
| 9 |
+
#include <c10/core/ScalarType.h>
|
| 10 |
+
|
| 11 |
+
#include <c10/util/Exception.h>
|
| 12 |
+
#include <c10/util/env.h>
|
| 13 |
+
#include <c10/util/irange.h>
|
| 14 |
+
|
| 15 |
+
#include <c10/core/SymInt.h>
|
| 16 |
+
#include <c10/core/SymFloat.h>
|
| 17 |
+
#include <c10/util/string_view.h>
|
| 18 |
+
#include <c10/util/Array.h>
|
| 19 |
+
#include <cmath>
|
| 20 |
+
#include <cstdint>
|
| 21 |
+
#include <functional>
|
| 22 |
+
|
| 23 |
+
namespace sdp {
|
| 24 |
+
|
| 25 |
+
constexpr int32_t num_backends = 5;
|
| 26 |
+
enum class SDPBackend {
|
| 27 |
+
error = -1,
|
| 28 |
+
math = 0,
|
| 29 |
+
flash_attention = 1,
|
| 30 |
+
efficient_attention = 2,
|
| 31 |
+
cudnn_attention = 3,
|
| 32 |
+
overrideable = 4
|
| 33 |
+
};
|
| 34 |
+
|
| 35 |
+
// Note that if this changed make sure to update
|
| 36 |
+
// the templated enum in mem_eff/kernel_forward.h and mem_eff/kernel_backward.h
|
| 37 |
+
enum class CustomMaskType {
|
| 38 |
+
NoCustomMask = 0,
|
| 39 |
+
CausalFromTopLeft = 1,
|
| 40 |
+
CausalFromBottomRight = 2,
|
| 41 |
+
NumCustomMaskTypes,
|
| 42 |
+
};
|
| 43 |
+
|
| 44 |
+
struct sdp_params {
|
| 45 |
+
at::Tensor query;
|
| 46 |
+
at::Tensor key;
|
| 47 |
+
at::Tensor value;
|
| 48 |
+
std::optional<at::Tensor> attn_mask;
|
| 49 |
+
double dropout;
|
| 50 |
+
bool is_causal;
|
| 51 |
+
bool enable_gqa;
|
| 52 |
+
};
|
| 53 |
+
|
| 54 |
+
SDPBackend select_sdp_backend_cpp(sdp_params const& kernel_params);
|
| 55 |
+
|
| 56 |
+
inline c10::SymFloat calculate_scale(
|
| 57 |
+
const at::Tensor& query,
|
| 58 |
+
std::optional<double> scale) {
|
| 59 |
+
const auto softmax_scale = scale.has_value()
|
| 60 |
+
? scale.value()
|
| 61 |
+
: (c10::SymFloat(1.0) / (c10::SymFloat(query.sym_size(-1)).sqrt()));
|
| 62 |
+
return c10::SymFloat(softmax_scale);
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
using c10::array_of;
|
| 66 |
+
|
| 67 |
+
inline bool input_requires_grad(sdp_params const& params) {
|
| 68 |
+
const bool any_inputs_require_grad = params.query.requires_grad() ||
|
| 69 |
+
params.key.requires_grad() || params.value.requires_grad();
|
| 70 |
+
const bool gradmode_enabled = at::GradMode::is_enabled();
|
| 71 |
+
return any_inputs_require_grad && gradmode_enabled;
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
inline bool has_for_nested_inputs(sdp_params const& params) {
|
| 75 |
+
return
|
| 76 |
+
(params.query.is_nested() && params.query.layout() == c10::kStrided) ||
|
| 77 |
+
(params.key.is_nested() && params.key.layout() == c10::kStrided) ||
|
| 78 |
+
(params.value.is_nested() && params.value.layout() == c10::kStrided);
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
inline bool has_for_dense_inputs(sdp_params const& params) {
|
| 82 |
+
return !params.query.is_nested() || !params.key.is_nested() || !params.value.is_nested();
|
| 83 |
+
}
|
| 84 |
+
|
| 85 |
+
inline bool has_only_dense_inputs(sdp_params const& params) {
|
| 86 |
+
return !params.query.is_nested() && !params.key.is_nested() && !params.value.is_nested();
|
| 87 |
+
}
|
| 88 |
+
|
| 89 |
+
template <typename dtype_vector>
|
| 90 |
+
inline bool check_tensor_dtype(
|
| 91 |
+
sdp_params const& params,
|
| 92 |
+
dtype_vector allowed_dtypes,
|
| 93 |
+
bool debug) {
|
| 94 |
+
auto query_dtype = params.query.dtype();
|
| 95 |
+
if (!(query_dtype == params.key.dtype() &&
|
| 96 |
+
query_dtype == params.value.dtype() &&
|
| 97 |
+
(std::find(allowed_dtypes.begin(), allowed_dtypes.end(), query_dtype) !=
|
| 98 |
+
allowed_dtypes.end()))) {
|
| 99 |
+
if (debug) {
|
| 100 |
+
TORCH_WARN(
|
| 101 |
+
"Expected query, key and value to all be of dtype: {",
|
| 102 |
+
c10::Join(", ", allowed_dtypes),
|
| 103 |
+
"}. Got ",
|
| 104 |
+
"Query dtype: ",
|
| 105 |
+
params.query.dtype(),
|
| 106 |
+
", Key dtype: ",
|
| 107 |
+
params.key.dtype(),
|
| 108 |
+
", and Value dtype: ",
|
| 109 |
+
params.value.dtype(),
|
| 110 |
+
" instead.");
|
| 111 |
+
}
|
| 112 |
+
return false;
|
| 113 |
+
}
|
| 114 |
+
return true;
|
| 115 |
+
}
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
inline bool try_broadcast_param_size(
|
| 119 |
+
const c10::SymInt q_size,
|
| 120 |
+
const c10::SymInt k_size,
|
| 121 |
+
const c10::SymInt v_size,
|
| 122 |
+
c10::string_view param_name,
|
| 123 |
+
bool debug) {
|
| 124 |
+
auto max_size = std::max({q_size, k_size, v_size});
|
| 125 |
+
if ((q_size != max_size && q_size != 1) ||
|
| 126 |
+
(k_size != max_size && k_size != 1) ||
|
| 127 |
+
(v_size != max_size && v_size != 1)) {
|
| 128 |
+
if (debug) {
|
| 129 |
+
TORCH_WARN(
|
| 130 |
+
"Both fused kernels require query, key and value to have broadcastable ",
|
| 131 |
+
param_name,
|
| 132 |
+
"got Query ",
|
| 133 |
+
param_name,
|
| 134 |
+
q_size,
|
| 135 |
+
", Key ",
|
| 136 |
+
param_name,
|
| 137 |
+
k_size,
|
| 138 |
+
", Value ",
|
| 139 |
+
param_name,
|
| 140 |
+
v_size,
|
| 141 |
+
" instead.");
|
| 142 |
+
}
|
| 143 |
+
return false;
|
| 144 |
+
}
|
| 145 |
+
return true;
|
| 146 |
+
}
|
| 147 |
+
|
| 148 |
+
inline bool check_for_seq_len_0_and_consistent_head_dim_nested_tensor_helper(
|
| 149 |
+
at::Tensor const& param,
|
| 150 |
+
c10::string_view param_name,
|
| 151 |
+
bool debug) {
|
| 152 |
+
const auto nt_tensor_impl = at::native::get_nested_tensor_impl(param);
|
| 153 |
+
const at::Tensor& sizes = nt_tensor_impl->get_nested_sizes();
|
| 154 |
+
auto num_head_dims = nt_tensor_impl->opt_size(1);
|
| 155 |
+
if (!num_head_dims.has_value()) {
|
| 156 |
+
// num_head_dims is ragged
|
| 157 |
+
if (debug) {
|
| 158 |
+
TORCH_WARN(
|
| 159 |
+
"Fused kernels do not support ragged num_head_dims, ",
|
| 160 |
+
param_name,
|
| 161 |
+
"has a ragged num_heads.");
|
| 162 |
+
}
|
| 163 |
+
return false;
|
| 164 |
+
}
|
| 165 |
+
|
| 166 |
+
auto* sizes_ptr = sizes.data_ptr<int64_t>();
|
| 167 |
+
const int64_t n_tensors = param.size(0);
|
| 168 |
+
const int64_t size_tensor_stride = sizes.stride(0);
|
| 169 |
+
|
| 170 |
+
// This is being called inside sdp with shape [batch, heads, {seq_len}, dim]
|
| 171 |
+
for (const auto i : c10::irange(n_tensors)) {
|
| 172 |
+
if (sizes_ptr[(i * size_tensor_stride) + 1] == 0) {
|
| 173 |
+
if (debug) {
|
| 174 |
+
TORCH_WARN(
|
| 175 |
+
"Fused kernels do not support seq_len == 0, ",
|
| 176 |
+
param_name,
|
| 177 |
+
"has a seq len of 0.");
|
| 178 |
+
}
|
| 179 |
+
return false;
|
| 180 |
+
}
|
| 181 |
+
}
|
| 182 |
+
return true;
|
| 183 |
+
}
|
| 184 |
+
|
| 185 |
+
inline bool check_for_seq_len_0_nested_tensor(sdp_params const& params, bool debug) {
|
| 186 |
+
// When this function is called we are assured that the nt is dim==4
|
| 187 |
+
bool q_is_safe = params.query.is_nested()
|
| 188 |
+
? check_for_seq_len_0_and_consistent_head_dim_nested_tensor_helper(
|
| 189 |
+
params.query, "query ", debug)
|
| 190 |
+
: true;
|
| 191 |
+
// short circuit if any is unsafe
|
| 192 |
+
if (!q_is_safe) {
|
| 193 |
+
return false;
|
| 194 |
+
}
|
| 195 |
+
|
| 196 |
+
bool k_is_safe = params.key.is_nested()
|
| 197 |
+
? check_for_seq_len_0_and_consistent_head_dim_nested_tensor_helper(
|
| 198 |
+
params.key, "key ", debug)
|
| 199 |
+
: true;
|
| 200 |
+
if (!k_is_safe) {
|
| 201 |
+
return false;
|
| 202 |
+
}
|
| 203 |
+
|
| 204 |
+
bool v_is_safe = params.value.is_nested()
|
| 205 |
+
? check_for_seq_len_0_and_consistent_head_dim_nested_tensor_helper(
|
| 206 |
+
params.value, "value ", debug)
|
| 207 |
+
: true;
|
| 208 |
+
if (!v_is_safe) {
|
| 209 |
+
return false;
|
| 210 |
+
}
|
| 211 |
+
|
| 212 |
+
// We now know none of the inputs have ragged num_heads, so we can safely
|
| 213 |
+
// access .size(1)
|
| 214 |
+
auto q_num_heads = params.query.size(1);
|
| 215 |
+
auto k_num_heads = params.key.size(1);
|
| 216 |
+
auto v_num_heads = params.value.size(1);
|
| 217 |
+
bool same_num_heads =
|
| 218 |
+
q_num_heads == k_num_heads && q_num_heads == v_num_heads;
|
| 219 |
+
|
| 220 |
+
if (!same_num_heads) {
|
| 221 |
+
if (input_requires_grad(params)){
|
| 222 |
+
if (debug) {
|
| 223 |
+
TORCH_WARN(
|
| 224 |
+
"Both fused kernels do not support training with broadcasted NT inputs.");
|
| 225 |
+
}
|
| 226 |
+
return false;
|
| 227 |
+
}
|
| 228 |
+
return try_broadcast_param_size(
|
| 229 |
+
q_num_heads, k_num_heads, v_num_heads, "num heads ", debug);
|
| 230 |
+
}
|
| 231 |
+
|
| 232 |
+
return true;
|
| 233 |
+
}
|
| 234 |
+
|
| 235 |
+
inline bool check_nested_tensor(sdp_params const& params, bool debug) {
|
| 236 |
+
// Return false if have nested tensor
|
| 237 |
+
if (!has_only_dense_inputs(params)) {
|
| 238 |
+
if (debug) {
|
| 239 |
+
TORCH_WARN(
|
| 240 |
+
"Both fused kernels of cpp version currently do not support Nested Tensor inputs.");
|
| 241 |
+
}
|
| 242 |
+
return false;
|
| 243 |
+
}
|
| 244 |
+
return true;
|
| 245 |
+
}
|
| 246 |
+
|
| 247 |
+
inline bool check_for_dropout(sdp_params const& params, bool debug) {
|
| 248 |
+
if (params.dropout > 0.0) {
|
| 249 |
+
if (debug) {
|
| 250 |
+
TORCH_WARN("Both fused kernels do not support non-zero dropout.");
|
| 251 |
+
}
|
| 252 |
+
return false;
|
| 253 |
+
}
|
| 254 |
+
return true;
|
| 255 |
+
}
|
| 256 |
+
|
| 257 |
+
inline bool check_requires_grad_and_nested(sdp_params const& params, bool debug) {
|
| 258 |
+
if (input_requires_grad(params)) {
|
| 259 |
+
if (debug) {
|
| 260 |
+
TORCH_WARN(
|
| 261 |
+
"Memory efficient attention currently doesn't support training with NT inputs.");
|
| 262 |
+
}
|
| 263 |
+
return false;
|
| 264 |
+
}
|
| 265 |
+
return true;
|
| 266 |
+
}
|
| 267 |
+
|
| 268 |
+
inline bool check_for_attn_mask(sdp_params const& params, bool debug) {
|
| 269 |
+
if (params.attn_mask.has_value()) {
|
| 270 |
+
if (debug) {
|
| 271 |
+
TORCH_WARN("Flash Attention does not support non-null attn_mask.");
|
| 272 |
+
}
|
| 273 |
+
return false;
|
| 274 |
+
}
|
| 275 |
+
return true;
|
| 276 |
+
}
|
| 277 |
+
|
| 278 |
+
inline bool check_attn_mask_shape(sdp_params const& params, bool debug) {
|
| 279 |
+
auto attn_mask = params.attn_mask;
|
| 280 |
+
if (!attn_mask.has_value()) {
|
| 281 |
+
return true;
|
| 282 |
+
}
|
| 283 |
+
if (attn_mask.value().requires_grad()) {
|
| 284 |
+
return false;
|
| 285 |
+
}
|
| 286 |
+
auto batchSize = params.query.sym_size(0);
|
| 287 |
+
auto qSize = params.query.sym_size(2);
|
| 288 |
+
auto kvSize = params.key.sym_size(2);
|
| 289 |
+
auto num_head = params.query.sym_size(1);
|
| 290 |
+
if (attn_mask.value().sym_size(-2) != qSize && attn_mask.value().sym_size(-2) != 1) {
|
| 291 |
+
return false;
|
| 292 |
+
}
|
| 293 |
+
if (attn_mask.value().sym_size(-1) != kvSize && attn_mask.value().sym_size(-1) != 1) {
|
| 294 |
+
return false;
|
| 295 |
+
}
|
| 296 |
+
if (attn_mask.value().dim() == 2) {
|
| 297 |
+
return true;
|
| 298 |
+
} else if (attn_mask.value().dim() == 4) {
|
| 299 |
+
if ((attn_mask.value().sym_size(0) == 1 || attn_mask.value().sym_size(0) == batchSize)
|
| 300 |
+
&& (attn_mask.value().sym_size(1) == 1 || attn_mask.value().sym_size(1) == num_head)) {
|
| 301 |
+
return true;
|
| 302 |
+
}
|
| 303 |
+
}
|
| 304 |
+
if (debug) {
|
| 305 |
+
TORCH_WARN("Please use the following attn mask shapes: ",
|
| 306 |
+
"2d - ({Q_seq_len, 1} x {KV_seq_len, 1}); ",
|
| 307 |
+
"4d - ({Batch, 1} x {Num_heads, 1} x {Q_seq_len, 1} x {KV_seq_len, 1})");
|
| 308 |
+
}
|
| 309 |
+
return false;
|
| 310 |
+
}
|
| 311 |
+
|
| 312 |
+
inline bool check_tensor_shapes(sdp_params const& params, bool debug) {
|
| 313 |
+
auto query_dim = params.query.dim();
|
| 314 |
+
if (!(query_dim == params.key.dim() && query_dim == params.value.dim() &&
|
| 315 |
+
(query_dim == 4))) {
|
| 316 |
+
if (debug) {
|
| 317 |
+
TORCH_WARN(
|
| 318 |
+
"All fused kernels requires query, key and value to be 4 dimensional, but got Query dim: ",
|
| 319 |
+
query_dim,
|
| 320 |
+
", Key dim: ",
|
| 321 |
+
params.key.dim(),
|
| 322 |
+
", Value dim: ",
|
| 323 |
+
params.value.dim(),
|
| 324 |
+
" instead.");
|
| 325 |
+
}
|
| 326 |
+
return false;
|
| 327 |
+
}
|
| 328 |
+
return true;
|
| 329 |
+
}
|
| 330 |
+
|
| 331 |
+
inline bool check_safe_kv_broadcast(at::Tensor const& param, bool debug) {
|
| 332 |
+
const auto nt_tensor_impl = at::native::get_nested_tensor_impl(param);
|
| 333 |
+
auto seq_len = nt_tensor_impl->opt_size(2);
|
| 334 |
+
if (!seq_len.has_value()) {
|
| 335 |
+
if (debug) {
|
| 336 |
+
TORCH_WARN(
|
| 337 |
+
"For both fused kernels, if one of key/value batch_size requires "
|
| 338 |
+
"broadcasting and the other does not, then the other must have a ",
|
| 339 |
+
"consistent seq_len dim.")
|
| 340 |
+
}
|
| 341 |
+
return false;
|
| 342 |
+
}
|
| 343 |
+
return true;
|
| 344 |
+
}
|
| 345 |
+
|
| 346 |
+
inline bool check_grouped_query_attention(sdp_params const& params, bool debug) {
|
| 347 |
+
const auto q_num_heads = params.query.sym_size(-3);
|
| 348 |
+
const auto k_num_heads = params.key.sym_size(-3);
|
| 349 |
+
const auto v_num_heads = params.value.sym_size(-3);
|
| 350 |
+
const bool same_kv_heads = k_num_heads == v_num_heads;
|
| 351 |
+
|
| 352 |
+
if (!(same_kv_heads)){
|
| 353 |
+
if (debug) {
|
| 354 |
+
TORCH_WARN(
|
| 355 |
+
"Both fused kernels require key and value to have the same num_heads and batch_size but got: ",
|
| 356 |
+
"Key sizes: ",
|
| 357 |
+
params.key.sizes(),
|
| 358 |
+
", Value sizes: ",
|
| 359 |
+
params.value.sizes(),
|
| 360 |
+
", Query sizes: ",
|
| 361 |
+
params.query.sizes(),
|
| 362 |
+
" instead.");
|
| 363 |
+
}
|
| 364 |
+
return false;
|
| 365 |
+
}
|
| 366 |
+
// Check if grouped query attention is supported and validate the number of
|
| 367 |
+
// heads
|
| 368 |
+
if (q_num_heads % k_num_heads != 0) {
|
| 369 |
+
if (debug) {
|
| 370 |
+
TORCH_WARN(
|
| 371 |
+
"FlashAttentionV2 only supports grouped query attention, where the number of heads in key/value must divide number of heads in query.",
|
| 372 |
+
"Got input Key sizes(): ",
|
| 373 |
+
params.key.sym_size(-3),
|
| 374 |
+
", Value sizes(): ",
|
| 375 |
+
params.value.sym_size(-3),
|
| 376 |
+
", Query sizes(): ",
|
| 377 |
+
params.query.sym_size(-3),
|
| 378 |
+
" instead.");
|
| 379 |
+
}
|
| 380 |
+
return false;
|
| 381 |
+
}
|
| 382 |
+
return true;
|
| 383 |
+
}
|
| 384 |
+
|
| 385 |
+
template <bool supports_gqa>
|
| 386 |
+
inline bool check_batch_size_and_num_heads_dense(sdp_params const& params, bool debug) {
|
| 387 |
+
// This is expected to be called after check_tensor_shapes ensuring that the
|
| 388 |
+
// size() calls won't error since the inputs are all 4 dimensional
|
| 389 |
+
|
| 390 |
+
auto q_batch_size = params.query.sym_size(0);
|
| 391 |
+
auto k_batch_size = params.key.sym_size(0);
|
| 392 |
+
auto v_batch_size = params.value.sym_size(0);
|
| 393 |
+
|
| 394 |
+
bool same_batch_size =
|
| 395 |
+
q_batch_size == k_batch_size && q_batch_size == v_batch_size;
|
| 396 |
+
|
| 397 |
+
auto q_num_heads = params.query.sym_size(-3);
|
| 398 |
+
auto k_num_heads = params.key.sym_size(-3);
|
| 399 |
+
auto v_num_heads = params.value.sym_size(-3);
|
| 400 |
+
|
| 401 |
+
bool same_num_heads =
|
| 402 |
+
q_num_heads == k_num_heads && q_num_heads == v_num_heads;
|
| 403 |
+
|
| 404 |
+
if (!same_batch_size){
|
| 405 |
+
if(debug) {
|
| 406 |
+
TORCH_WARN(
|
| 407 |
+
"For dense inputs, both fused kernels require query, key and value to have the same batch_size. ",
|
| 408 |
+
"Query.sizes(): ",
|
| 409 |
+
params.query.sizes(),
|
| 410 |
+
", Key.sizes(): ",
|
| 411 |
+
params.key.sizes(),
|
| 412 |
+
", Value.sizes(): ",
|
| 413 |
+
params.value.sizes(),
|
| 414 |
+
" instead. To broadcast dense inputs, try using unsqueeze and expand_to before passing them into the kernel.");
|
| 415 |
+
}
|
| 416 |
+
return false;
|
| 417 |
+
}
|
| 418 |
+
|
| 419 |
+
if(params.enable_gqa && supports_gqa){
|
| 420 |
+
return check_grouped_query_attention(params, debug);
|
| 421 |
+
}
|
| 422 |
+
|
| 423 |
+
if (!same_num_heads){
|
| 424 |
+
if (debug) {
|
| 425 |
+
TORCH_WARN(
|
| 426 |
+
"For dense input, both fused kernels require query, key and value to have the same num_heads. ",
|
| 427 |
+
"Query.sizes(): ",
|
| 428 |
+
params.query.sizes(),
|
| 429 |
+
", Key sizes(): ",
|
| 430 |
+
params.key.sizes(),
|
| 431 |
+
", Value sizes(): ",
|
| 432 |
+
params.value.sizes(),
|
| 433 |
+
" instead. To broadcast dense inputs, try using unsqueeze and expand_to before passing them into the kernel.");
|
| 434 |
+
}
|
| 435 |
+
return false;
|
| 436 |
+
}
|
| 437 |
+
// If all checks pass, return true
|
| 438 |
+
return true;
|
| 439 |
+
}
|
| 440 |
+
|
| 441 |
+
inline bool check_batch_size_nested(sdp_params const& params, bool debug) {
|
| 442 |
+
// This is expected to be called after check_tensor_shapes ensuring that the
|
| 443 |
+
// size() calls won't error since the inputs are all 4 dimensional
|
| 444 |
+
auto q_batch_size = params.query.sym_size(0);
|
| 445 |
+
auto k_batch_size = params.key.sym_size(0);
|
| 446 |
+
auto v_batch_size = params.value.sym_size(0);
|
| 447 |
+
|
| 448 |
+
bool same_batch_size =
|
| 449 |
+
q_batch_size == k_batch_size && q_batch_size == v_batch_size;
|
| 450 |
+
|
| 451 |
+
// num_heads logic for nested input is checked in
|
| 452 |
+
// check_for_seq_len_0_nested_tensor as there is handling there to make sure
|
| 453 |
+
// num_heads is not ragged
|
| 454 |
+
bool broadcastable_batch_size = true;
|
| 455 |
+
if (!same_batch_size) {
|
| 456 |
+
if (input_requires_grad(params)){
|
| 457 |
+
if (debug) {
|
| 458 |
+
TORCH_WARN(
|
| 459 |
+
"Both fused kernels do not support training with broadcasted NT inputs.");
|
| 460 |
+
}
|
| 461 |
+
return false;
|
| 462 |
+
}
|
| 463 |
+
// try to broadcast batchsize
|
| 464 |
+
broadcastable_batch_size = try_broadcast_param_size(
|
| 465 |
+
q_batch_size, k_batch_size, v_batch_size, "batch size ", debug);
|
| 466 |
+
|
| 467 |
+
// if only one of k or v require broadcasting of batch size, the other
|
| 468 |
+
// must have a consistent seq_len dim
|
| 469 |
+
if (broadcastable_batch_size) {
|
| 470 |
+
if (k_batch_size == 1 && v_batch_size != 1 &&
|
| 471 |
+
!check_safe_kv_broadcast(params.value, debug)) {
|
| 472 |
+
return false;
|
| 473 |
+
}
|
| 474 |
+
if (v_batch_size == 1 && k_batch_size != 1 &&
|
| 475 |
+
!check_safe_kv_broadcast(params.key, debug)) {
|
| 476 |
+
return false;
|
| 477 |
+
}
|
| 478 |
+
}
|
| 479 |
+
}
|
| 480 |
+
return broadcastable_batch_size;
|
| 481 |
+
}
|
| 482 |
+
|
| 483 |
+
inline bool check_nonzero_sequence_lengths_dense(sdp_params const& params, bool debug) {
|
| 484 |
+
// In some cases people will pass in 0 sized tensors, this will
|
| 485 |
+
// cause the fused path to error with unaligned mask
|
| 486 |
+
bool zero_seq_len_q = params.query.sym_size(-2) == 0;
|
| 487 |
+
bool zero_seq_len_k = params.key.sym_size(-2) == 0;
|
| 488 |
+
if (zero_seq_len_q || zero_seq_len_k) {
|
| 489 |
+
if (debug) {
|
| 490 |
+
TORCH_WARN(
|
| 491 |
+
"All fused kernels do not support zero seq_len_q or seq_len_kv.");
|
| 492 |
+
}
|
| 493 |
+
return false;
|
| 494 |
+
}
|
| 495 |
+
return true;
|
| 496 |
+
}
|
| 497 |
+
|
| 498 |
+
template<bool ignore_singleton_dim>
|
| 499 |
+
inline bool check_last_dim_stride_equals_1_dense(sdp_params const& params, bool debug) {
|
| 500 |
+
// The stride checking for NestedTensors is done within the kernel
|
| 501 |
+
// And .contiguous will be called if needed
|
| 502 |
+
|
| 503 |
+
// This function checks that the last dimension of the inputs to
|
| 504 |
+
// fused_attention have stride 1
|
| 505 |
+
bool qkv_strides_equal_1 = params.query.sym_stride(-1) == 1 &&
|
| 506 |
+
params.key.sym_stride(-1) == 1 && params.value.sym_stride(-1) == 1;
|
| 507 |
+
|
| 508 |
+
// https://github.com/pytorch/pytorch/issues/116333
|
| 509 |
+
// If the head_dim is size 1 the stride won't matter, but we
|
| 510 |
+
// check this condition before padding the head_dim to 1
|
| 511 |
+
if (ignore_singleton_dim){
|
| 512 |
+
qkv_strides_equal_1 = qkv_strides_equal_1 || params.query.sym_size(-1) == 1;
|
| 513 |
+
}
|
| 514 |
+
bool mask_stride_equal_1 = params.attn_mask.has_value()
|
| 515 |
+
? params.attn_mask.value().sym_stride(-1) == 1
|
| 516 |
+
: true;
|
| 517 |
+
if (!(qkv_strides_equal_1 && mask_stride_equal_1)) {
|
| 518 |
+
if (debug) {
|
| 519 |
+
std::ostringstream epilogue_message;
|
| 520 |
+
if (params.attn_mask.has_value()) {
|
| 521 |
+
epilogue_message << ", Attn_mask.stride(-1): "
|
| 522 |
+
<< params.attn_mask.value().sym_stride(-1);
|
| 523 |
+
}
|
| 524 |
+
epilogue_message << " instead.";
|
| 525 |
+
TORCH_WARN(
|
| 526 |
+
"All fused kernels require the last dimension of the input to have stride 1. ",
|
| 527 |
+
"Got Query.stride(-1): ",
|
| 528 |
+
params.query.sym_stride(-1),
|
| 529 |
+
", Key.stride(-1): ",
|
| 530 |
+
params.key.sym_stride(-1),
|
| 531 |
+
", Value.stride(-1): ",
|
| 532 |
+
params.value.sym_stride(-1),
|
| 533 |
+
epilogue_message.str());
|
| 534 |
+
}
|
| 535 |
+
|
| 536 |
+
return false;
|
| 537 |
+
}
|
| 538 |
+
return true;
|
| 539 |
+
}
|
| 540 |
+
|
| 541 |
+
inline bool check_runtime_disabled_flash(sdp_params const& params, bool debug) {
|
| 542 |
+
// We check the global context to see if user has explicitly turned of flash
|
| 543 |
+
// sdp kernels
|
| 544 |
+
if (!at::globalContext().userEnabledFlashSDP()) {
|
| 545 |
+
if (debug) {
|
| 546 |
+
TORCH_WARN("Flash attention has been runtime disabled.");
|
| 547 |
+
}
|
| 548 |
+
return false;
|
| 549 |
+
}
|
| 550 |
+
return true;
|
| 551 |
+
}
|
| 552 |
+
|
| 553 |
+
inline bool check_runtime_disabled_mem_efficient(sdp_params const& params, bool debug) {
|
| 554 |
+
// We check the global context to see if user has explicitly turned of
|
| 555 |
+
// mem_efficient sdp kernels
|
| 556 |
+
if (!at::globalContext().userEnabledMemEfficientSDP()) {
|
| 557 |
+
if (debug) {
|
| 558 |
+
TORCH_WARN("Memory Efficient attention has been runtime disabled.");
|
| 559 |
+
}
|
| 560 |
+
return false;
|
| 561 |
+
}
|
| 562 |
+
return true;
|
| 563 |
+
}
|
| 564 |
+
|
| 565 |
+
|
| 566 |
+
} // namespace sdp
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/utils/Factory.h
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/core/Tensor.h>
|
| 4 |
+
|
| 5 |
+
namespace at {
|
| 6 |
+
namespace native {
|
| 7 |
+
namespace mobile {
|
| 8 |
+
|
| 9 |
+
Tensor allocate_padded_contiguous_if_needed(
|
| 10 |
+
const Tensor& input,
|
| 11 |
+
c10::MemoryFormat memory_format);
|
| 12 |
+
|
| 13 |
+
// TODO: Remove this function when at::native::empty() is modified to accept a
|
| 14 |
+
// custom memory allocator.
|
| 15 |
+
|
| 16 |
+
at::Tensor empty_with_tail_padding(
|
| 17 |
+
IntArrayRef size,
|
| 18 |
+
const caffe2::TypeMeta dtype,
|
| 19 |
+
c10::MemoryFormat memory_format,
|
| 20 |
+
std::optional<DimnameList> maybe_names);
|
| 21 |
+
|
| 22 |
+
} // namespace mobile
|
| 23 |
+
} // namespace native
|
| 24 |
+
} // namespace at
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/utils/ParamUtils.h
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <c10/util/ArrayRef.h>
|
| 4 |
+
#include <vector>
|
| 5 |
+
|
| 6 |
+
namespace at {
|
| 7 |
+
namespace native {
|
| 8 |
+
|
| 9 |
+
template <typename T>
|
| 10 |
+
inline std::vector<T> _expand_param_if_needed(
|
| 11 |
+
ArrayRef<T> list_param,
|
| 12 |
+
const char* param_name,
|
| 13 |
+
int64_t expected_dim) {
|
| 14 |
+
if (list_param.size() == 1) {
|
| 15 |
+
return std::vector<T>(expected_dim, list_param[0]);
|
| 16 |
+
} else if ((int64_t)list_param.size() != expected_dim) {
|
| 17 |
+
std::ostringstream ss;
|
| 18 |
+
ss << "expected " << param_name << " to be a single integer value or a "
|
| 19 |
+
<< "list of " << expected_dim << " values to match the convolution "
|
| 20 |
+
<< "dimensions, but got " << param_name << "=" << list_param;
|
| 21 |
+
AT_ERROR(ss.str());
|
| 22 |
+
} else {
|
| 23 |
+
return list_param.vec();
|
| 24 |
+
}
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
inline std::vector<int64_t> expand_param_if_needed(
|
| 28 |
+
IntArrayRef list_param,
|
| 29 |
+
const char* param_name,
|
| 30 |
+
int64_t expected_dim) {
|
| 31 |
+
return _expand_param_if_needed(list_param, param_name, expected_dim);
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
inline std::vector<c10::SymInt> expand_param_if_needed(
|
| 35 |
+
SymIntArrayRef list_param,
|
| 36 |
+
const char* param_name,
|
| 37 |
+
int64_t expected_dim) {
|
| 38 |
+
return _expand_param_if_needed(list_param, param_name, expected_dim);
|
| 39 |
+
}
|
| 40 |
+
|
| 41 |
+
} // namespace native
|
| 42 |
+
} // namespace at
|
.venv/lib/python3.11/site-packages/torch/include/ATen/native/utils/ParamsHash.h
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <c10/util/irange.h>
|
| 4 |
+
#include <memory>
|
| 5 |
+
#include <mutex>
|
| 6 |
+
|
| 7 |
+
namespace at::native {
|
| 8 |
+
|
| 9 |
+
// Hashing machinery for Params
|
| 10 |
+
// Fowler–Noll–Vo hash function
|
| 11 |
+
// see
|
| 12 |
+
// https://en.wikipedia.org/wiki/Fowler%E2%80%93Noll%E2%80%93Vo_hash_function
|
| 13 |
+
template <typename Params>
|
| 14 |
+
struct ParamsHash {
|
| 15 |
+
// Params must be a POD because we read out its memory
|
| 16 |
+
// contents as char* when hashing
|
| 17 |
+
static_assert(std::is_standard_layout_v<Params>, "Params is not POD");
|
| 18 |
+
|
| 19 |
+
size_t operator()(const Params& params) const {
|
| 20 |
+
auto ptr = reinterpret_cast<const uint8_t*>(¶ms);
|
| 21 |
+
uint32_t value = 0x811C9DC5;
|
| 22 |
+
for (const auto i : c10::irange(sizeof(Params))) {
|
| 23 |
+
value ^= ptr[i];
|
| 24 |
+
value *= 0x01000193;
|
| 25 |
+
}
|
| 26 |
+
return (size_t)value;
|
| 27 |
+
}
|
| 28 |
+
};
|
| 29 |
+
|
| 30 |
+
template <typename Params>
|
| 31 |
+
struct ParamsEqual {
|
| 32 |
+
// Params must be a POD because we read out its memory
|
| 33 |
+
// contents as char* when comparing
|
| 34 |
+
static_assert(std::is_standard_layout_v<Params>, "Params is not POD");
|
| 35 |
+
|
| 36 |
+
bool operator()(const Params& a, const Params& b) const {
|
| 37 |
+
auto ptr1 = reinterpret_cast<const uint8_t*>(&a);
|
| 38 |
+
auto ptr2 = reinterpret_cast<const uint8_t*>(&b);
|
| 39 |
+
return memcmp(ptr1, ptr2, sizeof(Params)) == 0;
|
| 40 |
+
}
|
| 41 |
+
};
|
| 42 |
+
|
| 43 |
+
// Provide explicit byte-for-byte constructors to avoid uwittingly leaving
|
| 44 |
+
// padding bytes unitialized (e.g., when passing Params by value)
|
| 45 |
+
template <typename T>
|
| 46 |
+
struct ParamsWrapper {
|
| 47 |
+
T pod;
|
| 48 |
+
static_assert(
|
| 49 |
+
std::is_standard_layout_v<T>,
|
| 50 |
+
"ParamsWrapper cannot wrap non-POD data");
|
| 51 |
+
|
| 52 |
+
ParamsWrapper() {
|
| 53 |
+
memset(&(this->pod), 0, sizeof(this->pod));
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
ParamsWrapper(const ParamsWrapper& other) {
|
| 57 |
+
memcpy(&(this->pod), &(other.pod), sizeof(this->pod));
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
ParamsWrapper(ParamsWrapper&& other) noexcept {
|
| 61 |
+
memcpy(&(this->pod), &(other.pod), sizeof(this->pod));
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
ParamsWrapper& operator=(const ParamsWrapper& other) {
|
| 65 |
+
memcpy(&(this->pod), &(other.pod), sizeof(this->pod));
|
| 66 |
+
return *this;
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
ParamsWrapper& operator=(ParamsWrapper&& other) noexcept {
|
| 70 |
+
memcpy(&(this->pod), &(other.pod), sizeof(this->pod));
|
| 71 |
+
return *this;
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
inline friend bool operator==(
|
| 75 |
+
const ParamsWrapper& lhs,
|
| 76 |
+
const ParamsWrapper& rhs) noexcept {
|
| 77 |
+
auto ptr1 = reinterpret_cast<const uint8_t*>(&(lhs.pod));
|
| 78 |
+
auto ptr2 = reinterpret_cast<const uint8_t*>(&(rhs.pod));
|
| 79 |
+
return memcmp(ptr1, ptr2, sizeof(lhs.pod)) == 0;
|
| 80 |
+
}
|
| 81 |
+
};
|
| 82 |
+
|
| 83 |
+
// Wrapped version: this allows the outer struct to have custom copy and move
|
| 84 |
+
// constructors for additional safety
|
| 85 |
+
template <typename ParamsWrapper>
|
| 86 |
+
struct ParamsWrapperHash {
|
| 87 |
+
// Params must be a POD because we read out its memory
|
| 88 |
+
// contents as char* when hashing
|
| 89 |
+
static_assert(
|
| 90 |
+
std::is_standard_layout_v<decltype(ParamsWrapper::pod)>,
|
| 91 |
+
"ParamsWrapper cannot wrap non-POD data");
|
| 92 |
+
|
| 93 |
+
size_t operator()(const ParamsWrapper& params_wrapper) const {
|
| 94 |
+
auto ptr = reinterpret_cast<const uint8_t*>(&(params_wrapper.pod));
|
| 95 |
+
uint32_t value = 0x811C9DC5;
|
| 96 |
+
for (const auto i : c10::irange(sizeof(params_wrapper.pod))) {
|
| 97 |
+
value ^= ptr[i];
|
| 98 |
+
value *= 0x01000193;
|
| 99 |
+
}
|
| 100 |
+
return (size_t)value;
|
| 101 |
+
}
|
| 102 |
+
};
|
| 103 |
+
|
| 104 |
+
} // namespace at::native
|
.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_empty_per_channel_affine_quantized.h
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// @generated by torchgen/gen.py from Function.h
|
| 4 |
+
|
| 5 |
+
#include <ATen/Context.h>
|
| 6 |
+
#include <ATen/DeviceGuard.h>
|
| 7 |
+
#include <ATen/TensorUtils.h>
|
| 8 |
+
#include <ATen/TracerMode.h>
|
| 9 |
+
#include <ATen/core/Generator.h>
|
| 10 |
+
#include <ATen/core/Reduction.h>
|
| 11 |
+
#include <ATen/core/Tensor.h>
|
| 12 |
+
#include <c10/core/Scalar.h>
|
| 13 |
+
#include <c10/core/Storage.h>
|
| 14 |
+
#include <c10/core/TensorOptions.h>
|
| 15 |
+
#include <c10/util/Deprecated.h>
|
| 16 |
+
#include <optional>
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
#include <ATen/ops/_empty_per_channel_affine_quantized_ops.h>
|
| 21 |
+
|
| 22 |
+
namespace at {
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
// aten::_empty_per_channel_affine_quantized(SymInt[] size, *, Tensor scales, Tensor zero_points, int axis, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=contiguous_format) -> Tensor
|
| 26 |
+
inline at::Tensor _empty_per_channel_affine_quantized(at::IntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, at::TensorOptions options={}, ::std::optional<at::MemoryFormat> memory_format=c10::MemoryFormat::Contiguous) {
|
| 27 |
+
return at::_ops::_empty_per_channel_affine_quantized::call(c10::fromIntArrayRefSlow(size), scales, zero_points, axis, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format));
|
| 28 |
+
}
|
| 29 |
+
namespace symint {
|
| 30 |
+
template <typename T, typename = std::enable_if_t<std::is_same<T, int64_t>::value>>
|
| 31 |
+
at::Tensor _empty_per_channel_affine_quantized(at::IntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, at::TensorOptions options={}, ::std::optional<at::MemoryFormat> memory_format=c10::MemoryFormat::Contiguous) {
|
| 32 |
+
return at::_ops::_empty_per_channel_affine_quantized::call(c10::fromIntArrayRefSlow(size), scales, zero_points, axis, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format));
|
| 33 |
+
}
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
// aten::_empty_per_channel_affine_quantized(SymInt[] size, *, Tensor scales, Tensor zero_points, int axis, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=contiguous_format) -> Tensor
|
| 37 |
+
inline at::Tensor _empty_per_channel_affine_quantized(at::IntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, ::std::optional<at::ScalarType> dtype, ::std::optional<at::Layout> layout, ::std::optional<at::Device> device, ::std::optional<bool> pin_memory, ::std::optional<at::MemoryFormat> memory_format) {
|
| 38 |
+
return at::_ops::_empty_per_channel_affine_quantized::call(c10::fromIntArrayRefSlow(size), scales, zero_points, axis, dtype, layout, device, pin_memory, memory_format);
|
| 39 |
+
}
|
| 40 |
+
namespace symint {
|
| 41 |
+
template <typename T, typename = std::enable_if_t<std::is_same<T, int64_t>::value>>
|
| 42 |
+
at::Tensor _empty_per_channel_affine_quantized(at::IntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, ::std::optional<at::ScalarType> dtype, ::std::optional<at::Layout> layout, ::std::optional<at::Device> device, ::std::optional<bool> pin_memory, ::std::optional<at::MemoryFormat> memory_format) {
|
| 43 |
+
return at::_ops::_empty_per_channel_affine_quantized::call(c10::fromIntArrayRefSlow(size), scales, zero_points, axis, dtype, layout, device, pin_memory, memory_format);
|
| 44 |
+
}
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
// aten::_empty_per_channel_affine_quantized(SymInt[] size, *, Tensor scales, Tensor zero_points, int axis, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=contiguous_format) -> Tensor
|
| 48 |
+
inline at::Tensor _empty_per_channel_affine_quantized_symint(c10::SymIntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, at::TensorOptions options={}, ::std::optional<at::MemoryFormat> memory_format=c10::MemoryFormat::Contiguous) {
|
| 49 |
+
return at::_ops::_empty_per_channel_affine_quantized::call(size, scales, zero_points, axis, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format));
|
| 50 |
+
}
|
| 51 |
+
namespace symint {
|
| 52 |
+
template <typename T, typename = std::enable_if_t<std::is_same<T, c10::SymInt>::value>>
|
| 53 |
+
at::Tensor _empty_per_channel_affine_quantized(c10::SymIntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, at::TensorOptions options={}, ::std::optional<at::MemoryFormat> memory_format=c10::MemoryFormat::Contiguous) {
|
| 54 |
+
return at::_ops::_empty_per_channel_affine_quantized::call(size, scales, zero_points, axis, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format));
|
| 55 |
+
}
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
// aten::_empty_per_channel_affine_quantized(SymInt[] size, *, Tensor scales, Tensor zero_points, int axis, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=contiguous_format) -> Tensor
|
| 59 |
+
inline at::Tensor _empty_per_channel_affine_quantized_symint(c10::SymIntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, ::std::optional<at::ScalarType> dtype, ::std::optional<at::Layout> layout, ::std::optional<at::Device> device, ::std::optional<bool> pin_memory, ::std::optional<at::MemoryFormat> memory_format) {
|
| 60 |
+
return at::_ops::_empty_per_channel_affine_quantized::call(size, scales, zero_points, axis, dtype, layout, device, pin_memory, memory_format);
|
| 61 |
+
}
|
| 62 |
+
namespace symint {
|
| 63 |
+
template <typename T, typename = std::enable_if_t<std::is_same<T, c10::SymInt>::value>>
|
| 64 |
+
at::Tensor _empty_per_channel_affine_quantized(c10::SymIntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, ::std::optional<at::ScalarType> dtype, ::std::optional<at::Layout> layout, ::std::optional<at::Device> device, ::std::optional<bool> pin_memory, ::std::optional<at::MemoryFormat> memory_format) {
|
| 65 |
+
return at::_ops::_empty_per_channel_affine_quantized::call(size, scales, zero_points, axis, dtype, layout, device, pin_memory, memory_format);
|
| 66 |
+
}
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
// aten::_empty_per_channel_affine_quantized.out(SymInt[] size, *, Tensor scales, Tensor zero_points, int axis, MemoryFormat? memory_format=contiguous_format, Tensor(a!) out) -> Tensor(a!)
|
| 70 |
+
inline at::Tensor & _empty_per_channel_affine_quantized_out(at::Tensor & out, at::IntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, ::std::optional<at::MemoryFormat> memory_format=c10::MemoryFormat::Contiguous) {
|
| 71 |
+
return at::_ops::_empty_per_channel_affine_quantized_out::call(c10::fromIntArrayRefSlow(size), scales, zero_points, axis, memory_format, out);
|
| 72 |
+
}
|
| 73 |
+
namespace symint {
|
| 74 |
+
template <typename T, typename = std::enable_if_t<std::is_same<T, int64_t>::value>>
|
| 75 |
+
at::Tensor & _empty_per_channel_affine_quantized_out(at::Tensor & out, at::IntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, ::std::optional<at::MemoryFormat> memory_format=c10::MemoryFormat::Contiguous) {
|
| 76 |
+
return at::_ops::_empty_per_channel_affine_quantized_out::call(c10::fromIntArrayRefSlow(size), scales, zero_points, axis, memory_format, out);
|
| 77 |
+
}
|
| 78 |
+
}
|
| 79 |
+
|
| 80 |
+
// aten::_empty_per_channel_affine_quantized.out(SymInt[] size, *, Tensor scales, Tensor zero_points, int axis, MemoryFormat? memory_format=contiguous_format, Tensor(a!) out) -> Tensor(a!)
|
| 81 |
+
inline at::Tensor & _empty_per_channel_affine_quantized_outf(at::IntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, ::std::optional<at::MemoryFormat> memory_format, at::Tensor & out) {
|
| 82 |
+
return at::_ops::_empty_per_channel_affine_quantized_out::call(c10::fromIntArrayRefSlow(size), scales, zero_points, axis, memory_format, out);
|
| 83 |
+
}
|
| 84 |
+
namespace symint {
|
| 85 |
+
template <typename T, typename = std::enable_if_t<std::is_same<T, int64_t>::value>>
|
| 86 |
+
at::Tensor & _empty_per_channel_affine_quantized_outf(at::IntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, ::std::optional<at::MemoryFormat> memory_format, at::Tensor & out) {
|
| 87 |
+
return at::_ops::_empty_per_channel_affine_quantized_out::call(c10::fromIntArrayRefSlow(size), scales, zero_points, axis, memory_format, out);
|
| 88 |
+
}
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
// aten::_empty_per_channel_affine_quantized.out(SymInt[] size, *, Tensor scales, Tensor zero_points, int axis, MemoryFormat? memory_format=contiguous_format, Tensor(a!) out) -> Tensor(a!)
|
| 92 |
+
inline at::Tensor & _empty_per_channel_affine_quantized_symint_out(at::Tensor & out, c10::SymIntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, ::std::optional<at::MemoryFormat> memory_format=c10::MemoryFormat::Contiguous) {
|
| 93 |
+
return at::_ops::_empty_per_channel_affine_quantized_out::call(size, scales, zero_points, axis, memory_format, out);
|
| 94 |
+
}
|
| 95 |
+
namespace symint {
|
| 96 |
+
template <typename T, typename = std::enable_if_t<std::is_same<T, c10::SymInt>::value>>
|
| 97 |
+
at::Tensor & _empty_per_channel_affine_quantized_out(at::Tensor & out, c10::SymIntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, ::std::optional<at::MemoryFormat> memory_format=c10::MemoryFormat::Contiguous) {
|
| 98 |
+
return at::_ops::_empty_per_channel_affine_quantized_out::call(size, scales, zero_points, axis, memory_format, out);
|
| 99 |
+
}
|
| 100 |
+
}
|
| 101 |
+
|
| 102 |
+
// aten::_empty_per_channel_affine_quantized.out(SymInt[] size, *, Tensor scales, Tensor zero_points, int axis, MemoryFormat? memory_format=contiguous_format, Tensor(a!) out) -> Tensor(a!)
|
| 103 |
+
inline at::Tensor & _empty_per_channel_affine_quantized_symint_outf(c10::SymIntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, ::std::optional<at::MemoryFormat> memory_format, at::Tensor & out) {
|
| 104 |
+
return at::_ops::_empty_per_channel_affine_quantized_out::call(size, scales, zero_points, axis, memory_format, out);
|
| 105 |
+
}
|
| 106 |
+
namespace symint {
|
| 107 |
+
template <typename T, typename = std::enable_if_t<std::is_same<T, c10::SymInt>::value>>
|
| 108 |
+
at::Tensor & _empty_per_channel_affine_quantized_outf(c10::SymIntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, ::std::optional<at::MemoryFormat> memory_format, at::Tensor & out) {
|
| 109 |
+
return at::_ops::_empty_per_channel_affine_quantized_out::call(size, scales, zero_points, axis, memory_format, out);
|
| 110 |
+
}
|
| 111 |
+
}
|
| 112 |
+
|
| 113 |
+
}
|
.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_foreach_addcmul_native.h
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// @generated by torchgen/gen.py from NativeFunction.h
|
| 4 |
+
|
| 5 |
+
#include <c10/core/Scalar.h>
|
| 6 |
+
#include <c10/core/Storage.h>
|
| 7 |
+
#include <c10/core/TensorOptions.h>
|
| 8 |
+
#include <c10/util/Deprecated.h>
|
| 9 |
+
#include <optional>
|
| 10 |
+
#include <c10/core/QScheme.h>
|
| 11 |
+
#include <ATen/core/Reduction.h>
|
| 12 |
+
#include <ATen/core/Tensor.h>
|
| 13 |
+
#include <tuple>
|
| 14 |
+
#include <vector>
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
namespace at {
|
| 18 |
+
namespace native {
|
| 19 |
+
TORCH_API ::std::vector<at::Tensor> foreach_tensor_addcmul_scalar_slow(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Scalar & value=1);
|
| 20 |
+
TORCH_API void _foreach_addcmul_Scalar_out(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Scalar & value, at::TensorList out);
|
| 21 |
+
TORCH_API void foreach_tensor_addcmul_scalar_slow_(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Scalar & value=1);
|
| 22 |
+
TORCH_API ::std::vector<at::Tensor> foreach_tensor_addcmul_scalar_cuda(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Scalar & value=1);
|
| 23 |
+
TORCH_API void foreach_tensor_addcmul_scalar_cuda_(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Scalar & value=1);
|
| 24 |
+
TORCH_API ::std::vector<at::Tensor> foreach_tensor_addcmul_scalarlist_slow(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, at::ArrayRef<at::Scalar> scalars);
|
| 25 |
+
TORCH_API void _foreach_addcmul_ScalarList_out(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, at::ArrayRef<at::Scalar> scalars, at::TensorList out);
|
| 26 |
+
TORCH_API void foreach_tensor_addcmul_scalarlist_slow_(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, at::ArrayRef<at::Scalar> scalars);
|
| 27 |
+
TORCH_API ::std::vector<at::Tensor> foreach_tensor_addcmul_scalarlist_cuda(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, at::ArrayRef<at::Scalar> scalars);
|
| 28 |
+
TORCH_API void foreach_tensor_addcmul_scalarlist_cuda_(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, at::ArrayRef<at::Scalar> scalars);
|
| 29 |
+
TORCH_API ::std::vector<at::Tensor> foreach_tensor_addcmul_tensor_slow(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Tensor & scalars);
|
| 30 |
+
TORCH_API void _foreach_addcmul_Tensor_out(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Tensor & scalars, at::TensorList out);
|
| 31 |
+
TORCH_API void foreach_tensor_addcmul_tensor_slow_(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Tensor & scalars);
|
| 32 |
+
TORCH_API ::std::vector<at::Tensor> foreach_tensor_addcmul_tensor_cuda(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Tensor & scalars);
|
| 33 |
+
TORCH_API void foreach_tensor_addcmul_tensor_cuda_(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Tensor & scalars);
|
| 34 |
+
} // namespace native
|
| 35 |
+
} // namespace at
|
.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_foreach_expm1_ops.h
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// @generated by torchgen/gen.py from Operator.h
|
| 4 |
+
|
| 5 |
+
#include <tuple>
|
| 6 |
+
#include <vector>
|
| 7 |
+
|
| 8 |
+
// Forward declarations of any types needed in the operator signatures.
|
| 9 |
+
// We can't directly include these classes because it will cause circular include dependencies.
|
| 10 |
+
// This file is included by TensorBody.h, which defines the Tensor class.
|
| 11 |
+
#include <ATen/core/ATen_fwd.h>
|
| 12 |
+
|
| 13 |
+
namespace at {
|
| 14 |
+
namespace _ops {
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
struct TORCH_API _foreach_expm1 {
|
| 18 |
+
using schema = ::std::vector<at::Tensor> (at::TensorList);
|
| 19 |
+
using ptr_schema = schema*;
|
| 20 |
+
// See Note [static constexpr char* members for windows NVCC]
|
| 21 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::_foreach_expm1")
|
| 22 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "")
|
| 23 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "_foreach_expm1(Tensor[] self) -> Tensor[]")
|
| 24 |
+
static ::std::vector<at::Tensor> call(at::TensorList self);
|
| 25 |
+
static ::std::vector<at::Tensor> redispatch(c10::DispatchKeySet dispatchKeySet, at::TensorList self);
|
| 26 |
+
};
|
| 27 |
+
|
| 28 |
+
struct TORCH_API _foreach_expm1_ {
|
| 29 |
+
using schema = void (at::TensorList);
|
| 30 |
+
using ptr_schema = schema*;
|
| 31 |
+
// See Note [static constexpr char* members for windows NVCC]
|
| 32 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::_foreach_expm1_")
|
| 33 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "")
|
| 34 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "_foreach_expm1_(Tensor(a!)[] self) -> ()")
|
| 35 |
+
static void call(at::TensorList self);
|
| 36 |
+
static void redispatch(c10::DispatchKeySet dispatchKeySet, at::TensorList self);
|
| 37 |
+
};
|
| 38 |
+
|
| 39 |
+
struct TORCH_API _foreach_expm1_out {
|
| 40 |
+
using schema = void (at::TensorList, at::TensorList);
|
| 41 |
+
using ptr_schema = schema*;
|
| 42 |
+
// See Note [static constexpr char* members for windows NVCC]
|
| 43 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::_foreach_expm1")
|
| 44 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "out")
|
| 45 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "_foreach_expm1.out(Tensor[] self, *, Tensor(a!)[] out) -> ()")
|
| 46 |
+
static void call(at::TensorList self, at::TensorList out);
|
| 47 |
+
static void redispatch(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out);
|
| 48 |
+
};
|
| 49 |
+
|
| 50 |
+
}} // namespace at::_ops
|
.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_sparse_mask_projection_ops.h
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// @generated by torchgen/gen.py from Operator.h
|
| 4 |
+
|
| 5 |
+
#include <tuple>
|
| 6 |
+
#include <vector>
|
| 7 |
+
|
| 8 |
+
// Forward declarations of any types needed in the operator signatures.
|
| 9 |
+
// We can't directly include these classes because it will cause circular include dependencies.
|
| 10 |
+
// This file is included by TensorBody.h, which defines the Tensor class.
|
| 11 |
+
#include <ATen/core/ATen_fwd.h>
|
| 12 |
+
|
| 13 |
+
namespace at {
|
| 14 |
+
namespace _ops {
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
struct TORCH_API _sparse_mask_projection {
|
| 18 |
+
using schema = at::Tensor (const at::Tensor &, const at::Tensor &, bool);
|
| 19 |
+
using ptr_schema = schema*;
|
| 20 |
+
// See Note [static constexpr char* members for windows NVCC]
|
| 21 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::_sparse_mask_projection")
|
| 22 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "")
|
| 23 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "_sparse_mask_projection(Tensor self, Tensor mask, bool accumulate_matches=False) -> Tensor")
|
| 24 |
+
static at::Tensor call(const at::Tensor & self, const at::Tensor & mask, bool accumulate_matches);
|
| 25 |
+
static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mask, bool accumulate_matches);
|
| 26 |
+
};
|
| 27 |
+
|
| 28 |
+
struct TORCH_API _sparse_mask_projection_out {
|
| 29 |
+
using schema = at::Tensor & (const at::Tensor &, const at::Tensor &, bool, at::Tensor &);
|
| 30 |
+
using ptr_schema = schema*;
|
| 31 |
+
// See Note [static constexpr char* members for windows NVCC]
|
| 32 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::_sparse_mask_projection")
|
| 33 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "out")
|
| 34 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "_sparse_mask_projection.out(Tensor self, Tensor mask, bool accumulate_matches=False, *, Tensor(a!) out) -> Tensor(a!)")
|
| 35 |
+
static at::Tensor & call(const at::Tensor & self, const at::Tensor & mask, bool accumulate_matches, at::Tensor & out);
|
| 36 |
+
static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mask, bool accumulate_matches, at::Tensor & out);
|
| 37 |
+
};
|
| 38 |
+
|
| 39 |
+
}} // namespace at::_ops
|
.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_sparse_sum_backward_ops.h
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// @generated by torchgen/gen.py from Operator.h
|
| 4 |
+
|
| 5 |
+
#include <tuple>
|
| 6 |
+
#include <vector>
|
| 7 |
+
|
| 8 |
+
// Forward declarations of any types needed in the operator signatures.
|
| 9 |
+
// We can't directly include these classes because it will cause circular include dependencies.
|
| 10 |
+
// This file is included by TensorBody.h, which defines the Tensor class.
|
| 11 |
+
#include <ATen/core/ATen_fwd.h>
|
| 12 |
+
|
| 13 |
+
namespace at {
|
| 14 |
+
namespace _ops {
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
struct TORCH_API _sparse_sum_backward {
|
| 18 |
+
using schema = at::Tensor (const at::Tensor &, const at::Tensor &, at::IntArrayRef);
|
| 19 |
+
using ptr_schema = schema*;
|
| 20 |
+
// See Note [static constexpr char* members for windows NVCC]
|
| 21 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::_sparse_sum_backward")
|
| 22 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "")
|
| 23 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "_sparse_sum_backward(Tensor grad, Tensor self, int[] dim) -> Tensor")
|
| 24 |
+
static at::Tensor call(const at::Tensor & grad, const at::Tensor & self, at::IntArrayRef dim);
|
| 25 |
+
static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & self, at::IntArrayRef dim);
|
| 26 |
+
};
|
| 27 |
+
|
| 28 |
+
struct TORCH_API _sparse_sum_backward_out {
|
| 29 |
+
using schema = at::Tensor & (const at::Tensor &, const at::Tensor &, at::IntArrayRef, at::Tensor &);
|
| 30 |
+
using ptr_schema = schema*;
|
| 31 |
+
// See Note [static constexpr char* members for windows NVCC]
|
| 32 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::_sparse_sum_backward")
|
| 33 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "out")
|
| 34 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "_sparse_sum_backward.out(Tensor grad, Tensor self, int[] dim, *, Tensor(a!) out) -> Tensor(a!)")
|
| 35 |
+
static at::Tensor & call(const at::Tensor & grad, const at::Tensor & self, at::IntArrayRef dim, at::Tensor & out);
|
| 36 |
+
static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & self, at::IntArrayRef dim, at::Tensor & out);
|
| 37 |
+
};
|
| 38 |
+
|
| 39 |
+
}} // namespace at::_ops
|
.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_test_autograd_multiple_dispatch_view_copy_ops.h
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// @generated by torchgen/gen.py from Operator.h
|
| 4 |
+
|
| 5 |
+
#include <tuple>
|
| 6 |
+
#include <vector>
|
| 7 |
+
|
| 8 |
+
// Forward declarations of any types needed in the operator signatures.
|
| 9 |
+
// We can't directly include these classes because it will cause circular include dependencies.
|
| 10 |
+
// This file is included by TensorBody.h, which defines the Tensor class.
|
| 11 |
+
#include <ATen/core/ATen_fwd.h>
|
| 12 |
+
|
| 13 |
+
namespace at {
|
| 14 |
+
namespace _ops {
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
struct TORCH_API _test_autograd_multiple_dispatch_view_copy {
|
| 18 |
+
using schema = at::Tensor (const at::Tensor &);
|
| 19 |
+
using ptr_schema = schema*;
|
| 20 |
+
// See Note [static constexpr char* members for windows NVCC]
|
| 21 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::_test_autograd_multiple_dispatch_view_copy")
|
| 22 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "")
|
| 23 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "_test_autograd_multiple_dispatch_view_copy(Tensor self) -> Tensor")
|
| 24 |
+
static at::Tensor call(const at::Tensor & self);
|
| 25 |
+
static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self);
|
| 26 |
+
};
|
| 27 |
+
|
| 28 |
+
struct TORCH_API _test_autograd_multiple_dispatch_view_copy_out {
|
| 29 |
+
using schema = at::Tensor & (const at::Tensor &, at::Tensor &);
|
| 30 |
+
using ptr_schema = schema*;
|
| 31 |
+
// See Note [static constexpr char* members for windows NVCC]
|
| 32 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::_test_autograd_multiple_dispatch_view_copy")
|
| 33 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "out")
|
| 34 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "_test_autograd_multiple_dispatch_view_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)")
|
| 35 |
+
static at::Tensor & call(const at::Tensor & self, at::Tensor & out);
|
| 36 |
+
static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out);
|
| 37 |
+
};
|
| 38 |
+
|
| 39 |
+
}} // namespace at::_ops
|
.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_unsafe_masked_index_put_accumulate_ops.h
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// @generated by torchgen/gen.py from Operator.h
|
| 4 |
+
|
| 5 |
+
#include <tuple>
|
| 6 |
+
#include <vector>
|
| 7 |
+
|
| 8 |
+
// Forward declarations of any types needed in the operator signatures.
|
| 9 |
+
// We can't directly include these classes because it will cause circular include dependencies.
|
| 10 |
+
// This file is included by TensorBody.h, which defines the Tensor class.
|
| 11 |
+
#include <ATen/core/ATen_fwd.h>
|
| 12 |
+
|
| 13 |
+
namespace at {
|
| 14 |
+
namespace _ops {
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
struct TORCH_API _unsafe_masked_index_put_accumulate {
|
| 18 |
+
using schema = at::Tensor (const at::Tensor &, const at::Tensor &, const c10::List<::std::optional<at::Tensor>> &, const at::Tensor &);
|
| 19 |
+
using ptr_schema = schema*;
|
| 20 |
+
// See Note [static constexpr char* members for windows NVCC]
|
| 21 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::_unsafe_masked_index_put_accumulate")
|
| 22 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "")
|
| 23 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "_unsafe_masked_index_put_accumulate(Tensor self, Tensor mask, Tensor?[] indices, Tensor values) -> Tensor")
|
| 24 |
+
static at::Tensor call(const at::Tensor & self, const at::Tensor & mask, const c10::List<::std::optional<at::Tensor>> & indices, const at::Tensor & values);
|
| 25 |
+
static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mask, const c10::List<::std::optional<at::Tensor>> & indices, const at::Tensor & values);
|
| 26 |
+
};
|
| 27 |
+
|
| 28 |
+
}} // namespace at::_ops
|
.venv/lib/python3.11/site-packages/torch/include/ATen/ops/alias_ops.h
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// @generated by torchgen/gen.py from Operator.h
|
| 4 |
+
|
| 5 |
+
#include <tuple>
|
| 6 |
+
#include <vector>
|
| 7 |
+
|
| 8 |
+
// Forward declarations of any types needed in the operator signatures.
|
| 9 |
+
// We can't directly include these classes because it will cause circular include dependencies.
|
| 10 |
+
// This file is included by TensorBody.h, which defines the Tensor class.
|
| 11 |
+
#include <ATen/core/ATen_fwd.h>
|
| 12 |
+
|
| 13 |
+
namespace at {
|
| 14 |
+
namespace _ops {
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
struct TORCH_API alias {
|
| 18 |
+
using schema = at::Tensor (const at::Tensor &);
|
| 19 |
+
using ptr_schema = schema*;
|
| 20 |
+
// See Note [static constexpr char* members for windows NVCC]
|
| 21 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::alias")
|
| 22 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "")
|
| 23 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "alias(Tensor(a) self) -> Tensor(a)")
|
| 24 |
+
static at::Tensor call(const at::Tensor & self);
|
| 25 |
+
static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self);
|
| 26 |
+
};
|
| 27 |
+
|
| 28 |
+
}} // namespace at::_ops
|
.venv/lib/python3.11/site-packages/torch/include/ATen/ops/arcsinh_compositeimplicitautograd_dispatch.h
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
// @generated by torchgen/gen.py from DispatchKeyFunction.h
|
| 3 |
+
|
| 4 |
+
// NB: The implementing C++ file is RegisterDispatchKey.cpp
|
| 5 |
+
|
| 6 |
+
// The only #includes we need are for custom classes that have defaults in the C++ API
|
| 7 |
+
#include <c10/core/MemoryFormat.h>
|
| 8 |
+
#include <c10/core/Scalar.h>
|
| 9 |
+
#include <ATen/core/Reduction.h>
|
| 10 |
+
|
| 11 |
+
// Forward declarations of any types needed in the operator signatures.
|
| 12 |
+
// We can't directly include these classes because it will cause circular include dependencies.
|
| 13 |
+
// This file is included by TensorBody.h, which defines the Tensor class.
|
| 14 |
+
#include <ATen/core/ATen_fwd.h>
|
| 15 |
+
|
| 16 |
+
namespace at {
|
| 17 |
+
|
| 18 |
+
namespace compositeimplicitautograd {
|
| 19 |
+
|
| 20 |
+
TORCH_API at::Tensor arcsinh(const at::Tensor & self);
|
| 21 |
+
TORCH_API at::Tensor & arcsinh_out(at::Tensor & out, const at::Tensor & self);
|
| 22 |
+
TORCH_API at::Tensor & arcsinh_outf(const at::Tensor & self, at::Tensor & out);
|
| 23 |
+
TORCH_API at::Tensor & arcsinh_(at::Tensor & self);
|
| 24 |
+
|
| 25 |
+
} // namespace compositeimplicitautograd
|
| 26 |
+
} // namespace at
|
.venv/lib/python3.11/site-packages/torch/include/ATen/ops/block_diag_native.h
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// @generated by torchgen/gen.py from NativeFunction.h
|
| 4 |
+
|
| 5 |
+
#include <c10/core/Scalar.h>
|
| 6 |
+
#include <c10/core/Storage.h>
|
| 7 |
+
#include <c10/core/TensorOptions.h>
|
| 8 |
+
#include <c10/util/Deprecated.h>
|
| 9 |
+
#include <optional>
|
| 10 |
+
#include <c10/core/QScheme.h>
|
| 11 |
+
#include <ATen/core/Reduction.h>
|
| 12 |
+
#include <ATen/core/Tensor.h>
|
| 13 |
+
#include <tuple>
|
| 14 |
+
#include <vector>
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
namespace at {
|
| 18 |
+
namespace native {
|
| 19 |
+
TORCH_API at::Tensor block_diag(at::TensorList tensors);
|
| 20 |
+
TORCH_API at::Tensor & block_diag_out(at::TensorList tensors, at::Tensor & out);
|
| 21 |
+
} // namespace native
|
| 22 |
+
} // namespace at
|
.venv/lib/python3.11/site-packages/torch/include/ATen/ops/cat_compositeexplicitautogradnonfunctional_dispatch.h
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
// @generated by torchgen/gen.py from DispatchKeyFunction.h
|
| 3 |
+
|
| 4 |
+
// NB: The implementing C++ file is RegisterDispatchKey.cpp
|
| 5 |
+
|
| 6 |
+
// The only #includes we need are for custom classes that have defaults in the C++ API
|
| 7 |
+
#include <c10/core/MemoryFormat.h>
|
| 8 |
+
#include <c10/core/Scalar.h>
|
| 9 |
+
#include <ATen/core/Reduction.h>
|
| 10 |
+
|
| 11 |
+
// Forward declarations of any types needed in the operator signatures.
|
| 12 |
+
// We can't directly include these classes because it will cause circular include dependencies.
|
| 13 |
+
// This file is included by TensorBody.h, which defines the Tensor class.
|
| 14 |
+
#include <ATen/core/ATen_fwd.h>
|
| 15 |
+
|
| 16 |
+
namespace at {
|
| 17 |
+
|
| 18 |
+
namespace compositeexplicitautogradnonfunctional {
|
| 19 |
+
|
| 20 |
+
TORCH_API at::Tensor cat(const at::ITensorListRef & tensors, int64_t dim=0);
|
| 21 |
+
|
| 22 |
+
} // namespace compositeexplicitautogradnonfunctional
|
| 23 |
+
} // namespace at
|