Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/ArrayRef.h +7 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/Backend.h +7 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/CPUApplyUtils.h +356 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/CPUFixedAllocator.h +38 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/CPUGeneratorImpl.h +54 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/CUDAFunctions.h +34 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/CollapseDims.h +99 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/CompositeImplicitAutogradNestedTensorFunctions_inl.h +30 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/Config.h +28 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/Device.h +7 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/DeviceAccelerator.h +118 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/DimVector.h +7 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/Dispatch_v2.h +182 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/DynamicLibrary.h +41 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/EmptyTensor.h +171 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/ExpandUtils.h +540 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/FunctionalTensorWrapper.h +476 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/Functions.h +1476 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/InitialTensorOptions.h +20 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/LegacyBatchedTensorImpl.h +166 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/LegacyVmapMode.h +31 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/LegacyVmapTransforms.h +188 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/MethodOperators.h +449 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/NamedTensor.h +6 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/NativeMetaFunctions.h +1352 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/NestedTensorImpl.h +292 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/NumericUtils.h +208 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/ParallelOpenMP.h +59 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/RedispatchFunctions.h +0 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/RegistrationDeclarations.h +0 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/SDPBackend.h +21 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/Scalar.h +8 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/StorageUtils.h +54 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/TensorAccessor.h +7 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/ThreadLocalPythonObjects.h +26 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/ThreadLocalState.h +131 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/Utils.h +143 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpp_custom_type_hack.h +115 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/THC/THCAtomics.cuh +8 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/THC/THCDeviceUtils.cuh +8 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/fbgemm/ConvUtils.h +195 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/fbgemm/Fbgemm.h +1515 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/fbgemm/FbgemmBuild.h +116 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/fbgemm/FbgemmConvert.h +205 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/fbgemm/FbgemmEmbedding.h +383 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/fbgemm/FbgemmFP16.h +60 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/fbgemm/FbgemmFP32.h +54 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/fbgemm/FbgemmFPCommon.h +319 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/fbgemm/FbgemmI64.h +36 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/fbgemm/FbgemmI8DepthwiseAvx2.h +117 -0
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/ArrayRef.h
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
#include <c10/util/ArrayRef.h>
|
| 4 |
+
|
| 5 |
+
#else
|
| 6 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 7 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/Backend.h
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
#include <c10/core/Backend.h>
|
| 4 |
+
|
| 5 |
+
#else
|
| 6 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 7 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/CPUApplyUtils.h
ADDED
|
@@ -0,0 +1,356 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
#include <ATen/CollapseDims.h>
|
| 5 |
+
#include <ATen/Parallel.h>
|
| 6 |
+
#include <ATen/TensorUtils.h>
|
| 7 |
+
#include <c10/util/irange.h>
|
| 8 |
+
#include <cstring>
|
| 9 |
+
#include <limits>
|
| 10 |
+
|
| 11 |
+
namespace at {
|
| 12 |
+
|
| 13 |
+
/*
|
| 14 |
+
* The basic strategy for apply is as follows:
|
| 15 |
+
*
|
| 16 |
+
* 1. Starting with the outermost index, loop until we reach a dimension where
|
| 17 |
+
* the data is no longer contiguous, i.e. the stride at that dimension is not
|
| 18 |
+
* equal to the size of the tensor defined by the outer dimensions. Let's call
|
| 19 |
+
* this outer (contiguous) tensor A. Note that if the Tensor is contiguous, then
|
| 20 |
+
* A is equal to the entire Tensor. Let's call the inner tensor B.
|
| 21 |
+
*
|
| 22 |
+
* 2. We loop through the indices in B, starting at its outermost dimension. For
|
| 23 |
+
* example, if B is a 2x2 matrix, then we do:
|
| 24 |
+
*
|
| 25 |
+
* B[0][0]
|
| 26 |
+
* B[0][1]
|
| 27 |
+
* B[1][0]
|
| 28 |
+
* B[1][1]
|
| 29 |
+
*
|
| 30 |
+
* We set the offset into the underlying storage as (storageOffset + stride_B *
|
| 31 |
+
* index_B), i.e. basically we compute the offset into the storage as we would
|
| 32 |
+
* normally for a Tensor. But because we are guaranteed the subsequent data is
|
| 33 |
+
* contiguous in memory, we can simply loop for sizeof(A) iterations and perform
|
| 34 |
+
* the operation, without having to follow the order described by the strides of
|
| 35 |
+
* A.
|
| 36 |
+
*
|
| 37 |
+
* 3. As an optimization, we merge dimensions of A that are contiguous in
|
| 38 |
+
* memory. For example, if A is a 3x3x3x3 tensor narrowed from a 3x3x4x3 tensor,
|
| 39 |
+
* then the first two dimensions can be merged for the purposes of APPLY,
|
| 40 |
+
* reducing the number of nested loops.
|
| 41 |
+
*/
|
| 42 |
+
|
| 43 |
+
// Returns a permuted view of `tensor_` whose dimensions are reordered so that
// their strides are in descending order (largest stride first).
inline Tensor sort_strides(Tensor& tensor_) {
  const IntArrayRef stride_of = tensor_.strides();
  const auto ndim = tensor_.ndimension();

  // Start from the identity permutation 0, 1, ..., ndim - 1.
  std::vector<int64_t> order;
  order.reserve(ndim);
  for (const auto d : c10::irange(ndim)) {
    order.push_back(d);
  }

  // Reorder the dimension indices by the stride they refer to.
  const auto by_descending_stride = [&stride_of](int64_t lhs, int64_t rhs) {
    return stride_of[lhs] > stride_of[rhs];
  };
  std::sort(order.begin(), order.end(), by_descending_stride);

  return tensor_.permute(order);
}
|
| 56 |
+
|
| 57 |
+
template <typename T, int N>
|
| 58 |
+
struct strided_tensor_iter_fixed {
|
| 59 |
+
public:
|
| 60 |
+
T* data_ = NULL;
|
| 61 |
+
int64_t dim_ = 0;
|
| 62 |
+
|
| 63 |
+
// NOLINTNEXTLINE(*array*)
|
| 64 |
+
int64_t counter_[N] = {0};
|
| 65 |
+
// NOLINTNEXTLINE(*array*)
|
| 66 |
+
int64_t sizes_[N] = {0};
|
| 67 |
+
// NOLINTNEXTLINE(*array*)
|
| 68 |
+
int64_t strides_[N] = {0};
|
| 69 |
+
|
| 70 |
+
strided_tensor_iter_fixed(strided_tensor_iter_fixed const&) = delete;
|
| 71 |
+
strided_tensor_iter_fixed& operator=(strided_tensor_iter_fixed const& x) =
|
| 72 |
+
delete;
|
| 73 |
+
strided_tensor_iter_fixed(strided_tensor_iter_fixed&&) noexcept = default;
|
| 74 |
+
strided_tensor_iter_fixed& operator=(strided_tensor_iter_fixed&& x) noexcept =
|
| 75 |
+
default;
|
| 76 |
+
~strided_tensor_iter_fixed() noexcept = default;
|
| 77 |
+
strided_tensor_iter_fixed(
|
| 78 |
+
Tensor& tensor,
|
| 79 |
+
[[maybe_unused]] bool sort_strides = false)
|
| 80 |
+
: data_(tensor.data_ptr<T>()) {
|
| 81 |
+
std::memset(counter_, 0, sizeof(int64_t) * N);
|
| 82 |
+
if (tensor.dim() > 0) {
|
| 83 |
+
std::memcpy(
|
| 84 |
+
sizes_, tensor.sizes().data(), tensor.dim() * sizeof(int64_t));
|
| 85 |
+
std::memcpy(
|
| 86 |
+
strides_, tensor.strides().data(), tensor.dim() * sizeof(int64_t));
|
| 87 |
+
}
|
| 88 |
+
dim_ = std::get<1>(collapse_dims(sizes_, strides_, tensor.ndimension()));
|
| 89 |
+
}
|
| 90 |
+
};
|
| 91 |
+
|
| 92 |
+
template <typename T>
|
| 93 |
+
struct strided_tensor_iter {
|
| 94 |
+
private:
|
| 95 |
+
public:
|
| 96 |
+
T* data_ = NULL;
|
| 97 |
+
int64_t dim_;
|
| 98 |
+
|
| 99 |
+
std::vector<int64_t> counter_;
|
| 100 |
+
std::vector<int64_t> sizes_;
|
| 101 |
+
std::vector<int64_t> strides_;
|
| 102 |
+
|
| 103 |
+
strided_tensor_iter(strided_tensor_iter const&) = delete;
|
| 104 |
+
strided_tensor_iter& operator=(strided_tensor_iter const& x) = delete;
|
| 105 |
+
strided_tensor_iter(strided_tensor_iter&&) noexcept = default;
|
| 106 |
+
strided_tensor_iter& operator=(strided_tensor_iter&&) noexcept = default;
|
| 107 |
+
~strided_tensor_iter() noexcept = default;
|
| 108 |
+
strided_tensor_iter(Tensor& tensor)
|
| 109 |
+
: data_(tensor.data_ptr<T>()),
|
| 110 |
+
dim_(tensor.ndimension()),
|
| 111 |
+
counter_(dim_, 0),
|
| 112 |
+
sizes_(tensor.sizes().vec()),
|
| 113 |
+
strides_(tensor.strides().vec()) {
|
| 114 |
+
dim_ = std::get<1>(collapse_dims(sizes_.data(), strides_.data(), dim_));
|
| 115 |
+
}
|
| 116 |
+
};
|
| 117 |
+
|
| 118 |
+
// Returns true when every tensor in `tensors` has the same number of
// elements. An empty list counts as consistent.
inline bool _all_equal_numel(at::ArrayRef<Tensor> tensors) {
  if (!tensors.empty()) {
    // Every later tensor is compared against the first one.
    const int64_t expected = tensors[0].numel();
    for (const auto idx : c10::irange(1, tensors.size())) {
      if (tensors[idx].numel() != expected) {
        return false;
      }
    }
  }
  return true;
}
|
| 128 |
+
|
| 129 |
+
// Builds the error message reported when the tensors in `tensors` do not all
// have the same number of elements: it lists every tensor's sizes followed by
// every tensor's element count.
inline std::string _all_equal_numel_error(at::ArrayRef<Tensor> tensors) {
  // With an empty list, `tensors.size() - 1` would wrap around (size_t) and
  // the loops below would index far out of range.
  if (tensors.empty()) {
    return "inconsistent tensor size: no tensors given";
  }

  const size_t last = tensors.size() - 1;
  std::ostringstream oss;
  oss << "inconsistent tensor size, expected ";
  for (size_t i = 0; i < last; i++) {
    oss << tensors[i].sizes() << ", ";
  }
  oss << "and " << tensors[last].sizes()
      << " to have the same number of elements, but got ";
  for (size_t i = 0; i < last; i++) {
    oss << tensors[i].numel() << ", ";
  }
  oss << "and " << tensors[last].numel() << " elements respectively";
  return oss.str();
}
|
| 144 |
+
|
| 145 |
+
// Validates the inputs of a CPU_tensor_apply* call: checks device type,
// layout, and that all tensors have the same number of elements.
// Returns false when any tensor has zero elements (nothing to iterate over),
// true otherwise.
inline bool _apply_preamble(ArrayRef<Tensor> tensors) {
  checkDeviceType("CPU_tensor_apply", tensors, kCPU);
  checkLayout("CPU_tensor_apply", tensors, kStrided);
  TORCH_CHECK(_all_equal_numel(tensors), _all_equal_numel_error(tensors));

  // An empty tensor has no elements
  bool has_elements = true;
  for (const auto& t : tensors) {
    if (t.numel() == 0) {
      has_elements = false;
      break;
    }
  }
  return has_elements;
}
|
| 155 |
+
|
| 156 |
+
// Returns the largest ndimension() among `tensors`, or 0 for an empty list.
inline int64_t _max_dim_tensors(ArrayRef<Tensor> tensors) {
  int64_t widest = 0;
  for (const auto& t : tensors) {
    const int64_t nd = t.ndimension();
    if (nd > widest) {
      widest = nd;
    }
  }
  return widest;
}
|
| 162 |
+
|
| 163 |
+
// Recursion terminator: no iterators left to advance.
inline void iterate(int64_t /*size*/) {}

// Advances every iterator by `size` elements along its innermost (last)
// dimension: bumps that dimension's counter and moves the data pointer by
// `size` times the innermost stride.
template <typename Arg, typename... Args>
inline void iterate(int64_t size, Arg& iter, Args&... iter_tail) {
  const int64_t innermost = iter.dim_ - 1;
  iter.counter_[innermost] += size;
  iter.data_ += size * iter.strides_[innermost];
  iterate(size, iter_tail...);
}
|
| 171 |
+
|
| 172 |
+
// Recursion terminator: with no iterators left, nothing blocks continuing.
inline bool iterate_continue() {
  return true;
}

// Returns true while every iterator's innermost counter is still below the
// innermost size, i.e. no iterator has run off the end of its last dimension.
template <typename Arg, typename... Args>
inline bool iterate_continue(Arg& iter, Args&... iter_tail) {
  const int64_t innermost = iter.dim_ - 1;
  if (iter.counter_[innermost] >= iter.sizes_[innermost]) {
    return false;
  }
  return iterate_continue(iter_tail...);
}
|
| 181 |
+
|
| 182 |
+
// Recursion terminator: no iterator imposes a limit.
inline int64_t max_iterate_size() {
  return std::numeric_limits<int64_t>::max();
}

// Returns the smallest number of elements remaining in the innermost
// dimension (size minus counter) across all given iterators.
template <typename Arg, typename... Args>
inline int64_t max_iterate_size(Arg& iter, Args&... iter_tail) {
  const int64_t innermost = iter.dim_ - 1;
  const int64_t left_here = iter.sizes_[innermost] - iter.counter_[innermost];
  const int64_t left_rest = max_iterate_size(iter_tail...);
  return left_rest < left_here ? left_rest : left_here;
}
|
| 192 |
+
|
| 193 |
+
// Recursion terminator: no iterators left to normalise.
inline void iterate_overflow() {}

// Carries a finished innermost dimension into the outer ones: when the last
// counter has reached its size, walks from the innermost dimension outwards,
// resetting every counter that equals its size, incrementing the next outer
// counter, and moving the data pointer back by the full extent of the reset
// dimension and forward by one stride of the outer dimension. The outermost
// counter is never reset.
template <typename Arg, typename... Args>
inline void iterate_overflow(Arg& iter, Args&... iter_tail) {
  const int64_t innermost = iter.dim_ - 1;
  if (iter.counter_[innermost] == iter.sizes_[innermost]) {
    for (int64_t d = innermost; d > 0; --d) {
      if (iter.counter_[d] != iter.sizes_[d]) {
        continue;
      }
      iter.counter_[d] = 0;
      ++iter.counter_[d - 1];
      iter.data_ = iter.data_ - (iter.sizes_[d] * iter.strides_[d]) +
          iter.strides_[d - 1];
    }
  }
  iterate_overflow(iter_tail...);
}
|
| 209 |
+
|
| 210 |
+
// Recursion terminator: no iterators left to reposition.
inline void forward(int64_t /*offset*/) {}

// Moves every iterator ahead by `offset` elements in logical (row-major)
// order: decomposes the offset into per-dimension indices, innermost
// dimension first, and adds each to the counter and the data pointer.
template <typename Arg, typename... Args>
inline void forward(int64_t offset, Arg& iter, Args&... iter_tail) {
  int64_t remaining = offset;
  for (int64_t d = iter.dim_ - 1; d >= 0; --d) {
    const int64_t step = remaining % iter.sizes_[d];
    remaining = remaining / iter.sizes_[d];
    iter.data_ += step * iter.strides_[d];
    iter.counter_[d] += step;
  }
  forward(offset, iter_tail...);
}
|
| 223 |
+
|
| 224 |
+
// Recursion terminator: no iterators means zero dimensions.
inline int64_t max_dim() {
  return 0;
}

// Returns the largest dim_ among the given iterators.
template <typename Arg, typename... Args>
inline int64_t max_dim(Arg& iter, Args&... iter_tail) {
  const int64_t rest = max_dim(iter_tail...);
  return iter.dim_ < rest ? rest : iter.dim_;
}
|
| 232 |
+
|
| 233 |
+
// Overload taking no arguments; does nothing.
inline void apply_op() {}

// Calls `op` once per element, passing the element each iterator currently
// points at (dereferenced data_), for `numel` elements starting `offset`
// elements into the iteration order.
template <typename Op, typename... Args>
inline void apply_op(
    int64_t numel,
    int64_t offset,
    const Op& op,
    Args... iters) {
  // For 0-dim tensors
  const bool single_scalar = (numel == 1 && max_dim(iters...) == 0);
  if (single_scalar) {
    op(*iters.data_...);
    return;
  }

  if (offset > 0) {
    forward(offset, iters...);
  }

  // Splitting this into chunks helps the compiler create faster assembly
  int64_t done = 0;
  while (done < numel) {
    // Inner chunk: run along the innermost dimension until one iterator
    // reaches its end or all requested elements have been visited.
    while (iterate_continue(iters...) && done < numel) {
      op(*iters.data_...);
      iterate(1, iters...);
      ++done;
    }
    // Carry into the outer dimensions before starting the next chunk.
    iterate_overflow(iters...);
  }
}
|
| 258 |
+
|
| 259 |
+
/*
  Apply a pointwise operator to a sequence of tensors.

  The calling convention for op is a function/functor that takes one element
  reference per given tensor, in the same order as the tensors (apply_op
  invokes it as op(*iters.data_...)). For example, to compute a = b * c, op
  would be of the form:
  [](scalar& a_val, const scalar& b_val, const scalar& c_val) { a_val =
  b_val * c_val; };
*/
|
| 268 |
+
|
| 269 |
+
template <typename scalar1, typename scalar2, typename Op>
|
| 270 |
+
inline void CPU_tensor_apply2(Tensor tensor1, Tensor tensor2, const Op op) {
|
| 271 |
+
if (!_apply_preamble({tensor1, tensor2}))
|
| 272 |
+
return;
|
| 273 |
+
if (_max_dim_tensors({tensor1, tensor2}) <= 8) {
|
| 274 |
+
apply_op(
|
| 275 |
+
tensor1.numel(),
|
| 276 |
+
0,
|
| 277 |
+
op,
|
| 278 |
+
strided_tensor_iter_fixed<scalar1, 8>(tensor1),
|
| 279 |
+
strided_tensor_iter_fixed<scalar2, 8>(tensor2));
|
| 280 |
+
} else {
|
| 281 |
+
apply_op(
|
| 282 |
+
tensor1.numel(),
|
| 283 |
+
0,
|
| 284 |
+
op,
|
| 285 |
+
strided_tensor_iter<scalar1>(tensor1),
|
| 286 |
+
strided_tensor_iter<scalar2>(tensor2));
|
| 287 |
+
}
|
| 288 |
+
}
|
| 289 |
+
|
| 290 |
+
template <typename scalar1, typename scalar2, typename scalar3, typename Op>
|
| 291 |
+
inline void CPU_tensor_apply3(
|
| 292 |
+
Tensor tensor1,
|
| 293 |
+
Tensor tensor2,
|
| 294 |
+
Tensor tensor3,
|
| 295 |
+
const Op op) {
|
| 296 |
+
if (!_apply_preamble({tensor1, tensor2, tensor3}))
|
| 297 |
+
return;
|
| 298 |
+
if (_max_dim_tensors({tensor1, tensor2, tensor3}) <= 8) {
|
| 299 |
+
apply_op(
|
| 300 |
+
tensor1.numel(),
|
| 301 |
+
0,
|
| 302 |
+
op,
|
| 303 |
+
strided_tensor_iter_fixed<scalar1, 8>(tensor1),
|
| 304 |
+
strided_tensor_iter_fixed<scalar2, 8>(tensor2),
|
| 305 |
+
strided_tensor_iter_fixed<scalar3, 8>(tensor3));
|
| 306 |
+
} else {
|
| 307 |
+
apply_op(
|
| 308 |
+
tensor1.numel(),
|
| 309 |
+
0,
|
| 310 |
+
op,
|
| 311 |
+
strided_tensor_iter<scalar1>(tensor1),
|
| 312 |
+
strided_tensor_iter<scalar2>(tensor2),
|
| 313 |
+
strided_tensor_iter<scalar3>(tensor3));
|
| 314 |
+
}
|
| 315 |
+
}
|
| 316 |
+
|
| 317 |
+
template <
|
| 318 |
+
typename scalar1,
|
| 319 |
+
typename scalar2,
|
| 320 |
+
typename scalar3,
|
| 321 |
+
typename scalar4,
|
| 322 |
+
typename Op>
|
| 323 |
+
inline void CPU_tensor_apply4(
|
| 324 |
+
Tensor tensor1,
|
| 325 |
+
Tensor tensor2,
|
| 326 |
+
Tensor tensor3,
|
| 327 |
+
Tensor tensor4,
|
| 328 |
+
const Op op) {
|
| 329 |
+
if (!_apply_preamble({tensor1, tensor2, tensor3, tensor4}))
|
| 330 |
+
return;
|
| 331 |
+
if (_max_dim_tensors({tensor1, tensor2, tensor3, tensor4}) <= 8) {
|
| 332 |
+
apply_op(
|
| 333 |
+
tensor1.numel(),
|
| 334 |
+
0,
|
| 335 |
+
op,
|
| 336 |
+
strided_tensor_iter_fixed<scalar1, 8>(tensor1),
|
| 337 |
+
strided_tensor_iter_fixed<scalar2, 8>(tensor2),
|
| 338 |
+
strided_tensor_iter_fixed<scalar3, 8>(tensor3),
|
| 339 |
+
strided_tensor_iter_fixed<scalar4, 8>(tensor4));
|
| 340 |
+
} else {
|
| 341 |
+
apply_op(
|
| 342 |
+
tensor1.numel(),
|
| 343 |
+
0,
|
| 344 |
+
op,
|
| 345 |
+
strided_tensor_iter<scalar1>(tensor1),
|
| 346 |
+
strided_tensor_iter<scalar2>(tensor2),
|
| 347 |
+
strided_tensor_iter<scalar3>(tensor3),
|
| 348 |
+
strided_tensor_iter<scalar4>(tensor4));
|
| 349 |
+
}
|
| 350 |
+
}
|
| 351 |
+
|
| 352 |
+
} // namespace at
|
| 353 |
+
|
| 354 |
+
#else
|
| 355 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 356 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/CPUFixedAllocator.h
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
#include <c10/core/Allocator.h>
|
| 5 |
+
#include <c10/util/Exception.h>
|
| 6 |
+
|
| 7 |
+
// This file creates a fake allocator that just throws exceptions if
|
| 8 |
+
// it is actually used.
|
| 9 |
+
|
| 10 |
+
// state passed to the allocator is the std::function<void(void*)> called
|
| 11 |
+
// when the blob is released by ATen
|
| 12 |
+
|
| 13 |
+
namespace at {
|
| 14 |
+
|
| 15 |
+
// Allocation hook of the fixed allocator: always fails via TORCH_CHECK,
// because an external blob cannot be (re)allocated through this allocator.
static void* cpu_fixed_malloc(void* /*unused*/, ptrdiff_t /*unused*/) {
  TORCH_CHECK(false, "attempting to resize a tensor view of an external blob");
}
|
| 18 |
+
|
| 19 |
+
// Reallocation hook of the fixed allocator: always fails via TORCH_CHECK,
// because an external blob cannot be resized through this allocator.
static void* cpu_fixed_realloc(
    void* /*unused*/,
    void* /*unused*/,
    ptrdiff_t /*unused*/) {
  TORCH_CHECK(false, "attempting to resize a tensor view of an external blob");
}
|
| 22 |
+
|
| 23 |
+
// Release hook of the fixed allocator. `state` must point to a heap-allocated
// std::function<void(void*)>; it is invoked with `allocation` and then
// deleted, so it must not be used again afterwards.
static void cpu_fixed_free(void* state, void* allocation) {
  using ReleaseFn = std::function<void(void*)>;
  ReleaseFn* const release_cb = static_cast<ReleaseFn*>(state);
  (*release_cb)(allocation);
  delete release_cb;
}
|
| 28 |
+
|
| 29 |
+
// Allocator table wired to the fixed-blob hooks above: malloc and realloc
// always fail, free forwards to the release callback stored in the state.
static Allocator CPU_fixed_allocator = {
    cpu_fixed_malloc,
    cpu_fixed_realloc,
    cpu_fixed_free,
};
|
| 33 |
+
|
| 34 |
+
} // namespace at
|
| 35 |
+
|
| 36 |
+
#else
|
| 37 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 38 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/CPUGeneratorImpl.h
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
#include <ATen/core/Generator.h>
|
| 5 |
+
#include <ATen/core/MT19937RNGEngine.h>
|
| 6 |
+
#include <c10/core/GeneratorImpl.h>
|
| 7 |
+
#include <optional>
|
| 8 |
+
|
| 9 |
+
namespace at {
|
| 10 |
+
|
| 11 |
+
// CPU random number generator implementation, derived from
// c10::GeneratorImpl. Holds an at::mt19937 engine plus optionally cached
// float/double normal samples. Only declarations live here; the definitions
// are elsewhere.
struct TORCH_API CPUGeneratorImpl : public c10::GeneratorImpl {
  // Constructors
  CPUGeneratorImpl(uint64_t seed_in = default_rng_seed_val);
  ~CPUGeneratorImpl() override = default;

  // CPUGeneratorImpl methods

  // Typed counterpart of the private clone_impl() hook.
  std::shared_ptr<CPUGeneratorImpl> clone() const;

  // Seed / offset interface overridden from c10::GeneratorImpl.
  void set_current_seed(uint64_t seed) override;
  void set_offset(uint64_t offset) override;
  uint64_t get_offset() const override;
  uint64_t current_seed() const override;
  uint64_t seed() override;

  // Generator state exchanged as a TensorImpl.
  void set_state(const c10::TensorImpl& new_state) override;
  c10::intrusive_ptr<c10::TensorImpl> get_state() const override;

  static c10::DeviceType device_type();

  // Raw 32-bit / 64-bit random draws.
  uint32_t random();
  uint64_t random64();

  // Accessors for the cached normal samples (see the optional members below).
  std::optional<float> next_float_normal_sample();
  std::optional<double> next_double_normal_sample();
  void set_next_float_normal_sample(std::optional<float> randn);
  void set_next_double_normal_sample(std::optional<double> randn);

  // Accessors for the underlying engine; both pass it by value.
  at::mt19937 engine();
  void set_engine(at::mt19937 engine);

 private:
  CPUGeneratorImpl* clone_impl() const override;

  at::mt19937 engine_;
  std::optional<float> next_float_normal_sample_;
  std::optional<double> next_double_normal_sample_;
};
|
| 41 |
+
|
| 42 |
+
namespace detail {
|
| 43 |
+
|
| 44 |
+
TORCH_API const Generator& getDefaultCPUGenerator();
|
| 45 |
+
TORCH_API Generator
|
| 46 |
+
createCPUGenerator(uint64_t seed_val = default_rng_seed_val);
|
| 47 |
+
|
| 48 |
+
} // namespace detail
|
| 49 |
+
|
| 50 |
+
} // namespace at
|
| 51 |
+
|
| 52 |
+
#else
|
| 53 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 54 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/CUDAFunctions.h
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#include <ATen/core/TensorBody.h>
|
| 3 |
+
|
| 4 |
+
// TODO Undo all logic introduced for Note [Avoiding Include Cycles In Static Dispatch]
|
| 5 |
+
// Code introduced to avoid cyclic dependency in static dispatch is no longer
|
| 6 |
+
// needed as static dispatch logic is moved from TensorBody.h, which caused cycles in the first place,
|
| 7 |
+
// to Operators.cpp for supporting multiple backends with multiple kernels.
|
| 8 |
+
//
|
| 9 |
+
// Note [Avoiding Include Cycles In Static Dispatch]
|
| 10 |
+
// In order to avoid #include cycles in the static dispatch build, we've carefully split out
|
| 11 |
+
// the static function definition files into {DispatchKey}Functions.h and {DispatchKey}Functions_inl.h.
|
| 12 |
+
//
|
| 13 |
+
// Without this split, the include cycle looks like TensorBody.h -> CPUFunctions.h -> TensorBody.h.
|
| 14 |
+
// - TensorBody.h #includes CPUFunctions.h in the static dispatch build, because the tensor methods
|
| 15 |
+
// all need to call into the fastpath C++ API defined in CPUFunctions.h. The methods are also all
|
| 16 |
+
// directly inlined into TensorBody.h.
|
| 17 |
+
// - CPUFunctions.h #includes TensorBody.h because it contains function declarations for the entire C++ API,
|
| 18 |
+
// which include functions that have defaultable std::optional<Tensor> arguments.
|
| 19 |
+
// That requires knowing the full Tensor class definition.
|
| 20 |
+
//
|
| 21 |
+
// We break the cycle by doing the following:
|
| 22 |
+
// - Split out CPUFunctions.h into two files: CPUFunctions.h and CPUFunctions_inl.h
|
| 23 |
+
// - CPUFunctions.h is a dummy file that just includes the Tensor class and includes CPUFunctions_inl.h
|
| 24 |
+
// - CPUFunctions_inl.h includes everything else
|
| 25 |
+
// - (only in the static dispatch build) TensorBody.h makes sure to finish defining the Tensor class,
|
| 26 |
+
// and then it includes CPUFunctions_inl.h.
|
| 27 |
+
// - All other files that want the cpu fastpath functions can include CPUFunctions.h directly.
|
| 28 |
+
// - This also means that in the static dispatch build, CPUFunctions.h only needs to
|
| 29 |
+
// #include TensorBody.h, and it will automatically bring in CPUFunctions_inl.h.
|
| 30 |
+
#include <ATen/CUDAFunctions_inl.h>
|
| 31 |
+
|
| 32 |
+
#else
|
| 33 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 34 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/CollapseDims.h
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#include <c10/util/Exception.h>
|
| 3 |
+
#include <utility>
|
| 4 |
+
|
| 5 |
+
namespace at {
|
| 6 |
+
|
| 7 |
+
/*
|
| 8 |
+
[collapse dims] Updates sizes, and strides to reflect a "collapse" of
|
| 9 |
+
the info, possibly excluding the optional excludeDim. A "collapsed" version
|
| 10 |
+
of the info is the fewest dims that order the tensor's elements in the same
|
| 11 |
+
way as the original info. If excludeDim is specified, the collapse is the
|
| 12 |
+
fewest dims that order the tensor's elements as the original and preserve the
|
| 13 |
+
excluded dimension, unless the tensor collapses to a point.
|
| 14 |
+
|
| 15 |
+
This function returns a pair of values.
|
| 16 |
+
|
| 17 |
+
1) The (new) index of the preserved dimension if excludeDim is
|
| 18 |
+
specified. 0 if the tensor is collapsed to a point. -1
|
| 19 |
+
otherwise.
|
| 20 |
+
|
| 21 |
+
2) The new number of dimensions.
|
| 22 |
+
*/
|
| 23 |
+
template <typename T>
|
| 24 |
+
inline std::pair<int64_t, int64_t> collapse_dims(
|
| 25 |
+
T* sizes,
|
| 26 |
+
T* strides,
|
| 27 |
+
int64_t dims,
|
| 28 |
+
const int excludeDim = -1) {
|
| 29 |
+
TORCH_CHECK(
|
| 30 |
+
excludeDim >= -1 && excludeDim < dims,
|
| 31 |
+
"expected excluded dim between -1 and dims - 1");
|
| 32 |
+
|
| 33 |
+
int64_t stopDim = (excludeDim == -1) ? dims : excludeDim;
|
| 34 |
+
int64_t newIndex = -1;
|
| 35 |
+
int64_t oldIndex = 0;
|
| 36 |
+
int64_t remappedExcludedDim = -1;
|
| 37 |
+
|
| 38 |
+
while (oldIndex < dims) {
|
| 39 |
+
// Finds a dimension to collapse into
|
| 40 |
+
for (; oldIndex < stopDim; ++oldIndex) {
|
| 41 |
+
if (sizes[oldIndex] == 1) {
|
| 42 |
+
continue;
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
++newIndex;
|
| 46 |
+
sizes[newIndex] = sizes[oldIndex];
|
| 47 |
+
strides[newIndex] = strides[oldIndex];
|
| 48 |
+
++oldIndex;
|
| 49 |
+
break;
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
// Collapses dims
|
| 53 |
+
for (; oldIndex < stopDim; ++oldIndex) {
|
| 54 |
+
if (sizes[oldIndex] == 1) {
|
| 55 |
+
continue;
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
if (strides[newIndex] == sizes[oldIndex] * strides[oldIndex]) {
|
| 59 |
+
sizes[newIndex] *= sizes[oldIndex];
|
| 60 |
+
strides[newIndex] = strides[oldIndex];
|
| 61 |
+
} else {
|
| 62 |
+
++newIndex;
|
| 63 |
+
sizes[newIndex] = sizes[oldIndex];
|
| 64 |
+
strides[newIndex] = strides[oldIndex];
|
| 65 |
+
}
|
| 66 |
+
}
|
| 67 |
+
|
| 68 |
+
// Handles excludeDim being set (oldIndex == excludeDim)
|
| 69 |
+
if (oldIndex != dims) {
|
| 70 |
+
// Preserves excluded dimension
|
| 71 |
+
++newIndex;
|
| 72 |
+
sizes[newIndex] = sizes[oldIndex];
|
| 73 |
+
strides[newIndex] = strides[oldIndex];
|
| 74 |
+
remappedExcludedDim = newIndex;
|
| 75 |
+
|
| 76 |
+
// Restarts iteration after excludeDim
|
| 77 |
+
++oldIndex;
|
| 78 |
+
stopDim = dims;
|
| 79 |
+
}
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
// Handles special case of all dims size 1
|
| 83 |
+
if (newIndex == -1 || (newIndex == 0 && sizes[0] == 1)) {
|
| 84 |
+
dims = 1;
|
| 85 |
+
sizes[0] = 1;
|
| 86 |
+
strides[0] = 1;
|
| 87 |
+
|
| 88 |
+
return std::pair<int64_t, int64_t>(0, 1);
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
dims = newIndex + 1;
|
| 92 |
+
return std::pair<int64_t, int64_t>(remappedExcludedDim, dims);
|
| 93 |
+
}
|
| 94 |
+
|
| 95 |
+
} // namespace at
|
| 96 |
+
|
| 97 |
+
#else
|
| 98 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 99 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/CompositeImplicitAutogradNestedTensorFunctions_inl.h
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
// @generated by torchgen/gen.py from DispatchKeyFunctions_inl.h
|
| 4 |
+
|
| 5 |
+
// NB: The implementing C++ file is RegisterDispatchKey.cpp
|
| 6 |
+
|
| 7 |
+
// The only #includes we need are for custom classes that have defaults in the C++ API
|
| 8 |
+
#include <c10/core/MemoryFormat.h>
|
| 9 |
+
#include <c10/core/Scalar.h>
|
| 10 |
+
#include <ATen/core/Reduction.h>
|
| 11 |
+
|
| 12 |
+
#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS)
|
| 13 |
+
#error This change adds a dependency on all pytorch operators, meaning the \
|
| 14 |
+
file will need to be re-compiled every time an operator is changed or added. \
|
| 15 |
+
Consider including a specific operator from \
|
| 16 |
+
<ATen/ops/{my_operator}_compositeimplicitautogradnestedtensor_dispatch.h>. \
|
| 17 |
+
See NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS].
|
| 18 |
+
#endif
|
| 19 |
+
|
| 20 |
+
#include <ATen/ops/randn_like_compositeimplicitautogradnestedtensor_dispatch.h>
|
| 21 |
+
#include <ATen/ops/reshape_compositeimplicitautogradnestedtensor_dispatch.h>
|
| 22 |
+
#include <ATen/ops/reshape_as_compositeimplicitautogradnestedtensor_dispatch.h>
|
| 23 |
+
#include <ATen/ops/zeros_like_compositeimplicitautogradnestedtensor_dispatch.h>
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
#else
|
| 29 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 30 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/Config.h
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
// Test these using #if AT_MKL_ENABLED(), not #ifdef, so that it's
|
| 5 |
+
// obvious if you forgot to include Config.h
|
| 6 |
+
// c.f. https://stackoverflow.com/questions/33759787/generating-an-error-if-checked-boolean-macro-is-not-defined
|
| 7 |
+
//
|
| 8 |
+
// DO NOT put the macros for CUDA libraries in this file; they belong in cuda/CUDAConfig.h
|
| 9 |
+
|
| 10 |
+
#define AT_MKLDNN_ENABLED() 1
|
| 11 |
+
#define AT_MKLDNN_ACL_ENABLED() 0
|
| 12 |
+
#define AT_MKL_ENABLED() 1
|
| 13 |
+
#define AT_MKL_SEQUENTIAL() 0
|
| 14 |
+
#define AT_POCKETFFT_ENABLED() 0
|
| 15 |
+
#define AT_NNPACK_ENABLED() 1
|
| 16 |
+
#define CAFFE2_STATIC_LINK_CUDA() 0
|
| 17 |
+
#define AT_BUILD_WITH_BLAS() 1
|
| 18 |
+
#define AT_BUILD_WITH_LAPACK() 1
|
| 19 |
+
#define AT_PARALLEL_OPENMP 1
|
| 20 |
+
#define AT_PARALLEL_NATIVE 0
|
| 21 |
+
#define AT_BLAS_F2C() 0
|
| 22 |
+
#define AT_BLAS_USE_CBLAS_DOT() 0
|
| 23 |
+
#define AT_KLEIDIAI_ENABLED() 0
|
| 24 |
+
#define AT_USE_EIGEN_SPARSE() 0
|
| 25 |
+
|
| 26 |
+
#else
|
| 27 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 28 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/Device.h
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
#include <c10/core/Device.h>
|
| 4 |
+
|
| 5 |
+
#else
|
| 6 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 7 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/DeviceAccelerator.h
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
#include <c10/core/CachingDeviceAllocator.h>
|
| 5 |
+
#include <c10/core/DeviceCapability.h>
|
| 6 |
+
#include <c10/core/DeviceType.h>
|
| 7 |
+
#include <c10/macros/Macros.h>
|
| 8 |
+
|
| 9 |
+
#include <ATen/detail/MTIAHooksInterface.h>
|
| 10 |
+
#include <optional>
|
| 11 |
+
|
| 12 |
+
namespace at::accelerator {
|
| 13 |
+
|
| 14 |
+
// Note [Accelerator Concept]
|
| 15 |
+
// This file defines the top level Accelerator concept for PyTorch.
|
| 16 |
+
// A device is an accelerator per the definition here if:
|
| 17 |
+
// - It is mutually exclusive with all other accelerators
|
| 18 |
+
// - It performs asynchronous compute via a Stream/Event system
|
| 19 |
+
// - It provides a set of common APIs as defined by AcceleratorHooksInterface
|
| 20 |
+
//
|
| 21 |
+
// As of today, accelerator devices are (in no particular order):
|
| 22 |
+
// CUDA, MTIA, XPU, HIP, MPS, PrivateUse1
|
| 23 |
+
|
| 24 |
+
// Ensures that only one accelerator is available (at
|
| 25 |
+
// compile time if possible) and returns it.
|
| 26 |
+
// When checked is true, the returned optional always has a value.
|
| 27 |
+
TORCH_API std::optional<c10::DeviceType> getAccelerator(bool checked = false);
|
| 28 |
+
|
| 29 |
+
// Check if the given device type is an accelerator.
|
| 30 |
+
TORCH_API bool isAccelerator(c10::DeviceType device_type);
|
| 31 |
+
|
| 32 |
+
// Check if the given device type is an accelerator other than the excluded ones.
|
| 33 |
+
template <
|
| 34 |
+
typename... T,
|
| 35 |
+
typename = std::enable_if_t<(std::is_same_v<T, c10::DeviceType> && ...)>>
|
| 36 |
+
inline bool isAcceleratorExcluded(
|
| 37 |
+
c10::DeviceType device_type,
|
| 38 |
+
c10::DeviceType first_excluded,
|
| 39 |
+
T... rest_excluded) {
|
| 40 |
+
if constexpr (sizeof...(rest_excluded) > 0) {
|
| 41 |
+
return device_type != first_excluded &&
|
| 42 |
+
isAcceleratorExcluded(device_type, rest_excluded...);
|
| 43 |
+
} else {
|
| 44 |
+
return device_type != first_excluded && isAccelerator(device_type);
|
| 45 |
+
}
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
// Return the number of devices available. Note that this is *REQUIRED* to
|
| 49 |
+
// not raise any exception.
|
| 50 |
+
TORCH_API c10::DeviceIndex deviceCount();
|
| 51 |
+
|
| 52 |
+
// Set the current device index to the given device index.
|
| 53 |
+
TORCH_API void setDeviceIndex(c10::DeviceIndex device_index);
|
| 54 |
+
|
| 55 |
+
// Get the current device index.
|
| 56 |
+
TORCH_API c10::DeviceIndex getDeviceIndex();
|
| 57 |
+
|
| 58 |
+
// Set the current stream to a given stream. Note that this API doesn't change
|
| 59 |
+
// the current device index.
|
| 60 |
+
TORCH_API void setCurrentStream(c10::Stream stream);
|
| 61 |
+
|
| 62 |
+
// Get the current stream of the given device index.
|
| 63 |
+
TORCH_API c10::Stream getCurrentStream(c10::DeviceIndex device_index);
|
| 64 |
+
|
| 65 |
+
// Wait (by blocking the calling thread) until all the work previously enqueued
|
| 66 |
+
// on the given device index has been completed.
|
| 67 |
+
TORCH_API void synchronizeDevice(c10::DeviceIndex device_index);
|
| 68 |
+
|
| 69 |
+
// Set the current device index to the given device_index and return the
|
| 70 |
+
// original device index that was active before the change.
|
| 71 |
+
TORCH_API c10::DeviceIndex exchangeDevice(c10::DeviceIndex device_index);
|
| 72 |
+
|
| 73 |
+
// Set the current device index to the given device_index. Avoid creating a new
|
| 74 |
+
// context if the context for device_index is not initialized. Return the
|
| 75 |
+
// original device index that was active before the change.
|
| 76 |
+
TORCH_API c10::DeviceIndex maybeExchangeDevice(c10::DeviceIndex device_index);
|
| 77 |
+
|
| 78 |
+
// Get the device capability of the given device index.
|
| 79 |
+
TORCH_API c10::DeviceCapability getDeviceCapability(
|
| 80 |
+
c10::DeviceIndex device_index);
|
| 81 |
+
|
| 82 |
+
TORCH_API inline void emptyCache() {
|
| 83 |
+
const auto device_type = getAccelerator(true).value();
|
| 84 |
+
at::getDeviceAllocator(device_type)->emptyCache();
|
| 85 |
+
}
|
| 86 |
+
|
| 87 |
+
TORCH_API inline at::CachingDeviceAllocator::DeviceStats getDeviceStats(
|
| 88 |
+
c10::DeviceIndex device_index) {
|
| 89 |
+
const auto device_type = getAccelerator(true).value();
|
| 90 |
+
return at::getDeviceAllocator(device_type)->getDeviceStats(device_index);
|
| 91 |
+
}
|
| 92 |
+
|
| 93 |
+
TORCH_API inline void resetAccumulatedStats(c10::DeviceIndex device_index) {
|
| 94 |
+
const auto device_type = getAccelerator(true).value();
|
| 95 |
+
at::getDeviceAllocator(device_type)->resetAccumulatedStats(device_index);
|
| 96 |
+
}
|
| 97 |
+
|
| 98 |
+
TORCH_API inline void resetPeakStats(c10::DeviceIndex device_index) {
|
| 99 |
+
const auto device_type = getAccelerator(true).value();
|
| 100 |
+
at::getDeviceAllocator(device_type)->resetPeakStats(device_index);
|
| 101 |
+
}
|
| 102 |
+
|
| 103 |
+
TORCH_API inline std::pair<size_t, size_t> getMemoryInfo(
|
| 104 |
+
c10::DeviceIndex device_index) {
|
| 105 |
+
const auto device_type = getAccelerator(true).value();
|
| 106 |
+
return at::getDeviceAllocator(device_type)->getMemoryInfo(device_index);
|
| 107 |
+
}
|
| 108 |
+
} // namespace at::accelerator
|
| 109 |
+
|
| 110 |
+
namespace at {
|
| 111 |
+
// Keep BC only
|
| 112 |
+
using at::accelerator::getAccelerator;
|
| 113 |
+
using at::accelerator::isAccelerator;
|
| 114 |
+
} // namespace at
|
| 115 |
+
|
| 116 |
+
#else
|
| 117 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 118 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/DimVector.h
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
#include <ATen/core/DimVector.h>
|
| 4 |
+
|
| 5 |
+
#else
|
| 6 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 7 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/Dispatch_v2.h
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
#include <torch/headeronly/core/Dispatch_v2.h>
|
| 5 |
+
|
| 6 |
+
// Get AT_DISPATCH_SWITCH and AT_DISPATCH_CASE:
|
| 7 |
+
#include <ATen/Dispatch.h>
|
| 8 |
+
|
| 9 |
+
// This is a new implementation of the AT_DISPATCH macro family from
|
| 10 |
+
// ATen/Dispatch.h
|
| 11 |
+
//
|
| 12 |
+
// The intended usage is:
|
| 13 |
+
//
|
| 14 |
+
// ScalarType scalar_type;
|
| 15 |
+
//
|
| 16 |
+
// AT_DISPATCH_V2(
|
| 17 |
+
// scalar_type,
|
| 18 |
+
// "debug string",
|
| 19 |
+
// AT_WRAP([&] {
|
| 20 |
+
// ... code to specialize with scalar_t ...
|
| 21 |
+
// }),
|
| 22 |
+
// kHalf,
|
| 23 |
+
// AT_EXPAND(AT_ALL_TYPES),
|
| 24 |
+
// ... as many types arguments as needed ...
|
| 25 |
+
// )
|
| 26 |
+
//
|
| 27 |
+
// For example, given an old style:
|
| 28 |
+
//
|
| 29 |
+
// AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
|
| 30 |
+
// kComplexHalf,
|
| 31 |
+
// kHalf,
|
| 32 |
+
// self.scalar_type(),
|
| 33 |
+
// "_local_scalar_dense_cpu",
|
| 34 |
+
// [&] {
|
| 35 |
+
// scalar_t value = *self.data_ptr<scalar_t>();
|
| 36 |
+
// r = Scalar(value);
|
| 37 |
+
// }
|
| 38 |
+
// )
|
| 39 |
+
//
|
| 40 |
+
// You now write:
|
| 41 |
+
//
|
| 42 |
+
// AT_DISPATCH_V2(
|
| 43 |
+
// self.scalar_type(),
|
| 44 |
+
// "_local_scalar_dense_cpu",
|
| 45 |
+
// AT_WRAP([&] {
|
| 46 |
+
// scalar_t value = *self.data_ptr<scalar_t>();
|
| 47 |
+
// r = Scalar(value);
|
| 48 |
+
// }),
|
| 49 |
+
// AT_EXPAND(AT_ALL_TYPES),
|
| 50 |
+
// AT_EXPAND(AT_COMPLEX_TYPES),
|
| 51 |
+
// kComplexHalf,
|
| 52 |
+
// kHalf,
|
| 53 |
+
// )
|
| 54 |
+
//
|
| 55 |
+
// Notably, it sports the following improvements:
|
| 56 |
+
//
|
| 57 |
+
// - It is not necessary to specify the arity (e.g.,
|
| 58 |
+
// AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND{2,3,4,...})
|
| 59 |
+
// when using the macro
|
| 60 |
+
//
|
| 61 |
+
// - It is not necessary to specify each dtype individually; if
|
| 62 |
+
// there is a set of related dtypes and you want to dispatch
|
| 63 |
+
// over all of them, you can simply say, e.g., AT_EXPAND(AT_INTEGRAL_TYPES)
|
| 64 |
+
// in your argument list.
|
| 65 |
+
//
|
| 66 |
+
// However, you must remember to wrap the payload body in AT_WRAP, or commas
|
| 67 |
+
// inside your lambda will be improperly handled. Furthermore, if you add more
|
| 68 |
+
// entries to ScalarType than can be supported by this macro, it will fail
|
| 69 |
+
// with an obscure error (due to attempting to concatenate AT_AP with
|
| 70 |
+
// something that is not a number).
|
| 71 |
+
//
|
| 72 |
+
// The implementation strategy is to use the count arguments trick
|
| 73 |
+
// (e.g., as described in https://stackoverflow.com/a/2124385/23845)
|
| 74 |
+
// to discover how many dtypes have been passed, and then dispatch to a
|
| 75 |
+
// hand-written macro for each arity that applies as many DISPATCH_CASE as
|
| 76 |
+
// necessary. The hand-written macros can be regenerated for other arities
|
| 77 |
+
// with the script below.
|
| 78 |
+
//
|
| 79 |
+
// There is some delicacy in the implementation in controlling when
|
| 80 |
+
// macro expansion occurs, mediated with AT_EXPAND and AT_GUARD. I mostly
|
| 81 |
+
// relied on GPT4 to help me get it right.
|
| 82 |
+
|
| 83 |
+
// See documentation above
|
| 84 |
+
#define AT_DISPATCH_V2(TYPE, NAME, BODY, ...) \
|
| 85 |
+
THO_DISPATCH_V2_TMPL( \
|
| 86 |
+
AT_DISPATCH_SWITCH, \
|
| 87 |
+
AT_DISPATCH_CASE, \
|
| 88 |
+
TYPE, \
|
| 89 |
+
NAME, \
|
| 90 |
+
AT_WRAP(BODY), \
|
| 91 |
+
__VA_ARGS__)
|
| 92 |
+
|
| 93 |
+
// Unused helper macros, kept for BC:
|
| 94 |
+
#define AT_AP_VAR(N, T, ...) \
|
| 95 |
+
AT_EXPAND(AT_CONCAT(AT_AP, AT_NUM_ARGS(__VA_ARGS__))(AT_WRAP(N), __VA_ARGS__))
|
| 96 |
+
|
| 97 |
+
// Ensure we never have too many scalar types for the expansion here to
|
| 98 |
+
// support. To bump this, you must regenerate the macros below.
|
| 99 |
+
static_assert(static_cast<int>(c10::ScalarType::NumOptions) < 60);
|
| 100 |
+
|
| 101 |
+
// Python code to regenerate the generated code below:
|
| 102 |
+
#if 0
|
| 103 |
+
|
| 104 |
+
num_args = 60
|
| 105 |
+
|
| 106 |
+
for i in range(1, num_args+1):
|
| 107 |
+
args = ', '.join(f'_{i}' for i in range(1, i+1))
|
| 108 |
+
cases = ' '.join([f'AT_DISPATCH_CASE(_{j}, N)' for j in range(1, i+1)])
|
| 109 |
+
print(f'#define AT_AP{i}(N, {args}) {cases}')
|
| 110 |
+
|
| 111 |
+
#endif
|
| 112 |
+
|
| 113 |
+
// Begin generated code
|
| 114 |
+
// clang-format off
|
| 115 |
+
|
| 116 |
+
#define AT_AP1(N, _1) AT_DISPATCH_CASE(_1, N)
|
| 117 |
+
#define AT_AP2(N, _1, _2) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N)
|
| 118 |
+
#define AT_AP3(N, _1, _2, _3) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N)
|
| 119 |
+
#define AT_AP4(N, _1, _2, _3, _4) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N)
|
| 120 |
+
#define AT_AP5(N, _1, _2, _3, _4, _5) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N)
|
| 121 |
+
#define AT_AP6(N, _1, _2, _3, _4, _5, _6) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N)
|
| 122 |
+
#define AT_AP7(N, _1, _2, _3, _4, _5, _6, _7) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N)
|
| 123 |
+
#define AT_AP8(N, _1, _2, _3, _4, _5, _6, _7, _8) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N)
|
| 124 |
+
#define AT_AP9(N, _1, _2, _3, _4, _5, _6, _7, _8, _9) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N)
|
| 125 |
+
#define AT_AP10(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N)
|
| 126 |
+
#define AT_AP11(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N)
|
| 127 |
+
#define AT_AP12(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N)
|
| 128 |
+
#define AT_AP13(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N)
|
| 129 |
+
#define AT_AP14(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N)
|
| 130 |
+
#define AT_AP15(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N)
|
| 131 |
+
#define AT_AP16(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N)
|
| 132 |
+
#define AT_AP17(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N)
|
| 133 |
+
#define AT_AP18(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N)
|
| 134 |
+
#define AT_AP19(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N)
|
| 135 |
+
#define AT_AP20(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N)
|
| 136 |
+
#define AT_AP21(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N)
|
| 137 |
+
#define AT_AP22(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N)
|
| 138 |
+
#define AT_AP23(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N)
|
| 139 |
+
#define AT_AP24(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N)
|
| 140 |
+
#define AT_AP25(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N)
|
| 141 |
+
#define AT_AP26(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N)
|
| 142 |
+
#define AT_AP27(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N)
|
| 143 |
+
#define AT_AP28(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N)
|
| 144 |
+
// AT_AP29(N, _1, ..., _29): emits AT_DISPATCH_CASE(_i, N) for i = 1..29,
// in argument order, separated only by whitespace.
#define AT_AP29(N, \
    _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, \
    _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, \
    _21, _22, _23, _24, _25, _26, _27, _28, _29) \
  AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) \
  AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) \
  AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) \
  AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) \
  AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) \
  AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N)
|
| 145 |
+
// AT_AP30(N, _1, ..., _30): emits AT_DISPATCH_CASE(_i, N) for i = 1..30,
// in argument order, separated only by whitespace.
#define AT_AP30(N, \
    _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, \
    _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, \
    _21, _22, _23, _24, _25, _26, _27, _28, _29, _30) \
  AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) \
  AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) \
  AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) \
  AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) \
  AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) \
  AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N)
|
| 146 |
+
// AT_AP31(N, _1, ..., _31): emits AT_DISPATCH_CASE(_i, N) for i = 1..31,
// in argument order, separated only by whitespace.
#define AT_AP31(N, \
    _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, \
    _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, \
    _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, \
    _31) \
  AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) \
  AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) \
  AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) \
  AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) \
  AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) \
  AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) \
  AT_DISPATCH_CASE(_31, N)
|
| 147 |
+
// AT_AP32(N, _1, ..., _32): emits AT_DISPATCH_CASE(_i, N) for i = 1..32,
// in argument order, separated only by whitespace.
#define AT_AP32(N, \
    _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, \
    _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, \
    _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, \
    _31, _32) \
  AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) \
  AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) \
  AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) \
  AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) \
  AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) \
  AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) \
  AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N)
|
| 148 |
+
// AT_AP33(N, _1, ..., _33): emits AT_DISPATCH_CASE(_i, N) for i = 1..33,
// in argument order, separated only by whitespace.
#define AT_AP33(N, \
    _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, \
    _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, \
    _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, \
    _31, _32, _33) \
  AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) \
  AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) \
  AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) \
  AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) \
  AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) \
  AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) \
  AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N)
|
| 149 |
+
// AT_AP34(N, _1, ..., _34): emits AT_DISPATCH_CASE(_i, N) for i = 1..34,
// in argument order, separated only by whitespace.
#define AT_AP34(N, \
    _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, \
    _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, \
    _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, \
    _31, _32, _33, _34) \
  AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) \
  AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) \
  AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) \
  AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) \
  AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) \
  AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) \
  AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N)
|
| 150 |
+
// AT_AP35(N, _1, ..., _35): emits AT_DISPATCH_CASE(_i, N) for i = 1..35,
// in argument order, separated only by whitespace.
#define AT_AP35(N, \
    _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, \
    _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, \
    _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, \
    _31, _32, _33, _34, _35) \
  AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) \
  AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) \
  AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) \
  AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) \
  AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) \
  AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) \
  AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N)
|
| 151 |
+
// AT_AP36(N, _1, ..., _36): emits AT_DISPATCH_CASE(_i, N) for i = 1..36,
// in argument order, separated only by whitespace.
#define AT_AP36(N, \
    _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, \
    _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, \
    _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, \
    _31, _32, _33, _34, _35, _36) \
  AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) \
  AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) \
  AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) \
  AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) \
  AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) \
  AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) \
  AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) \
  AT_DISPATCH_CASE(_36, N)
|
| 152 |
+
// AT_AP37(N, _1, ..., _37): emits AT_DISPATCH_CASE(_i, N) for i = 1..37,
// in argument order, separated only by whitespace.
#define AT_AP37(N, \
    _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, \
    _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, \
    _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, \
    _31, _32, _33, _34, _35, _36, _37) \
  AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) \
  AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) \
  AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) \
  AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) \
  AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) \
  AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) \
  AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) \
  AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N)
|
| 153 |
+
// AT_AP38(N, _1, ..., _38): emits AT_DISPATCH_CASE(_i, N) for i = 1..38,
// in argument order, separated only by whitespace.
#define AT_AP38(N, \
    _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, \
    _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, \
    _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, \
    _31, _32, _33, _34, _35, _36, _37, _38) \
  AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) \
  AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) \
  AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) \
  AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) \
  AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) \
  AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) \
  AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) \
  AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N)
|
| 154 |
+
// AT_AP39(N, _1, ..., _39): emits AT_DISPATCH_CASE(_i, N) for i = 1..39,
// in argument order, separated only by whitespace.
#define AT_AP39(N, \
    _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, \
    _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, \
    _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, \
    _31, _32, _33, _34, _35, _36, _37, _38, _39) \
  AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) \
  AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) \
  AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) \
  AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) \
  AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) \
  AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) \
  AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) \
  AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N)
|
| 155 |
+
// AT_AP40(N, _1, ..., _40): emits AT_DISPATCH_CASE(_i, N) for i = 1..40,
// in argument order, separated only by whitespace.
#define AT_AP40(N, \
    _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, \
    _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, \
    _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, \
    _31, _32, _33, _34, _35, _36, _37, _38, _39, _40) \
  AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) \
  AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) \
  AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) \
  AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) \
  AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) \
  AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) \
  AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) \
  AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N)
|
| 156 |
+
// AT_AP41(N, _1, ..., _41): emits AT_DISPATCH_CASE(_i, N) for i = 1..41,
// in argument order, separated only by whitespace.
#define AT_AP41(N, \
    _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, \
    _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, \
    _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, \
    _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, \
    _41) \
  AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) \
  AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) \
  AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) \
  AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) \
  AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) \
  AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) \
  AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) \
  AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) \
  AT_DISPATCH_CASE(_41, N)
|
| 157 |
+
// AT_AP42(N, _1, ..., _42): emits AT_DISPATCH_CASE(_i, N) for i = 1..42,
// in argument order, separated only by whitespace.
#define AT_AP42(N, \
    _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, \
    _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, \
    _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, \
    _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, \
    _41, _42) \
  AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) \
  AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) \
  AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) \
  AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) \
  AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) \
  AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) \
  AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) \
  AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) \
  AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N)
|
| 158 |
+
// AT_AP43(N, _1, ..., _43): emits AT_DISPATCH_CASE(_i, N) for i = 1..43,
// in argument order, separated only by whitespace.
#define AT_AP43(N, \
    _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, \
    _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, \
    _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, \
    _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, \
    _41, _42, _43) \
  AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) \
  AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) \
  AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) \
  AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) \
  AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) \
  AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) \
  AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) \
  AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) \
  AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N)
|
| 159 |
+
// AT_AP44(N, _1, ..., _44): emits AT_DISPATCH_CASE(_i, N) for i = 1..44,
// in argument order, separated only by whitespace.
#define AT_AP44(N, \
    _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, \
    _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, \
    _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, \
    _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, \
    _41, _42, _43, _44) \
  AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) \
  AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) \
  AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) \
  AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) \
  AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) \
  AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) \
  AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) \
  AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) \
  AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N)
|
| 160 |
+
// AT_AP45(N, _1, ..., _45): emits AT_DISPATCH_CASE(_i, N) for i = 1..45,
// in argument order, separated only by whitespace.
#define AT_AP45(N, \
    _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, \
    _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, \
    _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, \
    _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, \
    _41, _42, _43, _44, _45) \
  AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) \
  AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) \
  AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) \
  AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) \
  AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) \
  AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) \
  AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) \
  AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) \
  AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N)
|
| 161 |
+
// AT_AP46(N, _1, ..., _46): emits AT_DISPATCH_CASE(_i, N) for i = 1..46,
// in argument order, separated only by whitespace.
#define AT_AP46(N, \
    _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, \
    _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, \
    _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, \
    _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, \
    _41, _42, _43, _44, _45, _46) \
  AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) \
  AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) \
  AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) \
  AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) \
  AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) \
  AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) \
  AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) \
  AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) \
  AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N) \
  AT_DISPATCH_CASE(_46, N)
|
| 162 |
+
// AT_AP47(N, _1, ..., _47): emits AT_DISPATCH_CASE(_i, N) for i = 1..47,
// in argument order, separated only by whitespace.
#define AT_AP47(N, \
    _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, \
    _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, \
    _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, \
    _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, \
    _41, _42, _43, _44, _45, _46, _47) \
  AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) \
  AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) \
  AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) \
  AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) \
  AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) \
  AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) \
  AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) \
  AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) \
  AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N) \
  AT_DISPATCH_CASE(_46, N) AT_DISPATCH_CASE(_47, N)
|
| 163 |
+
// AT_AP48(N, _1, ..., _48): emits AT_DISPATCH_CASE(_i, N) for i = 1..48,
// in argument order, separated only by whitespace.
#define AT_AP48(N, \
    _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, \
    _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, \
    _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, \
    _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, \
    _41, _42, _43, _44, _45, _46, _47, _48) \
  AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) \
  AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) \
  AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) \
  AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) \
  AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) \
  AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) \
  AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) \
  AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) \
  AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N) \
  AT_DISPATCH_CASE(_46, N) AT_DISPATCH_CASE(_47, N) AT_DISPATCH_CASE(_48, N)
|
| 164 |
+
// AT_AP49(N, _1, ..., _49): emits AT_DISPATCH_CASE(_i, N) for i = 1..49,
// in argument order, separated only by whitespace.
#define AT_AP49(N, \
    _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, \
    _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, \
    _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, \
    _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, \
    _41, _42, _43, _44, _45, _46, _47, _48, _49) \
  AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) \
  AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) \
  AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) \
  AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) \
  AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) \
  AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) \
  AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) \
  AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) \
  AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N) \
  AT_DISPATCH_CASE(_46, N) AT_DISPATCH_CASE(_47, N) AT_DISPATCH_CASE(_48, N) AT_DISPATCH_CASE(_49, N)
|
| 165 |
+
// AT_AP50(N, _1, ..., _50): emits AT_DISPATCH_CASE(_i, N) for i = 1..50,
// in argument order, separated only by whitespace.
#define AT_AP50(N, \
    _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, \
    _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, \
    _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, \
    _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, \
    _41, _42, _43, _44, _45, _46, _47, _48, _49, _50) \
  AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) \
  AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) \
  AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) \
  AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) \
  AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) \
  AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) \
  AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) \
  AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) \
  AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N) \
  AT_DISPATCH_CASE(_46, N) AT_DISPATCH_CASE(_47, N) AT_DISPATCH_CASE(_48, N) AT_DISPATCH_CASE(_49, N) AT_DISPATCH_CASE(_50, N)
|
| 166 |
+
// AT_AP51(N, _1, ..., _51): emits AT_DISPATCH_CASE(_i, N) for i = 1..51,
// in argument order, separated only by whitespace.
#define AT_AP51(N, \
    _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, \
    _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, \
    _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, \
    _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, \
    _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, \
    _51) \
  AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) \
  AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) \
  AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) \
  AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) \
  AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) \
  AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) \
  AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) \
  AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) \
  AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N) \
  AT_DISPATCH_CASE(_46, N) AT_DISPATCH_CASE(_47, N) AT_DISPATCH_CASE(_48, N) AT_DISPATCH_CASE(_49, N) AT_DISPATCH_CASE(_50, N) \
  AT_DISPATCH_CASE(_51, N)
|
| 167 |
+
// AT_AP52(N, _1, ..., _52): emits AT_DISPATCH_CASE(_i, N) for i = 1..52,
// in argument order, separated only by whitespace.
#define AT_AP52(N, \
    _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, \
    _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, \
    _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, \
    _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, \
    _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, \
    _51, _52) \
  AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) \
  AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) \
  AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) \
  AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) \
  AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) \
  AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) \
  AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) \
  AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) \
  AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N) \
  AT_DISPATCH_CASE(_46, N) AT_DISPATCH_CASE(_47, N) AT_DISPATCH_CASE(_48, N) AT_DISPATCH_CASE(_49, N) AT_DISPATCH_CASE(_50, N) \
  AT_DISPATCH_CASE(_51, N) AT_DISPATCH_CASE(_52, N)
|
| 168 |
+
// AT_AP53(N, _1, ..., _53): emits AT_DISPATCH_CASE(_i, N) for i = 1..53,
// in argument order, separated only by whitespace.
#define AT_AP53(N, \
    _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, \
    _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, \
    _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, \
    _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, \
    _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, \
    _51, _52, _53) \
  AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) \
  AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) \
  AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) \
  AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) \
  AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) \
  AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) \
  AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) \
  AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) \
  AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N) \
  AT_DISPATCH_CASE(_46, N) AT_DISPATCH_CASE(_47, N) AT_DISPATCH_CASE(_48, N) AT_DISPATCH_CASE(_49, N) AT_DISPATCH_CASE(_50, N) \
  AT_DISPATCH_CASE(_51, N) AT_DISPATCH_CASE(_52, N) AT_DISPATCH_CASE(_53, N)
|
| 169 |
+
#define AT_AP54(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N) AT_DISPATCH_CASE(_46, N) AT_DISPATCH_CASE(_47, N) AT_DISPATCH_CASE(_48, N) AT_DISPATCH_CASE(_49, N) AT_DISPATCH_CASE(_50, N) AT_DISPATCH_CASE(_51, N) AT_DISPATCH_CASE(_52, N) AT_DISPATCH_CASE(_53, N) AT_DISPATCH_CASE(_54, N)
|
| 170 |
+
#define AT_AP55(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N) AT_DISPATCH_CASE(_46, N) AT_DISPATCH_CASE(_47, N) AT_DISPATCH_CASE(_48, N) AT_DISPATCH_CASE(_49, N) AT_DISPATCH_CASE(_50, N) AT_DISPATCH_CASE(_51, N) AT_DISPATCH_CASE(_52, N) AT_DISPATCH_CASE(_53, N) AT_DISPATCH_CASE(_54, N) AT_DISPATCH_CASE(_55, N)
|
| 171 |
+
#define AT_AP56(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55, _56) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N) AT_DISPATCH_CASE(_46, N) AT_DISPATCH_CASE(_47, N) AT_DISPATCH_CASE(_48, N) AT_DISPATCH_CASE(_49, N) AT_DISPATCH_CASE(_50, N) AT_DISPATCH_CASE(_51, N) AT_DISPATCH_CASE(_52, N) AT_DISPATCH_CASE(_53, N) AT_DISPATCH_CASE(_54, N) AT_DISPATCH_CASE(_55, N) AT_DISPATCH_CASE(_56, N)
|
| 172 |
+
#define AT_AP57(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55, _56, _57) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N) AT_DISPATCH_CASE(_46, N) AT_DISPATCH_CASE(_47, N) AT_DISPATCH_CASE(_48, N) AT_DISPATCH_CASE(_49, N) AT_DISPATCH_CASE(_50, N) AT_DISPATCH_CASE(_51, N) AT_DISPATCH_CASE(_52, N) AT_DISPATCH_CASE(_53, N) AT_DISPATCH_CASE(_54, N) AT_DISPATCH_CASE(_55, N) AT_DISPATCH_CASE(_56, N) AT_DISPATCH_CASE(_57, N)
|
| 173 |
+
#define AT_AP58(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55, _56, _57, _58) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N) AT_DISPATCH_CASE(_46, N) AT_DISPATCH_CASE(_47, N) AT_DISPATCH_CASE(_48, N) AT_DISPATCH_CASE(_49, N) AT_DISPATCH_CASE(_50, N) AT_DISPATCH_CASE(_51, N) AT_DISPATCH_CASE(_52, N) AT_DISPATCH_CASE(_53, N) AT_DISPATCH_CASE(_54, N) AT_DISPATCH_CASE(_55, N) AT_DISPATCH_CASE(_56, N) AT_DISPATCH_CASE(_57, N) AT_DISPATCH_CASE(_58, N)
|
| 174 |
+
#define AT_AP59(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55, _56, _57, _58, _59) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N) AT_DISPATCH_CASE(_46, N) AT_DISPATCH_CASE(_47, N) AT_DISPATCH_CASE(_48, N) AT_DISPATCH_CASE(_49, N) AT_DISPATCH_CASE(_50, N) AT_DISPATCH_CASE(_51, N) AT_DISPATCH_CASE(_52, N) AT_DISPATCH_CASE(_53, N) AT_DISPATCH_CASE(_54, N) AT_DISPATCH_CASE(_55, N) AT_DISPATCH_CASE(_56, N) AT_DISPATCH_CASE(_57, N) AT_DISPATCH_CASE(_58, N) AT_DISPATCH_CASE(_59, N)
|
| 175 |
+
#define AT_AP60(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55, _56, _57, _58, _59, _60) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N) AT_DISPATCH_CASE(_46, N) AT_DISPATCH_CASE(_47, N) AT_DISPATCH_CASE(_48, N) AT_DISPATCH_CASE(_49, N) AT_DISPATCH_CASE(_50, N) AT_DISPATCH_CASE(_51, N) AT_DISPATCH_CASE(_52, N) AT_DISPATCH_CASE(_53, N) AT_DISPATCH_CASE(_54, N) AT_DISPATCH_CASE(_55, N) AT_DISPATCH_CASE(_56, N) AT_DISPATCH_CASE(_57, N) AT_DISPATCH_CASE(_58, N) AT_DISPATCH_CASE(_59, N) AT_DISPATCH_CASE(_60, N)
|
| 176 |
+
|
| 177 |
+
// End generated code
|
| 178 |
+
// clang-format on
|
| 179 |
+
|
| 180 |
+
#else
|
| 181 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 182 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/DynamicLibrary.h
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
#include <ATen/Utils.h>
|
| 5 |
+
#include <c10/macros/Export.h>
|
| 6 |
+
#include <c10/util/Exception.h>
|
| 7 |
+
|
| 8 |
+
namespace c10 {
|
| 9 |
+
|
| 10 |
+
class DynamicLibraryError : public Error {
|
| 11 |
+
using Error::Error;
|
| 12 |
+
};
|
| 13 |
+
|
| 14 |
+
} // namespace c10
|
| 15 |
+
|
| 16 |
+
namespace at {
|
| 17 |
+
|
| 18 |
+
struct DynamicLibrary {
|
| 19 |
+
AT_DISALLOW_COPY_AND_ASSIGN(DynamicLibrary);
|
| 20 |
+
DynamicLibrary(DynamicLibrary&& other) = delete;
|
| 21 |
+
DynamicLibrary& operator=(DynamicLibrary&&) = delete;
|
| 22 |
+
|
| 23 |
+
TORCH_API DynamicLibrary(
|
| 24 |
+
const char* name,
|
| 25 |
+
const char* alt_name = nullptr,
|
| 26 |
+
bool leak_handle = false);
|
| 27 |
+
|
| 28 |
+
TORCH_API void* sym(const char* name);
|
| 29 |
+
|
| 30 |
+
TORCH_API ~DynamicLibrary();
|
| 31 |
+
|
| 32 |
+
private:
|
| 33 |
+
bool leak_handle;
|
| 34 |
+
void* handle = nullptr;
|
| 35 |
+
};
|
| 36 |
+
|
| 37 |
+
} // namespace at
|
| 38 |
+
|
| 39 |
+
#else
|
| 40 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 41 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/EmptyTensor.h
ADDED
|
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
#include <ATen/core/TensorBase.h>
|
| 4 |
+
|
| 5 |
+
namespace at::detail {
|
| 6 |
+
|
| 7 |
+
// Validates a concrete (int64_t) shape: every entry of `size` must be >= 0.
// The first negative entry fails a TORCH_CHECK whose message contains the
// offending value followed by the complete size list.
inline void check_size_nonnegative(ArrayRef<int64_t> size) {
  const auto last = size.end();
  for (auto it = size.begin(); it != last; ++it) {
    const auto& x = *it;
    TORCH_CHECK(
        x >= 0,
        "Trying to create tensor with negative dimension ",
        x,
        ": ",
        size);
  }
}
|
| 17 |
+
|
| 18 |
+
// Symbolic-shape counterpart of the int64_t overload: each entry is tested
// with `sym_ge(0)` through TORCH_SYM_CHECK. On failure the message contains
// the offending entry followed by the complete size list.
inline void check_size_nonnegative(ArrayRef<c10::SymInt> size) {
  const auto last = size.end();
  for (auto it = size.begin(); it != last; ++it) {
    const auto& x = *it;
    TORCH_SYM_CHECK(
        x.sym_ge(0),
        "Trying to create tensor with negative dimension ",
        x,
        ": ",
        size);
  }
}
|
| 28 |
+
|
| 29 |
+
TORCH_API size_t computeStorageNbytesContiguous(
|
| 30 |
+
IntArrayRef sizes,
|
| 31 |
+
size_t itemsize,
|
| 32 |
+
size_t storage_offset = 0);
|
| 33 |
+
TORCH_API SymInt computeStorageNbytesContiguous(
|
| 34 |
+
SymIntArrayRef sizes,
|
| 35 |
+
const SymInt& itemsize,
|
| 36 |
+
const SymInt& storage_offset = 0);
|
| 37 |
+
TORCH_API size_t computeStorageNbytes(
|
| 38 |
+
IntArrayRef sizes,
|
| 39 |
+
IntArrayRef strides,
|
| 40 |
+
size_t itemsize,
|
| 41 |
+
size_t storage_offset = 0);
|
| 42 |
+
TORCH_API SymInt computeStorageNbytes(
|
| 43 |
+
SymIntArrayRef sizes,
|
| 44 |
+
SymIntArrayRef strides,
|
| 45 |
+
const SymInt& itemsize,
|
| 46 |
+
const SymInt& storage_offset = 0);
|
| 47 |
+
|
| 48 |
+
TORCH_API TensorBase empty_generic(
|
| 49 |
+
IntArrayRef size,
|
| 50 |
+
c10::Allocator* allocator,
|
| 51 |
+
c10::DispatchKeySet ks,
|
| 52 |
+
ScalarType scalar_type,
|
| 53 |
+
std::optional<c10::MemoryFormat> memory_format_opt);
|
| 54 |
+
|
| 55 |
+
TORCH_API TensorBase empty_generic_symint(
|
| 56 |
+
SymIntArrayRef size,
|
| 57 |
+
c10::Allocator* allocator,
|
| 58 |
+
c10::DispatchKeySet ks,
|
| 59 |
+
ScalarType scalar_type,
|
| 60 |
+
std::optional<c10::MemoryFormat> memory_format_opt);
|
| 61 |
+
|
| 62 |
+
TORCH_API TensorBase empty_strided_generic(
|
| 63 |
+
IntArrayRef size,
|
| 64 |
+
IntArrayRef stride,
|
| 65 |
+
c10::Allocator* allocator,
|
| 66 |
+
c10::DispatchKeySet ks,
|
| 67 |
+
ScalarType scalar_type);
|
| 68 |
+
|
| 69 |
+
TORCH_API TensorBase empty_strided_symint_generic(
|
| 70 |
+
SymIntArrayRef size,
|
| 71 |
+
SymIntArrayRef stride,
|
| 72 |
+
c10::Allocator* allocator,
|
| 73 |
+
c10::DispatchKeySet ks,
|
| 74 |
+
ScalarType scalar_type);
|
| 75 |
+
|
| 76 |
+
TORCH_API TensorBase empty_cpu(
|
| 77 |
+
IntArrayRef size,
|
| 78 |
+
ScalarType dtype,
|
| 79 |
+
bool pin_memory = false,
|
| 80 |
+
std::optional<c10::MemoryFormat> memory_format_opt = std::nullopt);
|
| 81 |
+
|
| 82 |
+
TORCH_API TensorBase empty_cpu(
|
| 83 |
+
IntArrayRef size,
|
| 84 |
+
std::optional<ScalarType> dtype_opt,
|
| 85 |
+
std::optional<Layout> layout_opt,
|
| 86 |
+
std::optional<Device> device_opt,
|
| 87 |
+
std::optional<bool> pin_memory_opt,
|
| 88 |
+
std::optional<c10::MemoryFormat> memory_format_opt);
|
| 89 |
+
|
| 90 |
+
TORCH_API TensorBase empty_cpu(IntArrayRef size, const TensorOptions& options);
|
| 91 |
+
|
| 92 |
+
TORCH_API TensorBase empty_strided_cpu(
|
| 93 |
+
IntArrayRef size,
|
| 94 |
+
IntArrayRef stride,
|
| 95 |
+
ScalarType dtype,
|
| 96 |
+
bool pin_memory = false);
|
| 97 |
+
|
| 98 |
+
TORCH_API TensorBase empty_strided_cpu(
|
| 99 |
+
IntArrayRef size,
|
| 100 |
+
IntArrayRef stride,
|
| 101 |
+
std::optional<ScalarType> dtype_opt,
|
| 102 |
+
std::optional<Layout> layout_opt,
|
| 103 |
+
std::optional<Device> device_opt,
|
| 104 |
+
std::optional<bool> pin_memory_opt);
|
| 105 |
+
|
| 106 |
+
TORCH_API TensorBase empty_strided_cpu(
|
| 107 |
+
IntArrayRef size,
|
| 108 |
+
IntArrayRef stride,
|
| 109 |
+
const TensorOptions& options);
|
| 110 |
+
|
| 111 |
+
TORCH_API TensorBase empty_meta(
|
| 112 |
+
IntArrayRef size,
|
| 113 |
+
ScalarType dtype,
|
| 114 |
+
std::optional<c10::MemoryFormat> memory_format_opt = std::nullopt);
|
| 115 |
+
|
| 116 |
+
TORCH_API TensorBase empty_meta(
|
| 117 |
+
IntArrayRef size,
|
| 118 |
+
std::optional<ScalarType> dtype_opt,
|
| 119 |
+
std::optional<Layout> layout_opt,
|
| 120 |
+
std::optional<Device> device_opt,
|
| 121 |
+
std::optional<bool> pin_memory_opt,
|
| 122 |
+
std::optional<c10::MemoryFormat> memory_format_opt);
|
| 123 |
+
|
| 124 |
+
TORCH_API TensorBase empty_symint_meta(
|
| 125 |
+
SymIntArrayRef size,
|
| 126 |
+
std::optional<ScalarType> dtype_opt,
|
| 127 |
+
std::optional<Layout> layout_opt,
|
| 128 |
+
std::optional<Device> device_opt,
|
| 129 |
+
std::optional<bool> pin_memory_opt,
|
| 130 |
+
std::optional<c10::MemoryFormat> memory_format_opt);
|
| 131 |
+
|
| 132 |
+
TORCH_API TensorBase empty_meta(IntArrayRef size, const TensorOptions& options);
|
| 133 |
+
|
| 134 |
+
TORCH_API TensorBase
|
| 135 |
+
empty_strided_meta(IntArrayRef size, IntArrayRef stride, ScalarType dtype);
|
| 136 |
+
|
| 137 |
+
TORCH_API TensorBase empty_strided_meta(
|
| 138 |
+
IntArrayRef size,
|
| 139 |
+
IntArrayRef stride,
|
| 140 |
+
std::optional<ScalarType> dtype_opt,
|
| 141 |
+
std::optional<Layout> layout_opt,
|
| 142 |
+
std::optional<Device> device_opt,
|
| 143 |
+
std::optional<bool> pin_memory_opt);
|
| 144 |
+
|
| 145 |
+
TORCH_API TensorBase empty_strided_meta(
|
| 146 |
+
IntArrayRef size,
|
| 147 |
+
IntArrayRef stride,
|
| 148 |
+
const TensorOptions& options);
|
| 149 |
+
|
| 150 |
+
TORCH_API TensorBase empty_strided_symint_meta(
|
| 151 |
+
SymIntArrayRef size,
|
| 152 |
+
SymIntArrayRef stride,
|
| 153 |
+
ScalarType dtype);
|
| 154 |
+
|
| 155 |
+
TORCH_API TensorBase empty_strided_symint_meta(
|
| 156 |
+
SymIntArrayRef size,
|
| 157 |
+
SymIntArrayRef stride,
|
| 158 |
+
std::optional<ScalarType> dtype_opt,
|
| 159 |
+
std::optional<Layout> layout_opt,
|
| 160 |
+
std::optional<Device> device_opt);
|
| 161 |
+
|
| 162 |
+
TORCH_API TensorBase empty_strided_symint_meta(
|
| 163 |
+
SymIntArrayRef size,
|
| 164 |
+
SymIntArrayRef stride,
|
| 165 |
+
const TensorOptions& options);
|
| 166 |
+
|
| 167 |
+
} // namespace at::detail
|
| 168 |
+
|
| 169 |
+
#else
|
| 170 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 171 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/ExpandUtils.h
ADDED
|
@@ -0,0 +1,540 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
#ifndef AT_PER_OPERATOR_HEADERS
|
| 5 |
+
#include <ATen/Functions.h>
|
| 6 |
+
#else
|
| 7 |
+
#include <ATen/ops/view.h>
|
| 8 |
+
#include <ATen/ops/view_copy.h>
|
| 9 |
+
#endif
|
| 10 |
+
|
| 11 |
+
#include <ATen/Tensor.h>
|
| 12 |
+
#include <ATen/core/DimVector.h>
|
| 13 |
+
#include <c10/util/Exception.h>
|
| 14 |
+
#include <c10/util/MaybeOwned.h>
|
| 15 |
+
#include <c10/util/irange.h>
|
| 16 |
+
|
| 17 |
+
#include <functional>
|
| 18 |
+
#include <tuple>
|
| 19 |
+
#include <utility>
|
| 20 |
+
|
| 21 |
+
namespace at {
|
| 22 |
+
|
| 23 |
+
TORCH_API std::vector<int64_t> infer_size(IntArrayRef a, IntArrayRef b);
|
| 24 |
+
TORCH_API std::vector<SymInt> infer_size_symint(
|
| 25 |
+
SymIntArrayRef a,
|
| 26 |
+
SymIntArrayRef b);
|
| 27 |
+
TORCH_API DimVector infer_size_dimvector(IntArrayRef a, IntArrayRef b);
|
| 28 |
+
TORCH_API SymDimVector
|
| 29 |
+
infer_size_symdimvector(SymIntArrayRef a, SymIntArrayRef b);
|
| 30 |
+
|
| 31 |
+
// Named type instead of a pair/tuple so that we can be sure to
|
| 32 |
+
// construct the vectors in place and get NRVO.
|
| 33 |
+
template <typename Container>
|
| 34 |
+
struct InferExpandGeometryResult {
|
| 35 |
+
Container sizes;
|
| 36 |
+
Container strides;
|
| 37 |
+
explicit InferExpandGeometryResult(size_t ndim)
|
| 38 |
+
: sizes(ndim), strides(ndim) {}
|
| 39 |
+
explicit InferExpandGeometryResult(IntArrayRef sizes_, size_t ndim)
|
| 40 |
+
: sizes(sizes_.begin(), sizes_.end()), strides(ndim) {}
|
| 41 |
+
};
|
| 42 |
+
|
| 43 |
+
TORCH_API std::tuple<std::vector<int64_t>, std::vector<int64_t>>
|
| 44 |
+
inferExpandGeometry(
|
| 45 |
+
IntArrayRef tensor_sizes,
|
| 46 |
+
IntArrayRef tensor_strides,
|
| 47 |
+
IntArrayRef sizes);
|
| 48 |
+
|
| 49 |
+
TORCH_API InferExpandGeometryResult<DimVector> inferExpandGeometry_dimvector(
|
| 50 |
+
IntArrayRef tensor_sizes,
|
| 51 |
+
IntArrayRef tensor_strides,
|
| 52 |
+
IntArrayRef sizes);
|
| 53 |
+
|
| 54 |
+
TORCH_API std::vector<int64_t> infer_dense_strides(
|
| 55 |
+
IntArrayRef tensor_sizes,
|
| 56 |
+
IntArrayRef tensor_strides);
|
| 57 |
+
|
| 58 |
+
// True if input shapes are expandable
|
| 59 |
+
// NOTE: infer_size did a similar check, please keep them sync if change is
|
| 60 |
+
// needed
|
| 61 |
+
inline bool are_expandable(IntArrayRef shape1, IntArrayRef shape2) {
|
| 62 |
+
size_t ndim1 = shape1.size();
|
| 63 |
+
size_t ndim2 = shape2.size();
|
| 64 |
+
size_t ndim = ndim1 < ndim2 ? ndim1 : ndim2;
|
| 65 |
+
|
| 66 |
+
for (int64_t i = static_cast<int64_t>(ndim) - 1; i >= 0; --i) {
|
| 67 |
+
if (shape1[--ndim1] == shape2[--ndim2] || shape1[ndim1] == 1 ||
|
| 68 |
+
shape2[ndim2] == 1) {
|
| 69 |
+
continue;
|
| 70 |
+
}
|
| 71 |
+
return false;
|
| 72 |
+
}
|
| 73 |
+
return true;
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
// avoid copy-construction of Tensor by using a reference_wrapper.
|
| 77 |
+
inline void check_defined(
|
| 78 |
+
std::initializer_list<std::reference_wrapper<const Tensor>> tensors,
|
| 79 |
+
const char* api_name) {
|
| 80 |
+
for (auto& t : tensors) {
|
| 81 |
+
if (!t.get().defined()) {
|
| 82 |
+
TORCH_CHECK(false, api_name, "(...) called with an undefined Tensor");
|
| 83 |
+
}
|
| 84 |
+
}
|
| 85 |
+
}
|
| 86 |
+
|
| 87 |
+
// NOTE [ ExpandUtils Borrowing ]
|
| 88 |
+
//
|
| 89 |
+
// Functions in ExpandUtils return `c10::MaybeOwned<Tensor>` because
|
| 90 |
+
// expansion may not actually be needed, in which case we can improve
|
| 91 |
+
// efficiency by returning
|
| 92 |
+
// `c10::MaybeOwned<Tensor>::borrowed(to_expand)`. However, this means
|
| 93 |
+
// that you need to be careful: the returned `c10::MaybeOwned<Tensor>`
|
| 94 |
+
// must not outlive the original `Tensor` object that `to_expand`
|
| 95 |
+
// referred to! The deleted rvalue reference overloads of these
|
| 96 |
+
// functions help with this by preventing trivial use of a temporary
|
| 97 |
+
// resulting from a function call, but it is still possible to make a
|
| 98 |
+
// mistake.
|
| 99 |
+
|
| 100 |
+
// Expands `to_expand` to the symbolic sizes of `tensor`.
//
// When the sizes already match, no expansion is done and the result merely
// borrows `to_expand`; otherwise the result owns the expanded tensor. A
// borrowed result must not outlive `to_expand` (see
// NOTE [ ExpandUtils Borrowing ]).
inline c10::MaybeOwned<Tensor> expand_inplace(
    const Tensor& tensor,
    const Tensor& to_expand) {
  const bool needs_expand =
      !tensor.sym_sizes().equals(to_expand.sym_sizes());
  if (needs_expand) {
    return c10::MaybeOwned<Tensor>::owned(
        to_expand.expand_symint(tensor.sym_sizes()));
  }
  return c10::MaybeOwned<Tensor>::borrowed(to_expand);
}

// Deleted: a temporary `to_expand` could be borrowed and then dangle.
inline c10::MaybeOwned<Tensor> expand_inplace(
    const Tensor& tensor,
    Tensor&& to_expand) = delete;
|
| 113 |
+
|
| 114 |
+
// Checked variant of the two-tensor expand_inplace: first verifies through
// check_defined that both tensors are defined (`api_name` is used in the
// failure message), then forwards to the unchecked overload.
inline c10::MaybeOwned<Tensor> expand_inplace(
    const Tensor& tensor,
    const Tensor& to_expand,
    const char* api_name) {
  check_defined({tensor, to_expand}, api_name);
  return expand_inplace(tensor, to_expand);
}

// Deleted: a temporary `to_expand` could be borrowed and then dangle.
inline c10::MaybeOwned<Tensor> expand_inplace(
    const Tensor& tensor,
    Tensor&& to_expand,
    const char* api_name) = delete;
|
| 126 |
+
|
| 127 |
+
inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
|
| 128 |
+
expand_inplace(
|
| 129 |
+
const Tensor& tensor,
|
| 130 |
+
const Tensor& to_expand1,
|
| 131 |
+
const Tensor& to_expand2) {
|
| 132 |
+
if (tensor.sizes().equals(to_expand1.sizes()) &&
|
| 133 |
+
tensor.sizes().equals((to_expand2.sizes()))) {
|
| 134 |
+
return std::make_tuple(
|
| 135 |
+
c10::MaybeOwned<Tensor>::borrowed(to_expand1),
|
| 136 |
+
c10::MaybeOwned<Tensor>::borrowed(to_expand2));
|
| 137 |
+
}
|
| 138 |
+
|
| 139 |
+
return std::make_tuple(
|
| 140 |
+
c10::MaybeOwned<Tensor>::owned(to_expand1.expand(tensor.sizes())),
|
| 141 |
+
c10::MaybeOwned<Tensor>::owned(to_expand2.expand(tensor.sizes())));
|
| 142 |
+
}
|
| 143 |
+
|
| 144 |
+
inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
|
| 145 |
+
expand_inplace(
|
| 146 |
+
const Tensor& tensor,
|
| 147 |
+
Tensor&& to_expand1,
|
| 148 |
+
const Tensor& to_expand2) = delete;
|
| 149 |
+
inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
|
| 150 |
+
expand_inplace(
|
| 151 |
+
const Tensor& tensor,
|
| 152 |
+
const Tensor& to_expand1,
|
| 153 |
+
Tensor&& to_expand2) = delete;
|
| 154 |
+
inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
|
| 155 |
+
expand_inplace(const Tensor& tensor, Tensor&& to_expand1, Tensor&& to_expand2) =
|
| 156 |
+
delete;
|
| 157 |
+
|
| 158 |
+
inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
|
| 159 |
+
expand_inplace(
|
| 160 |
+
const Tensor& tensor,
|
| 161 |
+
const Tensor& to_expand1,
|
| 162 |
+
const Tensor& to_expand2,
|
| 163 |
+
const char* api_name) {
|
| 164 |
+
check_defined({tensor, to_expand1, to_expand2}, api_name);
|
| 165 |
+
return expand_inplace(tensor, to_expand1, to_expand2);
|
| 166 |
+
}
|
| 167 |
+
|
| 168 |
+
inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
|
| 169 |
+
expand_inplace(
|
| 170 |
+
const Tensor& tensor,
|
| 171 |
+
Tensor&& to_expand1,
|
| 172 |
+
const Tensor& to_expand2,
|
| 173 |
+
const char* api_name) = delete;
|
| 174 |
+
inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
|
| 175 |
+
expand_inplace(
|
| 176 |
+
const Tensor& tensor,
|
| 177 |
+
const Tensor& to_expand1,
|
| 178 |
+
Tensor&& to_expand2,
|
| 179 |
+
const char* api_name) = delete;
|
| 180 |
+
inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
|
| 181 |
+
expand_inplace(
|
| 182 |
+
const Tensor& tensor,
|
| 183 |
+
Tensor&& to_expand1,
|
| 184 |
+
Tensor&& to_expand2,
|
| 185 |
+
const char* api_name) = delete;
|
| 186 |
+
|
| 187 |
+
// See NOTE [ ExpandUtils Borrowing ] above for `MaybeOwned` explanation.
|
| 188 |
+
inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
|
| 189 |
+
expand_outplace(const Tensor& to_expand1, const Tensor& to_expand2) {
|
| 190 |
+
auto s1 = to_expand1.sym_sizes();
|
| 191 |
+
auto s2 = to_expand2.sym_sizes();
|
| 192 |
+
if (s1.equals(s2)) {
|
| 193 |
+
return std::make_tuple(
|
| 194 |
+
c10::MaybeOwned<Tensor>::borrowed(to_expand1),
|
| 195 |
+
c10::MaybeOwned<Tensor>::borrowed(to_expand2));
|
| 196 |
+
}
|
| 197 |
+
|
| 198 |
+
auto expanded_size = infer_size_symdimvector(s1, s2);
|
| 199 |
+
return std::make_tuple(
|
| 200 |
+
c10::MaybeOwned<Tensor>::owned(to_expand1.expand_symint(expanded_size)),
|
| 201 |
+
c10::MaybeOwned<Tensor>::owned(to_expand2.expand_symint(expanded_size)));
|
| 202 |
+
}
|
| 203 |
+
|
| 204 |
+
inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
|
| 205 |
+
expand_outplace(Tensor&& to_expand1, const Tensor& to_expand2) = delete;
|
| 206 |
+
inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
|
| 207 |
+
expand_outplace(const Tensor& to_expand1, Tensor&& to_expand2) = delete;
|
| 208 |
+
inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
|
| 209 |
+
expand_outplace(Tensor&& to_expand1, Tensor&& to_expand2) = delete;
|
| 210 |
+
|
| 211 |
+
inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
|
| 212 |
+
expand_outplace(
|
| 213 |
+
const Tensor& to_expand1,
|
| 214 |
+
const Tensor& to_expand2,
|
| 215 |
+
const char* api_name) {
|
| 216 |
+
check_defined({to_expand1, to_expand2}, api_name);
|
| 217 |
+
return expand_outplace(to_expand1, to_expand2);
|
| 218 |
+
}
|
| 219 |
+
|
| 220 |
+
inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
|
| 221 |
+
expand_outplace(
|
| 222 |
+
Tensor&& to_expand1,
|
| 223 |
+
const Tensor& to_expand2,
|
| 224 |
+
const char* api_name) = delete;
|
| 225 |
+
inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
|
| 226 |
+
expand_outplace(
|
| 227 |
+
const Tensor& to_expand1,
|
| 228 |
+
Tensor&& to_expand2,
|
| 229 |
+
const char* api_name) = delete;
|
| 230 |
+
inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
|
| 231 |
+
expand_outplace(
|
| 232 |
+
Tensor&& to_expand1,
|
| 233 |
+
Tensor&& to_expand2,
|
| 234 |
+
const char* api_name) = delete;
|
| 235 |
+
|
| 236 |
+
inline std::tuple<
|
| 237 |
+
c10::MaybeOwned<Tensor>,
|
| 238 |
+
c10::MaybeOwned<Tensor>,
|
| 239 |
+
c10::MaybeOwned<Tensor>>
|
| 240 |
+
expand_outplace(
|
| 241 |
+
const Tensor& to_expand1,
|
| 242 |
+
const Tensor& to_expand2,
|
| 243 |
+
const Tensor& to_expand3) {
|
| 244 |
+
if (to_expand1.sizes().equals(to_expand2.sizes()) &&
|
| 245 |
+
to_expand1.sizes().equals(to_expand3.sizes())) {
|
| 246 |
+
return std::make_tuple(
|
| 247 |
+
c10::MaybeOwned<Tensor>::borrowed(to_expand1),
|
| 248 |
+
c10::MaybeOwned<Tensor>::borrowed(to_expand2),
|
| 249 |
+
c10::MaybeOwned<Tensor>::borrowed(to_expand3));
|
| 250 |
+
}
|
| 251 |
+
|
| 252 |
+
auto expanded_size12 =
|
| 253 |
+
infer_size_dimvector(to_expand1.sizes(), to_expand2.sizes());
|
| 254 |
+
auto expanded_size =
|
| 255 |
+
infer_size_dimvector(expanded_size12, to_expand3.sizes());
|
| 256 |
+
return std::make_tuple(
|
| 257 |
+
c10::MaybeOwned<Tensor>::owned(to_expand1.expand(expanded_size)),
|
| 258 |
+
c10::MaybeOwned<Tensor>::owned(to_expand2.expand(expanded_size)),
|
| 259 |
+
c10::MaybeOwned<Tensor>::owned(to_expand3.expand(expanded_size)));
|
| 260 |
+
}
|
| 261 |
+
|
| 262 |
+
inline std::tuple<
|
| 263 |
+
c10::MaybeOwned<Tensor>,
|
| 264 |
+
c10::MaybeOwned<Tensor>,
|
| 265 |
+
c10::MaybeOwned<Tensor>>
|
| 266 |
+
expand_outplace(
|
| 267 |
+
Tensor&& to_expand1,
|
| 268 |
+
const Tensor& to_expand2,
|
| 269 |
+
const Tensor& to_expand3) = delete;
|
| 270 |
+
inline std::tuple<
|
| 271 |
+
c10::MaybeOwned<Tensor>,
|
| 272 |
+
c10::MaybeOwned<Tensor>,
|
| 273 |
+
c10::MaybeOwned<Tensor>>
|
| 274 |
+
expand_outplace(
|
| 275 |
+
const Tensor& to_expand1,
|
| 276 |
+
Tensor&& to_expand2,
|
| 277 |
+
const Tensor& to_expand3) = delete;
|
| 278 |
+
inline std::tuple<
|
| 279 |
+
c10::MaybeOwned<Tensor>,
|
| 280 |
+
c10::MaybeOwned<Tensor>,
|
| 281 |
+
c10::MaybeOwned<Tensor>>
|
| 282 |
+
expand_outplace(
|
| 283 |
+
Tensor&& to_expand1,
|
| 284 |
+
Tensor&& to_expand2,
|
| 285 |
+
const Tensor& to_expand3) = delete;
|
| 286 |
+
inline std::tuple<
|
| 287 |
+
c10::MaybeOwned<Tensor>,
|
| 288 |
+
c10::MaybeOwned<Tensor>,
|
| 289 |
+
c10::MaybeOwned<Tensor>>
|
| 290 |
+
expand_outplace(
|
| 291 |
+
const Tensor& to_expand1,
|
| 292 |
+
const Tensor& to_expand2,
|
| 293 |
+
Tensor&& to_expand3) = delete;
|
| 294 |
+
inline std::tuple<
|
| 295 |
+
c10::MaybeOwned<Tensor>,
|
| 296 |
+
c10::MaybeOwned<Tensor>,
|
| 297 |
+
c10::MaybeOwned<Tensor>>
|
| 298 |
+
expand_outplace(
|
| 299 |
+
Tensor&& to_expand1,
|
| 300 |
+
const Tensor& to_expand2,
|
| 301 |
+
Tensor&& to_expand3) = delete;
|
| 302 |
+
inline std::tuple<
|
| 303 |
+
c10::MaybeOwned<Tensor>,
|
| 304 |
+
c10::MaybeOwned<Tensor>,
|
| 305 |
+
c10::MaybeOwned<Tensor>>
|
| 306 |
+
expand_outplace(
|
| 307 |
+
const Tensor& to_expand1,
|
| 308 |
+
Tensor&& to_expand2,
|
| 309 |
+
Tensor&& to_expand3) = delete;
|
| 310 |
+
inline std::tuple<
|
| 311 |
+
c10::MaybeOwned<Tensor>,
|
| 312 |
+
c10::MaybeOwned<Tensor>,
|
| 313 |
+
c10::MaybeOwned<Tensor>>
|
| 314 |
+
expand_outplace(Tensor&& to_expand1, Tensor&& to_expand2, Tensor&& to_expand3) =
|
| 315 |
+
delete;
|
| 316 |
+
|
| 317 |
+
inline std::tuple<
|
| 318 |
+
c10::MaybeOwned<Tensor>,
|
| 319 |
+
c10::MaybeOwned<Tensor>,
|
| 320 |
+
c10::MaybeOwned<Tensor>>
|
| 321 |
+
expand_outplace(
|
| 322 |
+
const Tensor& to_expand1,
|
| 323 |
+
const Tensor& to_expand2,
|
| 324 |
+
const Tensor& to_expand3,
|
| 325 |
+
const char* api_name) {
|
| 326 |
+
check_defined({to_expand1, to_expand2, to_expand3}, api_name);
|
| 327 |
+
return expand_outplace(to_expand1, to_expand2, to_expand3);
|
| 328 |
+
}
|
| 329 |
+
|
| 330 |
+
inline std::tuple<
|
| 331 |
+
c10::MaybeOwned<Tensor>,
|
| 332 |
+
c10::MaybeOwned<Tensor>,
|
| 333 |
+
c10::MaybeOwned<Tensor>>
|
| 334 |
+
expand_outplace(
|
| 335 |
+
Tensor&& to_expand1,
|
| 336 |
+
const Tensor& to_expand2,
|
| 337 |
+
const Tensor& to_expand3,
|
| 338 |
+
const char* api_name) = delete;
|
| 339 |
+
inline std::tuple<
|
| 340 |
+
c10::MaybeOwned<Tensor>,
|
| 341 |
+
c10::MaybeOwned<Tensor>,
|
| 342 |
+
c10::MaybeOwned<Tensor>>
|
| 343 |
+
expand_outplace(
|
| 344 |
+
const Tensor& to_expand1,
|
| 345 |
+
Tensor&& to_expand2,
|
| 346 |
+
const Tensor& to_expand3,
|
| 347 |
+
const char* api_name) = delete;
|
| 348 |
+
inline std::tuple<
|
| 349 |
+
c10::MaybeOwned<Tensor>,
|
| 350 |
+
c10::MaybeOwned<Tensor>,
|
| 351 |
+
c10::MaybeOwned<Tensor>>
|
| 352 |
+
expand_outplace(
|
| 353 |
+
Tensor&& to_expand1,
|
| 354 |
+
Tensor&& to_expand2,
|
| 355 |
+
const Tensor& to_expand3,
|
| 356 |
+
const char* api_name) = delete;
|
| 357 |
+
inline std::tuple<
|
| 358 |
+
c10::MaybeOwned<Tensor>,
|
| 359 |
+
c10::MaybeOwned<Tensor>,
|
| 360 |
+
c10::MaybeOwned<Tensor>>
|
| 361 |
+
expand_outplace(
|
| 362 |
+
const Tensor& to_expand1,
|
| 363 |
+
const Tensor& to_expand2,
|
| 364 |
+
Tensor&& to_expand3,
|
| 365 |
+
const char* api_name) = delete;
|
| 366 |
+
inline std::tuple<
|
| 367 |
+
c10::MaybeOwned<Tensor>,
|
| 368 |
+
c10::MaybeOwned<Tensor>,
|
| 369 |
+
c10::MaybeOwned<Tensor>>
|
| 370 |
+
expand_outplace(
|
| 371 |
+
Tensor&& to_expand1,
|
| 372 |
+
const Tensor& to_expand2,
|
| 373 |
+
Tensor&& to_expand3,
|
| 374 |
+
const char* api_name) = delete;
|
| 375 |
+
inline std::tuple<
|
| 376 |
+
c10::MaybeOwned<Tensor>,
|
| 377 |
+
c10::MaybeOwned<Tensor>,
|
| 378 |
+
c10::MaybeOwned<Tensor>>
|
| 379 |
+
expand_outplace(
|
| 380 |
+
const Tensor& to_expand1,
|
| 381 |
+
Tensor&& to_expand2,
|
| 382 |
+
Tensor&& to_expand3,
|
| 383 |
+
const char* api_name) = delete;
|
| 384 |
+
inline std::tuple<
|
| 385 |
+
c10::MaybeOwned<Tensor>,
|
| 386 |
+
c10::MaybeOwned<Tensor>,
|
| 387 |
+
c10::MaybeOwned<Tensor>>
|
| 388 |
+
expand_outplace(
|
| 389 |
+
Tensor&& to_expand1,
|
| 390 |
+
Tensor&& to_expand2,
|
| 391 |
+
Tensor&& to_expand3,
|
| 392 |
+
const char* api_name) = delete;
|
| 393 |
+
|
| 394 |
+
// Expands `to_expand` to `sizes`. Returns the input borrowed when it already
// has exactly those sizes (the argument must then outlive the result), and an
// owned, expanded tensor otherwise.
inline c10::MaybeOwned<Tensor> expand_size(
    const Tensor& to_expand,
    IntArrayRef sizes) {
  const bool already_sized = to_expand.sizes().equals(sizes);
  return already_sized
      ? c10::MaybeOwned<Tensor>::borrowed(to_expand)
      : c10::MaybeOwned<Tensor>::owned(to_expand.expand(sizes));
}
|
| 403 |
+
|
| 404 |
+
// Deleted: expand_size may return a MaybeOwned borrowing `to_expand`, which
// must therefore not be a temporary.
inline c10::MaybeOwned<Tensor> expand_size(
    Tensor&& to_expand, IntArrayRef sizes) = delete;
|
| 407 |
+
|
| 408 |
+
// Checked variant of expand_size: runs check_defined() on `to_expand`
// (passing `api_name` through) before delegating to the unchecked overload.
inline c10::MaybeOwned<Tensor> expand_size(
    const Tensor& to_expand,
    IntArrayRef sizes,
    const char* api_name) {
  check_defined({to_expand}, api_name);
  auto expanded = expand_size(to_expand, sizes);
  return expanded;
}
|
| 415 |
+
|
| 416 |
+
// Deleted: the checked expand_size may also return a borrowed reference to
// `to_expand`, so temporaries are rejected.
inline c10::MaybeOwned<Tensor> expand_size(
    Tensor&& to_expand, IntArrayRef sizes, const char* api_name) = delete;
|
| 420 |
+
|
| 421 |
+
// Broadcasts every defined tensor in `to_expand` to a common shape.
//
// Undefined (null) tensors are ignored when inferring the shape and are left
// as default-constructed entries in the result, so the output has the same
// length and ordering as the input.
inline std::vector<Tensor> expand_outplace(TensorList to_expand) {
  // Pass 1: fold the symbolic sizes of all defined tensors into one shape.
  SymDimVector common_sizes;
  bool seen_defined = false;
  for (const Tensor& candidate : to_expand) {
    if (!candidate.defined()) {
      continue;
    }
    if (seen_defined) {
      common_sizes =
          infer_size_symdimvector(common_sizes, candidate.sym_sizes());
    } else {
      common_sizes = candidate.sym_sizes();
      seen_defined = true;
    }
  }

  // Pass 2: reuse tensors that already have the common shape, expand the rest.
  std::vector<Tensor> result(to_expand.size());
  for (size_t idx = 0; idx < to_expand.size(); ++idx) {
    const Tensor& candidate = to_expand[idx];
    if (!candidate.defined()) {
      continue;
    }
    result[idx] = candidate.sym_sizes().equals(common_sizes)
        ? candidate
        : candidate.expand_symint(common_sizes);
  }
  return result;
}
|
| 448 |
+
|
| 449 |
+
template <typename T>
|
| 450 |
+
inline Tensor _sum_to(
|
| 451 |
+
Tensor tensor,
|
| 452 |
+
const c10::ArrayRef<T> shape,
|
| 453 |
+
bool always_return_non_view = false) {
|
| 454 |
+
if (shape.size() == 0) {
|
| 455 |
+
return tensor.sum();
|
| 456 |
+
}
|
| 457 |
+
|
| 458 |
+
auto sizes = at::symint::sizes<T>(tensor);
|
| 459 |
+
c10::SmallVector<int64_t, 8> reduce_dims;
|
| 460 |
+
const int64_t leading_dims = sizes.size() - shape.size();
|
| 461 |
+
for (const auto i : c10::irange(leading_dims)) {
|
| 462 |
+
reduce_dims.push_back(i);
|
| 463 |
+
}
|
| 464 |
+
for (int64_t i = leading_dims; i < static_cast<int64_t>(sizes.size()); ++i) {
|
| 465 |
+
if (TORCH_GUARD_OR_FALSE(sym_eq(shape[i - leading_dims], 1)) &&
|
| 466 |
+
TORCH_GUARD_OR_TRUE(sym_ne(sizes[i], 1))) {
|
| 467 |
+
reduce_dims.push_back(i);
|
| 468 |
+
} else {
|
| 469 |
+
// if we assume no reduction due to unbacked we ensure that at runtime.
|
| 470 |
+
TORCH_MAYBE_SYM_CHECK(
|
| 471 |
+
sym_eq(shape[i - leading_dims], sizes[i]),
|
| 472 |
+
"non-reduction path was assumed due to unbacked symbols expected those two sizes to be the same:",
|
| 473 |
+
shape[i - leading_dims],
|
| 474 |
+
", ",
|
| 475 |
+
sizes[i])
|
| 476 |
+
}
|
| 477 |
+
}
|
| 478 |
+
|
| 479 |
+
if (!reduce_dims.empty()) {
|
| 480 |
+
tensor = tensor.sum(reduce_dims, /*keepdim=*/true);
|
| 481 |
+
}
|
| 482 |
+
|
| 483 |
+
if (always_return_non_view) {
|
| 484 |
+
// This is only actually used by the functionalization pass.
|
| 485 |
+
// We want to be able to guarantee that this function doesn't return a view
|
| 486 |
+
// of the input.
|
| 487 |
+
return leading_dims > 0 ? at::symint::view_copy<T>(tensor, shape)
|
| 488 |
+
: tensor.clone();
|
| 489 |
+
} else {
|
| 490 |
+
return leading_dims > 0 ? at::symint::view<T>(tensor, shape) : tensor;
|
| 491 |
+
}
|
| 492 |
+
}
|
| 493 |
+
|
| 494 |
+
// Symbolic-shape entry point: sums `tensor` down to `shape` by forwarding to
// _sum_to instantiated for c10::SymInt.
inline Tensor sum_to(
    Tensor tensor,
    const c10::SymIntArrayRef shape,
    bool always_return_non_view = false) {
  return _sum_to<c10::SymInt>(
      std::move(tensor), shape, always_return_non_view);
}
|
| 500 |
+
|
| 501 |
+
// Sums `tensor` repeatedly to produce a tensor of shape `shape`.
|
| 502 |
+
// Precondition: is_expandable_to(shape, tensor.sizes()) must be true
|
| 503 |
+
inline Tensor sum_to(
|
| 504 |
+
Tensor tensor,
|
| 505 |
+
const IntArrayRef shape,
|
| 506 |
+
bool always_return_non_view = false) {
|
| 507 |
+
return _sum_to(std::move(tensor), shape, always_return_non_view);
|
| 508 |
+
}
|
| 509 |
+
|
| 510 |
+
// Returns true when a tensor of shape `shape` can be expanded (broadcast) to
// `desired`: `shape` must not have more dimensions than `desired`, and,
// aligning both from the trailing dimension, each size in `shape` must either
// equal the corresponding size in `desired` or be 1.
inline bool is_expandable_to(
    SymIntArrayRef shape,
    c10::SymIntArrayRef desired) {
  const size_t ndim = shape.size();
  const size_t target_dim = desired.size();
  if (ndim > target_dim) {
    return false;
  }

  // Walk from the last dimension backwards; `back` counts from the end.
  for (size_t back = 0; back < ndim; ++back) {
    const auto& have = shape[ndim - back - 1];
    const auto& want = desired[target_dim - back - 1];
    if (have != want && have != 1) {
      return false;
    }
  }
  return true;
}
|
| 527 |
+
|
| 528 |
+
// Returns true when a tensor of concrete shape `shape` can be expanded
// (broadcast) to `desired`: `shape` must not have more dimensions than
// `desired`, and, aligning both from the trailing dimension, each size in
// `shape` must either equal the corresponding size in `desired` or be 1.
//
// The comparison is done directly on the int64_t values. The sizes are
// deliberately NOT reinterpreted as c10::SymInt through a pointer cast:
// reading an int64_t buffer through a SymInt pointer is type punning, and the
// result for plain integers is identical to the symbolic overload anyway.
inline bool is_expandable_to(IntArrayRef shape, IntArrayRef desired) {
  const size_t ndim = shape.size();
  const size_t target_dim = desired.size();
  if (ndim > target_dim) {
    return false;
  }
  for (size_t i = 0; i < ndim; ++i) {
    const int64_t size = shape[ndim - i - 1];
    const int64_t target = desired[target_dim - i - 1];
    if (size != target && size != 1) {
      return false;
    }
  }
  return true;
}
|
| 535 |
+
|
| 536 |
+
} // namespace at
|
| 537 |
+
|
| 538 |
+
#else
|
| 539 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 540 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/FunctionalTensorWrapper.h
ADDED
|
@@ -0,0 +1,476 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
|
| 3 |
+
#pragma once
|
| 4 |
+
|
| 5 |
+
#include <ATen/ArrayRef.h>
|
| 6 |
+
#include <ATen/FunctionalStorageImpl.h>
|
| 7 |
+
#include <ATen/core/IListRef.h>
|
| 8 |
+
#include <ATen/core/List.h>
|
| 9 |
+
#include <ATen/core/boxing/BoxedKernel.h>
|
| 10 |
+
#include <ATen/core/boxing/impl/boxing.h>
|
| 11 |
+
#include <ATen/core/dispatch/Dispatcher.h>
|
| 12 |
+
|
| 13 |
+
#include <c10/core/DispatchKey.h>
|
| 14 |
+
|
| 15 |
+
namespace at {
|
| 16 |
+
|
| 17 |
+
// Note [Functionalization Pass In Core]
|
| 18 |
+
// The Functionalization pass is used to remove aliasing from a pytorch program.
|
| 19 |
+
//
|
| 20 |
+
// This is useful for backends that don't support aliasing, like XLA and Vulkan.
|
| 21 |
+
// It's also necessary in order to remove mutation from a program, which is
|
| 22 |
+
// needed in Functorch.
|
| 23 |
+
//
|
| 24 |
+
// Consider this program:
|
| 25 |
+
// a = torch.ones(...)
|
| 26 |
+
// b = a.view(...)
|
| 27 |
+
// b.add_(1)
|
| 28 |
+
//
|
| 29 |
+
// In this program, b is meant to alias with a due to the use of view(). At the
|
| 30 |
+
// end of the program, both a and b are full of 2's. However, backends that
|
| 31 |
+
// don't support aliasing aren't able to correctly implement the view()
|
| 32 |
+
// operator. Instead, they can opt into the Functionalization pass, which will
|
| 33 |
+
// sit between the user and the backend, and provide the necessary aliasing
|
| 34 |
+
// logic.
|
| 35 |
+
//
|
| 36 |
+
// The functionalization pass will turn the above program into a slightly
// different program that has the same semantics, transparently to the user,
// that backends like XLA/Vulkan are able to implement:
//   a = torch.ones(...)
//   b = a.view_copy(...)  # view() replaced with view_copy(). Backends like
//                         # XLA/Vulkan can implement this!
//   b.add_(1)
//   a.add_(1)  # Our functionalization pass machinery knows that a and b are
//              # aliased - it applies b's mutation to a too.
|
| 43 |
+
//
|
| 44 |
+
// So, how does the functionalization pass keep track of which tensors are
|
| 45 |
+
// aliased? The pass works by wrapping EVERY tensor in the program inside of a
|
| 46 |
+
// FunctionalTensorWrapper, which knows about its alias'd tensors.
|
| 47 |
+
//
|
| 48 |
+
// See Note [Functionalization: Alias Removal] for details on the aliasing
|
| 49 |
+
// machinery. See Note [Functionalization: Mutation Removal] for details on
|
| 50 |
+
// mutation removal.
|
| 51 |
+
// TensorImpl subclass that wraps an ordinary tensor (`value_`) so that the
// functionalization pass can track aliasing and mutation for it. Mutation
// bookkeeping is delegated to the FunctionalStorageImpl returned by
// functional_storage_impl(); view history is kept in `view_metas_`.
struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
  explicit FunctionalTensorWrapper(const Tensor& value);
  // Additional constructor to create a FunctionalTensorWrapper directly from an
  // underlying tensor that was created from a view. For example, the code
  //   b = a.view1()
  // will generate a constructor call to
  //   FunctionalTensorWrapper(b, a, view1_meta)
  explicit FunctionalTensorWrapper(
      const Tensor& view_value,
      const FunctionalTensorWrapper* base,
      const std::shared_ptr<functionalization::ViewMeta>& meta);

  // Get the underlying, actual tensor, that doesn't know anything about
  // functionalization.
  const Tensor& value() const {
    return value_;
  }
  // The concept of "level" is only ever important to functorch; it's exposed
  // here as more of a hook for functorch to use.
  int64_t level() const {
    return level_;
  }
  void set_level(int64_t level) {
    level_ = level;
  }
  bool has_metadata_mutation() const {
    return has_metadata_mutation_;
  }
  // The mutation bookkeeping below is forwarded to the functional storage.
  uint64_t mutation_counter() const {
    return functional_storage_impl()->mutation_counter();
  }
  void mark_mutation() {
    functional_storage_impl()->mark_mutation();
  }
  // Denotes a mutation that's hidden from autograd,
  // e.g. for the purposes of passing a tensor to a triton kernel
  void mark_mutation_hidden_from_autograd() {
    functional_storage_impl()->mark_mutation_hidden_from_autograd();
  }
  void mark_mutation_during_no_grad_or_inference_mode() {
    functional_storage_impl()->mark_mutation_during_no_grad_or_inference_mode();
  }
  // Are all the mutations happening to the tensor hidden from autograd
  bool are_all_mutations_hidden_from_autograd() const {
    return functional_storage_impl()->are_all_mutations_hidden_from_autograd();
  }
  // Did all mutations happen under no_grad or inference_mode
  // (We also need to ignore mutations fully hidden from autograd here)
  bool are_all_mutations_under_no_grad_or_inference_mode() const {
    return functional_storage_impl()
        ->are_all_mutations_under_no_grad_or_inference_mode();
  }

  // Sticky flag: once any ViewMeta with symbolic inputs has been seen, the
  // wrapper stays marked as symbolic.
  void maybe_mark_symbolic(functionalization::ViewMeta* meta) {
    is_symbolic_ = is_symbolic_ | meta->has_symbolic_inputs;
  }

  bool is_symbolic() const {
    return is_symbolic_;
  }

  // Retrieves the ViewMeta sequence of this tensor.
  const std::vector<std::shared_ptr<functionalization::ViewMeta>>& view_metas()
      const;

  // Syncs the underlying tensor with its alias, if it's out of date. This
  // involves two steps:
  //   1) Apply any pending updates/mutations to the alias.
  //   2) Replay the views (if any) to regenerate the current tensor off of
  //      the updated alias.
  void sync_();
  // Performs step (2) of the sync (regenerating this tensor from the base).
  // This is its own public API because it's needed by view_inplace ops like
  // transpose_. See Note [Functionalization Pass - Inplace View Ops]
  void regenerate_from_base();
  // Performs step (1) of the sync (applying pending updates). This is its own
  // public API because it's needed by functorch. functorch wants to make sure
  // that all input tensors to a functionalized program have been properly
  // synced so it can properly propagate mutations to inputs. It can't just
  // call sync_(), because the FunctionalTensorWrapper will look like it has no
  // aliases and sync_ will be a noop. We use the reference count on storage_
  // to determine if the wrapper is aliased, and by the time functorch is ready
  // to propagate updates to inputs, any intermediate views of the input
  // created by the program will have been deallocated. This function also
  // returns whether or not the base actually had any updates to apply.
  bool apply_updates();
  // Takes the current state of value_ and snapshots it, sending it as a pending
  // update to the alias.
  void commit_update();
  // When any tensor is mutated, the tensor increments its alias's "generation".
  // Separately, each tensor maintains its own "generation" counter, which is
  // used to determine if it's up-to-date with its alias. The act of syncing a
  // tensor will set a tensor's generation equal to its alias's generation.
  bool is_up_to_date() const;
  // Freezes the storage of this tensor, preventing subsequent mutations
  void freeze_storage() const;
  // Every FunctionalTensorWrapper contains a vector of ViewMeta objects
  // describing the series of view ops that ran to generate the current tensor
  // from the base tensor. This method is used by inplace-view ops like
  // transpose_. It appends a ViewMeta to the existing stack, and refreshes the
  // tensor by replaying the views off of the alias.
  void mutate_view_meta(
      const std::shared_ptr<at::functionalization::ViewMeta>& meta);

  // Custom implementation of self.set_(src)
  void set__impl(const FunctionalTensorWrapper* other);

  // Custom implementation of resize_storage_bytes_(self, new_size)
  void storage_resize_(const c10::SymInt& new_size);

  // Returns whether the current tensor's data was ever mutated
  bool has_data_mutation();

  // Returns whether the current FunctionalTensorWrapper
  // experienced a set_() call.
  bool was_storage_changed() {
    return was_storage_changed_;
  }

  // Records a set_() call: sets the flag and bumps the counter.
  void mark_storage_changed() {
    was_storage_changed_ = true;
    storage_changed_counter_++;
  }

  // Number of times mark_storage_changed() has been called on this wrapper.
  uint64_t storage_changed_counter() {
    return storage_changed_counter_;
  }

  // A FunctionalTensor is considered a base if it's not a view of another
  // tensor.
  bool isBaseTensor() const {
    return view_metas_.empty();
  }

  c10::SymInt get_storage_size(bool before) {
    return functional_storage_impl()->get_storage_size(before);
  }

  // Returns whether the FunctionalTensor experienced an
  // untyped_storage().resize_() call
  bool was_inductor_storage_resized() {
    return functional_storage_impl()->was_inductor_storage_resized();
  }

  // Forwards the storage's resize counter. The return type is an integer
  // (matching storage_changed_counter()) rather than bool, so that the count
  // is not collapsed to 0/1 on the way out.
  // NOTE(review): assumes FunctionalStorageImpl exposes this counter as an
  // unsigned integer -- confirm against FunctionalStorageImpl.h.
  uint64_t inductor_storage_resized_counter() {
    return functional_storage_impl()->inductor_storage_resized_counter();
  }
  // The functionalization pass can be used to remove mutations.
  // It does so by replacing any mutation op with its corresponding
  // out-of-place op, followed by a call to replace_(). e.g:
  //
  // a.add_(1)
  //
  // will turn into:
  //
  // tmp = a.add(1)
  // a.replace_(tmp)
  //
  // replace_() swaps out the wrapped tensor, value_, with tmp.
  void replace_(const Tensor& other, bool from_lazy_regenerate = false);

  bool is_multi_output_view() {
    return is_multi_output_view_;
  }

  // See Note[resize_() in functionalization pass]
  void maybe_replace_storage(const Tensor& other);

  // Replaces the storage with a new functional storage,
  // and clears the view_metas_ stack.
  // WARNING: Calling this function will sever the aliasing relationship between
  // the current FunctionalTensorWrapper and any of its outstanding aliases.
  // Please only call if you know what you're doing.
  void _unsafe_reset_storage();

  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
      const c10::VariableVersion& version_counter,
      bool allow_tensor_metadata_change) const override;

  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
      c10::VariableVersion&& version_counter,
      bool allow_tensor_metadata_change) const override;

  ~FunctionalTensorWrapper() override = default;

  // FunctionalTensorWrapper overrides all custom size/stride functions,
  // so that if the inner tensor has a custom implementation
  // we make sure to call that implementation.
  at::IntArrayRef sizes_custom() const override;
  at::IntArrayRef strides_custom() const override;
  int64_t dim_custom() const override;
  int64_t numel_custom() const override;
  c10::SymBool sym_is_contiguous_custom(
      at::MemoryFormat memory_format) const override;
  c10::SymIntArrayRef sym_sizes_custom() const override;
  c10::SymInt sym_size_custom(int64_t d) const override;
  c10::SymIntArrayRef sym_strides_custom() const override;
  c10::SymInt sym_storage_offset_custom() const override;
  c10::Device device_custom() const override;
  c10::Layout layout_impl() const override;

 private:
  const char* tensorimpl_type_name() const override;
  void set_constructor_metadata();
  functionalization::FunctionalStorageImpl* functional_storage_impl() const;

  // This is used to re-implement shallow_copy_and_detach for
  // FunctionalTensorWrapper. The implementation is identical, but we just need
  // to return a subclass instead of a plain TensorImpl.
  // TODO: maybe it's possible to arrange for that to happen automatically
  // without an override here?
  template <typename VariableVersion>
  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach_core(
      VariableVersion&& version_counter,
      bool allow_tensor_metadata_change) const;

  void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override;
  void copy_tensor_metadata_and_refresh(
      const FunctionalTensorWrapper* src_impl,
      FunctionalTensorWrapper* dest_impl,
      const c10::VariableVersion& version_counter,
      bool allow_tensor_metadata_change) const;

  // Note that value is not taken by reference: internally, the wrapper will
  // change the value tensor that it points to over time.
  Tensor value_;
  int64_t level_{};
  // NOTE: an older comment here described "two counters" used to tell whether
  // all mutations on a tensor are hidden from autograd. No such counters are
  // members of this class; that bookkeeping is reached through
  // functional_storage_impl() (see the mark_mutation*() methods above). The
  // fields below are plain per-wrapper flags.
  bool has_metadata_mutation_ = false;
  bool is_multi_output_view_ = false;
  // Did the tensor experience a set_() call.
  bool was_storage_changed_ = false;
  uint64_t storage_changed_counter_ = 0;
  // Did the tensor experience any view operation with symbolic int.
  bool is_symbolic_ = false;

  // This wrapper's own generation; see is_up_to_date().
  size_t generation_ = 0;
  // Sequence of view ops that produced this tensor from its base; empty for a
  // base tensor (see isBaseTensor()).
  std::vector<std::shared_ptr<at::functionalization::ViewMeta>> view_metas_;

 protected:
  static void copy_tensor_metadata(
      const FunctionalTensorWrapper* src_impl,
      FunctionalTensorWrapper* dest_impl,
      const c10::VariableVersion& version_counter,
      bool allow_tensor_metadata_change);
};
|
| 299 |
+
|
| 300 |
+
// Utility functions for the functionalization pass.
|
| 301 |
+
|
| 302 |
+
namespace functionalization {
|
| 303 |
+
namespace impl {
|
| 304 |
+
|
| 305 |
+
// Returns the TensorImpl held by `tensor`, viewed as a FunctionalTensorWrapper.
//
// The conversion is a plain static_cast, so no runtime type check is made:
// the caller is responsible for only passing tensors whose impl really is a
// FunctionalTensorWrapper. The only validation is a null check through
// TORCH_INTERNAL_ASSERT_DEBUG_ONLY (presumably active in debug builds only,
// as the macro name suggests).
inline FunctionalTensorWrapper* unsafeGetFunctionalWrapper(
    const Tensor& tensor) {
  auto* const raw_impl = tensor.unsafeGetTensorImpl();
  auto* const wrapper = static_cast<FunctionalTensorWrapper*>(raw_impl);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(wrapper != nullptr);
  return wrapper;
}
|
| 312 |
+
|
| 313 |
+
TORCH_API bool isBaseTensor(const at::Tensor& tensor);
|
| 314 |
+
|
| 315 |
+
TORCH_API bool isFunctionalTensor(const at::Tensor& tensor);
|
| 316 |
+
TORCH_API bool isFunctionalTensor(const std::optional<Tensor>& t);
|
| 317 |
+
TORCH_API bool isFunctionalTensor(
|
| 318 |
+
const c10::List<std::optional<Tensor>>& t_list);
|
| 319 |
+
TORCH_API bool isFunctionalTensor(ITensorListRef list);
|
| 320 |
+
|
| 321 |
+
TORCH_API Tensor to_functional_tensor(const Tensor& tensor);
|
| 322 |
+
TORCH_API std::optional<Tensor> to_functional_tensor(
|
| 323 |
+
const std::optional<Tensor>& tensor);
|
| 324 |
+
TORCH_API c10::List<std::optional<Tensor>> to_functional_tensor(
|
| 325 |
+
const c10::List<std::optional<Tensor>>& t_list);
|
| 326 |
+
TORCH_API std::vector<Tensor> to_functional_tensor(ITensorListRef t_list);
|
| 327 |
+
|
| 328 |
+
TORCH_API void freeze_functional_tensor(const Tensor& tensor);
|
| 329 |
+
|
| 330 |
+
TORCH_API Tensor
|
| 331 |
+
from_functional_tensor(const Tensor& tensor, bool assert_functional = true);
|
| 332 |
+
TORCH_API std::optional<Tensor> from_functional_tensor(
|
| 333 |
+
const std::optional<Tensor>& t,
|
| 334 |
+
bool assert_functional = true);
|
| 335 |
+
TORCH_API c10::List<std::optional<Tensor>> from_functional_tensor(
|
| 336 |
+
const c10::List<std::optional<Tensor>>& t_list);
|
| 337 |
+
TORCH_API std::vector<Tensor> from_functional_tensor(ITensorListRef t_list);
|
| 338 |
+
|
| 339 |
+
TORCH_API void sync(const at::Tensor& t);
|
| 340 |
+
TORCH_API void sync(const std::optional<Tensor>& t);
|
| 341 |
+
TORCH_API void sync(const c10::List<std::optional<Tensor>>& t_list);
|
| 342 |
+
TORCH_API void sync(ITensorListRef t_list);
|
| 343 |
+
|
| 344 |
+
TORCH_API void replace_(const Tensor& functional_tensor, const Tensor& other);
|
| 345 |
+
TORCH_API void replace_(
|
| 346 |
+
const ITensorListRef functional_tensor,
|
| 347 |
+
ITensorListRef other);
|
| 348 |
+
|
| 349 |
+
TORCH_API void commit_update(const Tensor& functional_tensor);
|
| 350 |
+
TORCH_API void commit_update(ITensorListRef functional_tensor);
|
| 351 |
+
|
| 352 |
+
TORCH_API void unsafe_reset_storage(const Tensor& functional_tensor);
|
| 353 |
+
|
| 354 |
+
TORCH_API void mark_mutation_hidden_from_autograd(
|
| 355 |
+
const Tensor& functional_tensor);
|
| 356 |
+
|
| 357 |
+
TORCH_API bool are_all_mutations_hidden_from_autograd(
|
| 358 |
+
const Tensor& functional_tensor);
|
| 359 |
+
|
| 360 |
+
TORCH_API bool are_all_mutations_under_no_grad_or_inference_mode(
|
| 361 |
+
const Tensor& functional_tensor);
|
| 362 |
+
|
| 363 |
+
// These two methods are XLA-specific logic and are no-ops
|
| 364 |
+
// for the normal functionalization flow.
|
| 365 |
+
TORCH_API void propagate_xla_data(
|
| 366 |
+
const Tensor& functional_tensor,
|
| 367 |
+
const Tensor& other);
|
| 368 |
+
TORCH_API void propagate_xla_data(
|
| 369 |
+
const ITensorListRef functional_tensor,
|
| 370 |
+
ITensorListRef other);
|
| 371 |
+
|
| 372 |
+
TORCH_API void propagate_xla_data_direct(
|
| 373 |
+
const Tensor& tensor,
|
| 374 |
+
const Tensor& other);
|
| 375 |
+
TORCH_API void propagate_xla_data_direct(
|
| 376 |
+
const ITensorListRef tensor,
|
| 377 |
+
ITensorListRef other);
|
| 378 |
+
|
| 379 |
+
Tensor create_functional_tensor_with_view_meta(
|
| 380 |
+
const Tensor& view_to_wrap,
|
| 381 |
+
const Tensor& base,
|
| 382 |
+
const std::shared_ptr<functionalization::ViewMeta>& meta,
|
| 383 |
+
int64_t out_idx = 0);
|
| 384 |
+
std::vector<Tensor> create_functional_tensor_with_view_meta(
|
| 385 |
+
ITensorListRef view_to_wrap,
|
| 386 |
+
const Tensor& base,
|
| 387 |
+
const std::shared_ptr<functionalization::ViewMeta>& meta);
|
| 388 |
+
|
| 389 |
+
void mutate_view_meta(
|
| 390 |
+
const Tensor& self,
|
| 391 |
+
const std::shared_ptr<functionalization::ViewMeta>& meta);
|
| 392 |
+
|
| 393 |
+
TORCH_API Tensor apply_view_meta_sequence(
|
| 394 |
+
const Tensor& base,
|
| 395 |
+
const std::vector<std::shared_ptr<functionalization::ViewMeta>>& sequence);
|
| 396 |
+
|
| 397 |
+
void set_sizes_strides_offset(const Tensor& out, const Tensor& meta_out);
|
| 398 |
+
void set_sizes_strides_offset(
|
| 399 |
+
const std::vector<Tensor>& outs,
|
| 400 |
+
const std::vector<Tensor>& meta_outs);
|
| 401 |
+
|
| 402 |
+
// ~~~~~ TLS used in functionalization ~~~~~
|
| 403 |
+
|
| 404 |
+
TORCH_API bool getFunctionalizationReapplyViewsTLS();
|
| 405 |
+
TORCH_API void setFunctionalizationReapplyViewsTLS(bool reapply_views);
|
| 406 |
+
|
| 407 |
+
class TORCH_API FunctionalizationReapplyViewsGuard {
|
| 408 |
+
public:
|
| 409 |
+
FunctionalizationReapplyViewsGuard(bool reapply_views)
|
| 410 |
+
: prev_(getFunctionalizationReapplyViewsTLS()) {
|
| 411 |
+
setFunctionalizationReapplyViewsTLS(reapply_views);
|
| 412 |
+
}
|
| 413 |
+
|
| 414 |
+
~FunctionalizationReapplyViewsGuard() {
|
| 415 |
+
setFunctionalizationReapplyViewsTLS(prev_);
|
| 416 |
+
}
|
| 417 |
+
|
| 418 |
+
FunctionalizationReapplyViewsGuard(
|
| 419 |
+
const FunctionalizationReapplyViewsGuard&) = delete;
|
| 420 |
+
FunctionalizationReapplyViewsGuard operator=(
|
| 421 |
+
const FunctionalizationReapplyViewsGuard&) = delete;
|
| 422 |
+
FunctionalizationReapplyViewsGuard(FunctionalizationReapplyViewsGuard&&) =
|
| 423 |
+
delete;
|
| 424 |
+
FunctionalizationReapplyViewsGuard operator=(
|
| 425 |
+
FunctionalizationReapplyViewsGuard&&) = delete;
|
| 426 |
+
|
| 427 |
+
private:
|
| 428 |
+
bool prev_;
|
| 429 |
+
};
|
| 430 |
+
|
| 431 |
+
} // namespace impl
|
| 432 |
+
|
| 433 |
+
// Helper function to call an out-of-place composite aten kernel that may use
|
| 434 |
+
// mutations / views internally, and functionalize them.
|
| 435 |
+
TORCH_API void functionalize_op_helper(
|
| 436 |
+
const c10::OperatorHandle& op,
|
| 437 |
+
torch::jit::Stack* stack);
|
| 438 |
+
|
| 439 |
+
template <class Op, bool symint, class ReturnType, class... ParameterTypes>
|
| 440 |
+
struct _functionalize_aten_op final {};
|
| 441 |
+
|
| 442 |
+
template <class Op, bool symint, class ReturnType, class... ParameterTypes>
|
| 443 |
+
struct _functionalize_aten_op<Op, symint, ReturnType(ParameterTypes...)> final {
|
| 444 |
+
static ReturnType call(
|
| 445 |
+
typename c10::maybe_keep_symint<symint, ParameterTypes>::type... args) {
|
| 446 |
+
using FuncType = ReturnType(
|
| 447 |
+
typename c10::maybe_keep_symint<symint, ParameterTypes>::type...);
|
| 448 |
+
auto op = c10::Dispatcher::singleton()
|
| 449 |
+
.findSchemaOrThrow(
|
| 450 |
+
(const char*)Op::name, (const char*)Op::overload_name)
|
| 451 |
+
.typed<FuncType>();
|
| 452 |
+
|
| 453 |
+
return c10::impl::BoxedKernelWrapper<FuncType>::call(
|
| 454 |
+
c10::BoxedKernel::makeFromFunction<functionalize_op_helper>(),
|
| 455 |
+
op,
|
| 456 |
+
// BoxedKernelWrapper knows to ignore this keyset argument,
|
| 457 |
+
// because functionalize_op_helper doesn't take in a DispatchKeySet
|
| 458 |
+
c10::DispatchKeySet(),
|
| 459 |
+
args...);
|
| 460 |
+
}
|
| 461 |
+
};
|
| 462 |
+
|
| 463 |
+
// Convenience alias: _functionalize_aten_op for operator `Op` with the
// `symint` template argument fixed to false, using `Op::schema` as the
// function type that selects the specialization providing `call`.
template <class Op>
|
| 464 |
+
using functionalize_aten_op =
|
| 465 |
+
_functionalize_aten_op<Op, false, typename Op::schema>;
|
| 466 |
+
|
| 467 |
+
// Convenience alias: _functionalize_aten_op for operator `Op` with the
// `symint` template argument fixed to true, using `Op::schema` as the
// function type that selects the specialization providing `call`.
template <class Op>
|
| 467 |
+
using functionalize_aten_op_symint =
|
| 468 |
+
_functionalize_aten_op<Op, true, typename Op::schema>;
|
| 470 |
+
|
| 471 |
+
} // namespace functionalization
|
| 472 |
+
} // namespace at
|
| 473 |
+
|
| 474 |
+
#else
|
| 475 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 476 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/Functions.h
ADDED
|
@@ -0,0 +1,1476 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
// @generated by torchgen/gen.py from Functions.h
|
| 5 |
+
|
| 6 |
+
#ifdef TORCH_ASSERT_NO_OPERATORS
|
| 7 |
+
#error This change adds a dependency on native_functions.yaml, \
|
| 8 |
+
meaning the file will need to be re-compiled every time an operator \
|
| 9 |
+
is changed or added. Consider if your change would be better placed in \
|
| 10 |
+
another file, or if a more specific header might achieve the same goal. \
|
| 11 |
+
See NOTE: [Tensor vs. TensorBase]
|
| 12 |
+
#endif
|
| 13 |
+
|
| 14 |
+
#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS)
|
| 15 |
+
#error This change adds a dependency on all pytorch operators, meaning the \
|
| 16 |
+
file will need to be re-compiled every time an operator is changed or added. \
|
| 17 |
+
Consider including a specific operator from <ATen/ops/{my_operator}.h> and \
|
| 18 |
+
see NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS].
|
| 19 |
+
#endif
|
| 20 |
+
|
| 21 |
+
// NOTE: [TORCH_ASSERT_ONLY_METHOD_OPERATORS]
|
| 22 |
+
//
|
| 23 |
+
// In ATen, certain generated header files include the definitions of
|
| 24 |
+
// every single operator in PyTorch. Unfortunately this means every
|
| 25 |
+
// time an operator signature is updated or changed in
|
| 26 |
+
// native_functions.yaml, you (and every other PyTorch developer) need
|
| 27 |
+
// to recompile every source file that includes any of these headers.
|
| 28 |
+
//
|
| 29 |
+
// To break up these header dependencies, and improve incremental
|
| 30 |
+
// build times for all PyTorch developers, these headers are split
|
| 31 |
+
// into per-operator headers in the `ATen/ops` folder. This limits
|
| 32 |
+
// incremental builds to only changes to methods of `Tensor`, or files
|
| 33 |
+
// that use the specific operator being changed. With `at::sum` as an
|
| 34 |
+
// example, you should include
|
| 35 |
+
//
|
| 36 |
+
// <ATen/ops/sum.h> // instead of ATen/Functions.h
|
| 37 |
+
// <ATen/ops/sum_native.h> // instead of ATen/NativeFunctions.h
|
| 38 |
+
// <ATen/ops/sum_ops.h> // instead of ATen/Operators.h
|
| 39 |
+
// <ATen/ops/sum_cpu_dispatch.h> // instead of ATen/CPUFunctions.h
|
| 40 |
+
//
|
| 41 |
+
// However, even if you're careful to use this in your own code,
|
| 42 |
+
// `Functions.h` might be included indirectly through another header
|
| 43 |
+
// without you realising. To avoid this, you can add
|
| 44 |
+
//
|
| 45 |
+
// #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
|
| 46 |
+
//
|
| 47 |
+
// to the top of your source file. This way any time the non-specific
|
| 48 |
+
// headers are included, the compiler will error out.
|
| 49 |
+
//
|
| 50 |
+
// Also, be aware that `ops` are not available in all build
|
| 51 |
+
// configurations (namely fb-internal) so you must guard these
|
| 52 |
+
// includes with `#ifdef AT_PER_OPERATOR_HEADERS`. e.g.
|
| 53 |
+
//
|
| 54 |
+
// #ifndef AT_PER_OPERATOR_HEADERS
|
| 55 |
+
// #include <ATen/Functions.h>
|
| 56 |
+
// #else
|
| 57 |
+
// #include <ATen/ops/sum.h>
|
| 58 |
+
// #endif
|
| 59 |
+
|
| 60 |
+
#include <ATen/Context.h>
|
| 61 |
+
#include <ATen/DeviceGuard.h>
|
| 62 |
+
#include <ATen/TensorUtils.h>
|
| 63 |
+
#include <ATen/TracerMode.h>
|
| 64 |
+
#include <ATen/core/Generator.h>
|
| 65 |
+
#include <ATen/core/Reduction.h>
|
| 66 |
+
#include <c10/core/SymInt.h>
|
| 67 |
+
#include <ATen/core/Tensor.h>
|
| 68 |
+
#include <c10/core/Scalar.h>
|
| 69 |
+
#include <c10/core/Storage.h>
|
| 70 |
+
#include <c10/core/TensorOptions.h>
|
| 71 |
+
#include <c10/util/Deprecated.h>
|
| 72 |
+
#include <optional>
|
| 73 |
+
#include <c10/util/OptionalArrayRef.h>
|
| 74 |
+
|
| 75 |
+
#include <ATen/ops/from_blob.h>
|
| 76 |
+
#include <ATen/ops/tensor.h>
|
| 77 |
+
|
| 78 |
+
#include <ATen/ops/_adaptive_avg_pool2d.h>
|
| 79 |
+
#include <ATen/ops/_adaptive_avg_pool2d_backward.h>
|
| 80 |
+
#include <ATen/ops/_adaptive_avg_pool3d.h>
|
| 81 |
+
#include <ATen/ops/_adaptive_avg_pool3d_backward.h>
|
| 82 |
+
#include <ATen/ops/_add_batch_dim.h>
|
| 83 |
+
#include <ATen/ops/_add_relu.h>
|
| 84 |
+
#include <ATen/ops/_addmm_activation.h>
|
| 85 |
+
#include <ATen/ops/_aminmax.h>
|
| 86 |
+
#include <ATen/ops/_amp_foreach_non_finite_check_and_unscale.h>
|
| 87 |
+
#include <ATen/ops/_amp_update_scale.h>
|
| 88 |
+
#include <ATen/ops/_assert_async.h>
|
| 89 |
+
#include <ATen/ops/_assert_scalar.h>
|
| 90 |
+
#include <ATen/ops/_assert_tensor_metadata.h>
|
| 91 |
+
#include <ATen/ops/_autocast_to_full_precision.h>
|
| 92 |
+
#include <ATen/ops/_autocast_to_reduced_precision.h>
|
| 93 |
+
#include <ATen/ops/_backward.h>
|
| 94 |
+
#include <ATen/ops/_batch_norm_impl_index.h>
|
| 95 |
+
#include <ATen/ops/_batch_norm_impl_index_backward.h>
|
| 96 |
+
#include <ATen/ops/_batch_norm_no_update.h>
|
| 97 |
+
#include <ATen/ops/_batch_norm_with_update.h>
|
| 98 |
+
#include <ATen/ops/_cast_Byte.h>
|
| 99 |
+
#include <ATen/ops/_cast_Char.h>
|
| 100 |
+
#include <ATen/ops/_cast_Double.h>
|
| 101 |
+
#include <ATen/ops/_cast_Float.h>
|
| 102 |
+
#include <ATen/ops/_cast_Half.h>
|
| 103 |
+
#include <ATen/ops/_cast_Int.h>
|
| 104 |
+
#include <ATen/ops/_cast_Long.h>
|
| 105 |
+
#include <ATen/ops/_cast_Short.h>
|
| 106 |
+
#include <ATen/ops/_cdist_backward.h>
|
| 107 |
+
#include <ATen/ops/_cdist_forward.h>
|
| 108 |
+
#include <ATen/ops/_cholesky_solve_helper.h>
|
| 109 |
+
#include <ATen/ops/_choose_qparams_per_tensor.h>
|
| 110 |
+
#include <ATen/ops/_chunk_cat.h>
|
| 111 |
+
#include <ATen/ops/_coalesce.h>
|
| 112 |
+
#include <ATen/ops/_coalesced.h>
|
| 113 |
+
#include <ATen/ops/_compute_linear_combination.h>
|
| 114 |
+
#include <ATen/ops/_conj.h>
|
| 115 |
+
#include <ATen/ops/_conj_copy.h>
|
| 116 |
+
#include <ATen/ops/_conj_physical.h>
|
| 117 |
+
#include <ATen/ops/_conv_depthwise2d.h>
|
| 118 |
+
#include <ATen/ops/_convert_indices_from_coo_to_csr.h>
|
| 119 |
+
#include <ATen/ops/_convert_indices_from_csr_to_coo.h>
|
| 120 |
+
#include <ATen/ops/_convert_weight_to_int4pack.h>
|
| 121 |
+
#include <ATen/ops/_convert_weight_to_int4pack_for_cpu.h>
|
| 122 |
+
#include <ATen/ops/_convolution.h>
|
| 123 |
+
#include <ATen/ops/_convolution_double_backward.h>
|
| 124 |
+
#include <ATen/ops/_convolution_mode.h>
|
| 125 |
+
#include <ATen/ops/_copy_from.h>
|
| 126 |
+
#include <ATen/ops/_copy_from_and_resize.h>
|
| 127 |
+
#include <ATen/ops/_cslt_compress.h>
|
| 128 |
+
#include <ATen/ops/_cslt_sparse_mm.h>
|
| 129 |
+
#include <ATen/ops/_cslt_sparse_mm_search.h>
|
| 130 |
+
#include <ATen/ops/_ctc_loss.h>
|
| 131 |
+
#include <ATen/ops/_ctc_loss_backward.h>
|
| 132 |
+
#include <ATen/ops/_cudnn_attention_backward.h>
|
| 133 |
+
#include <ATen/ops/_cudnn_attention_forward.h>
|
| 134 |
+
#include <ATen/ops/_cudnn_ctc_loss.h>
|
| 135 |
+
#include <ATen/ops/_cudnn_init_dropout_state.h>
|
| 136 |
+
#include <ATen/ops/_cudnn_rnn.h>
|
| 137 |
+
#include <ATen/ops/_cudnn_rnn_backward.h>
|
| 138 |
+
#include <ATen/ops/_cudnn_rnn_flatten_weight.h>
|
| 139 |
+
#include <ATen/ops/_cufft_clear_plan_cache.h>
|
| 140 |
+
#include <ATen/ops/_cufft_get_plan_cache_max_size.h>
|
| 141 |
+
#include <ATen/ops/_cufft_get_plan_cache_size.h>
|
| 142 |
+
#include <ATen/ops/_cufft_set_plan_cache_max_size.h>
|
| 143 |
+
#include <ATen/ops/_cummax_helper.h>
|
| 144 |
+
#include <ATen/ops/_cummin_helper.h>
|
| 145 |
+
#include <ATen/ops/_debug_has_internal_overlap.h>
|
| 146 |
+
#include <ATen/ops/_dimI.h>
|
| 147 |
+
#include <ATen/ops/_dimV.h>
|
| 148 |
+
#include <ATen/ops/_dim_arange.h>
|
| 149 |
+
#include <ATen/ops/_dirichlet_grad.h>
|
| 150 |
+
#include <ATen/ops/_dyn_quant_matmul_4bit.h>
|
| 151 |
+
#include <ATen/ops/_dyn_quant_pack_4bit_weight.h>
|
| 152 |
+
#include <ATen/ops/_efficient_attention_backward.h>
|
| 153 |
+
#include <ATen/ops/_efficient_attention_forward.h>
|
| 154 |
+
#include <ATen/ops/_efficientzerotensor.h>
|
| 155 |
+
#include <ATen/ops/_embedding_bag.h>
|
| 156 |
+
#include <ATen/ops/_embedding_bag_backward.h>
|
| 157 |
+
#include <ATen/ops/_embedding_bag_dense_backward.h>
|
| 158 |
+
#include <ATen/ops/_embedding_bag_forward_only.h>
|
| 159 |
+
#include <ATen/ops/_embedding_bag_per_sample_weights_backward.h>
|
| 160 |
+
#include <ATen/ops/_embedding_bag_sparse_backward.h>
|
| 161 |
+
#include <ATen/ops/_empty_affine_quantized.h>
|
| 162 |
+
#include <ATen/ops/_empty_per_channel_affine_quantized.h>
|
| 163 |
+
#include <ATen/ops/_euclidean_dist.h>
|
| 164 |
+
#include <ATen/ops/_fake_quantize_learnable_per_channel_affine.h>
|
| 165 |
+
#include <ATen/ops/_fake_quantize_learnable_per_channel_affine_backward.h>
|
| 166 |
+
#include <ATen/ops/_fake_quantize_learnable_per_tensor_affine.h>
|
| 167 |
+
#include <ATen/ops/_fake_quantize_learnable_per_tensor_affine_backward.h>
|
| 168 |
+
#include <ATen/ops/_fake_quantize_per_tensor_affine_cachemask_tensor_qparams.h>
|
| 169 |
+
#include <ATen/ops/_fft_c2c.h>
|
| 170 |
+
#include <ATen/ops/_fft_c2r.h>
|
| 171 |
+
#include <ATen/ops/_fft_r2c.h>
|
| 172 |
+
#include <ATen/ops/_fill_mem_eff_dropout_mask.h>
|
| 173 |
+
#include <ATen/ops/_flash_attention_backward.h>
|
| 174 |
+
#include <ATen/ops/_flash_attention_forward.h>
|
| 175 |
+
#include <ATen/ops/_foobar.h>
|
| 176 |
+
#include <ATen/ops/_foreach_abs.h>
|
| 177 |
+
#include <ATen/ops/_foreach_acos.h>
|
| 178 |
+
#include <ATen/ops/_foreach_add.h>
|
| 179 |
+
#include <ATen/ops/_foreach_addcdiv.h>
|
| 180 |
+
#include <ATen/ops/_foreach_addcmul.h>
|
| 181 |
+
#include <ATen/ops/_foreach_asin.h>
|
| 182 |
+
#include <ATen/ops/_foreach_atan.h>
|
| 183 |
+
#include <ATen/ops/_foreach_ceil.h>
|
| 184 |
+
#include <ATen/ops/_foreach_clamp_max.h>
|
| 185 |
+
#include <ATen/ops/_foreach_clamp_min.h>
|
| 186 |
+
#include <ATen/ops/_foreach_copy.h>
|
| 187 |
+
#include <ATen/ops/_foreach_cos.h>
|
| 188 |
+
#include <ATen/ops/_foreach_cosh.h>
|
| 189 |
+
#include <ATen/ops/_foreach_div.h>
|
| 190 |
+
#include <ATen/ops/_foreach_erf.h>
|
| 191 |
+
#include <ATen/ops/_foreach_erfc.h>
|
| 192 |
+
#include <ATen/ops/_foreach_exp.h>
|
| 193 |
+
#include <ATen/ops/_foreach_expm1.h>
|
| 194 |
+
#include <ATen/ops/_foreach_floor.h>
|
| 195 |
+
#include <ATen/ops/_foreach_frac.h>
|
| 196 |
+
#include <ATen/ops/_foreach_lerp.h>
|
| 197 |
+
#include <ATen/ops/_foreach_lgamma.h>
|
| 198 |
+
#include <ATen/ops/_foreach_log.h>
|
| 199 |
+
#include <ATen/ops/_foreach_log10.h>
|
| 200 |
+
#include <ATen/ops/_foreach_log1p.h>
|
| 201 |
+
#include <ATen/ops/_foreach_log2.h>
|
| 202 |
+
#include <ATen/ops/_foreach_max.h>
|
| 203 |
+
#include <ATen/ops/_foreach_maximum.h>
|
| 204 |
+
#include <ATen/ops/_foreach_minimum.h>
|
| 205 |
+
#include <ATen/ops/_foreach_mul.h>
|
| 206 |
+
#include <ATen/ops/_foreach_neg.h>
|
| 207 |
+
#include <ATen/ops/_foreach_norm.h>
|
| 208 |
+
#include <ATen/ops/_foreach_pow.h>
|
| 209 |
+
#include <ATen/ops/_foreach_reciprocal.h>
|
| 210 |
+
#include <ATen/ops/_foreach_round.h>
|
| 211 |
+
#include <ATen/ops/_foreach_rsqrt.h>
|
| 212 |
+
#include <ATen/ops/_foreach_sigmoid.h>
|
| 213 |
+
#include <ATen/ops/_foreach_sign.h>
|
| 214 |
+
#include <ATen/ops/_foreach_sin.h>
|
| 215 |
+
#include <ATen/ops/_foreach_sinh.h>
|
| 216 |
+
#include <ATen/ops/_foreach_sqrt.h>
|
| 217 |
+
#include <ATen/ops/_foreach_sub.h>
|
| 218 |
+
#include <ATen/ops/_foreach_tan.h>
|
| 219 |
+
#include <ATen/ops/_foreach_tanh.h>
|
| 220 |
+
#include <ATen/ops/_foreach_trunc.h>
|
| 221 |
+
#include <ATen/ops/_foreach_zero.h>
|
| 222 |
+
#include <ATen/ops/_functional_assert_async.h>
|
| 223 |
+
#include <ATen/ops/_functional_assert_scalar.h>
|
| 224 |
+
#include <ATen/ops/_functional_sym_constrain_range.h>
|
| 225 |
+
#include <ATen/ops/_functional_sym_constrain_range_for_size.h>
|
| 226 |
+
#include <ATen/ops/_fused_adagrad.h>
|
| 227 |
+
#include <ATen/ops/_fused_adam.h>
|
| 228 |
+
#include <ATen/ops/_fused_adamw.h>
|
| 229 |
+
#include <ATen/ops/_fused_dropout.h>
|
| 230 |
+
#include <ATen/ops/_fused_moving_avg_obs_fq_helper.h>
|
| 231 |
+
#include <ATen/ops/_fused_rms_norm.h>
|
| 232 |
+
#include <ATen/ops/_fused_rms_norm_backward.h>
|
| 233 |
+
#include <ATen/ops/_fused_sdp_choice.h>
|
| 234 |
+
#include <ATen/ops/_fused_sgd.h>
|
| 235 |
+
#include <ATen/ops/_fw_primal.h>
|
| 236 |
+
#include <ATen/ops/_fw_primal_copy.h>
|
| 237 |
+
#include <ATen/ops/_gather_sparse_backward.h>
|
| 238 |
+
#include <ATen/ops/_grid_sampler_2d_cpu_fallback.h>
|
| 239 |
+
#include <ATen/ops/_grid_sampler_2d_cpu_fallback_backward.h>
|
| 240 |
+
#include <ATen/ops/_grouped_mm.h>
|
| 241 |
+
#include <ATen/ops/_has_compatible_shallow_copy_type.h>
|
| 242 |
+
#include <ATen/ops/_has_same_storage_numel.h>
|
| 243 |
+
#include <ATen/ops/_histogramdd_bin_edges.h>
|
| 244 |
+
#include <ATen/ops/_histogramdd_from_bin_cts.h>
|
| 245 |
+
#include <ATen/ops/_histogramdd_from_bin_tensors.h>
|
| 246 |
+
#include <ATen/ops/_index_put_impl.h>
|
| 247 |
+
#include <ATen/ops/_indices.h>
|
| 248 |
+
#include <ATen/ops/_indices_copy.h>
|
| 249 |
+
#include <ATen/ops/_int_mm.h>
|
| 250 |
+
#include <ATen/ops/_is_all_true.h>
|
| 251 |
+
#include <ATen/ops/_is_any_true.h>
|
| 252 |
+
#include <ATen/ops/_is_zerotensor.h>
|
| 253 |
+
#include <ATen/ops/_jagged_to_padded_dense_forward.h>
|
| 254 |
+
#include <ATen/ops/_lazy_clone.h>
|
| 255 |
+
#include <ATen/ops/_linalg_check_errors.h>
|
| 256 |
+
#include <ATen/ops/_linalg_det.h>
|
| 257 |
+
#include <ATen/ops/_linalg_eigh.h>
|
| 258 |
+
#include <ATen/ops/_linalg_eigvals.h>
|
| 259 |
+
#include <ATen/ops/_linalg_slogdet.h>
|
| 260 |
+
#include <ATen/ops/_linalg_solve_ex.h>
|
| 261 |
+
#include <ATen/ops/_linalg_svd.h>
|
| 262 |
+
#include <ATen/ops/_local_scalar_dense.h>
|
| 263 |
+
#include <ATen/ops/_log_softmax.h>
|
| 264 |
+
#include <ATen/ops/_log_softmax_backward_data.h>
|
| 265 |
+
#include <ATen/ops/_logcumsumexp.h>
|
| 266 |
+
#include <ATen/ops/_lstm_mps.h>
|
| 267 |
+
#include <ATen/ops/_lu_with_info.h>
|
| 268 |
+
#include <ATen/ops/_make_dep_token.h>
|
| 269 |
+
#include <ATen/ops/_make_dual.h>
|
| 270 |
+
#include <ATen/ops/_make_dual_copy.h>
|
| 271 |
+
#include <ATen/ops/_make_per_channel_quantized_tensor.h>
|
| 272 |
+
#include <ATen/ops/_make_per_tensor_quantized_tensor.h>
|
| 273 |
+
#include <ATen/ops/_masked_scale.h>
|
| 274 |
+
#include <ATen/ops/_masked_softmax.h>
|
| 275 |
+
#include <ATen/ops/_masked_softmax_backward.h>
|
| 276 |
+
#include <ATen/ops/_mixed_dtypes_linear.h>
|
| 277 |
+
#include <ATen/ops/_mkldnn_reshape.h>
|
| 278 |
+
#include <ATen/ops/_mkldnn_transpose.h>
|
| 279 |
+
#include <ATen/ops/_mps_convolution.h>
|
| 280 |
+
#include <ATen/ops/_mps_convolution_transpose.h>
|
| 281 |
+
#include <ATen/ops/_native_batch_norm_legit.h>
|
| 282 |
+
#include <ATen/ops/_native_batch_norm_legit_no_training.h>
|
| 283 |
+
#include <ATen/ops/_native_multi_head_attention.h>
|
| 284 |
+
#include <ATen/ops/_neg_view.h>
|
| 285 |
+
#include <ATen/ops/_neg_view_copy.h>
|
| 286 |
+
#include <ATen/ops/_nested_compute_contiguous_strides_offsets.h>
|
| 287 |
+
#include <ATen/ops/_nested_from_padded.h>
|
| 288 |
+
#include <ATen/ops/_nested_from_padded_and_nested_example.h>
|
| 289 |
+
#include <ATen/ops/_nested_from_padded_tensor.h>
|
| 290 |
+
#include <ATen/ops/_nested_get_jagged_dummy.h>
|
| 291 |
+
#include <ATen/ops/_nested_get_lengths.h>
|
| 292 |
+
#include <ATen/ops/_nested_get_max_seqlen.h>
|
| 293 |
+
#include <ATen/ops/_nested_get_min_seqlen.h>
|
| 294 |
+
#include <ATen/ops/_nested_get_offsets.h>
|
| 295 |
+
#include <ATen/ops/_nested_get_ragged_idx.h>
|
| 296 |
+
#include <ATen/ops/_nested_get_values.h>
|
| 297 |
+
#include <ATen/ops/_nested_get_values_copy.h>
|
| 298 |
+
#include <ATen/ops/_nested_select_backward.h>
|
| 299 |
+
#include <ATen/ops/_nested_sum_backward.h>
|
| 300 |
+
#include <ATen/ops/_nested_tensor_from_mask.h>
|
| 301 |
+
#include <ATen/ops/_nested_tensor_from_mask_left_aligned.h>
|
| 302 |
+
#include <ATen/ops/_nested_tensor_from_tensor_list.h>
|
| 303 |
+
#include <ATen/ops/_nested_tensor_size.h>
|
| 304 |
+
#include <ATen/ops/_nested_tensor_softmax_with_shape.h>
|
| 305 |
+
#include <ATen/ops/_nested_tensor_storage_offsets.h>
|
| 306 |
+
#include <ATen/ops/_nested_tensor_strides.h>
|
| 307 |
+
#include <ATen/ops/_nested_view_from_buffer.h>
|
| 308 |
+
#include <ATen/ops/_nested_view_from_buffer_copy.h>
|
| 309 |
+
#include <ATen/ops/_nested_view_from_jagged.h>
|
| 310 |
+
#include <ATen/ops/_nested_view_from_jagged_copy.h>
|
| 311 |
+
#include <ATen/ops/_new_zeros_with_same_feature_meta.h>
|
| 312 |
+
#include <ATen/ops/_nnpack_available.h>
|
| 313 |
+
#include <ATen/ops/_nnpack_spatial_convolution.h>
|
| 314 |
+
#include <ATen/ops/_nnz.h>
|
| 315 |
+
#include <ATen/ops/_pack_padded_sequence.h>
|
| 316 |
+
#include <ATen/ops/_pack_padded_sequence_backward.h>
|
| 317 |
+
#include <ATen/ops/_pad_circular.h>
|
| 318 |
+
#include <ATen/ops/_pad_enum.h>
|
| 319 |
+
#include <ATen/ops/_pad_packed_sequence.h>
|
| 320 |
+
#include <ATen/ops/_padded_dense_to_jagged_forward.h>
|
| 321 |
+
#include <ATen/ops/_pdist_backward.h>
|
| 322 |
+
#include <ATen/ops/_pdist_forward.h>
|
| 323 |
+
#include <ATen/ops/_pin_memory.h>
|
| 324 |
+
#include <ATen/ops/_prelu_kernel.h>
|
| 325 |
+
#include <ATen/ops/_prelu_kernel_backward.h>
|
| 326 |
+
#include <ATen/ops/_print.h>
|
| 327 |
+
#include <ATen/ops/_propagate_xla_data.h>
|
| 328 |
+
#include <ATen/ops/_remove_batch_dim.h>
|
| 329 |
+
#include <ATen/ops/_reshape_alias.h>
|
| 330 |
+
#include <ATen/ops/_reshape_alias_copy.h>
|
| 331 |
+
#include <ATen/ops/_reshape_copy.h>
|
| 332 |
+
#include <ATen/ops/_reshape_from_tensor.h>
|
| 333 |
+
#include <ATen/ops/_resize_output.h>
|
| 334 |
+
#include <ATen/ops/_rowwise_prune.h>
|
| 335 |
+
#include <ATen/ops/_safe_softmax.h>
|
| 336 |
+
#include <ATen/ops/_sample_dirichlet.h>
|
| 337 |
+
#include <ATen/ops/_saturate_weight_to_fp16.h>
|
| 338 |
+
#include <ATen/ops/_scaled_dot_product_attention_math.h>
|
| 339 |
+
#include <ATen/ops/_scaled_dot_product_attention_math_for_mps.h>
|
| 340 |
+
#include <ATen/ops/_scaled_dot_product_cudnn_attention.h>
|
| 341 |
+
#include <ATen/ops/_scaled_dot_product_cudnn_attention_backward.h>
|
| 342 |
+
#include <ATen/ops/_scaled_dot_product_efficient_attention.h>
|
| 343 |
+
#include <ATen/ops/_scaled_dot_product_efficient_attention_backward.h>
|
| 344 |
+
#include <ATen/ops/_scaled_dot_product_flash_attention.h>
|
| 345 |
+
#include <ATen/ops/_scaled_dot_product_flash_attention_backward.h>
|
| 346 |
+
#include <ATen/ops/_scaled_dot_product_flash_attention_for_cpu.h>
|
| 347 |
+
#include <ATen/ops/_scaled_dot_product_flash_attention_for_cpu_backward.h>
|
| 348 |
+
#include <ATen/ops/_scaled_dot_product_fused_attention_overrideable.h>
|
| 349 |
+
#include <ATen/ops/_scaled_dot_product_fused_attention_overrideable_backward.h>
|
| 350 |
+
#include <ATen/ops/_scaled_grouped_mm.h>
|
| 351 |
+
#include <ATen/ops/_scaled_grouped_mm_v2.h>
|
| 352 |
+
#include <ATen/ops/_scaled_mm.h>
|
| 353 |
+
#include <ATen/ops/_scaled_mm_v2.h>
|
| 354 |
+
#include <ATen/ops/_segment_reduce_backward.h>
|
| 355 |
+
#include <ATen/ops/_shape_as_tensor.h>
|
| 356 |
+
#include <ATen/ops/_slow_conv2d_backward.h>
|
| 357 |
+
#include <ATen/ops/_slow_conv2d_forward.h>
|
| 358 |
+
#include <ATen/ops/_sobol_engine_draw.h>
|
| 359 |
+
#include <ATen/ops/_sobol_engine_ff.h>
|
| 360 |
+
#include <ATen/ops/_sobol_engine_initialize_state.h>
|
| 361 |
+
#include <ATen/ops/_sobol_engine_scramble.h>
|
| 362 |
+
#include <ATen/ops/_softmax.h>
|
| 363 |
+
#include <ATen/ops/_softmax_backward_data.h>
|
| 364 |
+
#include <ATen/ops/_sparse_addmm.h>
|
| 365 |
+
#include <ATen/ops/_sparse_broadcast_to.h>
|
| 366 |
+
#include <ATen/ops/_sparse_broadcast_to_copy.h>
|
| 367 |
+
#include <ATen/ops/_sparse_bsc_tensor_unsafe.h>
|
| 368 |
+
#include <ATen/ops/_sparse_bsr_tensor_unsafe.h>
|
| 369 |
+
#include <ATen/ops/_sparse_compressed_tensor_unsafe.h>
|
| 370 |
+
#include <ATen/ops/_sparse_compressed_tensor_with_dims.h>
|
| 371 |
+
#include <ATen/ops/_sparse_coo_tensor_unsafe.h>
|
| 372 |
+
#include <ATen/ops/_sparse_coo_tensor_with_dims.h>
|
| 373 |
+
#include <ATen/ops/_sparse_coo_tensor_with_dims_and_tensors.h>
|
| 374 |
+
#include <ATen/ops/_sparse_csc_tensor_unsafe.h>
|
| 375 |
+
#include <ATen/ops/_sparse_csr_prod.h>
|
| 376 |
+
#include <ATen/ops/_sparse_csr_sum.h>
|
| 377 |
+
#include <ATen/ops/_sparse_csr_tensor_unsafe.h>
|
| 378 |
+
#include <ATen/ops/_sparse_log_softmax.h>
|
| 379 |
+
#include <ATen/ops/_sparse_log_softmax_backward_data.h>
|
| 380 |
+
#include <ATen/ops/_sparse_mask_projection.h>
|
| 381 |
+
#include <ATen/ops/_sparse_mm.h>
|
| 382 |
+
#include <ATen/ops/_sparse_mm_reduce_impl.h>
|
| 383 |
+
#include <ATen/ops/_sparse_mm_reduce_impl_backward.h>
|
| 384 |
+
#include <ATen/ops/_sparse_semi_structured_addmm.h>
|
| 385 |
+
#include <ATen/ops/_sparse_semi_structured_apply.h>
|
| 386 |
+
#include <ATen/ops/_sparse_semi_structured_apply_dense.h>
|
| 387 |
+
#include <ATen/ops/_sparse_semi_structured_linear.h>
|
| 388 |
+
#include <ATen/ops/_sparse_semi_structured_mm.h>
|
| 389 |
+
#include <ATen/ops/_sparse_semi_structured_tile.h>
|
| 390 |
+
#include <ATen/ops/_sparse_softmax.h>
|
| 391 |
+
#include <ATen/ops/_sparse_softmax_backward_data.h>
|
| 392 |
+
#include <ATen/ops/_sparse_sparse_matmul.h>
|
| 393 |
+
#include <ATen/ops/_sparse_sum.h>
|
| 394 |
+
#include <ATen/ops/_sparse_sum_backward.h>
|
| 395 |
+
#include <ATen/ops/_spdiags.h>
|
| 396 |
+
#include <ATen/ops/_spsolve.h>
|
| 397 |
+
#include <ATen/ops/_stack.h>
|
| 398 |
+
#include <ATen/ops/_standard_gamma.h>
|
| 399 |
+
#include <ATen/ops/_standard_gamma_grad.h>
|
| 400 |
+
#include <ATen/ops/_test_ambiguous_defaults.h>
|
| 401 |
+
#include <ATen/ops/_test_autograd_multiple_dispatch.h>
|
| 402 |
+
#include <ATen/ops/_test_autograd_multiple_dispatch_view.h>
|
| 403 |
+
#include <ATen/ops/_test_autograd_multiple_dispatch_view_copy.h>
|
| 404 |
+
#include <ATen/ops/_test_check_tensor.h>
|
| 405 |
+
#include <ATen/ops/_test_functorch_fallback.h>
|
| 406 |
+
#include <ATen/ops/_test_optional_filled_intlist.h>
|
| 407 |
+
#include <ATen/ops/_test_optional_floatlist.h>
|
| 408 |
+
#include <ATen/ops/_test_optional_intlist.h>
|
| 409 |
+
#include <ATen/ops/_test_parallel_materialize.h>
|
| 410 |
+
#include <ATen/ops/_test_serialization_subcmul.h>
|
| 411 |
+
#include <ATen/ops/_test_string_default.h>
|
| 412 |
+
#include <ATen/ops/_test_warn_in_autograd.h>
|
| 413 |
+
#include <ATen/ops/_thnn_differentiable_gru_cell_backward.h>
|
| 414 |
+
#include <ATen/ops/_thnn_differentiable_lstm_cell_backward.h>
|
| 415 |
+
#include <ATen/ops/_thnn_fused_gru_cell.h>
|
| 416 |
+
#include <ATen/ops/_thnn_fused_gru_cell_backward.h>
|
| 417 |
+
#include <ATen/ops/_thnn_fused_lstm_cell.h>
|
| 418 |
+
#include <ATen/ops/_thnn_fused_lstm_cell_backward.h>
|
| 419 |
+
#include <ATen/ops/_thnn_fused_lstm_cell_backward_impl.h>
|
| 420 |
+
#include <ATen/ops/_to_copy.h>
|
| 421 |
+
#include <ATen/ops/_to_cpu.h>
|
| 422 |
+
#include <ATen/ops/_to_dense.h>
|
| 423 |
+
#include <ATen/ops/_to_sparse.h>
|
| 424 |
+
#include <ATen/ops/_to_sparse_bsc.h>
|
| 425 |
+
#include <ATen/ops/_to_sparse_bsr.h>
|
| 426 |
+
#include <ATen/ops/_to_sparse_csc.h>
|
| 427 |
+
#include <ATen/ops/_to_sparse_csr.h>
|
| 428 |
+
#include <ATen/ops/_to_sparse_semi_structured.h>
|
| 429 |
+
#include <ATen/ops/_transform_bias_rescale_qkv.h>
|
| 430 |
+
#include <ATen/ops/_transformer_encoder_layer_fwd.h>
|
| 431 |
+
#include <ATen/ops/_trilinear.h>
|
| 432 |
+
#include <ATen/ops/_triton_multi_head_attention.h>
|
| 433 |
+
#include <ATen/ops/_triton_scaled_dot_attention.h>
|
| 434 |
+
#include <ATen/ops/_unique.h>
|
| 435 |
+
#include <ATen/ops/_unique2.h>
|
| 436 |
+
#include <ATen/ops/_unpack_dual.h>
|
| 437 |
+
#include <ATen/ops/_unsafe_index.h>
|
| 438 |
+
#include <ATen/ops/_unsafe_index_put.h>
|
| 439 |
+
#include <ATen/ops/_unsafe_masked_index.h>
|
| 440 |
+
#include <ATen/ops/_unsafe_masked_index_put_accumulate.h>
|
| 441 |
+
#include <ATen/ops/_unsafe_view.h>
|
| 442 |
+
#include <ATen/ops/_upsample_bicubic2d_aa.h>
|
| 443 |
+
#include <ATen/ops/_upsample_bicubic2d_aa_backward.h>
|
| 444 |
+
#include <ATen/ops/_upsample_bilinear2d_aa.h>
|
| 445 |
+
#include <ATen/ops/_upsample_bilinear2d_aa_backward.h>
|
| 446 |
+
#include <ATen/ops/_upsample_nearest_exact1d.h>
|
| 447 |
+
#include <ATen/ops/_upsample_nearest_exact1d_backward.h>
|
| 448 |
+
#include <ATen/ops/_upsample_nearest_exact2d.h>
|
| 449 |
+
#include <ATen/ops/_upsample_nearest_exact2d_backward.h>
|
| 450 |
+
#include <ATen/ops/_upsample_nearest_exact3d.h>
|
| 451 |
+
#include <ATen/ops/_upsample_nearest_exact3d_backward.h>
|
| 452 |
+
#include <ATen/ops/_use_cudnn_ctc_loss.h>
|
| 453 |
+
#include <ATen/ops/_use_cudnn_rnn_flatten_weight.h>
|
| 454 |
+
#include <ATen/ops/_validate_compressed_sparse_indices.h>
|
| 455 |
+
#include <ATen/ops/_validate_sparse_bsc_tensor_args.h>
|
| 456 |
+
#include <ATen/ops/_validate_sparse_bsr_tensor_args.h>
|
| 457 |
+
#include <ATen/ops/_validate_sparse_compressed_tensor_args.h>
|
| 458 |
+
#include <ATen/ops/_validate_sparse_coo_tensor_args.h>
|
| 459 |
+
#include <ATen/ops/_validate_sparse_csc_tensor_args.h>
|
| 460 |
+
#include <ATen/ops/_validate_sparse_csr_tensor_args.h>
|
| 461 |
+
#include <ATen/ops/_values.h>
|
| 462 |
+
#include <ATen/ops/_values_copy.h>
|
| 463 |
+
#include <ATen/ops/_version.h>
|
| 464 |
+
#include <ATen/ops/_weight_int4pack_mm.h>
|
| 465 |
+
#include <ATen/ops/_weight_int4pack_mm_for_cpu.h>
|
| 466 |
+
#include <ATen/ops/_weight_int4pack_mm_with_scales_and_zeros.h>
|
| 467 |
+
#include <ATen/ops/_weight_int8pack_mm.h>
|
| 468 |
+
#include <ATen/ops/_weight_norm.h>
|
| 469 |
+
#include <ATen/ops/_weight_norm_differentiable_backward.h>
|
| 470 |
+
#include <ATen/ops/_weight_norm_interface.h>
|
| 471 |
+
#include <ATen/ops/_weight_norm_interface_backward.h>
|
| 472 |
+
#include <ATen/ops/_wrapped_linear_prepack.h>
|
| 473 |
+
#include <ATen/ops/_wrapped_quantized_linear_prepacked.h>
|
| 474 |
+
#include <ATen/ops/abs.h>
|
| 475 |
+
#include <ATen/ops/absolute.h>
|
| 476 |
+
#include <ATen/ops/acos.h>
|
| 477 |
+
#include <ATen/ops/acosh.h>
|
| 478 |
+
#include <ATen/ops/adaptive_avg_pool1d.h>
|
| 479 |
+
#include <ATen/ops/adaptive_avg_pool2d.h>
|
| 480 |
+
#include <ATen/ops/adaptive_avg_pool3d.h>
|
| 481 |
+
#include <ATen/ops/adaptive_avg_pool3d_backward.h>
|
| 482 |
+
#include <ATen/ops/adaptive_max_pool1d.h>
|
| 483 |
+
#include <ATen/ops/adaptive_max_pool2d.h>
|
| 484 |
+
#include <ATen/ops/adaptive_max_pool2d_backward.h>
|
| 485 |
+
#include <ATen/ops/adaptive_max_pool3d.h>
|
| 486 |
+
#include <ATen/ops/adaptive_max_pool3d_backward.h>
|
| 487 |
+
#include <ATen/ops/add.h>
|
| 488 |
+
#include <ATen/ops/addbmm.h>
|
| 489 |
+
#include <ATen/ops/addcdiv.h>
|
| 490 |
+
#include <ATen/ops/addcmul.h>
|
| 491 |
+
#include <ATen/ops/addmm.h>
|
| 492 |
+
#include <ATen/ops/addmv.h>
|
| 493 |
+
#include <ATen/ops/addr.h>
|
| 494 |
+
#include <ATen/ops/adjoint.h>
|
| 495 |
+
#include <ATen/ops/affine_grid_generator.h>
|
| 496 |
+
#include <ATen/ops/affine_grid_generator_backward.h>
|
| 497 |
+
#include <ATen/ops/alias.h>
|
| 498 |
+
#include <ATen/ops/alias_copy.h>
|
| 499 |
+
#include <ATen/ops/align_as.h>
|
| 500 |
+
#include <ATen/ops/align_tensors.h>
|
| 501 |
+
#include <ATen/ops/align_to.h>
|
| 502 |
+
#include <ATen/ops/all.h>
|
| 503 |
+
#include <ATen/ops/allclose.h>
|
| 504 |
+
#include <ATen/ops/alpha_dropout.h>
|
| 505 |
+
#include <ATen/ops/amax.h>
|
| 506 |
+
#include <ATen/ops/amin.h>
|
| 507 |
+
#include <ATen/ops/aminmax.h>
|
| 508 |
+
#include <ATen/ops/and.h>
|
| 509 |
+
#include <ATen/ops/angle.h>
|
| 510 |
+
#include <ATen/ops/any.h>
|
| 511 |
+
#include <ATen/ops/arange.h>
|
| 512 |
+
#include <ATen/ops/arccos.h>
|
| 513 |
+
#include <ATen/ops/arccosh.h>
|
| 514 |
+
#include <ATen/ops/arcsin.h>
|
| 515 |
+
#include <ATen/ops/arcsinh.h>
|
| 516 |
+
#include <ATen/ops/arctan.h>
|
| 517 |
+
#include <ATen/ops/arctan2.h>
|
| 518 |
+
#include <ATen/ops/arctanh.h>
|
| 519 |
+
#include <ATen/ops/argmax.h>
|
| 520 |
+
#include <ATen/ops/argmin.h>
|
| 521 |
+
#include <ATen/ops/argsort.h>
|
| 522 |
+
#include <ATen/ops/argwhere.h>
|
| 523 |
+
#include <ATen/ops/as_strided.h>
|
| 524 |
+
#include <ATen/ops/as_strided_copy.h>
|
| 525 |
+
#include <ATen/ops/as_strided_scatter.h>
|
| 526 |
+
#include <ATen/ops/asin.h>
|
| 527 |
+
#include <ATen/ops/asinh.h>
|
| 528 |
+
#include <ATen/ops/atan.h>
|
| 529 |
+
#include <ATen/ops/atan2.h>
|
| 530 |
+
#include <ATen/ops/atanh.h>
|
| 531 |
+
#include <ATen/ops/atleast_1d.h>
|
| 532 |
+
#include <ATen/ops/atleast_2d.h>
|
| 533 |
+
#include <ATen/ops/atleast_3d.h>
|
| 534 |
+
#include <ATen/ops/avg_pool1d.h>
|
| 535 |
+
#include <ATen/ops/avg_pool2d.h>
|
| 536 |
+
#include <ATen/ops/avg_pool2d_backward.h>
|
| 537 |
+
#include <ATen/ops/avg_pool3d.h>
|
| 538 |
+
#include <ATen/ops/avg_pool3d_backward.h>
|
| 539 |
+
#include <ATen/ops/baddbmm.h>
|
| 540 |
+
#include <ATen/ops/bartlett_window.h>
|
| 541 |
+
#include <ATen/ops/batch_norm.h>
|
| 542 |
+
#include <ATen/ops/batch_norm_backward.h>
|
| 543 |
+
#include <ATen/ops/batch_norm_backward_elemt.h>
|
| 544 |
+
#include <ATen/ops/batch_norm_backward_reduce.h>
|
| 545 |
+
#include <ATen/ops/batch_norm_elemt.h>
|
| 546 |
+
#include <ATen/ops/batch_norm_gather_stats.h>
|
| 547 |
+
#include <ATen/ops/batch_norm_gather_stats_with_counts.h>
|
| 548 |
+
#include <ATen/ops/batch_norm_stats.h>
|
| 549 |
+
#include <ATen/ops/batch_norm_update_stats.h>
|
| 550 |
+
#include <ATen/ops/bernoulli.h>
|
| 551 |
+
#include <ATen/ops/bilinear.h>
|
| 552 |
+
#include <ATen/ops/binary_cross_entropy.h>
|
| 553 |
+
#include <ATen/ops/binary_cross_entropy_backward.h>
|
| 554 |
+
#include <ATen/ops/binary_cross_entropy_with_logits.h>
|
| 555 |
+
#include <ATen/ops/bincount.h>
|
| 556 |
+
#include <ATen/ops/binomial.h>
|
| 557 |
+
#include <ATen/ops/bitwise_and.h>
|
| 558 |
+
#include <ATen/ops/bitwise_left_shift.h>
|
| 559 |
+
#include <ATen/ops/bitwise_not.h>
|
| 560 |
+
#include <ATen/ops/bitwise_or.h>
|
| 561 |
+
#include <ATen/ops/bitwise_right_shift.h>
|
| 562 |
+
#include <ATen/ops/bitwise_xor.h>
|
| 563 |
+
#include <ATen/ops/blackman_window.h>
|
| 564 |
+
#include <ATen/ops/block_diag.h>
|
| 565 |
+
#include <ATen/ops/bmm.h>
|
| 566 |
+
#include <ATen/ops/broadcast_tensors.h>
|
| 567 |
+
#include <ATen/ops/broadcast_to.h>
|
| 568 |
+
#include <ATen/ops/bucketize.h>
|
| 569 |
+
#include <ATen/ops/can_cast.h>
|
| 570 |
+
#include <ATen/ops/cartesian_prod.h>
|
| 571 |
+
#include <ATen/ops/cat.h>
|
| 572 |
+
#include <ATen/ops/cauchy.h>
|
| 573 |
+
#include <ATen/ops/ccol_indices.h>
|
| 574 |
+
#include <ATen/ops/ccol_indices_copy.h>
|
| 575 |
+
#include <ATen/ops/cdist.h>
|
| 576 |
+
#include <ATen/ops/ceil.h>
|
| 577 |
+
#include <ATen/ops/celu.h>
|
| 578 |
+
#include <ATen/ops/chain_matmul.h>
|
| 579 |
+
#include <ATen/ops/chalf.h>
|
| 580 |
+
#include <ATen/ops/channel_shuffle.h>
|
| 581 |
+
#include <ATen/ops/cholesky.h>
|
| 582 |
+
#include <ATen/ops/cholesky_inverse.h>
|
| 583 |
+
#include <ATen/ops/cholesky_solve.h>
|
| 584 |
+
#include <ATen/ops/choose_qparams_optimized.h>
|
| 585 |
+
#include <ATen/ops/chunk.h>
|
| 586 |
+
#include <ATen/ops/clamp.h>
|
| 587 |
+
#include <ATen/ops/clamp_max.h>
|
| 588 |
+
#include <ATen/ops/clamp_min.h>
|
| 589 |
+
#include <ATen/ops/clip.h>
|
| 590 |
+
#include <ATen/ops/clone.h>
|
| 591 |
+
#include <ATen/ops/coalesce.h>
|
| 592 |
+
#include <ATen/ops/col2im.h>
|
| 593 |
+
#include <ATen/ops/col_indices.h>
|
| 594 |
+
#include <ATen/ops/col_indices_copy.h>
|
| 595 |
+
#include <ATen/ops/column_stack.h>
|
| 596 |
+
#include <ATen/ops/combinations.h>
|
| 597 |
+
#include <ATen/ops/complex.h>
|
| 598 |
+
#include <ATen/ops/concat.h>
|
| 599 |
+
#include <ATen/ops/concatenate.h>
|
| 600 |
+
#include <ATen/ops/conj.h>
|
| 601 |
+
#include <ATen/ops/conj_physical.h>
|
| 602 |
+
#include <ATen/ops/constant_pad_nd.h>
|
| 603 |
+
#include <ATen/ops/contiguous.h>
|
| 604 |
+
#include <ATen/ops/conv1d.h>
|
| 605 |
+
#include <ATen/ops/conv2d.h>
|
| 606 |
+
#include <ATen/ops/conv3d.h>
|
| 607 |
+
#include <ATen/ops/conv_depthwise3d.h>
|
| 608 |
+
#include <ATen/ops/conv_tbc.h>
|
| 609 |
+
#include <ATen/ops/conv_tbc_backward.h>
|
| 610 |
+
#include <ATen/ops/conv_transpose1d.h>
|
| 611 |
+
#include <ATen/ops/conv_transpose2d.h>
|
| 612 |
+
#include <ATen/ops/conv_transpose3d.h>
|
| 613 |
+
#include <ATen/ops/convolution.h>
|
| 614 |
+
#include <ATen/ops/convolution_backward.h>
|
| 615 |
+
#include <ATen/ops/convolution_backward_overrideable.h>
|
| 616 |
+
#include <ATen/ops/convolution_overrideable.h>
|
| 617 |
+
#include <ATen/ops/copy.h>
|
| 618 |
+
#include <ATen/ops/copy_sparse_to_sparse.h>
|
| 619 |
+
#include <ATen/ops/copysign.h>
|
| 620 |
+
#include <ATen/ops/corrcoef.h>
|
| 621 |
+
#include <ATen/ops/cos.h>
|
| 622 |
+
#include <ATen/ops/cosh.h>
|
| 623 |
+
#include <ATen/ops/cosine_embedding_loss.h>
|
| 624 |
+
#include <ATen/ops/cosine_similarity.h>
|
| 625 |
+
#include <ATen/ops/count_nonzero.h>
|
| 626 |
+
#include <ATen/ops/cov.h>
|
| 627 |
+
#include <ATen/ops/cross.h>
|
| 628 |
+
#include <ATen/ops/cross_entropy_loss.h>
|
| 629 |
+
#include <ATen/ops/crow_indices.h>
|
| 630 |
+
#include <ATen/ops/crow_indices_copy.h>
|
| 631 |
+
#include <ATen/ops/ctc_loss.h>
|
| 632 |
+
#include <ATen/ops/cudnn_affine_grid_generator.h>
|
| 633 |
+
#include <ATen/ops/cudnn_affine_grid_generator_backward.h>
|
| 634 |
+
#include <ATen/ops/cudnn_batch_norm.h>
|
| 635 |
+
#include <ATen/ops/cudnn_batch_norm_backward.h>
|
| 636 |
+
#include <ATen/ops/cudnn_convolution.h>
|
| 637 |
+
#include <ATen/ops/cudnn_convolution_add_relu.h>
|
| 638 |
+
#include <ATen/ops/cudnn_convolution_relu.h>
|
| 639 |
+
#include <ATen/ops/cudnn_convolution_transpose.h>
|
| 640 |
+
#include <ATen/ops/cudnn_grid_sampler.h>
|
| 641 |
+
#include <ATen/ops/cudnn_grid_sampler_backward.h>
|
| 642 |
+
#include <ATen/ops/cudnn_is_acceptable.h>
|
| 643 |
+
#include <ATen/ops/cummax.h>
|
| 644 |
+
#include <ATen/ops/cummaxmin_backward.h>
|
| 645 |
+
#include <ATen/ops/cummin.h>
|
| 646 |
+
#include <ATen/ops/cumprod.h>
|
| 647 |
+
#include <ATen/ops/cumprod_backward.h>
|
| 648 |
+
#include <ATen/ops/cumsum.h>
|
| 649 |
+
#include <ATen/ops/cumulative_trapezoid.h>
|
| 650 |
+
#include <ATen/ops/data.h>
|
| 651 |
+
#include <ATen/ops/deg2rad.h>
|
| 652 |
+
#include <ATen/ops/dense_dim.h>
|
| 653 |
+
#include <ATen/ops/dequantize.h>
|
| 654 |
+
#include <ATen/ops/det.h>
|
| 655 |
+
#include <ATen/ops/detach.h>
|
| 656 |
+
#include <ATen/ops/detach_copy.h>
|
| 657 |
+
#include <ATen/ops/diag.h>
|
| 658 |
+
#include <ATen/ops/diag_embed.h>
|
| 659 |
+
#include <ATen/ops/diagflat.h>
|
| 660 |
+
#include <ATen/ops/diagonal.h>
|
| 661 |
+
#include <ATen/ops/diagonal_backward.h>
|
| 662 |
+
#include <ATen/ops/diagonal_copy.h>
|
| 663 |
+
#include <ATen/ops/diagonal_scatter.h>
|
| 664 |
+
#include <ATen/ops/diff.h>
|
| 665 |
+
#include <ATen/ops/digamma.h>
|
| 666 |
+
#include <ATen/ops/dist.h>
|
| 667 |
+
#include <ATen/ops/div.h>
|
| 668 |
+
#include <ATen/ops/divide.h>
|
| 669 |
+
#include <ATen/ops/dot.h>
|
| 670 |
+
#include <ATen/ops/dropout.h>
|
| 671 |
+
#include <ATen/ops/dsplit.h>
|
| 672 |
+
#include <ATen/ops/dstack.h>
|
| 673 |
+
#include <ATen/ops/einsum.h>
|
| 674 |
+
#include <ATen/ops/elu.h>
|
| 675 |
+
#include <ATen/ops/elu_backward.h>
|
| 676 |
+
#include <ATen/ops/embedding.h>
|
| 677 |
+
#include <ATen/ops/embedding_backward.h>
|
| 678 |
+
#include <ATen/ops/embedding_bag.h>
|
| 679 |
+
#include <ATen/ops/embedding_dense_backward.h>
|
| 680 |
+
#include <ATen/ops/embedding_renorm.h>
|
| 681 |
+
#include <ATen/ops/embedding_sparse_backward.h>
|
| 682 |
+
#include <ATen/ops/empty.h>
|
| 683 |
+
#include <ATen/ops/empty_like.h>
|
| 684 |
+
#include <ATen/ops/empty_permuted.h>
|
| 685 |
+
#include <ATen/ops/empty_quantized.h>
|
| 686 |
+
#include <ATen/ops/empty_strided.h>
|
| 687 |
+
#include <ATen/ops/eq.h>
|
| 688 |
+
#include <ATen/ops/equal.h>
|
| 689 |
+
#include <ATen/ops/erf.h>
|
| 690 |
+
#include <ATen/ops/erfc.h>
|
| 691 |
+
#include <ATen/ops/erfinv.h>
|
| 692 |
+
#include <ATen/ops/exp.h>
|
| 693 |
+
#include <ATen/ops/exp2.h>
|
| 694 |
+
#include <ATen/ops/expand.h>
|
| 695 |
+
#include <ATen/ops/expand_as.h>
|
| 696 |
+
#include <ATen/ops/expand_copy.h>
|
| 697 |
+
#include <ATen/ops/expm1.h>
|
| 698 |
+
#include <ATen/ops/exponential.h>
|
| 699 |
+
#include <ATen/ops/eye.h>
|
| 700 |
+
#include <ATen/ops/fake_quantize_per_channel_affine.h>
|
| 701 |
+
#include <ATen/ops/fake_quantize_per_channel_affine_cachemask.h>
|
| 702 |
+
#include <ATen/ops/fake_quantize_per_channel_affine_cachemask_backward.h>
|
| 703 |
+
#include <ATen/ops/fake_quantize_per_tensor_affine.h>
|
| 704 |
+
#include <ATen/ops/fake_quantize_per_tensor_affine_cachemask.h>
|
| 705 |
+
#include <ATen/ops/fake_quantize_per_tensor_affine_cachemask_backward.h>
|
| 706 |
+
#include <ATen/ops/fbgemm_linear_fp16_weight.h>
|
| 707 |
+
#include <ATen/ops/fbgemm_linear_fp16_weight_fp32_activation.h>
|
| 708 |
+
#include <ATen/ops/fbgemm_linear_int8_weight.h>
|
| 709 |
+
#include <ATen/ops/fbgemm_linear_int8_weight_fp32_activation.h>
|
| 710 |
+
#include <ATen/ops/fbgemm_linear_quantize_weight.h>
|
| 711 |
+
#include <ATen/ops/fbgemm_pack_gemm_matrix_fp16.h>
|
| 712 |
+
#include <ATen/ops/fbgemm_pack_quantized_matrix.h>
|
| 713 |
+
#include <ATen/ops/feature_alpha_dropout.h>
|
| 714 |
+
#include <ATen/ops/feature_dropout.h>
|
| 715 |
+
#include <ATen/ops/fft_fft.h>
|
| 716 |
+
#include <ATen/ops/fft_fft2.h>
|
| 717 |
+
#include <ATen/ops/fft_fftfreq.h>
|
| 718 |
+
#include <ATen/ops/fft_fftn.h>
|
| 719 |
+
#include <ATen/ops/fft_fftshift.h>
|
| 720 |
+
#include <ATen/ops/fft_hfft.h>
|
| 721 |
+
#include <ATen/ops/fft_hfft2.h>
|
| 722 |
+
#include <ATen/ops/fft_hfftn.h>
|
| 723 |
+
#include <ATen/ops/fft_ifft.h>
|
| 724 |
+
#include <ATen/ops/fft_ifft2.h>
|
| 725 |
+
#include <ATen/ops/fft_ifftn.h>
|
| 726 |
+
#include <ATen/ops/fft_ifftshift.h>
|
| 727 |
+
#include <ATen/ops/fft_ihfft.h>
|
| 728 |
+
#include <ATen/ops/fft_ihfft2.h>
|
| 729 |
+
#include <ATen/ops/fft_ihfftn.h>
|
| 730 |
+
#include <ATen/ops/fft_irfft.h>
|
| 731 |
+
#include <ATen/ops/fft_irfft2.h>
|
| 732 |
+
#include <ATen/ops/fft_irfftn.h>
|
| 733 |
+
#include <ATen/ops/fft_rfft.h>
|
| 734 |
+
#include <ATen/ops/fft_rfft2.h>
|
| 735 |
+
#include <ATen/ops/fft_rfftfreq.h>
|
| 736 |
+
#include <ATen/ops/fft_rfftn.h>
|
| 737 |
+
#include <ATen/ops/fill.h>
|
| 738 |
+
#include <ATen/ops/fill_diagonal.h>
|
| 739 |
+
#include <ATen/ops/fix.h>
|
| 740 |
+
#include <ATen/ops/flatten.h>
|
| 741 |
+
#include <ATen/ops/flatten_dense_tensors.h>
|
| 742 |
+
#include <ATen/ops/flip.h>
|
| 743 |
+
#include <ATen/ops/fliplr.h>
|
| 744 |
+
#include <ATen/ops/flipud.h>
|
| 745 |
+
#include <ATen/ops/float_power.h>
|
| 746 |
+
#include <ATen/ops/floor.h>
|
| 747 |
+
#include <ATen/ops/floor_divide.h>
|
| 748 |
+
#include <ATen/ops/fmax.h>
|
| 749 |
+
#include <ATen/ops/fmin.h>
|
| 750 |
+
#include <ATen/ops/fmod.h>
|
| 751 |
+
#include <ATen/ops/frac.h>
|
| 752 |
+
#include <ATen/ops/fractional_max_pool2d.h>
|
| 753 |
+
#include <ATen/ops/fractional_max_pool2d_backward.h>
|
| 754 |
+
#include <ATen/ops/fractional_max_pool3d.h>
|
| 755 |
+
#include <ATen/ops/fractional_max_pool3d_backward.h>
|
| 756 |
+
#include <ATen/ops/frexp.h>
|
| 757 |
+
#include <ATen/ops/frobenius_norm.h>
|
| 758 |
+
#include <ATen/ops/from_file.h>
|
| 759 |
+
#include <ATen/ops/full.h>
|
| 760 |
+
#include <ATen/ops/full_like.h>
|
| 761 |
+
#include <ATen/ops/fused_moving_avg_obs_fake_quant.h>
|
| 762 |
+
#include <ATen/ops/gather.h>
|
| 763 |
+
#include <ATen/ops/gather_backward.h>
|
| 764 |
+
#include <ATen/ops/gcd.h>
|
| 765 |
+
#include <ATen/ops/ge.h>
|
| 766 |
+
#include <ATen/ops/gelu.h>
|
| 767 |
+
#include <ATen/ops/gelu_backward.h>
|
| 768 |
+
#include <ATen/ops/geometric.h>
|
| 769 |
+
#include <ATen/ops/geqrf.h>
|
| 770 |
+
#include <ATen/ops/ger.h>
|
| 771 |
+
#include <ATen/ops/glu.h>
|
| 772 |
+
#include <ATen/ops/glu_backward.h>
|
| 773 |
+
#include <ATen/ops/glu_backward_jvp.h>
|
| 774 |
+
#include <ATen/ops/glu_jvp.h>
|
| 775 |
+
#include <ATen/ops/gradient.h>
|
| 776 |
+
#include <ATen/ops/greater.h>
|
| 777 |
+
#include <ATen/ops/greater_equal.h>
|
| 778 |
+
#include <ATen/ops/grid_sampler.h>
|
| 779 |
+
#include <ATen/ops/grid_sampler_2d.h>
|
| 780 |
+
#include <ATen/ops/grid_sampler_2d_backward.h>
|
| 781 |
+
#include <ATen/ops/grid_sampler_3d.h>
|
| 782 |
+
#include <ATen/ops/grid_sampler_3d_backward.h>
|
| 783 |
+
#include <ATen/ops/group_norm.h>
|
| 784 |
+
#include <ATen/ops/gru.h>
|
| 785 |
+
#include <ATen/ops/gru_cell.h>
|
| 786 |
+
#include <ATen/ops/gt.h>
|
| 787 |
+
#include <ATen/ops/hamming_window.h>
|
| 788 |
+
#include <ATen/ops/hann_window.h>
|
| 789 |
+
#include <ATen/ops/hardshrink.h>
|
| 790 |
+
#include <ATen/ops/hardshrink_backward.h>
|
| 791 |
+
#include <ATen/ops/hardsigmoid.h>
|
| 792 |
+
#include <ATen/ops/hardsigmoid_backward.h>
|
| 793 |
+
#include <ATen/ops/hardswish.h>
|
| 794 |
+
#include <ATen/ops/hardswish_backward.h>
|
| 795 |
+
#include <ATen/ops/hardtanh.h>
|
| 796 |
+
#include <ATen/ops/hardtanh_backward.h>
|
| 797 |
+
#include <ATen/ops/hash_tensor.h>
|
| 798 |
+
#include <ATen/ops/heaviside.h>
|
| 799 |
+
#include <ATen/ops/hinge_embedding_loss.h>
|
| 800 |
+
#include <ATen/ops/histc.h>
|
| 801 |
+
#include <ATen/ops/histogram.h>
|
| 802 |
+
#include <ATen/ops/histogramdd.h>
|
| 803 |
+
#include <ATen/ops/hsplit.h>
|
| 804 |
+
#include <ATen/ops/hspmm.h>
|
| 805 |
+
#include <ATen/ops/hstack.h>
|
| 806 |
+
#include <ATen/ops/huber_loss.h>
|
| 807 |
+
#include <ATen/ops/huber_loss_backward.h>
|
| 808 |
+
#include <ATen/ops/hypot.h>
|
| 809 |
+
#include <ATen/ops/i0.h>
|
| 810 |
+
#include <ATen/ops/igamma.h>
|
| 811 |
+
#include <ATen/ops/igammac.h>
|
| 812 |
+
#include <ATen/ops/im2col.h>
|
| 813 |
+
#include <ATen/ops/imag.h>
|
| 814 |
+
#include <ATen/ops/index.h>
|
| 815 |
+
#include <ATen/ops/index_add.h>
|
| 816 |
+
#include <ATen/ops/index_copy.h>
|
| 817 |
+
#include <ATen/ops/index_fill.h>
|
| 818 |
+
#include <ATen/ops/index_put.h>
|
| 819 |
+
#include <ATen/ops/index_reduce.h>
|
| 820 |
+
#include <ATen/ops/index_select.h>
|
| 821 |
+
#include <ATen/ops/index_select_backward.h>
|
| 822 |
+
#include <ATen/ops/indices.h>
|
| 823 |
+
#include <ATen/ops/indices_copy.h>
|
| 824 |
+
#include <ATen/ops/infinitely_differentiable_gelu_backward.h>
|
| 825 |
+
#include <ATen/ops/inner.h>
|
| 826 |
+
#include <ATen/ops/instance_norm.h>
|
| 827 |
+
#include <ATen/ops/int_repr.h>
|
| 828 |
+
#include <ATen/ops/inverse.h>
|
| 829 |
+
#include <ATen/ops/is_coalesced.h>
|
| 830 |
+
#include <ATen/ops/is_complex.h>
|
| 831 |
+
#include <ATen/ops/is_conj.h>
|
| 832 |
+
#include <ATen/ops/is_distributed.h>
|
| 833 |
+
#include <ATen/ops/is_floating_point.h>
|
| 834 |
+
#include <ATen/ops/is_inference.h>
|
| 835 |
+
#include <ATen/ops/is_leaf.h>
|
| 836 |
+
#include <ATen/ops/is_neg.h>
|
| 837 |
+
#include <ATen/ops/is_nonzero.h>
|
| 838 |
+
#include <ATen/ops/is_pinned.h>
|
| 839 |
+
#include <ATen/ops/is_same_size.h>
|
| 840 |
+
#include <ATen/ops/is_set_to.h>
|
| 841 |
+
#include <ATen/ops/is_signed.h>
|
| 842 |
+
#include <ATen/ops/is_vulkan_available.h>
|
| 843 |
+
#include <ATen/ops/isclose.h>
|
| 844 |
+
#include <ATen/ops/isfinite.h>
|
| 845 |
+
#include <ATen/ops/isin.h>
|
| 846 |
+
#include <ATen/ops/isinf.h>
|
| 847 |
+
#include <ATen/ops/isnan.h>
|
| 848 |
+
#include <ATen/ops/isneginf.h>
|
| 849 |
+
#include <ATen/ops/isposinf.h>
|
| 850 |
+
#include <ATen/ops/isreal.h>
|
| 851 |
+
#include <ATen/ops/istft.h>
|
| 852 |
+
#include <ATen/ops/item.h>
|
| 853 |
+
#include <ATen/ops/kaiser_window.h>
|
| 854 |
+
#include <ATen/ops/kl_div.h>
|
| 855 |
+
#include <ATen/ops/kron.h>
|
| 856 |
+
#include <ATen/ops/kthvalue.h>
|
| 857 |
+
#include <ATen/ops/l1_loss.h>
|
| 858 |
+
#include <ATen/ops/layer_norm.h>
|
| 859 |
+
#include <ATen/ops/lcm.h>
|
| 860 |
+
#include <ATen/ops/ldexp.h>
|
| 861 |
+
#include <ATen/ops/le.h>
|
| 862 |
+
#include <ATen/ops/leaky_relu.h>
|
| 863 |
+
#include <ATen/ops/leaky_relu_backward.h>
|
| 864 |
+
#include <ATen/ops/lerp.h>
|
| 865 |
+
#include <ATen/ops/less.h>
|
| 866 |
+
#include <ATen/ops/less_equal.h>
|
| 867 |
+
#include <ATen/ops/lgamma.h>
|
| 868 |
+
#include <ATen/ops/lift.h>
|
| 869 |
+
#include <ATen/ops/lift_fresh.h>
|
| 870 |
+
#include <ATen/ops/lift_fresh_copy.h>
|
| 871 |
+
#include <ATen/ops/linalg_cholesky.h>
|
| 872 |
+
#include <ATen/ops/linalg_cholesky_ex.h>
|
| 873 |
+
#include <ATen/ops/linalg_cond.h>
|
| 874 |
+
#include <ATen/ops/linalg_cross.h>
|
| 875 |
+
#include <ATen/ops/linalg_det.h>
|
| 876 |
+
#include <ATen/ops/linalg_diagonal.h>
|
| 877 |
+
#include <ATen/ops/linalg_eig.h>
|
| 878 |
+
#include <ATen/ops/linalg_eigh.h>
|
| 879 |
+
#include <ATen/ops/linalg_eigvals.h>
|
| 880 |
+
#include <ATen/ops/linalg_eigvalsh.h>
|
| 881 |
+
#include <ATen/ops/linalg_householder_product.h>
|
| 882 |
+
#include <ATen/ops/linalg_inv.h>
|
| 883 |
+
#include <ATen/ops/linalg_inv_ex.h>
|
| 884 |
+
#include <ATen/ops/linalg_ldl_factor.h>
|
| 885 |
+
#include <ATen/ops/linalg_ldl_factor_ex.h>
|
| 886 |
+
#include <ATen/ops/linalg_ldl_solve.h>
|
| 887 |
+
#include <ATen/ops/linalg_lstsq.h>
|
| 888 |
+
#include <ATen/ops/linalg_lu.h>
|
| 889 |
+
#include <ATen/ops/linalg_lu_factor.h>
|
| 890 |
+
#include <ATen/ops/linalg_lu_factor_ex.h>
|
| 891 |
+
#include <ATen/ops/linalg_lu_solve.h>
|
| 892 |
+
#include <ATen/ops/linalg_matmul.h>
|
| 893 |
+
#include <ATen/ops/linalg_matrix_exp.h>
|
| 894 |
+
#include <ATen/ops/linalg_matrix_norm.h>
|
| 895 |
+
#include <ATen/ops/linalg_matrix_power.h>
|
| 896 |
+
#include <ATen/ops/linalg_matrix_rank.h>
|
| 897 |
+
#include <ATen/ops/linalg_multi_dot.h>
|
| 898 |
+
#include <ATen/ops/linalg_norm.h>
|
| 899 |
+
#include <ATen/ops/linalg_pinv.h>
|
| 900 |
+
#include <ATen/ops/linalg_qr.h>
|
| 901 |
+
#include <ATen/ops/linalg_slogdet.h>
|
| 902 |
+
#include <ATen/ops/linalg_solve.h>
|
| 903 |
+
#include <ATen/ops/linalg_solve_ex.h>
|
| 904 |
+
#include <ATen/ops/linalg_solve_triangular.h>
|
| 905 |
+
#include <ATen/ops/linalg_svd.h>
|
| 906 |
+
#include <ATen/ops/linalg_svdvals.h>
|
| 907 |
+
#include <ATen/ops/linalg_tensorinv.h>
|
| 908 |
+
#include <ATen/ops/linalg_tensorsolve.h>
|
| 909 |
+
#include <ATen/ops/linalg_vander.h>
|
| 910 |
+
#include <ATen/ops/linalg_vecdot.h>
|
| 911 |
+
#include <ATen/ops/linalg_vector_norm.h>
|
| 912 |
+
#include <ATen/ops/linear.h>
|
| 913 |
+
#include <ATen/ops/linear_backward.h>
|
| 914 |
+
#include <ATen/ops/linspace.h>
|
| 915 |
+
#include <ATen/ops/log.h>
|
| 916 |
+
#include <ATen/ops/log10.h>
|
| 917 |
+
#include <ATen/ops/log1p.h>
|
| 918 |
+
#include <ATen/ops/log2.h>
|
| 919 |
+
#include <ATen/ops/log_normal.h>
|
| 920 |
+
#include <ATen/ops/log_sigmoid.h>
|
| 921 |
+
#include <ATen/ops/log_sigmoid_backward.h>
|
| 922 |
+
#include <ATen/ops/log_sigmoid_forward.h>
|
| 923 |
+
#include <ATen/ops/log_softmax.h>
|
| 924 |
+
#include <ATen/ops/logaddexp.h>
|
| 925 |
+
#include <ATen/ops/logaddexp2.h>
|
| 926 |
+
#include <ATen/ops/logcumsumexp.h>
|
| 927 |
+
#include <ATen/ops/logdet.h>
|
| 928 |
+
#include <ATen/ops/logical_and.h>
|
| 929 |
+
#include <ATen/ops/logical_not.h>
|
| 930 |
+
#include <ATen/ops/logical_or.h>
|
| 931 |
+
#include <ATen/ops/logical_xor.h>
|
| 932 |
+
#include <ATen/ops/logit.h>
|
| 933 |
+
#include <ATen/ops/logit_backward.h>
|
| 934 |
+
#include <ATen/ops/logspace.h>
|
| 935 |
+
#include <ATen/ops/logsumexp.h>
|
| 936 |
+
#include <ATen/ops/lshift.h>
|
| 937 |
+
#include <ATen/ops/lstm.h>
|
| 938 |
+
#include <ATen/ops/lstm_cell.h>
|
| 939 |
+
#include <ATen/ops/lstm_mps_backward.h>
|
| 940 |
+
#include <ATen/ops/lt.h>
|
| 941 |
+
#include <ATen/ops/lu_solve.h>
|
| 942 |
+
#include <ATen/ops/lu_unpack.h>
|
| 943 |
+
#include <ATen/ops/mH.h>
|
| 944 |
+
#include <ATen/ops/mT.h>
|
| 945 |
+
#include <ATen/ops/margin_ranking_loss.h>
|
| 946 |
+
#include <ATen/ops/masked_fill.h>
|
| 947 |
+
#include <ATen/ops/masked_scatter.h>
|
| 948 |
+
#include <ATen/ops/masked_scatter_backward.h>
|
| 949 |
+
#include <ATen/ops/masked_select.h>
|
| 950 |
+
#include <ATen/ops/masked_select_backward.h>
|
| 951 |
+
#include <ATen/ops/matmul.h>
|
| 952 |
+
#include <ATen/ops/matmul_backward.h>
|
| 953 |
+
#include <ATen/ops/matrix_H.h>
|
| 954 |
+
#include <ATen/ops/matrix_exp.h>
|
| 955 |
+
#include <ATen/ops/matrix_exp_backward.h>
|
| 956 |
+
#include <ATen/ops/matrix_power.h>
|
| 957 |
+
#include <ATen/ops/max.h>
|
| 958 |
+
#include <ATen/ops/max_pool1d.h>
|
| 959 |
+
#include <ATen/ops/max_pool1d_with_indices.h>
|
| 960 |
+
#include <ATen/ops/max_pool2d.h>
|
| 961 |
+
#include <ATen/ops/max_pool2d_backward.h>
|
| 962 |
+
#include <ATen/ops/max_pool2d_with_indices.h>
|
| 963 |
+
#include <ATen/ops/max_pool2d_with_indices_backward.h>
|
| 964 |
+
#include <ATen/ops/max_pool3d.h>
|
| 965 |
+
#include <ATen/ops/max_pool3d_with_indices.h>
|
| 966 |
+
#include <ATen/ops/max_pool3d_with_indices_backward.h>
|
| 967 |
+
#include <ATen/ops/max_unpool2d.h>
|
| 968 |
+
#include <ATen/ops/max_unpool3d.h>
|
| 969 |
+
#include <ATen/ops/maximum.h>
|
| 970 |
+
#include <ATen/ops/mean.h>
|
| 971 |
+
#include <ATen/ops/median.h>
|
| 972 |
+
#include <ATen/ops/meshgrid.h>
|
| 973 |
+
#include <ATen/ops/min.h>
|
| 974 |
+
#include <ATen/ops/minimum.h>
|
| 975 |
+
#include <ATen/ops/miopen_batch_norm.h>
|
| 976 |
+
#include <ATen/ops/miopen_batch_norm_backward.h>
|
| 977 |
+
#include <ATen/ops/miopen_convolution.h>
|
| 978 |
+
#include <ATen/ops/miopen_convolution_add_relu.h>
|
| 979 |
+
#include <ATen/ops/miopen_convolution_relu.h>
|
| 980 |
+
#include <ATen/ops/miopen_convolution_transpose.h>
|
| 981 |
+
#include <ATen/ops/miopen_depthwise_convolution.h>
|
| 982 |
+
#include <ATen/ops/miopen_rnn.h>
|
| 983 |
+
#include <ATen/ops/miopen_rnn_backward.h>
|
| 984 |
+
#include <ATen/ops/mish.h>
|
| 985 |
+
#include <ATen/ops/mish_backward.h>
|
| 986 |
+
#include <ATen/ops/mkldnn_adaptive_avg_pool2d.h>
|
| 987 |
+
#include <ATen/ops/mkldnn_adaptive_avg_pool2d_backward.h>
|
| 988 |
+
#include <ATen/ops/mkldnn_convolution.h>
|
| 989 |
+
#include <ATen/ops/mkldnn_linear.h>
|
| 990 |
+
#include <ATen/ops/mkldnn_linear_backward.h>
|
| 991 |
+
#include <ATen/ops/mkldnn_linear_backward_input.h>
|
| 992 |
+
#include <ATen/ops/mkldnn_linear_backward_weights.h>
|
| 993 |
+
#include <ATen/ops/mkldnn_max_pool2d.h>
|
| 994 |
+
#include <ATen/ops/mkldnn_max_pool2d_backward.h>
|
| 995 |
+
#include <ATen/ops/mkldnn_max_pool3d.h>
|
| 996 |
+
#include <ATen/ops/mkldnn_max_pool3d_backward.h>
|
| 997 |
+
#include <ATen/ops/mkldnn_reorder_conv2d_weight.h>
|
| 998 |
+
#include <ATen/ops/mkldnn_reorder_conv3d_weight.h>
|
| 999 |
+
#include <ATen/ops/mkldnn_rnn_layer.h>
|
| 1000 |
+
#include <ATen/ops/mkldnn_rnn_layer_backward.h>
|
| 1001 |
+
#include <ATen/ops/mm.h>
|
| 1002 |
+
#include <ATen/ops/mode.h>
|
| 1003 |
+
#include <ATen/ops/moveaxis.h>
|
| 1004 |
+
#include <ATen/ops/movedim.h>
|
| 1005 |
+
#include <ATen/ops/mps_convolution_backward.h>
|
| 1006 |
+
#include <ATen/ops/mps_convolution_transpose_backward.h>
|
| 1007 |
+
#include <ATen/ops/mse_loss.h>
|
| 1008 |
+
#include <ATen/ops/mse_loss_backward.h>
|
| 1009 |
+
#include <ATen/ops/msort.h>
|
| 1010 |
+
#include <ATen/ops/mul.h>
|
| 1011 |
+
#include <ATen/ops/multi_margin_loss.h>
|
| 1012 |
+
#include <ATen/ops/multi_margin_loss_backward.h>
|
| 1013 |
+
#include <ATen/ops/multilabel_margin_loss.h>
|
| 1014 |
+
#include <ATen/ops/multilabel_margin_loss_backward.h>
|
| 1015 |
+
#include <ATen/ops/multilabel_margin_loss_forward.h>
|
| 1016 |
+
#include <ATen/ops/multinomial.h>
|
| 1017 |
+
#include <ATen/ops/multiply.h>
|
| 1018 |
+
#include <ATen/ops/mv.h>
|
| 1019 |
+
#include <ATen/ops/mvlgamma.h>
|
| 1020 |
+
#include <ATen/ops/nan_to_num.h>
|
| 1021 |
+
#include <ATen/ops/nanmean.h>
|
| 1022 |
+
#include <ATen/ops/nanmedian.h>
|
| 1023 |
+
#include <ATen/ops/nanquantile.h>
|
| 1024 |
+
#include <ATen/ops/nansum.h>
|
| 1025 |
+
#include <ATen/ops/narrow.h>
|
| 1026 |
+
#include <ATen/ops/narrow_copy.h>
|
| 1027 |
+
#include <ATen/ops/native_batch_norm.h>
|
| 1028 |
+
#include <ATen/ops/native_batch_norm_backward.h>
|
| 1029 |
+
#include <ATen/ops/native_channel_shuffle.h>
|
| 1030 |
+
#include <ATen/ops/native_dropout.h>
|
| 1031 |
+
#include <ATen/ops/native_dropout_backward.h>
|
| 1032 |
+
#include <ATen/ops/native_group_norm.h>
|
| 1033 |
+
#include <ATen/ops/native_group_norm_backward.h>
|
| 1034 |
+
#include <ATen/ops/native_layer_norm.h>
|
| 1035 |
+
#include <ATen/ops/native_layer_norm_backward.h>
|
| 1036 |
+
#include <ATen/ops/native_norm.h>
|
| 1037 |
+
#include <ATen/ops/ne.h>
|
| 1038 |
+
#include <ATen/ops/neg.h>
|
| 1039 |
+
#include <ATen/ops/negative.h>
|
| 1040 |
+
#include <ATen/ops/nested_to_padded_tensor.h>
|
| 1041 |
+
#include <ATen/ops/new_empty.h>
|
| 1042 |
+
#include <ATen/ops/new_empty_strided.h>
|
| 1043 |
+
#include <ATen/ops/new_full.h>
|
| 1044 |
+
#include <ATen/ops/new_ones.h>
|
| 1045 |
+
#include <ATen/ops/new_zeros.h>
|
| 1046 |
+
#include <ATen/ops/nextafter.h>
|
| 1047 |
+
#include <ATen/ops/nll_loss.h>
|
| 1048 |
+
#include <ATen/ops/nll_loss2d.h>
|
| 1049 |
+
#include <ATen/ops/nll_loss2d_backward.h>
|
| 1050 |
+
#include <ATen/ops/nll_loss2d_forward.h>
|
| 1051 |
+
#include <ATen/ops/nll_loss_backward.h>
|
| 1052 |
+
#include <ATen/ops/nll_loss_forward.h>
|
| 1053 |
+
#include <ATen/ops/nll_loss_nd.h>
|
| 1054 |
+
#include <ATen/ops/nonzero.h>
|
| 1055 |
+
#include <ATen/ops/nonzero_numpy.h>
|
| 1056 |
+
#include <ATen/ops/nonzero_static.h>
|
| 1057 |
+
#include <ATen/ops/norm.h>
|
| 1058 |
+
#include <ATen/ops/norm_except_dim.h>
|
| 1059 |
+
#include <ATen/ops/normal.h>
|
| 1060 |
+
#include <ATen/ops/not_equal.h>
|
| 1061 |
+
#include <ATen/ops/nuclear_norm.h>
|
| 1062 |
+
#include <ATen/ops/numpy_T.h>
|
| 1063 |
+
#include <ATen/ops/one_hot.h>
|
| 1064 |
+
#include <ATen/ops/ones.h>
|
| 1065 |
+
#include <ATen/ops/ones_like.h>
|
| 1066 |
+
#include <ATen/ops/or.h>
|
| 1067 |
+
#include <ATen/ops/orgqr.h>
|
| 1068 |
+
#include <ATen/ops/ormqr.h>
|
| 1069 |
+
#include <ATen/ops/outer.h>
|
| 1070 |
+
#include <ATen/ops/output_nr.h>
|
| 1071 |
+
#include <ATen/ops/pad.h>
|
| 1072 |
+
#include <ATen/ops/pad_sequence.h>
|
| 1073 |
+
#include <ATen/ops/pairwise_distance.h>
|
| 1074 |
+
#include <ATen/ops/pdist.h>
|
| 1075 |
+
#include <ATen/ops/permute.h>
|
| 1076 |
+
#include <ATen/ops/permute_copy.h>
|
| 1077 |
+
#include <ATen/ops/pin_memory.h>
|
| 1078 |
+
#include <ATen/ops/pinverse.h>
|
| 1079 |
+
#include <ATen/ops/pixel_shuffle.h>
|
| 1080 |
+
#include <ATen/ops/pixel_unshuffle.h>
|
| 1081 |
+
#include <ATen/ops/poisson.h>
|
| 1082 |
+
#include <ATen/ops/poisson_nll_loss.h>
|
| 1083 |
+
#include <ATen/ops/polar.h>
|
| 1084 |
+
#include <ATen/ops/polygamma.h>
|
| 1085 |
+
#include <ATen/ops/positive.h>
|
| 1086 |
+
#include <ATen/ops/pow.h>
|
| 1087 |
+
#include <ATen/ops/prelu.h>
|
| 1088 |
+
#include <ATen/ops/prod.h>
|
| 1089 |
+
#include <ATen/ops/promote_types.h>
|
| 1090 |
+
#include <ATen/ops/put.h>
|
| 1091 |
+
#include <ATen/ops/q_per_channel_axis.h>
|
| 1092 |
+
#include <ATen/ops/q_per_channel_scales.h>
|
| 1093 |
+
#include <ATen/ops/q_per_channel_zero_points.h>
|
| 1094 |
+
#include <ATen/ops/q_scale.h>
|
| 1095 |
+
#include <ATen/ops/q_zero_point.h>
|
| 1096 |
+
#include <ATen/ops/qr.h>
|
| 1097 |
+
#include <ATen/ops/qscheme.h>
|
| 1098 |
+
#include <ATen/ops/quantile.h>
|
| 1099 |
+
#include <ATen/ops/quantize_per_channel.h>
|
| 1100 |
+
#include <ATen/ops/quantize_per_tensor.h>
|
| 1101 |
+
#include <ATen/ops/quantize_per_tensor_dynamic.h>
|
| 1102 |
+
#include <ATen/ops/quantized_batch_norm.h>
|
| 1103 |
+
#include <ATen/ops/quantized_gru_cell.h>
|
| 1104 |
+
#include <ATen/ops/quantized_lstm_cell.h>
|
| 1105 |
+
#include <ATen/ops/quantized_max_pool1d.h>
|
| 1106 |
+
#include <ATen/ops/quantized_max_pool2d.h>
|
| 1107 |
+
#include <ATen/ops/quantized_max_pool3d.h>
|
| 1108 |
+
#include <ATen/ops/quantized_rnn_relu_cell.h>
|
| 1109 |
+
#include <ATen/ops/quantized_rnn_tanh_cell.h>
|
| 1110 |
+
#include <ATen/ops/rad2deg.h>
|
| 1111 |
+
#include <ATen/ops/rand.h>
|
| 1112 |
+
#include <ATen/ops/rand_like.h>
|
| 1113 |
+
#include <ATen/ops/randint.h>
|
| 1114 |
+
#include <ATen/ops/randint_like.h>
|
| 1115 |
+
#include <ATen/ops/randn.h>
|
| 1116 |
+
#include <ATen/ops/randn_like.h>
|
| 1117 |
+
#include <ATen/ops/random.h>
|
| 1118 |
+
#include <ATen/ops/randperm.h>
|
| 1119 |
+
#include <ATen/ops/range.h>
|
| 1120 |
+
#include <ATen/ops/ravel.h>
|
| 1121 |
+
#include <ATen/ops/real.h>
|
| 1122 |
+
#include <ATen/ops/reciprocal.h>
|
| 1123 |
+
#include <ATen/ops/record_stream.h>
|
| 1124 |
+
#include <ATen/ops/refine_names.h>
|
| 1125 |
+
#include <ATen/ops/reflection_pad1d.h>
|
| 1126 |
+
#include <ATen/ops/reflection_pad1d_backward.h>
|
| 1127 |
+
#include <ATen/ops/reflection_pad2d.h>
|
| 1128 |
+
#include <ATen/ops/reflection_pad2d_backward.h>
|
| 1129 |
+
#include <ATen/ops/reflection_pad3d.h>
|
| 1130 |
+
#include <ATen/ops/reflection_pad3d_backward.h>
|
| 1131 |
+
#include <ATen/ops/relu.h>
|
| 1132 |
+
#include <ATen/ops/relu6.h>
|
| 1133 |
+
#include <ATen/ops/remainder.h>
|
| 1134 |
+
#include <ATen/ops/rename.h>
|
| 1135 |
+
#include <ATen/ops/renorm.h>
|
| 1136 |
+
#include <ATen/ops/repeat.h>
|
| 1137 |
+
#include <ATen/ops/repeat_interleave.h>
|
| 1138 |
+
#include <ATen/ops/replication_pad1d.h>
|
| 1139 |
+
#include <ATen/ops/replication_pad1d_backward.h>
|
| 1140 |
+
#include <ATen/ops/replication_pad2d.h>
|
| 1141 |
+
#include <ATen/ops/replication_pad2d_backward.h>
|
| 1142 |
+
#include <ATen/ops/replication_pad3d.h>
|
| 1143 |
+
#include <ATen/ops/replication_pad3d_backward.h>
|
| 1144 |
+
#include <ATen/ops/requires_grad.h>
|
| 1145 |
+
#include <ATen/ops/reshape.h>
|
| 1146 |
+
#include <ATen/ops/reshape_as.h>
|
| 1147 |
+
#include <ATen/ops/resize.h>
|
| 1148 |
+
#include <ATen/ops/resize_as.h>
|
| 1149 |
+
#include <ATen/ops/resize_as_sparse.h>
|
| 1150 |
+
#include <ATen/ops/resolve_conj.h>
|
| 1151 |
+
#include <ATen/ops/resolve_neg.h>
|
| 1152 |
+
#include <ATen/ops/result_type.h>
|
| 1153 |
+
#include <ATen/ops/retain_grad.h>
|
| 1154 |
+
#include <ATen/ops/retains_grad.h>
|
| 1155 |
+
#include <ATen/ops/rms_norm.h>
|
| 1156 |
+
#include <ATen/ops/rnn_relu.h>
|
| 1157 |
+
#include <ATen/ops/rnn_relu_cell.h>
|
| 1158 |
+
#include <ATen/ops/rnn_tanh.h>
|
| 1159 |
+
#include <ATen/ops/rnn_tanh_cell.h>
|
| 1160 |
+
#include <ATen/ops/roll.h>
|
| 1161 |
+
#include <ATen/ops/rot90.h>
|
| 1162 |
+
#include <ATen/ops/round.h>
|
| 1163 |
+
#include <ATen/ops/row_indices.h>
|
| 1164 |
+
#include <ATen/ops/row_indices_copy.h>
|
| 1165 |
+
#include <ATen/ops/row_stack.h>
|
| 1166 |
+
#include <ATen/ops/rrelu.h>
|
| 1167 |
+
#include <ATen/ops/rrelu_with_noise.h>
|
| 1168 |
+
#include <ATen/ops/rrelu_with_noise_backward.h>
|
| 1169 |
+
#include <ATen/ops/rshift.h>
|
| 1170 |
+
#include <ATen/ops/rsqrt.h>
|
| 1171 |
+
#include <ATen/ops/rsub.h>
|
| 1172 |
+
#include <ATen/ops/scalar_tensor.h>
|
| 1173 |
+
#include <ATen/ops/scaled_dot_product_attention.h>
|
| 1174 |
+
#include <ATen/ops/scatter.h>
|
| 1175 |
+
#include <ATen/ops/scatter_add.h>
|
| 1176 |
+
#include <ATen/ops/scatter_reduce.h>
|
| 1177 |
+
#include <ATen/ops/searchsorted.h>
|
| 1178 |
+
#include <ATen/ops/segment_reduce.h>
|
| 1179 |
+
#include <ATen/ops/select.h>
|
| 1180 |
+
#include <ATen/ops/select_backward.h>
|
| 1181 |
+
#include <ATen/ops/select_copy.h>
|
| 1182 |
+
#include <ATen/ops/select_scatter.h>
|
| 1183 |
+
#include <ATen/ops/selu.h>
|
| 1184 |
+
#include <ATen/ops/set.h>
|
| 1185 |
+
#include <ATen/ops/set_data.h>
|
| 1186 |
+
#include <ATen/ops/sgn.h>
|
| 1187 |
+
#include <ATen/ops/sigmoid.h>
|
| 1188 |
+
#include <ATen/ops/sigmoid_backward.h>
|
| 1189 |
+
#include <ATen/ops/sign.h>
|
| 1190 |
+
#include <ATen/ops/signbit.h>
|
| 1191 |
+
#include <ATen/ops/silu.h>
|
| 1192 |
+
#include <ATen/ops/silu_backward.h>
|
| 1193 |
+
#include <ATen/ops/sin.h>
|
| 1194 |
+
#include <ATen/ops/sinc.h>
|
| 1195 |
+
#include <ATen/ops/sinh.h>
|
| 1196 |
+
#include <ATen/ops/size.h>
|
| 1197 |
+
#include <ATen/ops/slice.h>
|
| 1198 |
+
#include <ATen/ops/slice_backward.h>
|
| 1199 |
+
#include <ATen/ops/slice_copy.h>
|
| 1200 |
+
#include <ATen/ops/slice_inverse.h>
|
| 1201 |
+
#include <ATen/ops/slice_scatter.h>
|
| 1202 |
+
#include <ATen/ops/slogdet.h>
|
| 1203 |
+
#include <ATen/ops/slow_conv3d.h>
|
| 1204 |
+
#include <ATen/ops/slow_conv3d_forward.h>
|
| 1205 |
+
#include <ATen/ops/slow_conv_dilated2d.h>
|
| 1206 |
+
#include <ATen/ops/slow_conv_dilated3d.h>
|
| 1207 |
+
#include <ATen/ops/slow_conv_transpose2d.h>
|
| 1208 |
+
#include <ATen/ops/slow_conv_transpose3d.h>
|
| 1209 |
+
#include <ATen/ops/smm.h>
|
| 1210 |
+
#include <ATen/ops/smooth_l1_loss.h>
|
| 1211 |
+
#include <ATen/ops/smooth_l1_loss_backward.h>
|
| 1212 |
+
#include <ATen/ops/soft_margin_loss.h>
|
| 1213 |
+
#include <ATen/ops/soft_margin_loss_backward.h>
|
| 1214 |
+
#include <ATen/ops/softmax.h>
|
| 1215 |
+
#include <ATen/ops/softplus.h>
|
| 1216 |
+
#include <ATen/ops/softplus_backward.h>
|
| 1217 |
+
#include <ATen/ops/softshrink.h>
|
| 1218 |
+
#include <ATen/ops/softshrink_backward.h>
|
| 1219 |
+
#include <ATen/ops/sort.h>
|
| 1220 |
+
#include <ATen/ops/sparse_bsc_tensor.h>
|
| 1221 |
+
#include <ATen/ops/sparse_bsr_tensor.h>
|
| 1222 |
+
#include <ATen/ops/sparse_compressed_tensor.h>
|
| 1223 |
+
#include <ATen/ops/sparse_coo_tensor.h>
|
| 1224 |
+
#include <ATen/ops/sparse_csc_tensor.h>
|
| 1225 |
+
#include <ATen/ops/sparse_csr_tensor.h>
|
| 1226 |
+
#include <ATen/ops/sparse_dim.h>
|
| 1227 |
+
#include <ATen/ops/sparse_mask.h>
|
| 1228 |
+
#include <ATen/ops/sparse_resize.h>
|
| 1229 |
+
#include <ATen/ops/sparse_resize_and_clear.h>
|
| 1230 |
+
#include <ATen/ops/sparse_sampled_addmm.h>
|
| 1231 |
+
#include <ATen/ops/special_airy_ai.h>
|
| 1232 |
+
#include <ATen/ops/special_bessel_j0.h>
|
| 1233 |
+
#include <ATen/ops/special_bessel_j1.h>
|
| 1234 |
+
#include <ATen/ops/special_bessel_y0.h>
|
| 1235 |
+
#include <ATen/ops/special_bessel_y1.h>
|
| 1236 |
+
#include <ATen/ops/special_chebyshev_polynomial_t.h>
|
| 1237 |
+
#include <ATen/ops/special_chebyshev_polynomial_u.h>
|
| 1238 |
+
#include <ATen/ops/special_chebyshev_polynomial_v.h>
|
| 1239 |
+
#include <ATen/ops/special_chebyshev_polynomial_w.h>
|
| 1240 |
+
#include <ATen/ops/special_digamma.h>
|
| 1241 |
+
#include <ATen/ops/special_entr.h>
|
| 1242 |
+
#include <ATen/ops/special_erf.h>
|
| 1243 |
+
#include <ATen/ops/special_erfc.h>
|
| 1244 |
+
#include <ATen/ops/special_erfcx.h>
|
| 1245 |
+
#include <ATen/ops/special_erfinv.h>
|
| 1246 |
+
#include <ATen/ops/special_exp2.h>
|
| 1247 |
+
#include <ATen/ops/special_expit.h>
|
| 1248 |
+
#include <ATen/ops/special_expm1.h>
|
| 1249 |
+
#include <ATen/ops/special_gammainc.h>
|
| 1250 |
+
#include <ATen/ops/special_gammaincc.h>
|
| 1251 |
+
#include <ATen/ops/special_gammaln.h>
|
| 1252 |
+
#include <ATen/ops/special_hermite_polynomial_h.h>
|
| 1253 |
+
#include <ATen/ops/special_hermite_polynomial_he.h>
|
| 1254 |
+
#include <ATen/ops/special_i0.h>
|
| 1255 |
+
#include <ATen/ops/special_i0e.h>
|
| 1256 |
+
#include <ATen/ops/special_i1.h>
|
| 1257 |
+
#include <ATen/ops/special_i1e.h>
|
| 1258 |
+
#include <ATen/ops/special_laguerre_polynomial_l.h>
|
| 1259 |
+
#include <ATen/ops/special_legendre_polynomial_p.h>
|
| 1260 |
+
#include <ATen/ops/special_log1p.h>
|
| 1261 |
+
#include <ATen/ops/special_log_ndtr.h>
|
| 1262 |
+
#include <ATen/ops/special_log_softmax.h>
|
| 1263 |
+
#include <ATen/ops/special_logit.h>
|
| 1264 |
+
#include <ATen/ops/special_logsumexp.h>
|
| 1265 |
+
#include <ATen/ops/special_modified_bessel_i0.h>
|
| 1266 |
+
#include <ATen/ops/special_modified_bessel_i1.h>
|
| 1267 |
+
#include <ATen/ops/special_modified_bessel_k0.h>
|
| 1268 |
+
#include <ATen/ops/special_modified_bessel_k1.h>
|
| 1269 |
+
#include <ATen/ops/special_multigammaln.h>
|
| 1270 |
+
#include <ATen/ops/special_ndtr.h>
|
| 1271 |
+
#include <ATen/ops/special_ndtri.h>
|
| 1272 |
+
#include <ATen/ops/special_polygamma.h>
|
| 1273 |
+
#include <ATen/ops/special_psi.h>
|
| 1274 |
+
#include <ATen/ops/special_round.h>
|
| 1275 |
+
#include <ATen/ops/special_scaled_modified_bessel_k0.h>
|
| 1276 |
+
#include <ATen/ops/special_scaled_modified_bessel_k1.h>
|
| 1277 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_t.h>
|
| 1278 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_u.h>
|
| 1279 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_v.h>
|
| 1280 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_w.h>
|
| 1281 |
+
#include <ATen/ops/special_sinc.h>
|
| 1282 |
+
#include <ATen/ops/special_softmax.h>
|
| 1283 |
+
#include <ATen/ops/special_spherical_bessel_j0.h>
|
| 1284 |
+
#include <ATen/ops/special_xlog1py.h>
|
| 1285 |
+
#include <ATen/ops/special_xlogy.h>
|
| 1286 |
+
#include <ATen/ops/special_zeta.h>
|
| 1287 |
+
#include <ATen/ops/split.h>
|
| 1288 |
+
#include <ATen/ops/split_copy.h>
|
| 1289 |
+
#include <ATen/ops/split_with_sizes.h>
|
| 1290 |
+
#include <ATen/ops/split_with_sizes_copy.h>
|
| 1291 |
+
#include <ATen/ops/sqrt.h>
|
| 1292 |
+
#include <ATen/ops/square.h>
|
| 1293 |
+
#include <ATen/ops/squeeze.h>
|
| 1294 |
+
#include <ATen/ops/squeeze_copy.h>
|
| 1295 |
+
#include <ATen/ops/sspaddmm.h>
|
| 1296 |
+
#include <ATen/ops/stack.h>
|
| 1297 |
+
#include <ATen/ops/std.h>
|
| 1298 |
+
#include <ATen/ops/std_mean.h>
|
| 1299 |
+
#include <ATen/ops/stft.h>
|
| 1300 |
+
#include <ATen/ops/stride.h>
|
| 1301 |
+
#include <ATen/ops/sub.h>
|
| 1302 |
+
#include <ATen/ops/subtract.h>
|
| 1303 |
+
#include <ATen/ops/sum.h>
|
| 1304 |
+
#include <ATen/ops/sum_to_size.h>
|
| 1305 |
+
#include <ATen/ops/svd.h>
|
| 1306 |
+
#include <ATen/ops/swapaxes.h>
|
| 1307 |
+
#include <ATen/ops/swapdims.h>
|
| 1308 |
+
#include <ATen/ops/sym_constrain_range.h>
|
| 1309 |
+
#include <ATen/ops/sym_constrain_range_for_size.h>
|
| 1310 |
+
#include <ATen/ops/sym_is_contiguous.h>
|
| 1311 |
+
#include <ATen/ops/sym_numel.h>
|
| 1312 |
+
#include <ATen/ops/sym_size.h>
|
| 1313 |
+
#include <ATen/ops/sym_storage_offset.h>
|
| 1314 |
+
#include <ATen/ops/sym_stride.h>
|
| 1315 |
+
#include <ATen/ops/t.h>
|
| 1316 |
+
#include <ATen/ops/t_copy.h>
|
| 1317 |
+
#include <ATen/ops/take.h>
|
| 1318 |
+
#include <ATen/ops/take_along_dim.h>
|
| 1319 |
+
#include <ATen/ops/tan.h>
|
| 1320 |
+
#include <ATen/ops/tanh.h>
|
| 1321 |
+
#include <ATen/ops/tanh_backward.h>
|
| 1322 |
+
#include <ATen/ops/tensor_split.h>
|
| 1323 |
+
#include <ATen/ops/tensordot.h>
|
| 1324 |
+
#include <ATen/ops/thnn_conv2d.h>
|
| 1325 |
+
#include <ATen/ops/threshold.h>
|
| 1326 |
+
#include <ATen/ops/threshold_backward.h>
|
| 1327 |
+
#include <ATen/ops/tile.h>
|
| 1328 |
+
#include <ATen/ops/to.h>
|
| 1329 |
+
#include <ATen/ops/to_dense.h>
|
| 1330 |
+
#include <ATen/ops/to_dense_backward.h>
|
| 1331 |
+
#include <ATen/ops/to_mkldnn.h>
|
| 1332 |
+
#include <ATen/ops/to_mkldnn_backward.h>
|
| 1333 |
+
#include <ATen/ops/to_padded_tensor.h>
|
| 1334 |
+
#include <ATen/ops/to_sparse.h>
|
| 1335 |
+
#include <ATen/ops/to_sparse_bsc.h>
|
| 1336 |
+
#include <ATen/ops/to_sparse_bsr.h>
|
| 1337 |
+
#include <ATen/ops/to_sparse_csc.h>
|
| 1338 |
+
#include <ATen/ops/to_sparse_csr.h>
|
| 1339 |
+
#include <ATen/ops/topk.h>
|
| 1340 |
+
#include <ATen/ops/trace.h>
|
| 1341 |
+
#include <ATen/ops/trace_backward.h>
|
| 1342 |
+
#include <ATen/ops/transpose.h>
|
| 1343 |
+
#include <ATen/ops/transpose_copy.h>
|
| 1344 |
+
#include <ATen/ops/trapezoid.h>
|
| 1345 |
+
#include <ATen/ops/trapz.h>
|
| 1346 |
+
#include <ATen/ops/triangular_solve.h>
|
| 1347 |
+
#include <ATen/ops/tril.h>
|
| 1348 |
+
#include <ATen/ops/tril_indices.h>
|
| 1349 |
+
#include <ATen/ops/triplet_margin_loss.h>
|
| 1350 |
+
#include <ATen/ops/triu.h>
|
| 1351 |
+
#include <ATen/ops/triu_indices.h>
|
| 1352 |
+
#include <ATen/ops/true_divide.h>
|
| 1353 |
+
#include <ATen/ops/trunc.h>
|
| 1354 |
+
#include <ATen/ops/type_as.h>
|
| 1355 |
+
#include <ATen/ops/unbind.h>
|
| 1356 |
+
#include <ATen/ops/unbind_copy.h>
|
| 1357 |
+
#include <ATen/ops/unflatten.h>
|
| 1358 |
+
#include <ATen/ops/unflatten_dense_tensors.h>
|
| 1359 |
+
#include <ATen/ops/unfold.h>
|
| 1360 |
+
#include <ATen/ops/unfold_backward.h>
|
| 1361 |
+
#include <ATen/ops/unfold_copy.h>
|
| 1362 |
+
#include <ATen/ops/uniform.h>
|
| 1363 |
+
#include <ATen/ops/unique_consecutive.h>
|
| 1364 |
+
#include <ATen/ops/unique_dim.h>
|
| 1365 |
+
#include <ATen/ops/unique_dim_consecutive.h>
|
| 1366 |
+
#include <ATen/ops/unsafe_chunk.h>
|
| 1367 |
+
#include <ATen/ops/unsafe_split.h>
|
| 1368 |
+
#include <ATen/ops/unsafe_split_with_sizes.h>
|
| 1369 |
+
#include <ATen/ops/unsqueeze.h>
|
| 1370 |
+
#include <ATen/ops/unsqueeze_copy.h>
|
| 1371 |
+
#include <ATen/ops/upsample_bicubic2d.h>
|
| 1372 |
+
#include <ATen/ops/upsample_bicubic2d_backward.h>
|
| 1373 |
+
#include <ATen/ops/upsample_bilinear2d.h>
|
| 1374 |
+
#include <ATen/ops/upsample_bilinear2d_backward.h>
|
| 1375 |
+
#include <ATen/ops/upsample_linear1d.h>
|
| 1376 |
+
#include <ATen/ops/upsample_linear1d_backward.h>
|
| 1377 |
+
#include <ATen/ops/upsample_nearest1d.h>
|
| 1378 |
+
#include <ATen/ops/upsample_nearest1d_backward.h>
|
| 1379 |
+
#include <ATen/ops/upsample_nearest2d.h>
|
| 1380 |
+
#include <ATen/ops/upsample_nearest2d_backward.h>
|
| 1381 |
+
#include <ATen/ops/upsample_nearest3d.h>
|
| 1382 |
+
#include <ATen/ops/upsample_nearest3d_backward.h>
|
| 1383 |
+
#include <ATen/ops/upsample_trilinear3d.h>
|
| 1384 |
+
#include <ATen/ops/upsample_trilinear3d_backward.h>
|
| 1385 |
+
#include <ATen/ops/value_selecting_reduction_backward.h>
|
| 1386 |
+
#include <ATen/ops/values.h>
|
| 1387 |
+
#include <ATen/ops/values_copy.h>
|
| 1388 |
+
#include <ATen/ops/vander.h>
|
| 1389 |
+
#include <ATen/ops/var.h>
|
| 1390 |
+
#include <ATen/ops/var_mean.h>
|
| 1391 |
+
#include <ATen/ops/vdot.h>
|
| 1392 |
+
#include <ATen/ops/view.h>
|
| 1393 |
+
#include <ATen/ops/view_as.h>
|
| 1394 |
+
#include <ATen/ops/view_as_complex.h>
|
| 1395 |
+
#include <ATen/ops/view_as_complex_copy.h>
|
| 1396 |
+
#include <ATen/ops/view_as_real.h>
|
| 1397 |
+
#include <ATen/ops/view_as_real_copy.h>
|
| 1398 |
+
#include <ATen/ops/view_copy.h>
|
| 1399 |
+
#include <ATen/ops/vsplit.h>
|
| 1400 |
+
#include <ATen/ops/vstack.h>
|
| 1401 |
+
#include <ATen/ops/where.h>
|
| 1402 |
+
#include <ATen/ops/xlogy.h>
|
| 1403 |
+
#include <ATen/ops/xor.h>
|
| 1404 |
+
#include <ATen/ops/zero.h>
|
| 1405 |
+
#include <ATen/ops/zeros.h>
|
| 1406 |
+
#include <ATen/ops/zeros_like.h>
|
| 1407 |
+
|
| 1408 |
+
namespace at {
|
| 1409 |
+
|
| 1410 |
+
|
| 1411 |
+
|
| 1412 |
+
// Special C++ only overloads for std()-like functions (See gh-40287)
|
| 1413 |
+
// These are needed because int -> bool conversion takes precedence over int -> IntArrayRef
|
| 1414 |
+
// So, for example std(0) would select the std(unbiased=False) overload
|
| 1415 |
+
// C++-only convenience overload: takes a single `int` dimension, wraps it in
// a one-element IntArrayRef and forwards to at::var, so that an `int`
// argument is treated as a dimension rather than converted to bool.
inline Tensor var(const Tensor& self, int dim) {
|
| 1416 |
+
return at::var(self, IntArrayRef{dim});
|
| 1417 |
+
}
|
| 1418 |
+
// C++-only convenience overload: wraps the single `int` dimension in a
// one-element IntArrayRef and forwards to at::var_mean, returning its
// (Tensor, Tensor) tuple unchanged.
inline std::tuple<Tensor, Tensor> var_mean(const Tensor& self, int dim) {
|
| 1419 |
+
return at::var_mean(self, IntArrayRef{dim});
|
| 1420 |
+
}
|
| 1421 |
+
// C++-only convenience overload: wraps the single `int` dimension in a
// one-element IntArrayRef and forwards to at::std, so that e.g. std(0) is
// not resolved to the bool (`unbiased`) overload.
inline Tensor std(const Tensor& self, int dim) {
|
| 1422 |
+
return at::std(self, IntArrayRef{dim});
|
| 1423 |
+
}
|
| 1424 |
+
// C++-only convenience overload: wraps the single `int` dimension in a
// one-element IntArrayRef and forwards to at::std_mean, returning its
// (Tensor, Tensor) tuple unchanged.
inline std::tuple<Tensor, Tensor> std_mean(const Tensor& self, int dim) {
|
| 1425 |
+
return at::std_mean(self, IntArrayRef{dim});
|
| 1426 |
+
}
|
| 1427 |
+
|
| 1428 |
+
// Free-function form of the Tensor member: returns tensor.numel().
inline int64_t numel(const Tensor& tensor) {
|
| 1429 |
+
return tensor.numel();
|
| 1430 |
+
}
|
| 1431 |
+
|
| 1432 |
+
// Free-function form of the Tensor member: returns tensor.size(dim).
// `dim` is passed through unchanged; any validation is done by Tensor::size.
inline int64_t size(const Tensor& tensor, int64_t dim) {
|
| 1433 |
+
return tensor.size(dim);
|
| 1434 |
+
}
|
| 1435 |
+
|
| 1436 |
+
// Free-function form of the Tensor member: returns tensor.stride(dim).
// `dim` is passed through unchanged; any validation is done by Tensor::stride.
inline int64_t stride(const Tensor& tensor, int64_t dim) {
|
| 1437 |
+
return tensor.stride(dim);
|
| 1438 |
+
}
|
| 1439 |
+
|
| 1440 |
+
// Free-function form of the Tensor member: returns tensor.is_complex().
inline bool is_complex(const Tensor& tensor) {
|
| 1441 |
+
return tensor.is_complex();
|
| 1442 |
+
}
|
| 1443 |
+
|
| 1444 |
+
// Free-function form of the Tensor member: returns
// tensor.is_floating_point().
inline bool is_floating_point(const Tensor& tensor) {
|
| 1445 |
+
return tensor.is_floating_point();
|
| 1446 |
+
}
|
| 1447 |
+
|
| 1448 |
+
// Free-function form of the Tensor member: returns tensor.is_signed().
inline bool is_signed(const Tensor& tensor) {
|
| 1449 |
+
return tensor.is_signed();
|
| 1450 |
+
}
|
| 1451 |
+
|
| 1452 |
+
// Free-function form of the Tensor member: returns tensor.is_inference().
inline bool is_inference(const Tensor& tensor) {
|
| 1453 |
+
return tensor.is_inference();
|
| 1454 |
+
}
|
| 1455 |
+
|
| 1456 |
+
// Free-function form of the Tensor member: returns tensor._is_zerotensor().
// NOTE(review): the leading underscore suggests an internal API — confirm
// before relying on it from user code.
inline bool _is_zerotensor(const Tensor& tensor) {
|
| 1457 |
+
return tensor._is_zerotensor();
|
| 1458 |
+
}
|
| 1459 |
+
|
| 1460 |
+
// Free-function form of the Tensor member: returns tensor.is_conj().
inline bool is_conj(const Tensor& tensor) {
|
| 1461 |
+
return tensor.is_conj();
|
| 1462 |
+
}
|
| 1463 |
+
|
| 1464 |
+
// Free-function form of the Tensor member: returns the result of
// tensor.conj() by value.
inline Tensor conj(const Tensor& tensor) {
|
| 1465 |
+
return tensor.conj();
|
| 1466 |
+
}
|
| 1467 |
+
|
| 1468 |
+
// Free-function form of the Tensor member: returns tensor.is_neg().
inline bool is_neg(const Tensor& tensor) {
|
| 1469 |
+
return tensor.is_neg();
|
| 1470 |
+
}
|
| 1471 |
+
|
| 1472 |
+
}
|
| 1473 |
+
|
| 1474 |
+
#else
|
| 1475 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 1476 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/InitialTensorOptions.h
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
#include <c10/core/TensorOptions.h>
|
| 5 |
+
|
| 6 |
+
namespace at {
|
| 7 |
+
|
| 8 |
+
// Represents the initial TensorOptions, before the "defaults" are ever changed.
|
| 9 |
+
// This is designed to be used in library code, where the explicit devices,
|
| 10 |
+
// dtypes, etc. are known. NOTE: this is not a stable API.
|
| 11 |
+
// Builds the options explicitly rather than reading any current defaults:
// device kCPU, dtype kFloat, layout kStrided, requires_grad = false.
inline TensorOptions initialTensorOptions() {
|
| 12 |
+
return TensorOptions(kCPU).dtype(kFloat).layout(kStrided).requires_grad(
|
| 13 |
+
false);
|
| 14 |
+
}
|
| 15 |
+
|
| 16 |
+
} // namespace at
|
| 17 |
+
|
| 18 |
+
#else
|
| 19 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 20 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/LegacyBatchedTensorImpl.h
ADDED
|
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
#include <bitset>
|
| 5 |
+
|
| 6 |
+
#include <ATen/ArrayRef.h>
|
| 7 |
+
#include <ATen/SmallVector.h>
|
| 8 |
+
#include <ATen/Tensor.h>
|
| 9 |
+
|
| 10 |
+
namespace at {
|
| 11 |
+
|
| 12 |
+
// We assume this in a few other places in the codebase,
|
| 13 |
+
// but there isn't a centralized definition.
|
| 14 |
+
constexpr int64_t kVmapMaxTensorDims = 64;
|
| 15 |
+
|
| 16 |
+
// The valid vmap levels range from [0, 64). This effectively means that we
|
| 17 |
+
// support a maximum of 64 nested vmaps.
|
| 18 |
+
constexpr int64_t kVmapNumLevels = 64;
|
| 19 |
+
|
| 20 |
+
// Store this number of elements of BatchDims on the stack. Most people will
|
| 21 |
+
// probably use <= 5 nested vmaps, but adjust this number as necessary.
|
| 22 |
+
constexpr int64_t kBatchDimsStackSize = 5;
|
| 23 |
+
|
| 24 |
+
// a BatchDim represents a "private" dimension on a Tensor created inside of
|
| 25 |
+
// vmap. It is a (level, dim) tuple, with the `dim` indicating which dimension
|
| 26 |
+
// is being vmap'ed over and the `level` being an identifier for which vmap
|
| 27 |
+
// said dimension was created inside. The `dim` corresponds to a "physical
|
| 28 |
+
// dim" - it is a dimension index on the underlying physical tensor that is
|
| 29 |
+
// being vmapped over.
|
| 30 |
+
// Immutable (level, dim) pair: both fields are set once in the constructor
// and exposed only through const getters.
struct BatchDim {
|
| 31 |
+
// Note the argument order is (level, dim), while the members are declared
// (and therefore initialized) in the order dim_, level_.
BatchDim(int64_t level, int64_t dim) : dim_(dim), level_(level) {}
|
| 32 |
+
// Returns the stored dimension index.
int64_t dim() const {
|
| 33 |
+
return dim_;
|
| 34 |
+
}
|
| 35 |
+
// Returns the stored vmap level.
int64_t level() const {
|
| 36 |
+
return level_;
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
private:
|
| 40 |
+
int64_t dim_;
|
| 41 |
+
int64_t level_;
|
| 42 |
+
};
|
| 43 |
+
|
| 44 |
+
// List of BatchDim backed by a SmallVector sized with kBatchDimsStackSize.
using BatchDims = SmallVector<BatchDim, kBatchDimsStackSize>;
|
| 45 |
+
// ArrayRef over BatchDim elements.
// NOTE(review): presumably a non-owning view, as with other ArrayRef uses —
// confirm lifetime expectations at call sites.
using BatchDimsRef = ArrayRef<BatchDim>;
|
| 46 |
+
|
| 47 |
+
// A BatchedTensorImpl holds an underlying Tensor and a list of BatchDim
|
| 48 |
+
// NB: We use the term "BatchedTensor" to mean a Tensor that is backed with a
|
| 49 |
+
// BatchedTensorImpl.
|
| 50 |
+
//
|
| 51 |
+
// The batch dimensions are treated as being "private"; they are not
|
| 52 |
+
// user-visible. For example, in the following Tensor,
|
| 53 |
+
// bt = BatchedTensorImpl(ones(2, 3, 5, 7), [(lvl=1, dim=0), (lvl=2, dim=1)])
|
| 54 |
+
// dimensions 0 and 1 are batch dimensions.
|
| 55 |
+
//
|
| 56 |
+
// bt.sizes() returns (5, 7); bt.sum(0) performs a reduction over the (public)
|
| 57 |
+
// dim 0, which is equivalent to dim 2 in the underlying ones(2, 3, 5, 7)
|
| 58 |
+
// tensor.
|
| 59 |
+
struct TORCH_API BatchedTensorImpl : public c10::TensorImpl {
|
| 60 |
+
explicit BatchedTensorImpl(Tensor value, BatchDims bdims);
|
| 61 |
+
|
| 62 |
+
// Returns a reference to BatchDims that represent which dimensions of this
|
| 63 |
+
// tensor are private.
|
| 64 |
+
BatchDimsRef bdims() const {
|
| 65 |
+
return bdims_;
|
| 66 |
+
}
|
| 67 |
+
|
| 68 |
+
// BatchedTensorImpl wraps a Tensor
|
| 69 |
+
const Tensor& value() const {
|
| 70 |
+
return value_;
|
| 71 |
+
}
|
| 72 |
+
|
| 73 |
+
// Given a public dimension index, return the dimension index in the
|
| 74 |
+
// underlying value() tensor. For example, if we have
|
| 75 |
+
// bt = BatchedTensorImpl(ones(2, 3, 5, 7), [(lvl=1, dim=0), (lvl=2,
|
| 76 |
+
// dim=2)])
|
| 77 |
+
// bt.actualDim(0) -> 1
|
| 78 |
+
// bt.actualDim(1) -> 3
|
| 79 |
+
// bt.actualDim(2) -> Error
|
| 80 |
+
int64_t actualDim(int64_t dim, bool wrap_dim = true) const;
|
| 81 |
+
|
| 82 |
+
// We have to override this because we opted into CustomStrides
|
| 83 |
+
IntArrayRef strides_custom() const override;
|
| 84 |
+
// Override a bunch of methods inherited from TensorImpl to return error
|
| 85 |
+
// messages.
|
| 86 |
+
c10::SymBool sym_is_contiguous_custom(
|
| 87 |
+
at::MemoryFormat memory_format) const override;
|
| 88 |
+
void set_size(int64_t dim, int64_t new_size) override;
|
| 89 |
+
void set_stride(int64_t dim, int64_t new_stride) override;
|
| 90 |
+
void set_storage_offset(int64_t storage_offset) override;
|
| 91 |
+
#ifdef DEBUG
|
| 92 |
+
bool has_storage() const override;
|
| 93 |
+
#endif
|
| 94 |
+
|
| 95 |
+
private:
|
| 96 |
+
// see NOTE: [BatchedTensorImpl levels invariant]
|
| 97 |
+
void checkInvariants() const;
|
| 98 |
+
const char* tensorimpl_type_name() const override;
|
| 99 |
+
|
| 100 |
+
Tensor value_;
|
| 101 |
+
|
| 102 |
+
// Note: [BatchedTensorImpl levels invariant]
|
| 103 |
+
// There is an invariant that the BatchDims must be stored in increasing
|
| 104 |
+
// `level` order. That is, for i < j, bdims_[i].level must be less than
|
| 105 |
+
// bdims_[j].level.
|
| 106 |
+
BatchDims bdims_;
|
| 107 |
+
};
|
| 108 |
+
|
| 109 |
+
// NB: We use the term "BatchedTensor" to mean a Tensor that is backed with a
|
| 110 |
+
// BatchedTensorImpl.
|
| 111 |
+
inline bool isBatchedTensor(const Tensor& tensor) {
|
| 112 |
+
return tensor.unsafeGetTensorImpl()->key_set().has(DispatchKey::Batched);
|
| 113 |
+
}
|
| 114 |
+
|
| 115 |
+
// It is unsafe to call this on a Tensor that is not backed by a
|
| 116 |
+
// BatchedTensorImpl. Please use `maybeGetBatchedImpl` whenever possible.
|
| 117 |
+
inline BatchedTensorImpl* unsafeGetBatchedImpl(const Tensor& tensor) {
|
| 118 |
+
return static_cast<BatchedTensorImpl*>(tensor.unsafeGetTensorImpl());
|
| 119 |
+
}
|
| 120 |
+
|
| 121 |
+
inline BatchedTensorImpl* maybeGetBatchedImpl(const Tensor& tensor) {
|
| 122 |
+
if (!isBatchedTensor(tensor)) {
|
| 123 |
+
return nullptr;
|
| 124 |
+
}
|
| 125 |
+
return unsafeGetBatchedImpl(tensor);
|
| 126 |
+
}
|
| 127 |
+
|
| 128 |
+
// Returns a bitset. If bit i is set, then that means dim i is a batchdim.
|
| 129 |
+
inline std::bitset<kVmapMaxTensorDims> createBatchDimBitset(
|
| 130 |
+
BatchDimsRef bdims) {
|
| 131 |
+
std::bitset<kVmapMaxTensorDims> is_bdim;
|
| 132 |
+
for (const auto& bdim : bdims) {
|
| 133 |
+
is_bdim.set(bdim.dim());
|
| 134 |
+
}
|
| 135 |
+
return is_bdim;
|
| 136 |
+
}
|
| 137 |
+
|
| 138 |
+
// Creates a bitset for all of the levels present in `bdims`
|
| 139 |
+
inline std::bitset<kVmapNumLevels> createVmapLevelsBitset(BatchDimsRef bdims) {
|
| 140 |
+
std::bitset<kVmapNumLevels> result;
|
| 141 |
+
for (const auto& bdim : bdims) {
|
| 142 |
+
result.set(bdim.level());
|
| 143 |
+
}
|
| 144 |
+
return result;
|
| 145 |
+
}
|
| 146 |
+
|
| 147 |
+
inline std::ostream& operator<<(std::ostream& out, const BatchDim& bdim) {
|
| 148 |
+
out << "(lvl=" << bdim.level() << ", dim=" << bdim.dim() << ')';
|
| 149 |
+
return out;
|
| 150 |
+
}
|
| 151 |
+
|
| 152 |
+
// Use this to construct a BatchedTensor from a regular Tensor
|
| 153 |
+
TORCH_API Tensor makeBatched(Tensor tensor, BatchDims bdims);
|
| 154 |
+
|
| 155 |
+
// Adds a batch dim to `tensor`, returning a BatchedTensor
|
| 156 |
+
TORCH_API Tensor addBatchDim(Tensor tensor, int64_t level, int64_t dim);
|
| 157 |
+
|
| 158 |
+
// Checks if an inplace operation on self and other is "vmap compatible".
|
| 159 |
+
// See NOTE: [vmap-incompatible in-place operations] for the definition of this.
|
| 160 |
+
TORCH_API bool inplaceIsVmapCompatible(const Tensor& self, const Tensor& other);
|
| 161 |
+
|
| 162 |
+
} // namespace at
|
| 163 |
+
|
| 164 |
+
#else
|
| 165 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 166 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/LegacyVmapMode.h
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
#include <c10/core/impl/LocalDispatchKeySet.h>
|
| 5 |
+
|
| 6 |
+
namespace at::impl {
|
| 7 |
+
|
| 8 |
+
// VmapMode contains a thread local count of how many nested vmaps
|
| 9 |
+
// we are currently inside. That number is known as the `vmap level`.
|
| 10 |
+
// VmapMode is used in the implementation of the Python `torch.vmap` API.
|
| 11 |
+
//
|
| 12 |
+
// NOTE: this is NOT the c++ api for torch.vmap. That doesn't exist yet.
|
| 13 |
+
|
| 14 |
+
struct TORCH_API VmapMode {
|
| 15 |
+
// Returns the vmap level, aka the count of how many nested vmaps we're in.
|
| 16 |
+
static int64_t current_vmap_level();
|
| 17 |
+
|
| 18 |
+
// Increment the count of nested vmaps. If this causes the vmap level to be
|
| 19 |
+
// greater than 0, then it enables DispatchKey::VmapMode on all tensors.
|
| 20 |
+
static int64_t increment_nesting();
|
| 21 |
+
|
| 22 |
+
// Decrements the count of nested vmaps. If this causes the vmap level to be
|
| 23 |
+
// equal to 0, then it disables DispatchKey::VmapMode on all tensors.
|
| 24 |
+
static int64_t decrement_nesting();
|
| 25 |
+
};
|
| 26 |
+
|
| 27 |
+
} // namespace at::impl
|
| 28 |
+
|
| 29 |
+
#else
|
| 30 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 31 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/LegacyVmapTransforms.h
ADDED
|
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
#include <ATen/LegacyBatchedTensorImpl.h>
|
| 5 |
+
#include <ATen/core/IListRef.h>
|
| 6 |
+
|
| 7 |
+
namespace at {
|
| 8 |
+
|
| 9 |
+
// This file contains abstractions used for transforming *logical* vmap
|
| 10 |
+
// arguments into *physical* arguments. (Keep reading for definitions of these
|
| 11 |
+
// terms).
|
| 12 |
+
|
| 13 |
+
// NOTE: [Logical vs physical args]
|
| 14 |
+
// Consider the following vmap.
|
| 15 |
+
// vmap(vmap(func, in_dims=(2,)), in_dims=(0,))(torch.ones(2, 3, 4))
|
| 16 |
+
// This would produce a BatchedTensor wrapping a Tensor of size [2, 3, 4],
|
| 17 |
+
// with batch dims 0 and 2:
|
| 18 |
+
// BatchedTensor(ones(2, 3, 4), bdims=[(lvl=1,dim=0),(lvl=2,dim=2)])
|
| 19 |
+
//
|
| 20 |
+
// We say the *logical* view of the tensor has size [3] -- tensors inside
|
| 21 |
+
// `func` appear to have size [3].
|
| 22 |
+
// However, the *physical* underlying tensor (the one passed to vmap) has size
|
| 23 |
+
// [2, 3, 4].
|
| 24 |
+
//
|
| 25 |
+
// This notion of logical vs physical also extends to non-tensor arguments.
|
| 26 |
+
// Consider the previous tensor; let's assume the user called
|
| 27 |
+
// `torch.sum(tensor, dim=0)` inside of `func`. Then the logical
|
| 28 |
+
// dimension they are reducing over is dim 0 but the physical dim is dim 1
|
| 29 |
+
// (the first non-batch dimension)
|
| 30 |
+
|
| 31 |
+
// Forward declared; see NOTE: [What is a VmapPhysicalView?]
|
| 32 |
+
struct VmapPhysicalView;
|
| 33 |
+
|
| 34 |
+
// Most PyTorch operators take 4 or fewer inputs.
|
| 35 |
+
constexpr int64_t kVmapTransformStaticInputSize = 4;
|
| 36 |
+
using VmapPhysicalViewVec =
|
| 37 |
+
SmallVector<VmapPhysicalView, kVmapTransformStaticInputSize>;
|
| 38 |
+
|
| 39 |
+
// PyTorch generally advertises good performance for <= 5 dims.
|
| 40 |
+
// (see ATen/core/DimVector.h). We add a few extra dims (~3) for vmap
|
| 41 |
+
// dimensions to get 8. Adjust this number as necessary
|
| 42 |
+
constexpr int64_t kVmapStaticDimVecSize = 8;
|
| 43 |
+
using VmapDimVector = SmallVector<int64_t, kVmapStaticDimVecSize>;
|
| 44 |
+
using VmapSymDimVector = SmallVector<c10::SymInt, kVmapStaticDimVecSize>;
|
| 45 |
+
|
| 46 |
+
// NOTE: [What is an VmapTransform?]
|
| 47 |
+
// A *VmapTransform* converts logical views of tensors to physical views.
|
| 48 |
+
//
|
| 49 |
+
// Batching rules use VmapTransforms to convert logical arguments to
|
| 50 |
+
// physical arguments, then call one or more at:: operators that handle the
|
| 51 |
+
// physical arguments, and then convert the physical result back to a logical
|
| 52 |
+
// argument.
|
| 53 |
+
|
| 54 |
+
// VmapTransform for operators that take tensors with multiple batch dims.
|
| 55 |
+
// Given one or more logical views on Tensors, `logicalToPhysical`
|
| 56 |
+
// permutes all of the batch dims to the front of the tensor, aligns
|
| 57 |
+
// and expands the batch dims to match each other (according to their `level`),
|
| 58 |
+
// and returns a VmapPhysicalView on the tensor(s).
|
| 59 |
+
struct TORCH_API MultiBatchVmapTransform {
|
| 60 |
+
static VmapPhysicalView logicalToPhysical(const Tensor& logical_tensor);
|
| 61 |
+
static VmapPhysicalViewVec logicalToPhysical(ITensorListRef logical_tensors);
|
| 62 |
+
};
|
| 63 |
+
|
| 64 |
+
// VmapTransform for operators that broadcast all inputs.
|
| 65 |
+
// Given some logical views on Tensors, `logicalToPhysical`:
|
| 66 |
+
// - permutes all of the batch dims to the front of the tensors
|
| 67 |
+
// - aligns all the batch dims to the collective levels of all of the tensors.
|
| 68 |
+
// If a tensor does not have a batch dim for a vmap level, then it receives
|
| 69 |
+
// a size-one dimension for said level.
|
| 70 |
+
// - aligns the non-batch dims to have the same dimensionality, adding extra
|
| 71 |
+
// size-1 dimensions in between the batch dimensions and the non-batch
|
| 72 |
+
// dimensions so that the batch dimensions are lined up from the right.
|
| 73 |
+
//
|
| 74 |
+
// For example: given inputs of size (B, 2) and (B, 3, 2) where B is the batch
|
| 75 |
+
// dimension, BroadcastingVmapTransform returns VmapPhysicalViews that wrap
|
| 76 |
+
// tensors of size (B, 1, 2) and (B, 3, 2).
|
| 77 |
+
//
|
| 78 |
+
// Given inputs of size (B, 2) and (2,), BroadcastingVmapTransform returns
|
| 79 |
+
// VmapPhysicalViews wrapping tensors of size (B, 2) and (1, 2). We don't
|
| 80 |
+
// actually *need* to return a tensor of size (1, 2) for the second tensor
|
| 81 |
+
// because the broadcasting operation takes care of that for us, but we do
|
| 82 |
+
// it anyways to keep things simple.
|
| 83 |
+
struct TORCH_API BroadcastingVmapTransform {
|
| 84 |
+
static VmapPhysicalViewVec logicalToPhysical(TensorList logical_tensors);
|
| 85 |
+
};
|
| 86 |
+
|
| 87 |
+
// Forward declared, if you're reading this file head to toe, don't worry about
|
| 88 |
+
// it yet.
|
| 89 |
+
struct VmapPhysicalToLogicalMap;
|
| 90 |
+
|
| 91 |
+
// NOTE: [What is a VmapPhysicalView?]
|
| 92 |
+
// VmapPhysicalView represents a physical view on a Tensor.
|
| 93 |
+
//
|
| 94 |
+
// One can use it to further convert logical dimension indices, logical shapes,
|
| 95 |
+
// and more to their physical variants, or convert a new (physical) tensor into
|
| 96 |
+
// a logical BatchedTensor. (TODO(rzou): some of these are not yet implemented).
|
| 97 |
+
//
|
| 98 |
+
// VmapPhysicalView stores a physical tensor with all of its batch dimensions at
|
| 99 |
+
// the front and some levels that correspond to said batch dimensions.
|
| 100 |
+
//
|
| 101 |
+
// The levels bitset specifies which vmap levels correspond to the batch
|
| 102 |
+
// dimensions at the front of the tensor. In particular, the number of set bits
|
| 103 |
+
// corresponds to the number of batch dimensions on `tensor` and the rightmost
|
| 104 |
+
// bit of `levels` specifies the maximum number of nested vmaps we are in at
|
| 105 |
+
// this point in time.
|
| 106 |
+
// For example, given:
|
| 107 |
+
// physical_view = VmapPhysicalView(tensor=ones(2, 3, 4, 5, 6), levels={1, 3})
|
| 108 |
+
//
|
| 109 |
+
// Rightmost bit of `levels` is 3 indicating the number of nested vmaps less
|
| 110 |
+
// than or equal to 3.
|
| 111 |
+
// bitset: 010100
|
| 112 |
+
// ^
|
| 113 |
+
// |
|
| 114 |
+
// levels: 012345
|
| 115 |
+
struct TORCH_API VmapPhysicalView {
|
| 116 |
+
VmapPhysicalView(Tensor&& tensor, std::bitset<kVmapNumLevels> levels)
|
| 117 |
+
: levels_(levels), tensor_(std::move(tensor)) {
|
| 118 |
+
TORCH_INTERNAL_ASSERT(!isBatchedTensor(tensor_));
|
| 119 |
+
}
|
| 120 |
+
|
| 121 |
+
Tensor& tensor() {
|
| 122 |
+
return tensor_;
|
| 123 |
+
}
|
| 124 |
+
const Tensor& tensor() const {
|
| 125 |
+
return tensor_;
|
| 126 |
+
}
|
| 127 |
+
|
| 128 |
+
// Maps logical dim indices to physical dim indices. Also does dim wrapping.
|
| 129 |
+
//
|
| 130 |
+
// For example, given:
|
| 131 |
+
// physical_view = VmapPhysicalView(tensor=ones(2, 3, 4, 5), levels={1, 3})
|
| 132 |
+
//
|
| 133 |
+
// Then physical_view.getPhysicalDims({0, 1}) returns {2, 3}.
|
| 134 |
+
// This is because the count of levels tells us that the first two dimensions
|
| 135 |
+
// of `tensor_` are batch dimensions, so a logical dim of `n` is actually
|
| 136 |
+
// a physical dim of `n + 2`.
|
| 137 |
+
VmapDimVector getPhysicalDims(OptionalIntArrayRef logical_dims) const;
|
| 138 |
+
int64_t getPhysicalDim(int64_t logical_dim) const;
|
| 139 |
+
|
| 140 |
+
// Returns a VmapPhysicalToLogicalMap object. This can be used for
|
| 141 |
+
// mapping a physical tensor to a new logical tensor (BatchedTensor)
|
| 142 |
+
VmapPhysicalToLogicalMap getPhysicalToLogicalMap() const;
|
| 143 |
+
|
| 144 |
+
// Maps a logical shape to a physical shape by prepending the batch
|
| 145 |
+
// sizes to the logical shape.
|
| 146 |
+
VmapDimVector getPhysicalShape(IntArrayRef logical_shape) const;
|
| 147 |
+
|
| 148 |
+
int64_t numBatchDims() const;
|
| 149 |
+
|
| 150 |
+
private:
|
| 151 |
+
int64_t numLogicalDims() const;
|
| 152 |
+
|
| 153 |
+
std::bitset<kVmapNumLevels> levels_;
|
| 154 |
+
Tensor tensor_;
|
| 155 |
+
};
|
| 156 |
+
|
| 157 |
+
// Convenience struct used for mapping a physical tensor (a non-BatchedTensor)
|
| 158 |
+
// to a logical one (BatchedTensor). It holds some levels that are used to do
|
| 159 |
+
// the mapping and assumes that the batch dimensions in the physical tensor all
|
| 160 |
+
// occur at the front of the tensor.
|
| 161 |
+
struct TORCH_API VmapPhysicalToLogicalMap {
|
| 162 |
+
VmapPhysicalToLogicalMap(std::bitset<kVmapNumLevels> levels)
|
| 163 |
+
: levels_(levels) {}
|
| 164 |
+
|
| 165 |
+
// Maps a physical tensor to a new logical tensor (BatchedTensor).
|
| 166 |
+
// Assumes that all of the "batch dimensions" are at the front
|
| 167 |
+
// of the physical tensor. For example, given:
|
| 168 |
+
// - x = rank-4 Tensor with size 2, 3, 5, 7
|
| 169 |
+
// - levels = (2, 4)
|
| 170 |
+
// Returns:
|
| 171 |
+
// - BatchedTensor(x, bdims=[(dim=0,lvl=2), (dim=1, lvl=4)])
|
| 172 |
+
Tensor apply(const Tensor& physical_tensor) const;
|
| 173 |
+
|
| 174 |
+
// Given a vector of physical tensors,
|
| 175 |
+
// 1. maps each tensor to a new logical tensor. Assumes that all of the
|
| 176 |
+
// "batch dimensions" are at the front of the physical tensors.
|
| 177 |
+
// 2. stores the new logical tensors back into the passed-in vector. This is
|
| 178 |
+
// to avoid additional dynamic allocations.
|
| 179 |
+
void applyInplace(std::vector<Tensor>& physical_tensors) const;
|
| 180 |
+
|
| 181 |
+
std::bitset<kVmapNumLevels> levels_;
|
| 182 |
+
};
|
| 183 |
+
|
| 184 |
+
} // namespace at
|
| 185 |
+
|
| 186 |
+
#else
|
| 187 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 188 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/MethodOperators.h
ADDED
|
@@ -0,0 +1,449 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
// @generated by torchgen/gen.py from MethodOperators.h
|
| 5 |
+
|
| 6 |
+
#ifdef TORCH_ASSERT_NO_OPERATORS
|
| 7 |
+
#error This change adds a dependency on native_functions.yaml, \
|
| 8 |
+
meaning the file will need to be re-compiled every time an operator \
|
| 9 |
+
is changed or added. Consider if your change would be better placed in \
|
| 10 |
+
another file, or if a more specific header might achieve the same goal. \
|
| 11 |
+
See NOTE: [Tensor vs. TensorBase]
|
| 12 |
+
#endif
|
| 13 |
+
|
| 14 |
+
// Forward declarations of any types needed in the operator signatures.
|
| 15 |
+
// We can't directly include these classes because it will cause circular include dependencies.
|
| 16 |
+
// This file is included by TensorBody.h, which defines the Tensor class.
|
| 17 |
+
#include <ATen/core/ATen_fwd.h>
|
| 18 |
+
|
| 19 |
+
#include <ATen/ops/_addmm_activation_ops.h>
|
| 20 |
+
#include <ATen/ops/_autocast_to_full_precision_ops.h>
|
| 21 |
+
#include <ATen/ops/_autocast_to_reduced_precision_ops.h>
|
| 22 |
+
#include <ATen/ops/_backward_ops.h>
|
| 23 |
+
#include <ATen/ops/_coalesced_ops.h>
|
| 24 |
+
#include <ATen/ops/_conj_ops.h>
|
| 25 |
+
#include <ATen/ops/_conj_physical_ops.h>
|
| 26 |
+
#include <ATen/ops/_dimI_ops.h>
|
| 27 |
+
#include <ATen/ops/_dimV_ops.h>
|
| 28 |
+
#include <ATen/ops/_fw_primal_ops.h>
|
| 29 |
+
#include <ATen/ops/_indices_ops.h>
|
| 30 |
+
#include <ATen/ops/_is_all_true_ops.h>
|
| 31 |
+
#include <ATen/ops/_is_any_true_ops.h>
|
| 32 |
+
#include <ATen/ops/_is_zerotensor_ops.h>
|
| 33 |
+
#include <ATen/ops/_lazy_clone_ops.h>
|
| 34 |
+
#include <ATen/ops/_neg_view_ops.h>
|
| 35 |
+
#include <ATen/ops/_nested_tensor_size_ops.h>
|
| 36 |
+
#include <ATen/ops/_nested_tensor_storage_offsets_ops.h>
|
| 37 |
+
#include <ATen/ops/_nested_tensor_strides_ops.h>
|
| 38 |
+
#include <ATen/ops/_nnz_ops.h>
|
| 39 |
+
#include <ATen/ops/_reshape_alias_ops.h>
|
| 40 |
+
#include <ATen/ops/_sparse_mask_projection_ops.h>
|
| 41 |
+
#include <ATen/ops/_to_dense_ops.h>
|
| 42 |
+
#include <ATen/ops/_to_sparse_bsc_ops.h>
|
| 43 |
+
#include <ATen/ops/_to_sparse_bsr_ops.h>
|
| 44 |
+
#include <ATen/ops/_to_sparse_csc_ops.h>
|
| 45 |
+
#include <ATen/ops/_to_sparse_csr_ops.h>
|
| 46 |
+
#include <ATen/ops/_to_sparse_ops.h>
|
| 47 |
+
#include <ATen/ops/_values_ops.h>
|
| 48 |
+
#include <ATen/ops/_version_ops.h>
|
| 49 |
+
#include <ATen/ops/abs_ops.h>
|
| 50 |
+
#include <ATen/ops/absolute_ops.h>
|
| 51 |
+
#include <ATen/ops/acos_ops.h>
|
| 52 |
+
#include <ATen/ops/acosh_ops.h>
|
| 53 |
+
#include <ATen/ops/add_ops.h>
|
| 54 |
+
#include <ATen/ops/addbmm_ops.h>
|
| 55 |
+
#include <ATen/ops/addcdiv_ops.h>
|
| 56 |
+
#include <ATen/ops/addcmul_ops.h>
|
| 57 |
+
#include <ATen/ops/addmm_ops.h>
|
| 58 |
+
#include <ATen/ops/addmv_ops.h>
|
| 59 |
+
#include <ATen/ops/addr_ops.h>
|
| 60 |
+
#include <ATen/ops/adjoint_ops.h>
|
| 61 |
+
#include <ATen/ops/alias_ops.h>
|
| 62 |
+
#include <ATen/ops/align_as_ops.h>
|
| 63 |
+
#include <ATen/ops/align_to_ops.h>
|
| 64 |
+
#include <ATen/ops/all_ops.h>
|
| 65 |
+
#include <ATen/ops/allclose_ops.h>
|
| 66 |
+
#include <ATen/ops/amax_ops.h>
|
| 67 |
+
#include <ATen/ops/amin_ops.h>
|
| 68 |
+
#include <ATen/ops/aminmax_ops.h>
|
| 69 |
+
#include <ATen/ops/and_ops.h>
|
| 70 |
+
#include <ATen/ops/angle_ops.h>
|
| 71 |
+
#include <ATen/ops/any_ops.h>
|
| 72 |
+
#include <ATen/ops/arccos_ops.h>
|
| 73 |
+
#include <ATen/ops/arccosh_ops.h>
|
| 74 |
+
#include <ATen/ops/arcsin_ops.h>
|
| 75 |
+
#include <ATen/ops/arcsinh_ops.h>
|
| 76 |
+
#include <ATen/ops/arctan2_ops.h>
|
| 77 |
+
#include <ATen/ops/arctan_ops.h>
|
| 78 |
+
#include <ATen/ops/arctanh_ops.h>
|
| 79 |
+
#include <ATen/ops/argmax_ops.h>
|
| 80 |
+
#include <ATen/ops/argmin_ops.h>
|
| 81 |
+
#include <ATen/ops/argsort_ops.h>
|
| 82 |
+
#include <ATen/ops/argwhere_ops.h>
|
| 83 |
+
#include <ATen/ops/as_strided_ops.h>
|
| 84 |
+
#include <ATen/ops/as_strided_scatter_ops.h>
|
| 85 |
+
#include <ATen/ops/asin_ops.h>
|
| 86 |
+
#include <ATen/ops/asinh_ops.h>
|
| 87 |
+
#include <ATen/ops/atan2_ops.h>
|
| 88 |
+
#include <ATen/ops/atan_ops.h>
|
| 89 |
+
#include <ATen/ops/atanh_ops.h>
|
| 90 |
+
#include <ATen/ops/baddbmm_ops.h>
|
| 91 |
+
#include <ATen/ops/bernoulli_ops.h>
|
| 92 |
+
#include <ATen/ops/bincount_ops.h>
|
| 93 |
+
#include <ATen/ops/bitwise_and_ops.h>
|
| 94 |
+
#include <ATen/ops/bitwise_left_shift_ops.h>
|
| 95 |
+
#include <ATen/ops/bitwise_not_ops.h>
|
| 96 |
+
#include <ATen/ops/bitwise_or_ops.h>
|
| 97 |
+
#include <ATen/ops/bitwise_right_shift_ops.h>
|
| 98 |
+
#include <ATen/ops/bitwise_xor_ops.h>
|
| 99 |
+
#include <ATen/ops/bmm_ops.h>
|
| 100 |
+
#include <ATen/ops/broadcast_to_ops.h>
|
| 101 |
+
#include <ATen/ops/cauchy_ops.h>
|
| 102 |
+
#include <ATen/ops/ccol_indices_ops.h>
|
| 103 |
+
#include <ATen/ops/ceil_ops.h>
|
| 104 |
+
#include <ATen/ops/chalf_ops.h>
|
| 105 |
+
#include <ATen/ops/cholesky_inverse_ops.h>
|
| 106 |
+
#include <ATen/ops/cholesky_ops.h>
|
| 107 |
+
#include <ATen/ops/cholesky_solve_ops.h>
|
| 108 |
+
#include <ATen/ops/chunk_ops.h>
|
| 109 |
+
#include <ATen/ops/clamp_max_ops.h>
|
| 110 |
+
#include <ATen/ops/clamp_min_ops.h>
|
| 111 |
+
#include <ATen/ops/clamp_ops.h>
|
| 112 |
+
#include <ATen/ops/clip_ops.h>
|
| 113 |
+
#include <ATen/ops/clone_ops.h>
|
| 114 |
+
#include <ATen/ops/coalesce_ops.h>
|
| 115 |
+
#include <ATen/ops/col_indices_ops.h>
|
| 116 |
+
#include <ATen/ops/conj_ops.h>
|
| 117 |
+
#include <ATen/ops/conj_physical_ops.h>
|
| 118 |
+
#include <ATen/ops/contiguous_ops.h>
|
| 119 |
+
#include <ATen/ops/copy_ops.h>
|
| 120 |
+
#include <ATen/ops/copysign_ops.h>
|
| 121 |
+
#include <ATen/ops/corrcoef_ops.h>
|
| 122 |
+
#include <ATen/ops/cos_ops.h>
|
| 123 |
+
#include <ATen/ops/cosh_ops.h>
|
| 124 |
+
#include <ATen/ops/count_nonzero_ops.h>
|
| 125 |
+
#include <ATen/ops/cov_ops.h>
|
| 126 |
+
#include <ATen/ops/cross_ops.h>
|
| 127 |
+
#include <ATen/ops/crow_indices_ops.h>
|
| 128 |
+
#include <ATen/ops/cummax_ops.h>
|
| 129 |
+
#include <ATen/ops/cummin_ops.h>
|
| 130 |
+
#include <ATen/ops/cumprod_ops.h>
|
| 131 |
+
#include <ATen/ops/cumsum_ops.h>
|
| 132 |
+
#include <ATen/ops/data_ops.h>
|
| 133 |
+
#include <ATen/ops/deg2rad_ops.h>
|
| 134 |
+
#include <ATen/ops/dense_dim_ops.h>
|
| 135 |
+
#include <ATen/ops/dequantize_ops.h>
|
| 136 |
+
#include <ATen/ops/det_ops.h>
|
| 137 |
+
#include <ATen/ops/detach_ops.h>
|
| 138 |
+
#include <ATen/ops/diag_embed_ops.h>
|
| 139 |
+
#include <ATen/ops/diag_ops.h>
|
| 140 |
+
#include <ATen/ops/diagflat_ops.h>
|
| 141 |
+
#include <ATen/ops/diagonal_ops.h>
|
| 142 |
+
#include <ATen/ops/diagonal_scatter_ops.h>
|
| 143 |
+
#include <ATen/ops/diff_ops.h>
|
| 144 |
+
#include <ATen/ops/digamma_ops.h>
|
| 145 |
+
#include <ATen/ops/dist_ops.h>
|
| 146 |
+
#include <ATen/ops/div_ops.h>
|
| 147 |
+
#include <ATen/ops/divide_ops.h>
|
| 148 |
+
#include <ATen/ops/dot_ops.h>
|
| 149 |
+
#include <ATen/ops/dsplit_ops.h>
|
| 150 |
+
#include <ATen/ops/eq_ops.h>
|
| 151 |
+
#include <ATen/ops/equal_ops.h>
|
| 152 |
+
#include <ATen/ops/erf_ops.h>
|
| 153 |
+
#include <ATen/ops/erfc_ops.h>
|
| 154 |
+
#include <ATen/ops/erfinv_ops.h>
|
| 155 |
+
#include <ATen/ops/exp2_ops.h>
|
| 156 |
+
#include <ATen/ops/exp_ops.h>
|
| 157 |
+
#include <ATen/ops/expand_as_ops.h>
|
| 158 |
+
#include <ATen/ops/expand_ops.h>
|
| 159 |
+
#include <ATen/ops/expm1_ops.h>
|
| 160 |
+
#include <ATen/ops/exponential_ops.h>
|
| 161 |
+
#include <ATen/ops/fill_diagonal_ops.h>
|
| 162 |
+
#include <ATen/ops/fill_ops.h>
|
| 163 |
+
#include <ATen/ops/fix_ops.h>
|
| 164 |
+
#include <ATen/ops/flatten_ops.h>
|
| 165 |
+
#include <ATen/ops/flip_ops.h>
|
| 166 |
+
#include <ATen/ops/fliplr_ops.h>
|
| 167 |
+
#include <ATen/ops/flipud_ops.h>
|
| 168 |
+
#include <ATen/ops/float_power_ops.h>
|
| 169 |
+
#include <ATen/ops/floor_divide_ops.h>
|
| 170 |
+
#include <ATen/ops/floor_ops.h>
|
| 171 |
+
#include <ATen/ops/fmax_ops.h>
|
| 172 |
+
#include <ATen/ops/fmin_ops.h>
|
| 173 |
+
#include <ATen/ops/fmod_ops.h>
|
| 174 |
+
#include <ATen/ops/frac_ops.h>
|
| 175 |
+
#include <ATen/ops/frexp_ops.h>
|
| 176 |
+
#include <ATen/ops/gather_ops.h>
|
| 177 |
+
#include <ATen/ops/gcd_ops.h>
|
| 178 |
+
#include <ATen/ops/ge_ops.h>
|
| 179 |
+
#include <ATen/ops/geometric_ops.h>
|
| 180 |
+
#include <ATen/ops/geqrf_ops.h>
|
| 181 |
+
#include <ATen/ops/ger_ops.h>
|
| 182 |
+
#include <ATen/ops/greater_equal_ops.h>
|
| 183 |
+
#include <ATen/ops/greater_ops.h>
|
| 184 |
+
#include <ATen/ops/gt_ops.h>
|
| 185 |
+
#include <ATen/ops/hardshrink_backward_ops.h>
|
| 186 |
+
#include <ATen/ops/hardshrink_ops.h>
|
| 187 |
+
#include <ATen/ops/hash_tensor_ops.h>
|
| 188 |
+
#include <ATen/ops/heaviside_ops.h>
|
| 189 |
+
#include <ATen/ops/histc_ops.h>
|
| 190 |
+
#include <ATen/ops/histogram_ops.h>
|
| 191 |
+
#include <ATen/ops/hsplit_ops.h>
|
| 192 |
+
#include <ATen/ops/hypot_ops.h>
|
| 193 |
+
#include <ATen/ops/i0_ops.h>
|
| 194 |
+
#include <ATen/ops/igamma_ops.h>
|
| 195 |
+
#include <ATen/ops/igammac_ops.h>
|
| 196 |
+
#include <ATen/ops/index_add_ops.h>
|
| 197 |
+
#include <ATen/ops/index_copy_ops.h>
|
| 198 |
+
#include <ATen/ops/index_fill_ops.h>
|
| 199 |
+
#include <ATen/ops/index_ops.h>
|
| 200 |
+
#include <ATen/ops/index_put_ops.h>
|
| 201 |
+
#include <ATen/ops/index_reduce_ops.h>
|
| 202 |
+
#include <ATen/ops/index_select_ops.h>
|
| 203 |
+
#include <ATen/ops/indices_ops.h>
|
| 204 |
+
#include <ATen/ops/inner_ops.h>
|
| 205 |
+
#include <ATen/ops/int_repr_ops.h>
|
| 206 |
+
#include <ATen/ops/inverse_ops.h>
|
| 207 |
+
#include <ATen/ops/is_coalesced_ops.h>
|
| 208 |
+
#include <ATen/ops/is_complex_ops.h>
|
| 209 |
+
#include <ATen/ops/is_conj_ops.h>
|
| 210 |
+
#include <ATen/ops/is_distributed_ops.h>
|
| 211 |
+
#include <ATen/ops/is_floating_point_ops.h>
|
| 212 |
+
#include <ATen/ops/is_inference_ops.h>
|
| 213 |
+
#include <ATen/ops/is_leaf_ops.h>
|
| 214 |
+
#include <ATen/ops/is_neg_ops.h>
|
| 215 |
+
#include <ATen/ops/is_nonzero_ops.h>
|
| 216 |
+
#include <ATen/ops/is_pinned_ops.h>
|
| 217 |
+
#include <ATen/ops/is_same_size_ops.h>
|
| 218 |
+
#include <ATen/ops/is_set_to_ops.h>
|
| 219 |
+
#include <ATen/ops/is_signed_ops.h>
|
| 220 |
+
#include <ATen/ops/isclose_ops.h>
|
| 221 |
+
#include <ATen/ops/isfinite_ops.h>
|
| 222 |
+
#include <ATen/ops/isinf_ops.h>
|
| 223 |
+
#include <ATen/ops/isnan_ops.h>
|
| 224 |
+
#include <ATen/ops/isneginf_ops.h>
|
| 225 |
+
#include <ATen/ops/isposinf_ops.h>
|
| 226 |
+
#include <ATen/ops/isreal_ops.h>
|
| 227 |
+
#include <ATen/ops/istft_ops.h>
|
| 228 |
+
#include <ATen/ops/item_ops.h>
|
| 229 |
+
#include <ATen/ops/kron_ops.h>
|
| 230 |
+
#include <ATen/ops/kthvalue_ops.h>
|
| 231 |
+
#include <ATen/ops/lcm_ops.h>
|
| 232 |
+
#include <ATen/ops/ldexp_ops.h>
|
| 233 |
+
#include <ATen/ops/le_ops.h>
|
| 234 |
+
#include <ATen/ops/lerp_ops.h>
|
| 235 |
+
#include <ATen/ops/less_equal_ops.h>
|
| 236 |
+
#include <ATen/ops/less_ops.h>
|
| 237 |
+
#include <ATen/ops/lgamma_ops.h>
|
| 238 |
+
#include <ATen/ops/log10_ops.h>
|
| 239 |
+
#include <ATen/ops/log1p_ops.h>
|
| 240 |
+
#include <ATen/ops/log2_ops.h>
|
| 241 |
+
#include <ATen/ops/log_normal_ops.h>
|
| 242 |
+
#include <ATen/ops/log_ops.h>
|
| 243 |
+
#include <ATen/ops/log_softmax_ops.h>
|
| 244 |
+
#include <ATen/ops/logaddexp2_ops.h>
|
| 245 |
+
#include <ATen/ops/logaddexp_ops.h>
|
| 246 |
+
#include <ATen/ops/logcumsumexp_ops.h>
|
| 247 |
+
#include <ATen/ops/logdet_ops.h>
|
| 248 |
+
#include <ATen/ops/logical_and_ops.h>
|
| 249 |
+
#include <ATen/ops/logical_not_ops.h>
|
| 250 |
+
#include <ATen/ops/logical_or_ops.h>
|
| 251 |
+
#include <ATen/ops/logical_xor_ops.h>
|
| 252 |
+
#include <ATen/ops/logit_ops.h>
|
| 253 |
+
#include <ATen/ops/logsumexp_ops.h>
|
| 254 |
+
#include <ATen/ops/lshift_ops.h>
|
| 255 |
+
#include <ATen/ops/lt_ops.h>
|
| 256 |
+
#include <ATen/ops/lu_solve_ops.h>
|
| 257 |
+
#include <ATen/ops/mH_ops.h>
|
| 258 |
+
#include <ATen/ops/mT_ops.h>
|
| 259 |
+
#include <ATen/ops/masked_fill_ops.h>
|
| 260 |
+
#include <ATen/ops/masked_scatter_ops.h>
|
| 261 |
+
#include <ATen/ops/masked_select_ops.h>
|
| 262 |
+
#include <ATen/ops/matmul_ops.h>
|
| 263 |
+
#include <ATen/ops/matrix_H_ops.h>
|
| 264 |
+
#include <ATen/ops/matrix_exp_ops.h>
|
| 265 |
+
#include <ATen/ops/matrix_power_ops.h>
|
| 266 |
+
#include <ATen/ops/max_ops.h>
|
| 267 |
+
#include <ATen/ops/maximum_ops.h>
|
| 268 |
+
#include <ATen/ops/mean_ops.h>
|
| 269 |
+
#include <ATen/ops/median_ops.h>
|
| 270 |
+
#include <ATen/ops/min_ops.h>
|
| 271 |
+
#include <ATen/ops/minimum_ops.h>
|
| 272 |
+
#include <ATen/ops/mm_ops.h>
|
| 273 |
+
#include <ATen/ops/mode_ops.h>
|
| 274 |
+
#include <ATen/ops/moveaxis_ops.h>
|
| 275 |
+
#include <ATen/ops/movedim_ops.h>
|
| 276 |
+
#include <ATen/ops/msort_ops.h>
|
| 277 |
+
#include <ATen/ops/mul_ops.h>
|
| 278 |
+
#include <ATen/ops/multinomial_ops.h>
|
| 279 |
+
#include <ATen/ops/multiply_ops.h>
|
| 280 |
+
#include <ATen/ops/mv_ops.h>
|
| 281 |
+
#include <ATen/ops/mvlgamma_ops.h>
|
| 282 |
+
#include <ATen/ops/nan_to_num_ops.h>
|
| 283 |
+
#include <ATen/ops/nanmean_ops.h>
|
| 284 |
+
#include <ATen/ops/nanmedian_ops.h>
|
| 285 |
+
#include <ATen/ops/nanquantile_ops.h>
|
| 286 |
+
#include <ATen/ops/nansum_ops.h>
|
| 287 |
+
#include <ATen/ops/narrow_copy_ops.h>
|
| 288 |
+
#include <ATen/ops/narrow_ops.h>
|
| 289 |
+
#include <ATen/ops/ne_ops.h>
|
| 290 |
+
#include <ATen/ops/neg_ops.h>
|
| 291 |
+
#include <ATen/ops/negative_ops.h>
|
| 292 |
+
#include <ATen/ops/new_empty_ops.h>
|
| 293 |
+
#include <ATen/ops/new_empty_strided_ops.h>
|
| 294 |
+
#include <ATen/ops/new_full_ops.h>
|
| 295 |
+
#include <ATen/ops/new_ones_ops.h>
|
| 296 |
+
#include <ATen/ops/new_zeros_ops.h>
|
| 297 |
+
#include <ATen/ops/nextafter_ops.h>
|
| 298 |
+
#include <ATen/ops/nonzero_numpy_ops.h>
|
| 299 |
+
#include <ATen/ops/nonzero_ops.h>
|
| 300 |
+
#include <ATen/ops/nonzero_static_ops.h>
|
| 301 |
+
#include <ATen/ops/norm_ops.h>
|
| 302 |
+
#include <ATen/ops/normal_ops.h>
|
| 303 |
+
#include <ATen/ops/not_equal_ops.h>
|
| 304 |
+
#include <ATen/ops/numpy_T_ops.h>
|
| 305 |
+
#include <ATen/ops/or_ops.h>
|
| 306 |
+
#include <ATen/ops/orgqr_ops.h>
|
| 307 |
+
#include <ATen/ops/ormqr_ops.h>
|
| 308 |
+
#include <ATen/ops/outer_ops.h>
|
| 309 |
+
#include <ATen/ops/output_nr_ops.h>
|
| 310 |
+
#include <ATen/ops/permute_ops.h>
|
| 311 |
+
#include <ATen/ops/pin_memory_ops.h>
|
| 312 |
+
#include <ATen/ops/pinverse_ops.h>
|
| 313 |
+
#include <ATen/ops/polygamma_ops.h>
|
| 314 |
+
#include <ATen/ops/positive_ops.h>
|
| 315 |
+
#include <ATen/ops/pow_ops.h>
|
| 316 |
+
#include <ATen/ops/prelu_ops.h>
|
| 317 |
+
#include <ATen/ops/prod_ops.h>
|
| 318 |
+
#include <ATen/ops/put_ops.h>
|
| 319 |
+
#include <ATen/ops/q_per_channel_axis_ops.h>
|
| 320 |
+
#include <ATen/ops/q_per_channel_scales_ops.h>
|
| 321 |
+
#include <ATen/ops/q_per_channel_zero_points_ops.h>
|
| 322 |
+
#include <ATen/ops/q_scale_ops.h>
|
| 323 |
+
#include <ATen/ops/q_zero_point_ops.h>
|
| 324 |
+
#include <ATen/ops/qr_ops.h>
|
| 325 |
+
#include <ATen/ops/qscheme_ops.h>
|
| 326 |
+
#include <ATen/ops/quantile_ops.h>
|
| 327 |
+
#include <ATen/ops/rad2deg_ops.h>
|
| 328 |
+
#include <ATen/ops/random_ops.h>
|
| 329 |
+
#include <ATen/ops/ravel_ops.h>
|
| 330 |
+
#include <ATen/ops/reciprocal_ops.h>
|
| 331 |
+
#include <ATen/ops/record_stream_ops.h>
|
| 332 |
+
#include <ATen/ops/refine_names_ops.h>
|
| 333 |
+
#include <ATen/ops/relu_ops.h>
|
| 334 |
+
#include <ATen/ops/remainder_ops.h>
|
| 335 |
+
#include <ATen/ops/rename_ops.h>
|
| 336 |
+
#include <ATen/ops/renorm_ops.h>
|
| 337 |
+
#include <ATen/ops/repeat_interleave_ops.h>
|
| 338 |
+
#include <ATen/ops/repeat_ops.h>
|
| 339 |
+
#include <ATen/ops/requires_grad_ops.h>
|
| 340 |
+
#include <ATen/ops/reshape_as_ops.h>
|
| 341 |
+
#include <ATen/ops/reshape_ops.h>
|
| 342 |
+
#include <ATen/ops/resize_as_ops.h>
|
| 343 |
+
#include <ATen/ops/resize_as_sparse_ops.h>
|
| 344 |
+
#include <ATen/ops/resize_ops.h>
|
| 345 |
+
#include <ATen/ops/resolve_conj_ops.h>
|
| 346 |
+
#include <ATen/ops/resolve_neg_ops.h>
|
| 347 |
+
#include <ATen/ops/retain_grad_ops.h>
|
| 348 |
+
#include <ATen/ops/retains_grad_ops.h>
|
| 349 |
+
#include <ATen/ops/roll_ops.h>
|
| 350 |
+
#include <ATen/ops/rot90_ops.h>
|
| 351 |
+
#include <ATen/ops/round_ops.h>
|
| 352 |
+
#include <ATen/ops/row_indices_ops.h>
|
| 353 |
+
#include <ATen/ops/rshift_ops.h>
|
| 354 |
+
#include <ATen/ops/rsqrt_ops.h>
|
| 355 |
+
#include <ATen/ops/scatter_add_ops.h>
|
| 356 |
+
#include <ATen/ops/scatter_ops.h>
|
| 357 |
+
#include <ATen/ops/scatter_reduce_ops.h>
|
| 358 |
+
#include <ATen/ops/select_ops.h>
|
| 359 |
+
#include <ATen/ops/select_scatter_ops.h>
|
| 360 |
+
#include <ATen/ops/set_data_ops.h>
|
| 361 |
+
#include <ATen/ops/set_ops.h>
|
| 362 |
+
#include <ATen/ops/sgn_ops.h>
|
| 363 |
+
#include <ATen/ops/sigmoid_ops.h>
|
| 364 |
+
#include <ATen/ops/sign_ops.h>
|
| 365 |
+
#include <ATen/ops/signbit_ops.h>
|
| 366 |
+
#include <ATen/ops/sin_ops.h>
|
| 367 |
+
#include <ATen/ops/sinc_ops.h>
|
| 368 |
+
#include <ATen/ops/sinh_ops.h>
|
| 369 |
+
#include <ATen/ops/size_ops.h>
|
| 370 |
+
#include <ATen/ops/slice_inverse_ops.h>
|
| 371 |
+
#include <ATen/ops/slice_ops.h>
|
| 372 |
+
#include <ATen/ops/slice_scatter_ops.h>
|
| 373 |
+
#include <ATen/ops/slogdet_ops.h>
|
| 374 |
+
#include <ATen/ops/smm_ops.h>
|
| 375 |
+
#include <ATen/ops/softmax_ops.h>
|
| 376 |
+
#include <ATen/ops/sort_ops.h>
|
| 377 |
+
#include <ATen/ops/sparse_dim_ops.h>
|
| 378 |
+
#include <ATen/ops/sparse_mask_ops.h>
|
| 379 |
+
#include <ATen/ops/sparse_resize_and_clear_ops.h>
|
| 380 |
+
#include <ATen/ops/sparse_resize_ops.h>
|
| 381 |
+
#include <ATen/ops/split_ops.h>
|
| 382 |
+
#include <ATen/ops/split_with_sizes_ops.h>
|
| 383 |
+
#include <ATen/ops/sqrt_ops.h>
|
| 384 |
+
#include <ATen/ops/square_ops.h>
|
| 385 |
+
#include <ATen/ops/squeeze_ops.h>
|
| 386 |
+
#include <ATen/ops/sspaddmm_ops.h>
|
| 387 |
+
#include <ATen/ops/std_ops.h>
|
| 388 |
+
#include <ATen/ops/stft_ops.h>
|
| 389 |
+
#include <ATen/ops/stride_ops.h>
|
| 390 |
+
#include <ATen/ops/sub_ops.h>
|
| 391 |
+
#include <ATen/ops/subtract_ops.h>
|
| 392 |
+
#include <ATen/ops/sum_ops.h>
|
| 393 |
+
#include <ATen/ops/sum_to_size_ops.h>
|
| 394 |
+
#include <ATen/ops/svd_ops.h>
|
| 395 |
+
#include <ATen/ops/swapaxes_ops.h>
|
| 396 |
+
#include <ATen/ops/swapdims_ops.h>
|
| 397 |
+
#include <ATen/ops/t_ops.h>
|
| 398 |
+
#include <ATen/ops/take_along_dim_ops.h>
|
| 399 |
+
#include <ATen/ops/take_ops.h>
|
| 400 |
+
#include <ATen/ops/tan_ops.h>
|
| 401 |
+
#include <ATen/ops/tanh_ops.h>
|
| 402 |
+
#include <ATen/ops/tensor_split_ops.h>
|
| 403 |
+
#include <ATen/ops/tile_ops.h>
|
| 404 |
+
#include <ATen/ops/to_dense_ops.h>
|
| 405 |
+
#include <ATen/ops/to_mkldnn_ops.h>
|
| 406 |
+
#include <ATen/ops/to_ops.h>
|
| 407 |
+
#include <ATen/ops/to_padded_tensor_ops.h>
|
| 408 |
+
#include <ATen/ops/to_sparse_bsc_ops.h>
|
| 409 |
+
#include <ATen/ops/to_sparse_bsr_ops.h>
|
| 410 |
+
#include <ATen/ops/to_sparse_csc_ops.h>
|
| 411 |
+
#include <ATen/ops/to_sparse_csr_ops.h>
|
| 412 |
+
#include <ATen/ops/to_sparse_ops.h>
|
| 413 |
+
#include <ATen/ops/topk_ops.h>
|
| 414 |
+
#include <ATen/ops/trace_ops.h>
|
| 415 |
+
#include <ATen/ops/transpose_ops.h>
|
| 416 |
+
#include <ATen/ops/triangular_solve_ops.h>
|
| 417 |
+
#include <ATen/ops/tril_ops.h>
|
| 418 |
+
#include <ATen/ops/triu_ops.h>
|
| 419 |
+
#include <ATen/ops/true_divide_ops.h>
|
| 420 |
+
#include <ATen/ops/trunc_ops.h>
|
| 421 |
+
#include <ATen/ops/type_as_ops.h>
|
| 422 |
+
#include <ATen/ops/unbind_ops.h>
|
| 423 |
+
#include <ATen/ops/unflatten_ops.h>
|
| 424 |
+
#include <ATen/ops/unfold_ops.h>
|
| 425 |
+
#include <ATen/ops/uniform_ops.h>
|
| 426 |
+
#include <ATen/ops/unsafe_chunk_ops.h>
|
| 427 |
+
#include <ATen/ops/unsafe_split_ops.h>
|
| 428 |
+
#include <ATen/ops/unsafe_split_with_sizes_ops.h>
|
| 429 |
+
#include <ATen/ops/unsqueeze_ops.h>
|
| 430 |
+
#include <ATen/ops/values_ops.h>
|
| 431 |
+
#include <ATen/ops/var_ops.h>
|
| 432 |
+
#include <ATen/ops/vdot_ops.h>
|
| 433 |
+
#include <ATen/ops/view_as_ops.h>
|
| 434 |
+
#include <ATen/ops/view_ops.h>
|
| 435 |
+
#include <ATen/ops/vsplit_ops.h>
|
| 436 |
+
#include <ATen/ops/where_ops.h>
|
| 437 |
+
#include <ATen/ops/xlogy_ops.h>
|
| 438 |
+
#include <ATen/ops/xor_ops.h>
|
| 439 |
+
#include <ATen/ops/zero_ops.h>
|
| 440 |
+
|
| 441 |
+
namespace at {
|
| 442 |
+
namespace _ops {
|
| 443 |
+
|
| 444 |
+
} // namespace _ops
|
| 445 |
+
} // namespace at
|
| 446 |
+
|
| 447 |
+
#else
|
| 448 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 449 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/NamedTensor.h
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#include <ATen/core/NamedTensor.h>
|
| 3 |
+
|
| 4 |
+
#else
|
| 5 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 6 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/NativeMetaFunctions.h
ADDED
|
@@ -0,0 +1,1352 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
// @generated by torchgen/gen.py from NativeMetaFunctions.h
|
| 5 |
+
|
| 6 |
+
#include <ATen/core/Tensor.h>
|
| 7 |
+
#include <ATen/core/IListRef.h>
|
| 8 |
+
#include <ATen/TensorMeta.h>
|
| 9 |
+
#include <ATen/TensorIterator.h>
|
| 10 |
+
|
| 11 |
+
#include <ATen/ops/_adaptive_avg_pool2d_meta.h>
|
| 12 |
+
#include <ATen/ops/_adaptive_avg_pool2d_backward_meta.h>
|
| 13 |
+
#include <ATen/ops/_adaptive_avg_pool3d_meta.h>
|
| 14 |
+
#include <ATen/ops/_adaptive_avg_pool3d_backward_meta.h>
|
| 15 |
+
#include <ATen/ops/_add_batch_dim_meta.h>
|
| 16 |
+
#include <ATen/ops/_add_relu_meta.h>
|
| 17 |
+
#include <ATen/ops/_addmm_activation_meta.h>
|
| 18 |
+
#include <ATen/ops/_aminmax_meta.h>
|
| 19 |
+
#include <ATen/ops/_amp_foreach_non_finite_check_and_unscale_meta.h>
|
| 20 |
+
#include <ATen/ops/_amp_update_scale_meta.h>
|
| 21 |
+
#include <ATen/ops/_assert_async_meta.h>
|
| 22 |
+
#include <ATen/ops/_assert_scalar_meta.h>
|
| 23 |
+
#include <ATen/ops/_assert_tensor_metadata_meta.h>
|
| 24 |
+
#include <ATen/ops/_autocast_to_full_precision_meta.h>
|
| 25 |
+
#include <ATen/ops/_autocast_to_reduced_precision_meta.h>
|
| 26 |
+
#include <ATen/ops/_backward_meta.h>
|
| 27 |
+
#include <ATen/ops/_batch_norm_impl_index_meta.h>
|
| 28 |
+
#include <ATen/ops/_batch_norm_impl_index_backward_meta.h>
|
| 29 |
+
#include <ATen/ops/_batch_norm_no_update_meta.h>
|
| 30 |
+
#include <ATen/ops/_batch_norm_with_update_meta.h>
|
| 31 |
+
#include <ATen/ops/_cast_Byte_meta.h>
|
| 32 |
+
#include <ATen/ops/_cast_Char_meta.h>
|
| 33 |
+
#include <ATen/ops/_cast_Double_meta.h>
|
| 34 |
+
#include <ATen/ops/_cast_Float_meta.h>
|
| 35 |
+
#include <ATen/ops/_cast_Half_meta.h>
|
| 36 |
+
#include <ATen/ops/_cast_Int_meta.h>
|
| 37 |
+
#include <ATen/ops/_cast_Long_meta.h>
|
| 38 |
+
#include <ATen/ops/_cast_Short_meta.h>
|
| 39 |
+
#include <ATen/ops/_cdist_backward_meta.h>
|
| 40 |
+
#include <ATen/ops/_cdist_forward_meta.h>
|
| 41 |
+
#include <ATen/ops/_cholesky_solve_helper_meta.h>
|
| 42 |
+
#include <ATen/ops/_choose_qparams_per_tensor_meta.h>
|
| 43 |
+
#include <ATen/ops/_chunk_cat_meta.h>
|
| 44 |
+
#include <ATen/ops/_coalesce_meta.h>
|
| 45 |
+
#include <ATen/ops/_coalesced_meta.h>
|
| 46 |
+
#include <ATen/ops/_compute_linear_combination_meta.h>
|
| 47 |
+
#include <ATen/ops/_conj_meta.h>
|
| 48 |
+
#include <ATen/ops/_conj_copy_meta.h>
|
| 49 |
+
#include <ATen/ops/_conj_physical_meta.h>
|
| 50 |
+
#include <ATen/ops/_conv_depthwise2d_meta.h>
|
| 51 |
+
#include <ATen/ops/_convert_indices_from_coo_to_csr_meta.h>
|
| 52 |
+
#include <ATen/ops/_convert_indices_from_csr_to_coo_meta.h>
|
| 53 |
+
#include <ATen/ops/_convert_weight_to_int4pack_meta.h>
|
| 54 |
+
#include <ATen/ops/_convert_weight_to_int4pack_for_cpu_meta.h>
|
| 55 |
+
#include <ATen/ops/_convolution_meta.h>
|
| 56 |
+
#include <ATen/ops/_convolution_double_backward_meta.h>
|
| 57 |
+
#include <ATen/ops/_convolution_mode_meta.h>
|
| 58 |
+
#include <ATen/ops/_copy_from_meta.h>
|
| 59 |
+
#include <ATen/ops/_copy_from_and_resize_meta.h>
|
| 60 |
+
#include <ATen/ops/_cslt_compress_meta.h>
|
| 61 |
+
#include <ATen/ops/_cslt_sparse_mm_meta.h>
|
| 62 |
+
#include <ATen/ops/_cslt_sparse_mm_search_meta.h>
|
| 63 |
+
#include <ATen/ops/_ctc_loss_meta.h>
|
| 64 |
+
#include <ATen/ops/_ctc_loss_backward_meta.h>
|
| 65 |
+
#include <ATen/ops/_cudnn_attention_backward_meta.h>
|
| 66 |
+
#include <ATen/ops/_cudnn_attention_forward_meta.h>
|
| 67 |
+
#include <ATen/ops/_cudnn_ctc_loss_meta.h>
|
| 68 |
+
#include <ATen/ops/_cudnn_init_dropout_state_meta.h>
|
| 69 |
+
#include <ATen/ops/_cudnn_rnn_meta.h>
|
| 70 |
+
#include <ATen/ops/_cudnn_rnn_backward_meta.h>
|
| 71 |
+
#include <ATen/ops/_cudnn_rnn_flatten_weight_meta.h>
|
| 72 |
+
#include <ATen/ops/_cufft_clear_plan_cache_meta.h>
|
| 73 |
+
#include <ATen/ops/_cufft_get_plan_cache_max_size_meta.h>
|
| 74 |
+
#include <ATen/ops/_cufft_get_plan_cache_size_meta.h>
|
| 75 |
+
#include <ATen/ops/_cufft_set_plan_cache_max_size_meta.h>
|
| 76 |
+
#include <ATen/ops/_cummax_helper_meta.h>
|
| 77 |
+
#include <ATen/ops/_cummin_helper_meta.h>
|
| 78 |
+
#include <ATen/ops/_debug_has_internal_overlap_meta.h>
|
| 79 |
+
#include <ATen/ops/_dimI_meta.h>
|
| 80 |
+
#include <ATen/ops/_dimV_meta.h>
|
| 81 |
+
#include <ATen/ops/_dim_arange_meta.h>
|
| 82 |
+
#include <ATen/ops/_dirichlet_grad_meta.h>
|
| 83 |
+
#include <ATen/ops/_dyn_quant_matmul_4bit_meta.h>
|
| 84 |
+
#include <ATen/ops/_dyn_quant_pack_4bit_weight_meta.h>
|
| 85 |
+
#include <ATen/ops/_efficient_attention_backward_meta.h>
|
| 86 |
+
#include <ATen/ops/_efficient_attention_forward_meta.h>
|
| 87 |
+
#include <ATen/ops/_efficientzerotensor_meta.h>
|
| 88 |
+
#include <ATen/ops/_embedding_bag_meta.h>
|
| 89 |
+
#include <ATen/ops/_embedding_bag_backward_meta.h>
|
| 90 |
+
#include <ATen/ops/_embedding_bag_dense_backward_meta.h>
|
| 91 |
+
#include <ATen/ops/_embedding_bag_forward_only_meta.h>
|
| 92 |
+
#include <ATen/ops/_embedding_bag_per_sample_weights_backward_meta.h>
|
| 93 |
+
#include <ATen/ops/_embedding_bag_sparse_backward_meta.h>
|
| 94 |
+
#include <ATen/ops/_empty_affine_quantized_meta.h>
|
| 95 |
+
#include <ATen/ops/_empty_per_channel_affine_quantized_meta.h>
|
| 96 |
+
#include <ATen/ops/_euclidean_dist_meta.h>
|
| 97 |
+
#include <ATen/ops/_fake_quantize_learnable_per_channel_affine_meta.h>
|
| 98 |
+
#include <ATen/ops/_fake_quantize_learnable_per_channel_affine_backward_meta.h>
|
| 99 |
+
#include <ATen/ops/_fake_quantize_learnable_per_tensor_affine_meta.h>
|
| 100 |
+
#include <ATen/ops/_fake_quantize_learnable_per_tensor_affine_backward_meta.h>
|
| 101 |
+
#include <ATen/ops/_fake_quantize_per_tensor_affine_cachemask_tensor_qparams_meta.h>
|
| 102 |
+
#include <ATen/ops/_fft_c2c_meta.h>
|
| 103 |
+
#include <ATen/ops/_fft_c2r_meta.h>
|
| 104 |
+
#include <ATen/ops/_fft_r2c_meta.h>
|
| 105 |
+
#include <ATen/ops/_fill_mem_eff_dropout_mask_meta.h>
|
| 106 |
+
#include <ATen/ops/_flash_attention_backward_meta.h>
|
| 107 |
+
#include <ATen/ops/_flash_attention_forward_meta.h>
|
| 108 |
+
#include <ATen/ops/_foobar_meta.h>
|
| 109 |
+
#include <ATen/ops/_foreach_abs_meta.h>
|
| 110 |
+
#include <ATen/ops/_foreach_acos_meta.h>
|
| 111 |
+
#include <ATen/ops/_foreach_add_meta.h>
|
| 112 |
+
#include <ATen/ops/_foreach_addcdiv_meta.h>
|
| 113 |
+
#include <ATen/ops/_foreach_addcmul_meta.h>
|
| 114 |
+
#include <ATen/ops/_foreach_asin_meta.h>
|
| 115 |
+
#include <ATen/ops/_foreach_atan_meta.h>
|
| 116 |
+
#include <ATen/ops/_foreach_ceil_meta.h>
|
| 117 |
+
#include <ATen/ops/_foreach_clamp_max_meta.h>
|
| 118 |
+
#include <ATen/ops/_foreach_clamp_min_meta.h>
|
| 119 |
+
#include <ATen/ops/_foreach_copy_meta.h>
|
| 120 |
+
#include <ATen/ops/_foreach_cos_meta.h>
|
| 121 |
+
#include <ATen/ops/_foreach_cosh_meta.h>
|
| 122 |
+
#include <ATen/ops/_foreach_div_meta.h>
|
| 123 |
+
#include <ATen/ops/_foreach_erf_meta.h>
|
| 124 |
+
#include <ATen/ops/_foreach_erfc_meta.h>
|
| 125 |
+
#include <ATen/ops/_foreach_exp_meta.h>
|
| 126 |
+
#include <ATen/ops/_foreach_expm1_meta.h>
|
| 127 |
+
#include <ATen/ops/_foreach_floor_meta.h>
|
| 128 |
+
#include <ATen/ops/_foreach_frac_meta.h>
|
| 129 |
+
#include <ATen/ops/_foreach_lerp_meta.h>
|
| 130 |
+
#include <ATen/ops/_foreach_lgamma_meta.h>
|
| 131 |
+
#include <ATen/ops/_foreach_log_meta.h>
|
| 132 |
+
#include <ATen/ops/_foreach_log10_meta.h>
|
| 133 |
+
#include <ATen/ops/_foreach_log1p_meta.h>
|
| 134 |
+
#include <ATen/ops/_foreach_log2_meta.h>
|
| 135 |
+
#include <ATen/ops/_foreach_max_meta.h>
|
| 136 |
+
#include <ATen/ops/_foreach_maximum_meta.h>
|
| 137 |
+
#include <ATen/ops/_foreach_minimum_meta.h>
|
| 138 |
+
#include <ATen/ops/_foreach_mul_meta.h>
|
| 139 |
+
#include <ATen/ops/_foreach_neg_meta.h>
|
| 140 |
+
#include <ATen/ops/_foreach_norm_meta.h>
|
| 141 |
+
#include <ATen/ops/_foreach_pow_meta.h>
|
| 142 |
+
#include <ATen/ops/_foreach_reciprocal_meta.h>
|
| 143 |
+
#include <ATen/ops/_foreach_round_meta.h>
|
| 144 |
+
#include <ATen/ops/_foreach_rsqrt_meta.h>
|
| 145 |
+
#include <ATen/ops/_foreach_sigmoid_meta.h>
|
| 146 |
+
#include <ATen/ops/_foreach_sign_meta.h>
|
| 147 |
+
#include <ATen/ops/_foreach_sin_meta.h>
|
| 148 |
+
#include <ATen/ops/_foreach_sinh_meta.h>
|
| 149 |
+
#include <ATen/ops/_foreach_sqrt_meta.h>
|
| 150 |
+
#include <ATen/ops/_foreach_sub_meta.h>
|
| 151 |
+
#include <ATen/ops/_foreach_tan_meta.h>
|
| 152 |
+
#include <ATen/ops/_foreach_tanh_meta.h>
|
| 153 |
+
#include <ATen/ops/_foreach_trunc_meta.h>
|
| 154 |
+
#include <ATen/ops/_foreach_zero_meta.h>
|
| 155 |
+
#include <ATen/ops/_functional_assert_async_meta.h>
|
| 156 |
+
#include <ATen/ops/_functional_assert_scalar_meta.h>
|
| 157 |
+
#include <ATen/ops/_functional_sym_constrain_range_meta.h>
|
| 158 |
+
#include <ATen/ops/_functional_sym_constrain_range_for_size_meta.h>
|
| 159 |
+
#include <ATen/ops/_fused_adagrad_meta.h>
|
| 160 |
+
#include <ATen/ops/_fused_adam_meta.h>
|
| 161 |
+
#include <ATen/ops/_fused_adamw_meta.h>
|
| 162 |
+
#include <ATen/ops/_fused_dropout_meta.h>
|
| 163 |
+
#include <ATen/ops/_fused_moving_avg_obs_fq_helper_meta.h>
|
| 164 |
+
#include <ATen/ops/_fused_rms_norm_meta.h>
|
| 165 |
+
#include <ATen/ops/_fused_rms_norm_backward_meta.h>
|
| 166 |
+
#include <ATen/ops/_fused_sdp_choice_meta.h>
|
| 167 |
+
#include <ATen/ops/_fused_sgd_meta.h>
|
| 168 |
+
#include <ATen/ops/_fw_primal_meta.h>
|
| 169 |
+
#include <ATen/ops/_fw_primal_copy_meta.h>
|
| 170 |
+
#include <ATen/ops/_gather_sparse_backward_meta.h>
|
| 171 |
+
#include <ATen/ops/_grid_sampler_2d_cpu_fallback_meta.h>
|
| 172 |
+
#include <ATen/ops/_grid_sampler_2d_cpu_fallback_backward_meta.h>
|
| 173 |
+
#include <ATen/ops/_grouped_mm_meta.h>
|
| 174 |
+
#include <ATen/ops/_has_compatible_shallow_copy_type_meta.h>
|
| 175 |
+
#include <ATen/ops/_has_same_storage_numel_meta.h>
|
| 176 |
+
#include <ATen/ops/_histogramdd_bin_edges_meta.h>
|
| 177 |
+
#include <ATen/ops/_histogramdd_from_bin_cts_meta.h>
|
| 178 |
+
#include <ATen/ops/_histogramdd_from_bin_tensors_meta.h>
|
| 179 |
+
#include <ATen/ops/_index_put_impl_meta.h>
|
| 180 |
+
#include <ATen/ops/_indices_meta.h>
|
| 181 |
+
#include <ATen/ops/_indices_copy_meta.h>
|
| 182 |
+
#include <ATen/ops/_int_mm_meta.h>
|
| 183 |
+
#include <ATen/ops/_is_all_true_meta.h>
|
| 184 |
+
#include <ATen/ops/_is_any_true_meta.h>
|
| 185 |
+
#include <ATen/ops/_is_zerotensor_meta.h>
|
| 186 |
+
#include <ATen/ops/_jagged_to_padded_dense_forward_meta.h>
|
| 187 |
+
#include <ATen/ops/_lazy_clone_meta.h>
|
| 188 |
+
#include <ATen/ops/_linalg_check_errors_meta.h>
|
| 189 |
+
#include <ATen/ops/_linalg_det_meta.h>
|
| 190 |
+
#include <ATen/ops/_linalg_eigh_meta.h>
|
| 191 |
+
#include <ATen/ops/_linalg_eigvals_meta.h>
|
| 192 |
+
#include <ATen/ops/_linalg_slogdet_meta.h>
|
| 193 |
+
#include <ATen/ops/_linalg_solve_ex_meta.h>
|
| 194 |
+
#include <ATen/ops/_linalg_svd_meta.h>
|
| 195 |
+
#include <ATen/ops/_local_scalar_dense_meta.h>
|
| 196 |
+
#include <ATen/ops/_log_softmax_meta.h>
|
| 197 |
+
#include <ATen/ops/_log_softmax_backward_data_meta.h>
|
| 198 |
+
#include <ATen/ops/_logcumsumexp_meta.h>
|
| 199 |
+
#include <ATen/ops/_lstm_mps_meta.h>
|
| 200 |
+
#include <ATen/ops/_lu_with_info_meta.h>
|
| 201 |
+
#include <ATen/ops/_make_dep_token_meta.h>
|
| 202 |
+
#include <ATen/ops/_make_dual_meta.h>
|
| 203 |
+
#include <ATen/ops/_make_dual_copy_meta.h>
|
| 204 |
+
#include <ATen/ops/_make_per_channel_quantized_tensor_meta.h>
|
| 205 |
+
#include <ATen/ops/_make_per_tensor_quantized_tensor_meta.h>
|
| 206 |
+
#include <ATen/ops/_masked_scale_meta.h>
|
| 207 |
+
#include <ATen/ops/_masked_softmax_meta.h>
|
| 208 |
+
#include <ATen/ops/_masked_softmax_backward_meta.h>
|
| 209 |
+
#include <ATen/ops/_mixed_dtypes_linear_meta.h>
|
| 210 |
+
#include <ATen/ops/_mkldnn_reshape_meta.h>
|
| 211 |
+
#include <ATen/ops/_mkldnn_transpose_meta.h>
|
| 212 |
+
#include <ATen/ops/_mps_convolution_meta.h>
|
| 213 |
+
#include <ATen/ops/_mps_convolution_transpose_meta.h>
|
| 214 |
+
#include <ATen/ops/_native_batch_norm_legit_meta.h>
|
| 215 |
+
#include <ATen/ops/_native_batch_norm_legit_no_training_meta.h>
|
| 216 |
+
#include <ATen/ops/_native_multi_head_attention_meta.h>
|
| 217 |
+
#include <ATen/ops/_neg_view_meta.h>
|
| 218 |
+
#include <ATen/ops/_neg_view_copy_meta.h>
|
| 219 |
+
#include <ATen/ops/_nested_compute_contiguous_strides_offsets_meta.h>
|
| 220 |
+
#include <ATen/ops/_nested_from_padded_meta.h>
|
| 221 |
+
#include <ATen/ops/_nested_from_padded_and_nested_example_meta.h>
|
| 222 |
+
#include <ATen/ops/_nested_from_padded_tensor_meta.h>
|
| 223 |
+
#include <ATen/ops/_nested_get_jagged_dummy_meta.h>
|
| 224 |
+
#include <ATen/ops/_nested_get_lengths_meta.h>
|
| 225 |
+
#include <ATen/ops/_nested_get_max_seqlen_meta.h>
|
| 226 |
+
#include <ATen/ops/_nested_get_min_seqlen_meta.h>
|
| 227 |
+
#include <ATen/ops/_nested_get_offsets_meta.h>
|
| 228 |
+
#include <ATen/ops/_nested_get_ragged_idx_meta.h>
|
| 229 |
+
#include <ATen/ops/_nested_get_values_meta.h>
|
| 230 |
+
#include <ATen/ops/_nested_get_values_copy_meta.h>
|
| 231 |
+
#include <ATen/ops/_nested_select_backward_meta.h>
|
| 232 |
+
#include <ATen/ops/_nested_sum_backward_meta.h>
|
| 233 |
+
#include <ATen/ops/_nested_tensor_from_mask_meta.h>
|
| 234 |
+
#include <ATen/ops/_nested_tensor_from_mask_left_aligned_meta.h>
|
| 235 |
+
#include <ATen/ops/_nested_tensor_from_tensor_list_meta.h>
|
| 236 |
+
#include <ATen/ops/_nested_tensor_size_meta.h>
|
| 237 |
+
#include <ATen/ops/_nested_tensor_softmax_with_shape_meta.h>
|
| 238 |
+
#include <ATen/ops/_nested_tensor_storage_offsets_meta.h>
|
| 239 |
+
#include <ATen/ops/_nested_tensor_strides_meta.h>
|
| 240 |
+
#include <ATen/ops/_nested_view_from_buffer_meta.h>
|
| 241 |
+
#include <ATen/ops/_nested_view_from_buffer_copy_meta.h>
|
| 242 |
+
#include <ATen/ops/_nested_view_from_jagged_meta.h>
|
| 243 |
+
#include <ATen/ops/_nested_view_from_jagged_copy_meta.h>
|
| 244 |
+
#include <ATen/ops/_new_zeros_with_same_feature_meta_meta.h>
|
| 245 |
+
#include <ATen/ops/_nnpack_available_meta.h>
|
| 246 |
+
#include <ATen/ops/_nnpack_spatial_convolution_meta.h>
|
| 247 |
+
#include <ATen/ops/_nnz_meta.h>
|
| 248 |
+
#include <ATen/ops/_pack_padded_sequence_meta.h>
|
| 249 |
+
#include <ATen/ops/_pack_padded_sequence_backward_meta.h>
|
| 250 |
+
#include <ATen/ops/_pad_circular_meta.h>
|
| 251 |
+
#include <ATen/ops/_pad_enum_meta.h>
|
| 252 |
+
#include <ATen/ops/_pad_packed_sequence_meta.h>
|
| 253 |
+
#include <ATen/ops/_padded_dense_to_jagged_forward_meta.h>
|
| 254 |
+
#include <ATen/ops/_pdist_backward_meta.h>
|
| 255 |
+
#include <ATen/ops/_pdist_forward_meta.h>
|
| 256 |
+
#include <ATen/ops/_pin_memory_meta.h>
|
| 257 |
+
#include <ATen/ops/_prelu_kernel_meta.h>
|
| 258 |
+
#include <ATen/ops/_prelu_kernel_backward_meta.h>
|
| 259 |
+
#include <ATen/ops/_print_meta.h>
|
| 260 |
+
#include <ATen/ops/_propagate_xla_data_meta.h>
|
| 261 |
+
#include <ATen/ops/_remove_batch_dim_meta.h>
|
| 262 |
+
#include <ATen/ops/_reshape_alias_meta.h>
|
| 263 |
+
#include <ATen/ops/_reshape_alias_copy_meta.h>
|
| 264 |
+
#include <ATen/ops/_reshape_copy_meta.h>
|
| 265 |
+
#include <ATen/ops/_reshape_from_tensor_meta.h>
|
| 266 |
+
#include <ATen/ops/_resize_output_meta.h>
|
| 267 |
+
#include <ATen/ops/_rowwise_prune_meta.h>
|
| 268 |
+
#include <ATen/ops/_safe_softmax_meta.h>
|
| 269 |
+
#include <ATen/ops/_sample_dirichlet_meta.h>
|
| 270 |
+
#include <ATen/ops/_saturate_weight_to_fp16_meta.h>
|
| 271 |
+
#include <ATen/ops/_scaled_dot_product_attention_math_meta.h>
|
| 272 |
+
#include <ATen/ops/_scaled_dot_product_attention_math_for_mps_meta.h>
|
| 273 |
+
#include <ATen/ops/_scaled_dot_product_cudnn_attention_meta.h>
|
| 274 |
+
#include <ATen/ops/_scaled_dot_product_cudnn_attention_backward_meta.h>
|
| 275 |
+
#include <ATen/ops/_scaled_dot_product_efficient_attention_meta.h>
|
| 276 |
+
#include <ATen/ops/_scaled_dot_product_efficient_attention_backward_meta.h>
|
| 277 |
+
#include <ATen/ops/_scaled_dot_product_flash_attention_meta.h>
|
| 278 |
+
#include <ATen/ops/_scaled_dot_product_flash_attention_backward_meta.h>
|
| 279 |
+
#include <ATen/ops/_scaled_dot_product_flash_attention_for_cpu_meta.h>
|
| 280 |
+
#include <ATen/ops/_scaled_dot_product_flash_attention_for_cpu_backward_meta.h>
|
| 281 |
+
#include <ATen/ops/_scaled_dot_product_fused_attention_overrideable_meta.h>
|
| 282 |
+
#include <ATen/ops/_scaled_dot_product_fused_attention_overrideable_backward_meta.h>
|
| 283 |
+
#include <ATen/ops/_scaled_grouped_mm_meta.h>
|
| 284 |
+
#include <ATen/ops/_scaled_grouped_mm_v2_meta.h>
|
| 285 |
+
#include <ATen/ops/_scaled_mm_meta.h>
|
| 286 |
+
#include <ATen/ops/_scaled_mm_v2_meta.h>
|
| 287 |
+
#include <ATen/ops/_segment_reduce_backward_meta.h>
|
| 288 |
+
#include <ATen/ops/_shape_as_tensor_meta.h>
|
| 289 |
+
#include <ATen/ops/_slow_conv2d_backward_meta.h>
|
| 290 |
+
#include <ATen/ops/_slow_conv2d_forward_meta.h>
|
| 291 |
+
#include <ATen/ops/_sobol_engine_draw_meta.h>
|
| 292 |
+
#include <ATen/ops/_sobol_engine_ff_meta.h>
|
| 293 |
+
#include <ATen/ops/_sobol_engine_initialize_state_meta.h>
|
| 294 |
+
#include <ATen/ops/_sobol_engine_scramble_meta.h>
|
| 295 |
+
#include <ATen/ops/_softmax_meta.h>
|
| 296 |
+
#include <ATen/ops/_softmax_backward_data_meta.h>
|
| 297 |
+
#include <ATen/ops/_sparse_addmm_meta.h>
|
| 298 |
+
#include <ATen/ops/_sparse_broadcast_to_meta.h>
|
| 299 |
+
#include <ATen/ops/_sparse_broadcast_to_copy_meta.h>
|
| 300 |
+
#include <ATen/ops/_sparse_bsc_tensor_unsafe_meta.h>
|
| 301 |
+
#include <ATen/ops/_sparse_bsr_tensor_unsafe_meta.h>
|
| 302 |
+
#include <ATen/ops/_sparse_compressed_tensor_unsafe_meta.h>
|
| 303 |
+
#include <ATen/ops/_sparse_compressed_tensor_with_dims_meta.h>
|
| 304 |
+
#include <ATen/ops/_sparse_coo_tensor_unsafe_meta.h>
|
| 305 |
+
#include <ATen/ops/_sparse_coo_tensor_with_dims_meta.h>
|
| 306 |
+
#include <ATen/ops/_sparse_coo_tensor_with_dims_and_tensors_meta.h>
|
| 307 |
+
#include <ATen/ops/_sparse_csc_tensor_unsafe_meta.h>
|
| 308 |
+
#include <ATen/ops/_sparse_csr_prod_meta.h>
|
| 309 |
+
#include <ATen/ops/_sparse_csr_sum_meta.h>
|
| 310 |
+
#include <ATen/ops/_sparse_csr_tensor_unsafe_meta.h>
|
| 311 |
+
#include <ATen/ops/_sparse_log_softmax_meta.h>
|
| 312 |
+
#include <ATen/ops/_sparse_log_softmax_backward_data_meta.h>
|
| 313 |
+
#include <ATen/ops/_sparse_mask_projection_meta.h>
|
| 314 |
+
#include <ATen/ops/_sparse_mm_meta.h>
|
| 315 |
+
#include <ATen/ops/_sparse_mm_reduce_impl_meta.h>
|
| 316 |
+
#include <ATen/ops/_sparse_mm_reduce_impl_backward_meta.h>
|
| 317 |
+
#include <ATen/ops/_sparse_semi_structured_addmm_meta.h>
|
| 318 |
+
#include <ATen/ops/_sparse_semi_structured_apply_meta.h>
|
| 319 |
+
#include <ATen/ops/_sparse_semi_structured_apply_dense_meta.h>
|
| 320 |
+
#include <ATen/ops/_sparse_semi_structured_linear_meta.h>
|
| 321 |
+
#include <ATen/ops/_sparse_semi_structured_mm_meta.h>
|
| 322 |
+
#include <ATen/ops/_sparse_semi_structured_tile_meta.h>
|
| 323 |
+
#include <ATen/ops/_sparse_softmax_meta.h>
|
| 324 |
+
#include <ATen/ops/_sparse_softmax_backward_data_meta.h>
|
| 325 |
+
#include <ATen/ops/_sparse_sparse_matmul_meta.h>
|
| 326 |
+
#include <ATen/ops/_sparse_sum_meta.h>
|
| 327 |
+
#include <ATen/ops/_sparse_sum_backward_meta.h>
|
| 328 |
+
#include <ATen/ops/_spdiags_meta.h>
|
| 329 |
+
#include <ATen/ops/_spsolve_meta.h>
|
| 330 |
+
#include <ATen/ops/_stack_meta.h>
|
| 331 |
+
#include <ATen/ops/_standard_gamma_meta.h>
|
| 332 |
+
#include <ATen/ops/_standard_gamma_grad_meta.h>
|
| 333 |
+
#include <ATen/ops/_test_ambiguous_defaults_meta.h>
|
| 334 |
+
#include <ATen/ops/_test_autograd_multiple_dispatch_meta.h>
|
| 335 |
+
#include <ATen/ops/_test_autograd_multiple_dispatch_view_meta.h>
|
| 336 |
+
#include <ATen/ops/_test_autograd_multiple_dispatch_view_copy_meta.h>
|
| 337 |
+
#include <ATen/ops/_test_check_tensor_meta.h>
|
| 338 |
+
#include <ATen/ops/_test_functorch_fallback_meta.h>
|
| 339 |
+
#include <ATen/ops/_test_optional_filled_intlist_meta.h>
|
| 340 |
+
#include <ATen/ops/_test_optional_floatlist_meta.h>
|
| 341 |
+
#include <ATen/ops/_test_optional_intlist_meta.h>
|
| 342 |
+
#include <ATen/ops/_test_parallel_materialize_meta.h>
|
| 343 |
+
#include <ATen/ops/_test_serialization_subcmul_meta.h>
|
| 344 |
+
#include <ATen/ops/_test_string_default_meta.h>
|
| 345 |
+
#include <ATen/ops/_test_warn_in_autograd_meta.h>
|
| 346 |
+
#include <ATen/ops/_thnn_differentiable_gru_cell_backward_meta.h>
|
| 347 |
+
#include <ATen/ops/_thnn_differentiable_lstm_cell_backward_meta.h>
|
| 348 |
+
#include <ATen/ops/_thnn_fused_gru_cell_meta.h>
|
| 349 |
+
#include <ATen/ops/_thnn_fused_gru_cell_backward_meta.h>
|
| 350 |
+
#include <ATen/ops/_thnn_fused_lstm_cell_meta.h>
|
| 351 |
+
#include <ATen/ops/_thnn_fused_lstm_cell_backward_meta.h>
|
| 352 |
+
#include <ATen/ops/_thnn_fused_lstm_cell_backward_impl_meta.h>
|
| 353 |
+
#include <ATen/ops/_to_copy_meta.h>
|
| 354 |
+
#include <ATen/ops/_to_cpu_meta.h>
|
| 355 |
+
#include <ATen/ops/_to_dense_meta.h>
|
| 356 |
+
#include <ATen/ops/_to_sparse_meta.h>
|
| 357 |
+
#include <ATen/ops/_to_sparse_bsc_meta.h>
|
| 358 |
+
#include <ATen/ops/_to_sparse_bsr_meta.h>
|
| 359 |
+
#include <ATen/ops/_to_sparse_csc_meta.h>
|
| 360 |
+
#include <ATen/ops/_to_sparse_csr_meta.h>
|
| 361 |
+
#include <ATen/ops/_to_sparse_semi_structured_meta.h>
|
| 362 |
+
#include <ATen/ops/_transform_bias_rescale_qkv_meta.h>
|
| 363 |
+
#include <ATen/ops/_transformer_encoder_layer_fwd_meta.h>
|
| 364 |
+
#include <ATen/ops/_trilinear_meta.h>
|
| 365 |
+
#include <ATen/ops/_triton_multi_head_attention_meta.h>
|
| 366 |
+
#include <ATen/ops/_triton_scaled_dot_attention_meta.h>
|
| 367 |
+
#include <ATen/ops/_unique_meta.h>
|
| 368 |
+
#include <ATen/ops/_unique2_meta.h>
|
| 369 |
+
#include <ATen/ops/_unpack_dual_meta.h>
|
| 370 |
+
#include <ATen/ops/_unsafe_index_meta.h>
|
| 371 |
+
#include <ATen/ops/_unsafe_index_put_meta.h>
|
| 372 |
+
#include <ATen/ops/_unsafe_masked_index_meta.h>
|
| 373 |
+
#include <ATen/ops/_unsafe_masked_index_put_accumulate_meta.h>
|
| 374 |
+
#include <ATen/ops/_unsafe_view_meta.h>
|
| 375 |
+
#include <ATen/ops/_upsample_bicubic2d_aa_meta.h>
|
| 376 |
+
#include <ATen/ops/_upsample_bicubic2d_aa_backward_meta.h>
|
| 377 |
+
#include <ATen/ops/_upsample_bilinear2d_aa_meta.h>
|
| 378 |
+
#include <ATen/ops/_upsample_bilinear2d_aa_backward_meta.h>
|
| 379 |
+
#include <ATen/ops/_upsample_nearest_exact1d_meta.h>
|
| 380 |
+
#include <ATen/ops/_upsample_nearest_exact1d_backward_meta.h>
|
| 381 |
+
#include <ATen/ops/_upsample_nearest_exact2d_meta.h>
|
| 382 |
+
#include <ATen/ops/_upsample_nearest_exact2d_backward_meta.h>
|
| 383 |
+
#include <ATen/ops/_upsample_nearest_exact3d_meta.h>
|
| 384 |
+
#include <ATen/ops/_upsample_nearest_exact3d_backward_meta.h>
|
| 385 |
+
#include <ATen/ops/_use_cudnn_ctc_loss_meta.h>
|
| 386 |
+
#include <ATen/ops/_use_cudnn_rnn_flatten_weight_meta.h>
|
| 387 |
+
#include <ATen/ops/_validate_compressed_sparse_indices_meta.h>
|
| 388 |
+
#include <ATen/ops/_validate_sparse_bsc_tensor_args_meta.h>
|
| 389 |
+
#include <ATen/ops/_validate_sparse_bsr_tensor_args_meta.h>
|
| 390 |
+
#include <ATen/ops/_validate_sparse_compressed_tensor_args_meta.h>
|
| 391 |
+
#include <ATen/ops/_validate_sparse_coo_tensor_args_meta.h>
|
| 392 |
+
#include <ATen/ops/_validate_sparse_csc_tensor_args_meta.h>
|
| 393 |
+
#include <ATen/ops/_validate_sparse_csr_tensor_args_meta.h>
|
| 394 |
+
#include <ATen/ops/_values_meta.h>
|
| 395 |
+
#include <ATen/ops/_values_copy_meta.h>
|
| 396 |
+
#include <ATen/ops/_version_meta.h>
|
| 397 |
+
#include <ATen/ops/_weight_int4pack_mm_meta.h>
|
| 398 |
+
#include <ATen/ops/_weight_int4pack_mm_for_cpu_meta.h>
|
| 399 |
+
#include <ATen/ops/_weight_int4pack_mm_with_scales_and_zeros_meta.h>
|
| 400 |
+
#include <ATen/ops/_weight_int8pack_mm_meta.h>
|
| 401 |
+
#include <ATen/ops/_weight_norm_meta.h>
|
| 402 |
+
#include <ATen/ops/_weight_norm_differentiable_backward_meta.h>
|
| 403 |
+
#include <ATen/ops/_weight_norm_interface_meta.h>
|
| 404 |
+
#include <ATen/ops/_weight_norm_interface_backward_meta.h>
|
| 405 |
+
#include <ATen/ops/_wrapped_linear_prepack_meta.h>
|
| 406 |
+
#include <ATen/ops/_wrapped_quantized_linear_prepacked_meta.h>
|
| 407 |
+
#include <ATen/ops/abs_meta.h>
|
| 408 |
+
#include <ATen/ops/absolute_meta.h>
|
| 409 |
+
#include <ATen/ops/acos_meta.h>
|
| 410 |
+
#include <ATen/ops/acosh_meta.h>
|
| 411 |
+
#include <ATen/ops/adaptive_avg_pool1d_meta.h>
|
| 412 |
+
#include <ATen/ops/adaptive_avg_pool2d_meta.h>
|
| 413 |
+
#include <ATen/ops/adaptive_avg_pool3d_meta.h>
|
| 414 |
+
#include <ATen/ops/adaptive_avg_pool3d_backward_meta.h>
|
| 415 |
+
#include <ATen/ops/adaptive_max_pool1d_meta.h>
|
| 416 |
+
#include <ATen/ops/adaptive_max_pool2d_meta.h>
|
| 417 |
+
#include <ATen/ops/adaptive_max_pool2d_backward_meta.h>
|
| 418 |
+
#include <ATen/ops/adaptive_max_pool3d_meta.h>
|
| 419 |
+
#include <ATen/ops/adaptive_max_pool3d_backward_meta.h>
|
| 420 |
+
#include <ATen/ops/add_meta.h>
|
| 421 |
+
#include <ATen/ops/addbmm_meta.h>
|
| 422 |
+
#include <ATen/ops/addcdiv_meta.h>
|
| 423 |
+
#include <ATen/ops/addcmul_meta.h>
|
| 424 |
+
#include <ATen/ops/addmm_meta.h>
|
| 425 |
+
#include <ATen/ops/addmv_meta.h>
|
| 426 |
+
#include <ATen/ops/addr_meta.h>
|
| 427 |
+
#include <ATen/ops/adjoint_meta.h>
|
| 428 |
+
#include <ATen/ops/affine_grid_generator_meta.h>
|
| 429 |
+
#include <ATen/ops/affine_grid_generator_backward_meta.h>
|
| 430 |
+
#include <ATen/ops/alias_meta.h>
|
| 431 |
+
#include <ATen/ops/alias_copy_meta.h>
|
| 432 |
+
#include <ATen/ops/align_as_meta.h>
|
| 433 |
+
#include <ATen/ops/align_tensors_meta.h>
|
| 434 |
+
#include <ATen/ops/align_to_meta.h>
|
| 435 |
+
#include <ATen/ops/all_meta.h>
|
| 436 |
+
#include <ATen/ops/allclose_meta.h>
|
| 437 |
+
#include <ATen/ops/alpha_dropout_meta.h>
|
| 438 |
+
#include <ATen/ops/amax_meta.h>
|
| 439 |
+
#include <ATen/ops/amin_meta.h>
|
| 440 |
+
#include <ATen/ops/aminmax_meta.h>
|
| 441 |
+
#include <ATen/ops/and_meta.h>
|
| 442 |
+
#include <ATen/ops/angle_meta.h>
|
| 443 |
+
#include <ATen/ops/any_meta.h>
|
| 444 |
+
#include <ATen/ops/arange_meta.h>
|
| 445 |
+
#include <ATen/ops/arccos_meta.h>
|
| 446 |
+
#include <ATen/ops/arccosh_meta.h>
|
| 447 |
+
#include <ATen/ops/arcsin_meta.h>
|
| 448 |
+
#include <ATen/ops/arcsinh_meta.h>
|
| 449 |
+
#include <ATen/ops/arctan_meta.h>
|
| 450 |
+
#include <ATen/ops/arctan2_meta.h>
|
| 451 |
+
#include <ATen/ops/arctanh_meta.h>
|
| 452 |
+
#include <ATen/ops/argmax_meta.h>
|
| 453 |
+
#include <ATen/ops/argmin_meta.h>
|
| 454 |
+
#include <ATen/ops/argsort_meta.h>
|
| 455 |
+
#include <ATen/ops/argwhere_meta.h>
|
| 456 |
+
#include <ATen/ops/as_strided_meta.h>
|
| 457 |
+
#include <ATen/ops/as_strided_copy_meta.h>
|
| 458 |
+
#include <ATen/ops/as_strided_scatter_meta.h>
|
| 459 |
+
#include <ATen/ops/asin_meta.h>
|
| 460 |
+
#include <ATen/ops/asinh_meta.h>
|
| 461 |
+
#include <ATen/ops/atan_meta.h>
|
| 462 |
+
#include <ATen/ops/atan2_meta.h>
|
| 463 |
+
#include <ATen/ops/atanh_meta.h>
|
| 464 |
+
#include <ATen/ops/atleast_1d_meta.h>
|
| 465 |
+
#include <ATen/ops/atleast_2d_meta.h>
|
| 466 |
+
#include <ATen/ops/atleast_3d_meta.h>
|
| 467 |
+
#include <ATen/ops/avg_pool1d_meta.h>
|
| 468 |
+
#include <ATen/ops/avg_pool2d_meta.h>
|
| 469 |
+
#include <ATen/ops/avg_pool2d_backward_meta.h>
|
| 470 |
+
#include <ATen/ops/avg_pool3d_meta.h>
|
| 471 |
+
#include <ATen/ops/avg_pool3d_backward_meta.h>
|
| 472 |
+
#include <ATen/ops/baddbmm_meta.h>
|
| 473 |
+
#include <ATen/ops/bartlett_window_meta.h>
|
| 474 |
+
#include <ATen/ops/batch_norm_meta.h>
|
| 475 |
+
#include <ATen/ops/batch_norm_backward_meta.h>
|
| 476 |
+
#include <ATen/ops/batch_norm_backward_elemt_meta.h>
|
| 477 |
+
#include <ATen/ops/batch_norm_backward_reduce_meta.h>
|
| 478 |
+
#include <ATen/ops/batch_norm_elemt_meta.h>
|
| 479 |
+
#include <ATen/ops/batch_norm_gather_stats_meta.h>
|
| 480 |
+
#include <ATen/ops/batch_norm_gather_stats_with_counts_meta.h>
|
| 481 |
+
#include <ATen/ops/batch_norm_stats_meta.h>
|
| 482 |
+
#include <ATen/ops/batch_norm_update_stats_meta.h>
|
| 483 |
+
#include <ATen/ops/bernoulli_meta.h>
|
| 484 |
+
#include <ATen/ops/bilinear_meta.h>
|
| 485 |
+
#include <ATen/ops/binary_cross_entropy_meta.h>
|
| 486 |
+
#include <ATen/ops/binary_cross_entropy_backward_meta.h>
|
| 487 |
+
#include <ATen/ops/binary_cross_entropy_with_logits_meta.h>
|
| 488 |
+
#include <ATen/ops/bincount_meta.h>
|
| 489 |
+
#include <ATen/ops/binomial_meta.h>
|
| 490 |
+
#include <ATen/ops/bitwise_and_meta.h>
|
| 491 |
+
#include <ATen/ops/bitwise_left_shift_meta.h>
|
| 492 |
+
#include <ATen/ops/bitwise_not_meta.h>
|
| 493 |
+
#include <ATen/ops/bitwise_or_meta.h>
|
| 494 |
+
#include <ATen/ops/bitwise_right_shift_meta.h>
|
| 495 |
+
#include <ATen/ops/bitwise_xor_meta.h>
|
| 496 |
+
#include <ATen/ops/blackman_window_meta.h>
|
| 497 |
+
#include <ATen/ops/block_diag_meta.h>
|
| 498 |
+
#include <ATen/ops/bmm_meta.h>
|
| 499 |
+
#include <ATen/ops/broadcast_tensors_meta.h>
|
| 500 |
+
#include <ATen/ops/broadcast_to_meta.h>
|
| 501 |
+
#include <ATen/ops/bucketize_meta.h>
|
| 502 |
+
#include <ATen/ops/can_cast_meta.h>
|
| 503 |
+
#include <ATen/ops/cartesian_prod_meta.h>
|
| 504 |
+
#include <ATen/ops/cat_meta.h>
|
| 505 |
+
#include <ATen/ops/cauchy_meta.h>
|
| 506 |
+
#include <ATen/ops/ccol_indices_meta.h>
|
| 507 |
+
#include <ATen/ops/ccol_indices_copy_meta.h>
|
| 508 |
+
#include <ATen/ops/cdist_meta.h>
|
| 509 |
+
#include <ATen/ops/ceil_meta.h>
|
| 510 |
+
#include <ATen/ops/celu_meta.h>
|
| 511 |
+
#include <ATen/ops/chain_matmul_meta.h>
|
| 512 |
+
#include <ATen/ops/chalf_meta.h>
|
| 513 |
+
#include <ATen/ops/channel_shuffle_meta.h>
|
| 514 |
+
#include <ATen/ops/cholesky_meta.h>
|
| 515 |
+
#include <ATen/ops/cholesky_inverse_meta.h>
|
| 516 |
+
#include <ATen/ops/cholesky_solve_meta.h>
|
| 517 |
+
#include <ATen/ops/choose_qparams_optimized_meta.h>
|
| 518 |
+
#include <ATen/ops/chunk_meta.h>
|
| 519 |
+
#include <ATen/ops/clamp_meta.h>
|
| 520 |
+
#include <ATen/ops/clamp_max_meta.h>
|
| 521 |
+
#include <ATen/ops/clamp_min_meta.h>
|
| 522 |
+
#include <ATen/ops/clip_meta.h>
|
| 523 |
+
#include <ATen/ops/clone_meta.h>
|
| 524 |
+
#include <ATen/ops/coalesce_meta.h>
|
| 525 |
+
#include <ATen/ops/col2im_meta.h>
|
| 526 |
+
#include <ATen/ops/col_indices_meta.h>
|
| 527 |
+
#include <ATen/ops/col_indices_copy_meta.h>
|
| 528 |
+
#include <ATen/ops/column_stack_meta.h>
|
| 529 |
+
#include <ATen/ops/combinations_meta.h>
|
| 530 |
+
#include <ATen/ops/complex_meta.h>
|
| 531 |
+
#include <ATen/ops/concat_meta.h>
|
| 532 |
+
#include <ATen/ops/concatenate_meta.h>
|
| 533 |
+
#include <ATen/ops/conj_meta.h>
|
| 534 |
+
#include <ATen/ops/conj_physical_meta.h>
|
| 535 |
+
#include <ATen/ops/constant_pad_nd_meta.h>
|
| 536 |
+
#include <ATen/ops/contiguous_meta.h>
|
| 537 |
+
#include <ATen/ops/conv1d_meta.h>
|
| 538 |
+
#include <ATen/ops/conv2d_meta.h>
|
| 539 |
+
#include <ATen/ops/conv3d_meta.h>
|
| 540 |
+
#include <ATen/ops/conv_depthwise3d_meta.h>
|
| 541 |
+
#include <ATen/ops/conv_tbc_meta.h>
|
| 542 |
+
#include <ATen/ops/conv_tbc_backward_meta.h>
|
| 543 |
+
#include <ATen/ops/conv_transpose1d_meta.h>
|
| 544 |
+
#include <ATen/ops/conv_transpose2d_meta.h>
|
| 545 |
+
#include <ATen/ops/conv_transpose3d_meta.h>
|
| 546 |
+
#include <ATen/ops/convolution_meta.h>
|
| 547 |
+
#include <ATen/ops/convolution_backward_meta.h>
|
| 548 |
+
#include <ATen/ops/convolution_backward_overrideable_meta.h>
|
| 549 |
+
#include <ATen/ops/convolution_overrideable_meta.h>
|
| 550 |
+
#include <ATen/ops/copy_meta.h>
|
| 551 |
+
#include <ATen/ops/copy_sparse_to_sparse_meta.h>
|
| 552 |
+
#include <ATen/ops/copysign_meta.h>
|
| 553 |
+
#include <ATen/ops/corrcoef_meta.h>
|
| 554 |
+
#include <ATen/ops/cos_meta.h>
|
| 555 |
+
#include <ATen/ops/cosh_meta.h>
|
| 556 |
+
#include <ATen/ops/cosine_embedding_loss_meta.h>
|
| 557 |
+
#include <ATen/ops/cosine_similarity_meta.h>
|
| 558 |
+
#include <ATen/ops/count_nonzero_meta.h>
|
| 559 |
+
#include <ATen/ops/cov_meta.h>
|
| 560 |
+
#include <ATen/ops/cross_meta.h>
|
| 561 |
+
#include <ATen/ops/cross_entropy_loss_meta.h>
|
| 562 |
+
#include <ATen/ops/crow_indices_meta.h>
|
| 563 |
+
#include <ATen/ops/crow_indices_copy_meta.h>
|
| 564 |
+
#include <ATen/ops/ctc_loss_meta.h>
|
| 565 |
+
#include <ATen/ops/cudnn_affine_grid_generator_meta.h>
|
| 566 |
+
#include <ATen/ops/cudnn_affine_grid_generator_backward_meta.h>
|
| 567 |
+
#include <ATen/ops/cudnn_batch_norm_meta.h>
|
| 568 |
+
#include <ATen/ops/cudnn_batch_norm_backward_meta.h>
|
| 569 |
+
#include <ATen/ops/cudnn_convolution_meta.h>
|
| 570 |
+
#include <ATen/ops/cudnn_convolution_add_relu_meta.h>
|
| 571 |
+
#include <ATen/ops/cudnn_convolution_relu_meta.h>
|
| 572 |
+
#include <ATen/ops/cudnn_convolution_transpose_meta.h>
|
| 573 |
+
#include <ATen/ops/cudnn_grid_sampler_meta.h>
|
| 574 |
+
#include <ATen/ops/cudnn_grid_sampler_backward_meta.h>
|
| 575 |
+
#include <ATen/ops/cudnn_is_acceptable_meta.h>
|
| 576 |
+
#include <ATen/ops/cummax_meta.h>
|
| 577 |
+
#include <ATen/ops/cummaxmin_backward_meta.h>
|
| 578 |
+
#include <ATen/ops/cummin_meta.h>
|
| 579 |
+
#include <ATen/ops/cumprod_meta.h>
|
| 580 |
+
#include <ATen/ops/cumprod_backward_meta.h>
|
| 581 |
+
#include <ATen/ops/cumsum_meta.h>
|
| 582 |
+
#include <ATen/ops/cumulative_trapezoid_meta.h>
|
| 583 |
+
#include <ATen/ops/data_meta.h>
|
| 584 |
+
#include <ATen/ops/deg2rad_meta.h>
|
| 585 |
+
#include <ATen/ops/dense_dim_meta.h>
|
| 586 |
+
#include <ATen/ops/dequantize_meta.h>
|
| 587 |
+
#include <ATen/ops/det_meta.h>
|
| 588 |
+
#include <ATen/ops/detach_meta.h>
|
| 589 |
+
#include <ATen/ops/detach_copy_meta.h>
|
| 590 |
+
#include <ATen/ops/diag_meta.h>
|
| 591 |
+
#include <ATen/ops/diag_embed_meta.h>
|
| 592 |
+
#include <ATen/ops/diagflat_meta.h>
|
| 593 |
+
#include <ATen/ops/diagonal_meta.h>
|
| 594 |
+
#include <ATen/ops/diagonal_backward_meta.h>
|
| 595 |
+
#include <ATen/ops/diagonal_copy_meta.h>
|
| 596 |
+
#include <ATen/ops/diagonal_scatter_meta.h>
|
| 597 |
+
#include <ATen/ops/diff_meta.h>
|
| 598 |
+
#include <ATen/ops/digamma_meta.h>
|
| 599 |
+
#include <ATen/ops/dist_meta.h>
|
| 600 |
+
#include <ATen/ops/div_meta.h>
|
| 601 |
+
#include <ATen/ops/divide_meta.h>
|
| 602 |
+
#include <ATen/ops/dot_meta.h>
|
| 603 |
+
#include <ATen/ops/dropout_meta.h>
|
| 604 |
+
#include <ATen/ops/dsplit_meta.h>
|
| 605 |
+
#include <ATen/ops/dstack_meta.h>
|
| 606 |
+
#include <ATen/ops/einsum_meta.h>
|
| 607 |
+
#include <ATen/ops/elu_meta.h>
|
| 608 |
+
#include <ATen/ops/elu_backward_meta.h>
|
| 609 |
+
#include <ATen/ops/embedding_meta.h>
|
| 610 |
+
#include <ATen/ops/embedding_backward_meta.h>
|
| 611 |
+
#include <ATen/ops/embedding_bag_meta.h>
|
| 612 |
+
#include <ATen/ops/embedding_dense_backward_meta.h>
|
| 613 |
+
#include <ATen/ops/embedding_renorm_meta.h>
|
| 614 |
+
#include <ATen/ops/embedding_sparse_backward_meta.h>
|
| 615 |
+
#include <ATen/ops/empty_meta.h>
|
| 616 |
+
#include <ATen/ops/empty_like_meta.h>
|
| 617 |
+
#include <ATen/ops/empty_permuted_meta.h>
|
| 618 |
+
#include <ATen/ops/empty_quantized_meta.h>
|
| 619 |
+
#include <ATen/ops/empty_strided_meta.h>
|
| 620 |
+
#include <ATen/ops/eq_meta.h>
|
| 621 |
+
#include <ATen/ops/equal_meta.h>
|
| 622 |
+
#include <ATen/ops/erf_meta.h>
|
| 623 |
+
#include <ATen/ops/erfc_meta.h>
|
| 624 |
+
#include <ATen/ops/erfinv_meta.h>
|
| 625 |
+
#include <ATen/ops/exp_meta.h>
|
| 626 |
+
#include <ATen/ops/exp2_meta.h>
|
| 627 |
+
#include <ATen/ops/expand_meta.h>
|
| 628 |
+
#include <ATen/ops/expand_as_meta.h>
|
| 629 |
+
#include <ATen/ops/expand_copy_meta.h>
|
| 630 |
+
#include <ATen/ops/expm1_meta.h>
|
| 631 |
+
#include <ATen/ops/exponential_meta.h>
|
| 632 |
+
#include <ATen/ops/eye_meta.h>
|
| 633 |
+
#include <ATen/ops/fake_quantize_per_channel_affine_meta.h>
|
| 634 |
+
#include <ATen/ops/fake_quantize_per_channel_affine_cachemask_meta.h>
|
| 635 |
+
#include <ATen/ops/fake_quantize_per_channel_affine_cachemask_backward_meta.h>
|
| 636 |
+
#include <ATen/ops/fake_quantize_per_tensor_affine_meta.h>
|
| 637 |
+
#include <ATen/ops/fake_quantize_per_tensor_affine_cachemask_meta.h>
|
| 638 |
+
#include <ATen/ops/fake_quantize_per_tensor_affine_cachemask_backward_meta.h>
|
| 639 |
+
#include <ATen/ops/fbgemm_linear_fp16_weight_meta.h>
|
| 640 |
+
#include <ATen/ops/fbgemm_linear_fp16_weight_fp32_activation_meta.h>
|
| 641 |
+
#include <ATen/ops/fbgemm_linear_int8_weight_meta.h>
|
| 642 |
+
#include <ATen/ops/fbgemm_linear_int8_weight_fp32_activation_meta.h>
|
| 643 |
+
#include <ATen/ops/fbgemm_linear_quantize_weight_meta.h>
|
| 644 |
+
#include <ATen/ops/fbgemm_pack_gemm_matrix_fp16_meta.h>
|
| 645 |
+
#include <ATen/ops/fbgemm_pack_quantized_matrix_meta.h>
|
| 646 |
+
#include <ATen/ops/feature_alpha_dropout_meta.h>
|
| 647 |
+
#include <ATen/ops/feature_dropout_meta.h>
|
| 648 |
+
#include <ATen/ops/fft_fft_meta.h>
|
| 649 |
+
#include <ATen/ops/fft_fft2_meta.h>
|
| 650 |
+
#include <ATen/ops/fft_fftfreq_meta.h>
|
| 651 |
+
#include <ATen/ops/fft_fftn_meta.h>
|
| 652 |
+
#include <ATen/ops/fft_fftshift_meta.h>
|
| 653 |
+
#include <ATen/ops/fft_hfft_meta.h>
|
| 654 |
+
#include <ATen/ops/fft_hfft2_meta.h>
|
| 655 |
+
#include <ATen/ops/fft_hfftn_meta.h>
|
| 656 |
+
#include <ATen/ops/fft_ifft_meta.h>
|
| 657 |
+
#include <ATen/ops/fft_ifft2_meta.h>
|
| 658 |
+
#include <ATen/ops/fft_ifftn_meta.h>
|
| 659 |
+
#include <ATen/ops/fft_ifftshift_meta.h>
|
| 660 |
+
#include <ATen/ops/fft_ihfft_meta.h>
|
| 661 |
+
#include <ATen/ops/fft_ihfft2_meta.h>
|
| 662 |
+
#include <ATen/ops/fft_ihfftn_meta.h>
|
| 663 |
+
#include <ATen/ops/fft_irfft_meta.h>
|
| 664 |
+
#include <ATen/ops/fft_irfft2_meta.h>
|
| 665 |
+
#include <ATen/ops/fft_irfftn_meta.h>
|
| 666 |
+
#include <ATen/ops/fft_rfft_meta.h>
|
| 667 |
+
#include <ATen/ops/fft_rfft2_meta.h>
|
| 668 |
+
#include <ATen/ops/fft_rfftfreq_meta.h>
|
| 669 |
+
#include <ATen/ops/fft_rfftn_meta.h>
|
| 670 |
+
#include <ATen/ops/fill_meta.h>
|
| 671 |
+
#include <ATen/ops/fill_diagonal_meta.h>
|
| 672 |
+
#include <ATen/ops/fix_meta.h>
|
| 673 |
+
#include <ATen/ops/flatten_meta.h>
|
| 674 |
+
#include <ATen/ops/flatten_dense_tensors_meta.h>
|
| 675 |
+
#include <ATen/ops/flip_meta.h>
|
| 676 |
+
#include <ATen/ops/fliplr_meta.h>
|
| 677 |
+
#include <ATen/ops/flipud_meta.h>
|
| 678 |
+
#include <ATen/ops/float_power_meta.h>
|
| 679 |
+
#include <ATen/ops/floor_meta.h>
|
| 680 |
+
#include <ATen/ops/floor_divide_meta.h>
|
| 681 |
+
#include <ATen/ops/fmax_meta.h>
|
| 682 |
+
#include <ATen/ops/fmin_meta.h>
|
| 683 |
+
#include <ATen/ops/fmod_meta.h>
|
| 684 |
+
#include <ATen/ops/frac_meta.h>
|
| 685 |
+
#include <ATen/ops/fractional_max_pool2d_meta.h>
|
| 686 |
+
#include <ATen/ops/fractional_max_pool2d_backward_meta.h>
|
| 687 |
+
#include <ATen/ops/fractional_max_pool3d_meta.h>
|
| 688 |
+
#include <ATen/ops/fractional_max_pool3d_backward_meta.h>
|
| 689 |
+
#include <ATen/ops/frexp_meta.h>
|
| 690 |
+
#include <ATen/ops/frobenius_norm_meta.h>
|
| 691 |
+
#include <ATen/ops/from_file_meta.h>
|
| 692 |
+
#include <ATen/ops/full_meta.h>
|
| 693 |
+
#include <ATen/ops/full_like_meta.h>
|
| 694 |
+
#include <ATen/ops/fused_moving_avg_obs_fake_quant_meta.h>
|
| 695 |
+
#include <ATen/ops/gather_meta.h>
|
| 696 |
+
#include <ATen/ops/gather_backward_meta.h>
|
| 697 |
+
#include <ATen/ops/gcd_meta.h>
|
| 698 |
+
#include <ATen/ops/ge_meta.h>
|
| 699 |
+
#include <ATen/ops/gelu_meta.h>
|
| 700 |
+
#include <ATen/ops/gelu_backward_meta.h>
|
| 701 |
+
#include <ATen/ops/geometric_meta.h>
|
| 702 |
+
#include <ATen/ops/geqrf_meta.h>
|
| 703 |
+
#include <ATen/ops/ger_meta.h>
|
| 704 |
+
#include <ATen/ops/glu_meta.h>
|
| 705 |
+
#include <ATen/ops/glu_backward_meta.h>
|
| 706 |
+
#include <ATen/ops/glu_backward_jvp_meta.h>
|
| 707 |
+
#include <ATen/ops/glu_jvp_meta.h>
|
| 708 |
+
#include <ATen/ops/gradient_meta.h>
|
| 709 |
+
#include <ATen/ops/greater_meta.h>
|
| 710 |
+
#include <ATen/ops/greater_equal_meta.h>
|
| 711 |
+
#include <ATen/ops/grid_sampler_meta.h>
|
| 712 |
+
#include <ATen/ops/grid_sampler_2d_meta.h>
|
| 713 |
+
#include <ATen/ops/grid_sampler_2d_backward_meta.h>
|
| 714 |
+
#include <ATen/ops/grid_sampler_3d_meta.h>
|
| 715 |
+
#include <ATen/ops/grid_sampler_3d_backward_meta.h>
|
| 716 |
+
#include <ATen/ops/group_norm_meta.h>
|
| 717 |
+
#include <ATen/ops/gru_meta.h>
|
| 718 |
+
#include <ATen/ops/gru_cell_meta.h>
|
| 719 |
+
#include <ATen/ops/gt_meta.h>
|
| 720 |
+
#include <ATen/ops/hamming_window_meta.h>
|
| 721 |
+
#include <ATen/ops/hann_window_meta.h>
|
| 722 |
+
#include <ATen/ops/hardshrink_meta.h>
|
| 723 |
+
#include <ATen/ops/hardshrink_backward_meta.h>
|
| 724 |
+
#include <ATen/ops/hardsigmoid_meta.h>
|
| 725 |
+
#include <ATen/ops/hardsigmoid_backward_meta.h>
|
| 726 |
+
#include <ATen/ops/hardswish_meta.h>
|
| 727 |
+
#include <ATen/ops/hardswish_backward_meta.h>
|
| 728 |
+
#include <ATen/ops/hardtanh_meta.h>
|
| 729 |
+
#include <ATen/ops/hardtanh_backward_meta.h>
|
| 730 |
+
#include <ATen/ops/hash_tensor_meta.h>
|
| 731 |
+
#include <ATen/ops/heaviside_meta.h>
|
| 732 |
+
#include <ATen/ops/hinge_embedding_loss_meta.h>
|
| 733 |
+
#include <ATen/ops/histc_meta.h>
|
| 734 |
+
#include <ATen/ops/histogram_meta.h>
|
| 735 |
+
#include <ATen/ops/histogramdd_meta.h>
|
| 736 |
+
#include <ATen/ops/hsplit_meta.h>
|
| 737 |
+
#include <ATen/ops/hspmm_meta.h>
|
| 738 |
+
#include <ATen/ops/hstack_meta.h>
|
| 739 |
+
#include <ATen/ops/huber_loss_meta.h>
|
| 740 |
+
#include <ATen/ops/huber_loss_backward_meta.h>
|
| 741 |
+
#include <ATen/ops/hypot_meta.h>
|
| 742 |
+
#include <ATen/ops/i0_meta.h>
|
| 743 |
+
#include <ATen/ops/igamma_meta.h>
|
| 744 |
+
#include <ATen/ops/igammac_meta.h>
|
| 745 |
+
#include <ATen/ops/im2col_meta.h>
|
| 746 |
+
#include <ATen/ops/imag_meta.h>
|
| 747 |
+
#include <ATen/ops/index_meta.h>
|
| 748 |
+
#include <ATen/ops/index_add_meta.h>
|
| 749 |
+
#include <ATen/ops/index_copy_meta.h>
|
| 750 |
+
#include <ATen/ops/index_fill_meta.h>
|
| 751 |
+
#include <ATen/ops/index_put_meta.h>
|
| 752 |
+
#include <ATen/ops/index_reduce_meta.h>
|
| 753 |
+
#include <ATen/ops/index_select_meta.h>
|
| 754 |
+
#include <ATen/ops/index_select_backward_meta.h>
|
| 755 |
+
#include <ATen/ops/indices_meta.h>
|
| 756 |
+
#include <ATen/ops/indices_copy_meta.h>
|
| 757 |
+
#include <ATen/ops/infinitely_differentiable_gelu_backward_meta.h>
|
| 758 |
+
#include <ATen/ops/inner_meta.h>
|
| 759 |
+
#include <ATen/ops/instance_norm_meta.h>
|
| 760 |
+
#include <ATen/ops/int_repr_meta.h>
|
| 761 |
+
#include <ATen/ops/inverse_meta.h>
|
| 762 |
+
#include <ATen/ops/is_coalesced_meta.h>
|
| 763 |
+
#include <ATen/ops/is_complex_meta.h>
|
| 764 |
+
#include <ATen/ops/is_conj_meta.h>
|
| 765 |
+
#include <ATen/ops/is_distributed_meta.h>
|
| 766 |
+
#include <ATen/ops/is_floating_point_meta.h>
|
| 767 |
+
#include <ATen/ops/is_inference_meta.h>
|
| 768 |
+
#include <ATen/ops/is_leaf_meta.h>
|
| 769 |
+
#include <ATen/ops/is_neg_meta.h>
|
| 770 |
+
#include <ATen/ops/is_nonzero_meta.h>
|
| 771 |
+
#include <ATen/ops/is_pinned_meta.h>
|
| 772 |
+
#include <ATen/ops/is_same_size_meta.h>
|
| 773 |
+
#include <ATen/ops/is_set_to_meta.h>
|
| 774 |
+
#include <ATen/ops/is_signed_meta.h>
|
| 775 |
+
#include <ATen/ops/is_vulkan_available_meta.h>
|
| 776 |
+
#include <ATen/ops/isclose_meta.h>
|
| 777 |
+
#include <ATen/ops/isfinite_meta.h>
|
| 778 |
+
#include <ATen/ops/isin_meta.h>
|
| 779 |
+
#include <ATen/ops/isinf_meta.h>
|
| 780 |
+
#include <ATen/ops/isnan_meta.h>
|
| 781 |
+
#include <ATen/ops/isneginf_meta.h>
|
| 782 |
+
#include <ATen/ops/isposinf_meta.h>
|
| 783 |
+
#include <ATen/ops/isreal_meta.h>
|
| 784 |
+
#include <ATen/ops/istft_meta.h>
|
| 785 |
+
#include <ATen/ops/item_meta.h>
|
| 786 |
+
#include <ATen/ops/kaiser_window_meta.h>
|
| 787 |
+
#include <ATen/ops/kl_div_meta.h>
|
| 788 |
+
#include <ATen/ops/kron_meta.h>
|
| 789 |
+
#include <ATen/ops/kthvalue_meta.h>
|
| 790 |
+
#include <ATen/ops/l1_loss_meta.h>
|
| 791 |
+
#include <ATen/ops/layer_norm_meta.h>
|
| 792 |
+
#include <ATen/ops/lcm_meta.h>
|
| 793 |
+
#include <ATen/ops/ldexp_meta.h>
|
| 794 |
+
#include <ATen/ops/le_meta.h>
|
| 795 |
+
#include <ATen/ops/leaky_relu_meta.h>
|
| 796 |
+
#include <ATen/ops/leaky_relu_backward_meta.h>
|
| 797 |
+
#include <ATen/ops/lerp_meta.h>
|
| 798 |
+
#include <ATen/ops/less_meta.h>
|
| 799 |
+
#include <ATen/ops/less_equal_meta.h>
|
| 800 |
+
#include <ATen/ops/lgamma_meta.h>
|
| 801 |
+
#include <ATen/ops/lift_meta.h>
|
| 802 |
+
#include <ATen/ops/lift_fresh_meta.h>
|
| 803 |
+
#include <ATen/ops/lift_fresh_copy_meta.h>
|
| 804 |
+
#include <ATen/ops/linalg_cholesky_meta.h>
|
| 805 |
+
#include <ATen/ops/linalg_cholesky_ex_meta.h>
|
| 806 |
+
#include <ATen/ops/linalg_cond_meta.h>
|
| 807 |
+
#include <ATen/ops/linalg_cross_meta.h>
|
| 808 |
+
#include <ATen/ops/linalg_det_meta.h>
|
| 809 |
+
#include <ATen/ops/linalg_diagonal_meta.h>
|
| 810 |
+
#include <ATen/ops/linalg_eig_meta.h>
|
| 811 |
+
#include <ATen/ops/linalg_eigh_meta.h>
|
| 812 |
+
#include <ATen/ops/linalg_eigvals_meta.h>
|
| 813 |
+
#include <ATen/ops/linalg_eigvalsh_meta.h>
|
| 814 |
+
#include <ATen/ops/linalg_householder_product_meta.h>
|
| 815 |
+
#include <ATen/ops/linalg_inv_meta.h>
|
| 816 |
+
#include <ATen/ops/linalg_inv_ex_meta.h>
|
| 817 |
+
#include <ATen/ops/linalg_ldl_factor_meta.h>
|
| 818 |
+
#include <ATen/ops/linalg_ldl_factor_ex_meta.h>
|
| 819 |
+
#include <ATen/ops/linalg_ldl_solve_meta.h>
|
| 820 |
+
#include <ATen/ops/linalg_lstsq_meta.h>
|
| 821 |
+
#include <ATen/ops/linalg_lu_meta.h>
|
| 822 |
+
#include <ATen/ops/linalg_lu_factor_meta.h>
|
| 823 |
+
#include <ATen/ops/linalg_lu_factor_ex_meta.h>
|
| 824 |
+
#include <ATen/ops/linalg_lu_solve_meta.h>
|
| 825 |
+
#include <ATen/ops/linalg_matmul_meta.h>
|
| 826 |
+
#include <ATen/ops/linalg_matrix_exp_meta.h>
|
| 827 |
+
#include <ATen/ops/linalg_matrix_norm_meta.h>
|
| 828 |
+
#include <ATen/ops/linalg_matrix_power_meta.h>
|
| 829 |
+
#include <ATen/ops/linalg_matrix_rank_meta.h>
|
| 830 |
+
#include <ATen/ops/linalg_multi_dot_meta.h>
|
| 831 |
+
#include <ATen/ops/linalg_norm_meta.h>
|
| 832 |
+
#include <ATen/ops/linalg_pinv_meta.h>
|
| 833 |
+
#include <ATen/ops/linalg_qr_meta.h>
|
| 834 |
+
#include <ATen/ops/linalg_slogdet_meta.h>
|
| 835 |
+
#include <ATen/ops/linalg_solve_meta.h>
|
| 836 |
+
#include <ATen/ops/linalg_solve_ex_meta.h>
|
| 837 |
+
#include <ATen/ops/linalg_solve_triangular_meta.h>
|
| 838 |
+
#include <ATen/ops/linalg_svd_meta.h>
|
| 839 |
+
#include <ATen/ops/linalg_svdvals_meta.h>
|
| 840 |
+
#include <ATen/ops/linalg_tensorinv_meta.h>
|
| 841 |
+
#include <ATen/ops/linalg_tensorsolve_meta.h>
|
| 842 |
+
#include <ATen/ops/linalg_vander_meta.h>
|
| 843 |
+
#include <ATen/ops/linalg_vecdot_meta.h>
|
| 844 |
+
#include <ATen/ops/linalg_vector_norm_meta.h>
|
| 845 |
+
#include <ATen/ops/linear_meta.h>
|
| 846 |
+
#include <ATen/ops/linear_backward_meta.h>
|
| 847 |
+
#include <ATen/ops/linspace_meta.h>
|
| 848 |
+
#include <ATen/ops/log_meta.h>
|
| 849 |
+
#include <ATen/ops/log10_meta.h>
|
| 850 |
+
#include <ATen/ops/log1p_meta.h>
|
| 851 |
+
#include <ATen/ops/log2_meta.h>
|
| 852 |
+
#include <ATen/ops/log_normal_meta.h>
|
| 853 |
+
#include <ATen/ops/log_sigmoid_meta.h>
|
| 854 |
+
#include <ATen/ops/log_sigmoid_backward_meta.h>
|
| 855 |
+
#include <ATen/ops/log_sigmoid_forward_meta.h>
|
| 856 |
+
#include <ATen/ops/log_softmax_meta.h>
|
| 857 |
+
#include <ATen/ops/logaddexp_meta.h>
|
| 858 |
+
#include <ATen/ops/logaddexp2_meta.h>
|
| 859 |
+
#include <ATen/ops/logcumsumexp_meta.h>
|
| 860 |
+
#include <ATen/ops/logdet_meta.h>
|
| 861 |
+
#include <ATen/ops/logical_and_meta.h>
|
| 862 |
+
#include <ATen/ops/logical_not_meta.h>
|
| 863 |
+
#include <ATen/ops/logical_or_meta.h>
|
| 864 |
+
#include <ATen/ops/logical_xor_meta.h>
|
| 865 |
+
#include <ATen/ops/logit_meta.h>
|
| 866 |
+
#include <ATen/ops/logit_backward_meta.h>
|
| 867 |
+
#include <ATen/ops/logspace_meta.h>
|
| 868 |
+
#include <ATen/ops/logsumexp_meta.h>
|
| 869 |
+
#include <ATen/ops/lshift_meta.h>
|
| 870 |
+
#include <ATen/ops/lstm_meta.h>
|
| 871 |
+
#include <ATen/ops/lstm_cell_meta.h>
|
| 872 |
+
#include <ATen/ops/lstm_mps_backward_meta.h>
|
| 873 |
+
#include <ATen/ops/lt_meta.h>
|
| 874 |
+
#include <ATen/ops/lu_solve_meta.h>
|
| 875 |
+
#include <ATen/ops/lu_unpack_meta.h>
|
| 876 |
+
#include <ATen/ops/mH_meta.h>
|
| 877 |
+
#include <ATen/ops/mT_meta.h>
|
| 878 |
+
#include <ATen/ops/margin_ranking_loss_meta.h>
|
| 879 |
+
#include <ATen/ops/masked_fill_meta.h>
|
| 880 |
+
#include <ATen/ops/masked_scatter_meta.h>
|
| 881 |
+
#include <ATen/ops/masked_scatter_backward_meta.h>
|
| 882 |
+
#include <ATen/ops/masked_select_meta.h>
|
| 883 |
+
#include <ATen/ops/masked_select_backward_meta.h>
|
| 884 |
+
#include <ATen/ops/matmul_meta.h>
|
| 885 |
+
#include <ATen/ops/matmul_backward_meta.h>
|
| 886 |
+
#include <ATen/ops/matrix_H_meta.h>
|
| 887 |
+
#include <ATen/ops/matrix_exp_meta.h>
|
| 888 |
+
#include <ATen/ops/matrix_exp_backward_meta.h>
|
| 889 |
+
#include <ATen/ops/matrix_power_meta.h>
|
| 890 |
+
#include <ATen/ops/max_meta.h>
|
| 891 |
+
#include <ATen/ops/max_pool1d_meta.h>
|
| 892 |
+
#include <ATen/ops/max_pool1d_with_indices_meta.h>
|
| 893 |
+
#include <ATen/ops/max_pool2d_meta.h>
|
| 894 |
+
#include <ATen/ops/max_pool2d_backward_meta.h>
|
| 895 |
+
#include <ATen/ops/max_pool2d_with_indices_meta.h>
|
| 896 |
+
#include <ATen/ops/max_pool2d_with_indices_backward_meta.h>
|
| 897 |
+
#include <ATen/ops/max_pool3d_meta.h>
|
| 898 |
+
#include <ATen/ops/max_pool3d_with_indices_meta.h>
|
| 899 |
+
#include <ATen/ops/max_pool3d_with_indices_backward_meta.h>
|
| 900 |
+
#include <ATen/ops/max_unpool2d_meta.h>
|
| 901 |
+
#include <ATen/ops/max_unpool3d_meta.h>
|
| 902 |
+
#include <ATen/ops/maximum_meta.h>
|
| 903 |
+
#include <ATen/ops/mean_meta.h>
|
| 904 |
+
#include <ATen/ops/median_meta.h>
|
| 905 |
+
#include <ATen/ops/meshgrid_meta.h>
|
| 906 |
+
#include <ATen/ops/min_meta.h>
|
| 907 |
+
#include <ATen/ops/minimum_meta.h>
|
| 908 |
+
#include <ATen/ops/miopen_batch_norm_meta.h>
|
| 909 |
+
#include <ATen/ops/miopen_batch_norm_backward_meta.h>
|
| 910 |
+
#include <ATen/ops/miopen_convolution_meta.h>
|
| 911 |
+
#include <ATen/ops/miopen_convolution_add_relu_meta.h>
|
| 912 |
+
#include <ATen/ops/miopen_convolution_relu_meta.h>
|
| 913 |
+
#include <ATen/ops/miopen_convolution_transpose_meta.h>
|
| 914 |
+
#include <ATen/ops/miopen_depthwise_convolution_meta.h>
|
| 915 |
+
#include <ATen/ops/miopen_rnn_meta.h>
|
| 916 |
+
#include <ATen/ops/miopen_rnn_backward_meta.h>
|
| 917 |
+
#include <ATen/ops/mish_meta.h>
|
| 918 |
+
#include <ATen/ops/mish_backward_meta.h>
|
| 919 |
+
#include <ATen/ops/mkldnn_adaptive_avg_pool2d_meta.h>
|
| 920 |
+
#include <ATen/ops/mkldnn_adaptive_avg_pool2d_backward_meta.h>
|
| 921 |
+
#include <ATen/ops/mkldnn_convolution_meta.h>
|
| 922 |
+
#include <ATen/ops/mkldnn_linear_meta.h>
|
| 923 |
+
#include <ATen/ops/mkldnn_linear_backward_meta.h>
|
| 924 |
+
#include <ATen/ops/mkldnn_linear_backward_input_meta.h>
|
| 925 |
+
#include <ATen/ops/mkldnn_linear_backward_weights_meta.h>
|
| 926 |
+
#include <ATen/ops/mkldnn_max_pool2d_meta.h>
|
| 927 |
+
#include <ATen/ops/mkldnn_max_pool2d_backward_meta.h>
|
| 928 |
+
#include <ATen/ops/mkldnn_max_pool3d_meta.h>
|
| 929 |
+
#include <ATen/ops/mkldnn_max_pool3d_backward_meta.h>
|
| 930 |
+
#include <ATen/ops/mkldnn_reorder_conv2d_weight_meta.h>
|
| 931 |
+
#include <ATen/ops/mkldnn_reorder_conv3d_weight_meta.h>
|
| 932 |
+
#include <ATen/ops/mkldnn_rnn_layer_meta.h>
|
| 933 |
+
#include <ATen/ops/mkldnn_rnn_layer_backward_meta.h>
|
| 934 |
+
#include <ATen/ops/mm_meta.h>
|
| 935 |
+
#include <ATen/ops/mode_meta.h>
|
| 936 |
+
#include <ATen/ops/moveaxis_meta.h>
|
| 937 |
+
#include <ATen/ops/movedim_meta.h>
|
| 938 |
+
#include <ATen/ops/mps_convolution_backward_meta.h>
|
| 939 |
+
#include <ATen/ops/mps_convolution_transpose_backward_meta.h>
|
| 940 |
+
#include <ATen/ops/mse_loss_meta.h>
|
| 941 |
+
#include <ATen/ops/mse_loss_backward_meta.h>
|
| 942 |
+
#include <ATen/ops/msort_meta.h>
|
| 943 |
+
#include <ATen/ops/mul_meta.h>
|
| 944 |
+
#include <ATen/ops/multi_margin_loss_meta.h>
|
| 945 |
+
#include <ATen/ops/multi_margin_loss_backward_meta.h>
|
| 946 |
+
#include <ATen/ops/multilabel_margin_loss_meta.h>
|
| 947 |
+
#include <ATen/ops/multilabel_margin_loss_backward_meta.h>
|
| 948 |
+
#include <ATen/ops/multilabel_margin_loss_forward_meta.h>
|
| 949 |
+
#include <ATen/ops/multinomial_meta.h>
|
| 950 |
+
#include <ATen/ops/multiply_meta.h>
|
| 951 |
+
#include <ATen/ops/mv_meta.h>
|
| 952 |
+
#include <ATen/ops/mvlgamma_meta.h>
|
| 953 |
+
#include <ATen/ops/nan_to_num_meta.h>
|
| 954 |
+
#include <ATen/ops/nanmean_meta.h>
|
| 955 |
+
#include <ATen/ops/nanmedian_meta.h>
|
| 956 |
+
#include <ATen/ops/nanquantile_meta.h>
|
| 957 |
+
#include <ATen/ops/nansum_meta.h>
|
| 958 |
+
#include <ATen/ops/narrow_meta.h>
|
| 959 |
+
#include <ATen/ops/narrow_copy_meta.h>
|
| 960 |
+
#include <ATen/ops/native_batch_norm_meta.h>
|
| 961 |
+
#include <ATen/ops/native_batch_norm_backward_meta.h>
|
| 962 |
+
#include <ATen/ops/native_channel_shuffle_meta.h>
|
| 963 |
+
#include <ATen/ops/native_dropout_meta.h>
|
| 964 |
+
#include <ATen/ops/native_dropout_backward_meta.h>
|
| 965 |
+
#include <ATen/ops/native_group_norm_meta.h>
|
| 966 |
+
#include <ATen/ops/native_group_norm_backward_meta.h>
|
| 967 |
+
#include <ATen/ops/native_layer_norm_meta.h>
|
| 968 |
+
#include <ATen/ops/native_layer_norm_backward_meta.h>
|
| 969 |
+
#include <ATen/ops/native_norm_meta.h>
|
| 970 |
+
#include <ATen/ops/ne_meta.h>
|
| 971 |
+
#include <ATen/ops/neg_meta.h>
|
| 972 |
+
#include <ATen/ops/negative_meta.h>
|
| 973 |
+
#include <ATen/ops/nested_to_padded_tensor_meta.h>
|
| 974 |
+
#include <ATen/ops/new_empty_meta.h>
|
| 975 |
+
#include <ATen/ops/new_empty_strided_meta.h>
|
| 976 |
+
#include <ATen/ops/new_full_meta.h>
|
| 977 |
+
#include <ATen/ops/new_ones_meta.h>
|
| 978 |
+
#include <ATen/ops/new_zeros_meta.h>
|
| 979 |
+
#include <ATen/ops/nextafter_meta.h>
|
| 980 |
+
#include <ATen/ops/nll_loss_meta.h>
|
| 981 |
+
#include <ATen/ops/nll_loss2d_meta.h>
|
| 982 |
+
#include <ATen/ops/nll_loss2d_backward_meta.h>
|
| 983 |
+
#include <ATen/ops/nll_loss2d_forward_meta.h>
|
| 984 |
+
#include <ATen/ops/nll_loss_backward_meta.h>
|
| 985 |
+
#include <ATen/ops/nll_loss_forward_meta.h>
|
| 986 |
+
#include <ATen/ops/nll_loss_nd_meta.h>
|
| 987 |
+
#include <ATen/ops/nonzero_meta.h>
|
| 988 |
+
#include <ATen/ops/nonzero_numpy_meta.h>
|
| 989 |
+
#include <ATen/ops/nonzero_static_meta.h>
|
| 990 |
+
#include <ATen/ops/norm_meta.h>
|
| 991 |
+
#include <ATen/ops/norm_except_dim_meta.h>
|
| 992 |
+
#include <ATen/ops/normal_meta.h>
|
| 993 |
+
#include <ATen/ops/not_equal_meta.h>
|
| 994 |
+
#include <ATen/ops/nuclear_norm_meta.h>
|
| 995 |
+
#include <ATen/ops/numpy_T_meta.h>
|
| 996 |
+
#include <ATen/ops/one_hot_meta.h>
|
| 997 |
+
#include <ATen/ops/ones_meta.h>
|
| 998 |
+
#include <ATen/ops/ones_like_meta.h>
|
| 999 |
+
#include <ATen/ops/or_meta.h>
|
| 1000 |
+
#include <ATen/ops/orgqr_meta.h>
|
| 1001 |
+
#include <ATen/ops/ormqr_meta.h>
|
| 1002 |
+
#include <ATen/ops/outer_meta.h>
|
| 1003 |
+
#include <ATen/ops/output_nr_meta.h>
|
| 1004 |
+
#include <ATen/ops/pad_meta.h>
|
| 1005 |
+
#include <ATen/ops/pad_sequence_meta.h>
|
| 1006 |
+
#include <ATen/ops/pairwise_distance_meta.h>
|
| 1007 |
+
#include <ATen/ops/pdist_meta.h>
|
| 1008 |
+
#include <ATen/ops/permute_meta.h>
|
| 1009 |
+
#include <ATen/ops/permute_copy_meta.h>
|
| 1010 |
+
#include <ATen/ops/pin_memory_meta.h>
|
| 1011 |
+
#include <ATen/ops/pinverse_meta.h>
|
| 1012 |
+
#include <ATen/ops/pixel_shuffle_meta.h>
|
| 1013 |
+
#include <ATen/ops/pixel_unshuffle_meta.h>
|
| 1014 |
+
#include <ATen/ops/poisson_meta.h>
|
| 1015 |
+
#include <ATen/ops/poisson_nll_loss_meta.h>
|
| 1016 |
+
#include <ATen/ops/polar_meta.h>
|
| 1017 |
+
#include <ATen/ops/polygamma_meta.h>
|
| 1018 |
+
#include <ATen/ops/positive_meta.h>
|
| 1019 |
+
#include <ATen/ops/pow_meta.h>
|
| 1020 |
+
#include <ATen/ops/prelu_meta.h>
|
| 1021 |
+
#include <ATen/ops/prod_meta.h>
|
| 1022 |
+
#include <ATen/ops/promote_types_meta.h>
|
| 1023 |
+
#include <ATen/ops/put_meta.h>
|
| 1024 |
+
#include <ATen/ops/q_per_channel_axis_meta.h>
|
| 1025 |
+
#include <ATen/ops/q_per_channel_scales_meta.h>
|
| 1026 |
+
#include <ATen/ops/q_per_channel_zero_points_meta.h>
|
| 1027 |
+
#include <ATen/ops/q_scale_meta.h>
|
| 1028 |
+
#include <ATen/ops/q_zero_point_meta.h>
|
| 1029 |
+
#include <ATen/ops/qr_meta.h>
|
| 1030 |
+
#include <ATen/ops/qscheme_meta.h>
|
| 1031 |
+
#include <ATen/ops/quantile_meta.h>
|
| 1032 |
+
#include <ATen/ops/quantize_per_channel_meta.h>
|
| 1033 |
+
#include <ATen/ops/quantize_per_tensor_meta.h>
|
| 1034 |
+
#include <ATen/ops/quantize_per_tensor_dynamic_meta.h>
|
| 1035 |
+
#include <ATen/ops/quantized_batch_norm_meta.h>
|
| 1036 |
+
#include <ATen/ops/quantized_gru_cell_meta.h>
|
| 1037 |
+
#include <ATen/ops/quantized_lstm_cell_meta.h>
|
| 1038 |
+
#include <ATen/ops/quantized_max_pool1d_meta.h>
|
| 1039 |
+
#include <ATen/ops/quantized_max_pool2d_meta.h>
|
| 1040 |
+
#include <ATen/ops/quantized_max_pool3d_meta.h>
|
| 1041 |
+
#include <ATen/ops/quantized_rnn_relu_cell_meta.h>
|
| 1042 |
+
#include <ATen/ops/quantized_rnn_tanh_cell_meta.h>
|
| 1043 |
+
#include <ATen/ops/rad2deg_meta.h>
|
| 1044 |
+
#include <ATen/ops/rand_meta.h>
|
| 1045 |
+
#include <ATen/ops/rand_like_meta.h>
|
| 1046 |
+
#include <ATen/ops/randint_meta.h>
|
| 1047 |
+
#include <ATen/ops/randint_like_meta.h>
|
| 1048 |
+
#include <ATen/ops/randn_meta.h>
|
| 1049 |
+
#include <ATen/ops/randn_like_meta.h>
|
| 1050 |
+
#include <ATen/ops/random_meta.h>
|
| 1051 |
+
#include <ATen/ops/randperm_meta.h>
|
| 1052 |
+
#include <ATen/ops/range_meta.h>
|
| 1053 |
+
#include <ATen/ops/ravel_meta.h>
|
| 1054 |
+
#include <ATen/ops/real_meta.h>
|
| 1055 |
+
#include <ATen/ops/reciprocal_meta.h>
|
| 1056 |
+
#include <ATen/ops/record_stream_meta.h>
|
| 1057 |
+
#include <ATen/ops/refine_names_meta.h>
|
| 1058 |
+
#include <ATen/ops/reflection_pad1d_meta.h>
|
| 1059 |
+
#include <ATen/ops/reflection_pad1d_backward_meta.h>
|
| 1060 |
+
#include <ATen/ops/reflection_pad2d_meta.h>
|
| 1061 |
+
#include <ATen/ops/reflection_pad2d_backward_meta.h>
|
| 1062 |
+
#include <ATen/ops/reflection_pad3d_meta.h>
|
| 1063 |
+
#include <ATen/ops/reflection_pad3d_backward_meta.h>
|
| 1064 |
+
#include <ATen/ops/relu_meta.h>
|
| 1065 |
+
#include <ATen/ops/relu6_meta.h>
|
| 1066 |
+
#include <ATen/ops/remainder_meta.h>
|
| 1067 |
+
#include <ATen/ops/rename_meta.h>
|
| 1068 |
+
#include <ATen/ops/renorm_meta.h>
|
| 1069 |
+
#include <ATen/ops/repeat_meta.h>
|
| 1070 |
+
#include <ATen/ops/repeat_interleave_meta.h>
|
| 1071 |
+
#include <ATen/ops/replication_pad1d_meta.h>
|
| 1072 |
+
#include <ATen/ops/replication_pad1d_backward_meta.h>
|
| 1073 |
+
#include <ATen/ops/replication_pad2d_meta.h>
|
| 1074 |
+
#include <ATen/ops/replication_pad2d_backward_meta.h>
|
| 1075 |
+
#include <ATen/ops/replication_pad3d_meta.h>
|
| 1076 |
+
#include <ATen/ops/replication_pad3d_backward_meta.h>
|
| 1077 |
+
#include <ATen/ops/requires_grad_meta.h>
|
| 1078 |
+
#include <ATen/ops/reshape_meta.h>
|
| 1079 |
+
#include <ATen/ops/reshape_as_meta.h>
|
| 1080 |
+
#include <ATen/ops/resize_meta.h>
|
| 1081 |
+
#include <ATen/ops/resize_as_meta.h>
|
| 1082 |
+
#include <ATen/ops/resize_as_sparse_meta.h>
|
| 1083 |
+
#include <ATen/ops/resolve_conj_meta.h>
|
| 1084 |
+
#include <ATen/ops/resolve_neg_meta.h>
|
| 1085 |
+
#include <ATen/ops/result_type_meta.h>
|
| 1086 |
+
#include <ATen/ops/retain_grad_meta.h>
|
| 1087 |
+
#include <ATen/ops/retains_grad_meta.h>
|
| 1088 |
+
#include <ATen/ops/rms_norm_meta.h>
|
| 1089 |
+
#include <ATen/ops/rnn_relu_meta.h>
|
| 1090 |
+
#include <ATen/ops/rnn_relu_cell_meta.h>
|
| 1091 |
+
#include <ATen/ops/rnn_tanh_meta.h>
|
| 1092 |
+
#include <ATen/ops/rnn_tanh_cell_meta.h>
|
| 1093 |
+
#include <ATen/ops/roll_meta.h>
|
| 1094 |
+
#include <ATen/ops/rot90_meta.h>
|
| 1095 |
+
#include <ATen/ops/round_meta.h>
|
| 1096 |
+
#include <ATen/ops/row_indices_meta.h>
|
| 1097 |
+
#include <ATen/ops/row_indices_copy_meta.h>
|
| 1098 |
+
#include <ATen/ops/row_stack_meta.h>
|
| 1099 |
+
#include <ATen/ops/rrelu_meta.h>
|
| 1100 |
+
#include <ATen/ops/rrelu_with_noise_meta.h>
|
| 1101 |
+
#include <ATen/ops/rrelu_with_noise_backward_meta.h>
|
| 1102 |
+
#include <ATen/ops/rshift_meta.h>
|
| 1103 |
+
#include <ATen/ops/rsqrt_meta.h>
|
| 1104 |
+
#include <ATen/ops/rsub_meta.h>
|
| 1105 |
+
#include <ATen/ops/scalar_tensor_meta.h>
|
| 1106 |
+
#include <ATen/ops/scaled_dot_product_attention_meta.h>
|
| 1107 |
+
#include <ATen/ops/scatter_meta.h>
|
| 1108 |
+
#include <ATen/ops/scatter_add_meta.h>
|
| 1109 |
+
#include <ATen/ops/scatter_reduce_meta.h>
|
| 1110 |
+
#include <ATen/ops/searchsorted_meta.h>
|
| 1111 |
+
#include <ATen/ops/segment_reduce_meta.h>
|
| 1112 |
+
#include <ATen/ops/select_meta.h>
|
| 1113 |
+
#include <ATen/ops/select_backward_meta.h>
|
| 1114 |
+
#include <ATen/ops/select_copy_meta.h>
|
| 1115 |
+
#include <ATen/ops/select_scatter_meta.h>
|
| 1116 |
+
#include <ATen/ops/selu_meta.h>
|
| 1117 |
+
#include <ATen/ops/set_meta.h>
|
| 1118 |
+
#include <ATen/ops/set_data_meta.h>
|
| 1119 |
+
#include <ATen/ops/sgn_meta.h>
|
| 1120 |
+
#include <ATen/ops/sigmoid_meta.h>
|
| 1121 |
+
#include <ATen/ops/sigmoid_backward_meta.h>
|
| 1122 |
+
#include <ATen/ops/sign_meta.h>
|
| 1123 |
+
#include <ATen/ops/signbit_meta.h>
|
| 1124 |
+
#include <ATen/ops/silu_meta.h>
|
| 1125 |
+
#include <ATen/ops/silu_backward_meta.h>
|
| 1126 |
+
#include <ATen/ops/sin_meta.h>
|
| 1127 |
+
#include <ATen/ops/sinc_meta.h>
|
| 1128 |
+
#include <ATen/ops/sinh_meta.h>
|
| 1129 |
+
#include <ATen/ops/size_meta.h>
|
| 1130 |
+
#include <ATen/ops/slice_meta.h>
|
| 1131 |
+
#include <ATen/ops/slice_backward_meta.h>
|
| 1132 |
+
#include <ATen/ops/slice_copy_meta.h>
|
| 1133 |
+
#include <ATen/ops/slice_inverse_meta.h>
|
| 1134 |
+
#include <ATen/ops/slice_scatter_meta.h>
|
| 1135 |
+
#include <ATen/ops/slogdet_meta.h>
|
| 1136 |
+
#include <ATen/ops/slow_conv3d_meta.h>
|
| 1137 |
+
#include <ATen/ops/slow_conv3d_forward_meta.h>
|
| 1138 |
+
#include <ATen/ops/slow_conv_dilated2d_meta.h>
|
| 1139 |
+
#include <ATen/ops/slow_conv_dilated3d_meta.h>
|
| 1140 |
+
#include <ATen/ops/slow_conv_transpose2d_meta.h>
|
| 1141 |
+
#include <ATen/ops/slow_conv_transpose3d_meta.h>
|
| 1142 |
+
#include <ATen/ops/smm_meta.h>
|
| 1143 |
+
#include <ATen/ops/smooth_l1_loss_meta.h>
|
| 1144 |
+
#include <ATen/ops/smooth_l1_loss_backward_meta.h>
|
| 1145 |
+
#include <ATen/ops/soft_margin_loss_meta.h>
|
| 1146 |
+
#include <ATen/ops/soft_margin_loss_backward_meta.h>
|
| 1147 |
+
#include <ATen/ops/softmax_meta.h>
|
| 1148 |
+
#include <ATen/ops/softplus_meta.h>
|
| 1149 |
+
#include <ATen/ops/softplus_backward_meta.h>
|
| 1150 |
+
#include <ATen/ops/softshrink_meta.h>
|
| 1151 |
+
#include <ATen/ops/softshrink_backward_meta.h>
|
| 1152 |
+
#include <ATen/ops/sort_meta.h>
|
| 1153 |
+
#include <ATen/ops/sparse_bsc_tensor_meta.h>
|
| 1154 |
+
#include <ATen/ops/sparse_bsr_tensor_meta.h>
|
| 1155 |
+
#include <ATen/ops/sparse_compressed_tensor_meta.h>
|
| 1156 |
+
#include <ATen/ops/sparse_coo_tensor_meta.h>
|
| 1157 |
+
#include <ATen/ops/sparse_csc_tensor_meta.h>
|
| 1158 |
+
#include <ATen/ops/sparse_csr_tensor_meta.h>
|
| 1159 |
+
#include <ATen/ops/sparse_dim_meta.h>
|
| 1160 |
+
#include <ATen/ops/sparse_mask_meta.h>
|
| 1161 |
+
#include <ATen/ops/sparse_resize_meta.h>
|
| 1162 |
+
#include <ATen/ops/sparse_resize_and_clear_meta.h>
|
| 1163 |
+
#include <ATen/ops/sparse_sampled_addmm_meta.h>
|
| 1164 |
+
#include <ATen/ops/special_airy_ai_meta.h>
|
| 1165 |
+
#include <ATen/ops/special_bessel_j0_meta.h>
|
| 1166 |
+
#include <ATen/ops/special_bessel_j1_meta.h>
|
| 1167 |
+
#include <ATen/ops/special_bessel_y0_meta.h>
|
| 1168 |
+
#include <ATen/ops/special_bessel_y1_meta.h>
|
| 1169 |
+
#include <ATen/ops/special_chebyshev_polynomial_t_meta.h>
|
| 1170 |
+
#include <ATen/ops/special_chebyshev_polynomial_u_meta.h>
|
| 1171 |
+
#include <ATen/ops/special_chebyshev_polynomial_v_meta.h>
|
| 1172 |
+
#include <ATen/ops/special_chebyshev_polynomial_w_meta.h>
|
| 1173 |
+
#include <ATen/ops/special_digamma_meta.h>
|
| 1174 |
+
#include <ATen/ops/special_entr_meta.h>
|
| 1175 |
+
#include <ATen/ops/special_erf_meta.h>
|
| 1176 |
+
#include <ATen/ops/special_erfc_meta.h>
|
| 1177 |
+
#include <ATen/ops/special_erfcx_meta.h>
|
| 1178 |
+
#include <ATen/ops/special_erfinv_meta.h>
|
| 1179 |
+
#include <ATen/ops/special_exp2_meta.h>
|
| 1180 |
+
#include <ATen/ops/special_expit_meta.h>
|
| 1181 |
+
#include <ATen/ops/special_expm1_meta.h>
|
| 1182 |
+
#include <ATen/ops/special_gammainc_meta.h>
|
| 1183 |
+
#include <ATen/ops/special_gammaincc_meta.h>
|
| 1184 |
+
#include <ATen/ops/special_gammaln_meta.h>
|
| 1185 |
+
#include <ATen/ops/special_hermite_polynomial_h_meta.h>
|
| 1186 |
+
#include <ATen/ops/special_hermite_polynomial_he_meta.h>
|
| 1187 |
+
#include <ATen/ops/special_i0_meta.h>
|
| 1188 |
+
#include <ATen/ops/special_i0e_meta.h>
|
| 1189 |
+
#include <ATen/ops/special_i1_meta.h>
|
| 1190 |
+
#include <ATen/ops/special_i1e_meta.h>
|
| 1191 |
+
#include <ATen/ops/special_laguerre_polynomial_l_meta.h>
|
| 1192 |
+
#include <ATen/ops/special_legendre_polynomial_p_meta.h>
|
| 1193 |
+
#include <ATen/ops/special_log1p_meta.h>
|
| 1194 |
+
#include <ATen/ops/special_log_ndtr_meta.h>
|
| 1195 |
+
#include <ATen/ops/special_log_softmax_meta.h>
|
| 1196 |
+
#include <ATen/ops/special_logit_meta.h>
|
| 1197 |
+
#include <ATen/ops/special_logsumexp_meta.h>
|
| 1198 |
+
#include <ATen/ops/special_modified_bessel_i0_meta.h>
|
| 1199 |
+
#include <ATen/ops/special_modified_bessel_i1_meta.h>
|
| 1200 |
+
#include <ATen/ops/special_modified_bessel_k0_meta.h>
|
| 1201 |
+
#include <ATen/ops/special_modified_bessel_k1_meta.h>
|
| 1202 |
+
#include <ATen/ops/special_multigammaln_meta.h>
|
| 1203 |
+
#include <ATen/ops/special_ndtr_meta.h>
|
| 1204 |
+
#include <ATen/ops/special_ndtri_meta.h>
|
| 1205 |
+
#include <ATen/ops/special_polygamma_meta.h>
|
| 1206 |
+
#include <ATen/ops/special_psi_meta.h>
|
| 1207 |
+
#include <ATen/ops/special_round_meta.h>
|
| 1208 |
+
#include <ATen/ops/special_scaled_modified_bessel_k0_meta.h>
|
| 1209 |
+
#include <ATen/ops/special_scaled_modified_bessel_k1_meta.h>
|
| 1210 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_t_meta.h>
|
| 1211 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_u_meta.h>
|
| 1212 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_v_meta.h>
|
| 1213 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_w_meta.h>
|
| 1214 |
+
#include <ATen/ops/special_sinc_meta.h>
|
| 1215 |
+
#include <ATen/ops/special_softmax_meta.h>
|
| 1216 |
+
#include <ATen/ops/special_spherical_bessel_j0_meta.h>
|
| 1217 |
+
#include <ATen/ops/special_xlog1py_meta.h>
|
| 1218 |
+
#include <ATen/ops/special_xlogy_meta.h>
|
| 1219 |
+
#include <ATen/ops/special_zeta_meta.h>
|
| 1220 |
+
#include <ATen/ops/split_meta.h>
|
| 1221 |
+
#include <ATen/ops/split_copy_meta.h>
|
| 1222 |
+
#include <ATen/ops/split_with_sizes_meta.h>
|
| 1223 |
+
#include <ATen/ops/split_with_sizes_copy_meta.h>
|
| 1224 |
+
#include <ATen/ops/sqrt_meta.h>
|
| 1225 |
+
#include <ATen/ops/square_meta.h>
|
| 1226 |
+
#include <ATen/ops/squeeze_meta.h>
|
| 1227 |
+
#include <ATen/ops/squeeze_copy_meta.h>
|
| 1228 |
+
#include <ATen/ops/sspaddmm_meta.h>
|
| 1229 |
+
#include <ATen/ops/stack_meta.h>
|
| 1230 |
+
#include <ATen/ops/std_meta.h>
|
| 1231 |
+
#include <ATen/ops/std_mean_meta.h>
|
| 1232 |
+
#include <ATen/ops/stft_meta.h>
|
| 1233 |
+
#include <ATen/ops/stride_meta.h>
|
| 1234 |
+
#include <ATen/ops/sub_meta.h>
|
| 1235 |
+
#include <ATen/ops/subtract_meta.h>
|
| 1236 |
+
#include <ATen/ops/sum_meta.h>
|
| 1237 |
+
#include <ATen/ops/sum_to_size_meta.h>
|
| 1238 |
+
#include <ATen/ops/svd_meta.h>
|
| 1239 |
+
#include <ATen/ops/swapaxes_meta.h>
|
| 1240 |
+
#include <ATen/ops/swapdims_meta.h>
|
| 1241 |
+
#include <ATen/ops/sym_constrain_range_meta.h>
|
| 1242 |
+
#include <ATen/ops/sym_constrain_range_for_size_meta.h>
|
| 1243 |
+
#include <ATen/ops/sym_is_contiguous_meta.h>
|
| 1244 |
+
#include <ATen/ops/sym_numel_meta.h>
|
| 1245 |
+
#include <ATen/ops/sym_size_meta.h>
|
| 1246 |
+
#include <ATen/ops/sym_storage_offset_meta.h>
|
| 1247 |
+
#include <ATen/ops/sym_stride_meta.h>
|
| 1248 |
+
#include <ATen/ops/t_meta.h>
|
| 1249 |
+
#include <ATen/ops/t_copy_meta.h>
|
| 1250 |
+
#include <ATen/ops/take_meta.h>
|
| 1251 |
+
#include <ATen/ops/take_along_dim_meta.h>
|
| 1252 |
+
#include <ATen/ops/tan_meta.h>
|
| 1253 |
+
#include <ATen/ops/tanh_meta.h>
|
| 1254 |
+
#include <ATen/ops/tanh_backward_meta.h>
|
| 1255 |
+
#include <ATen/ops/tensor_split_meta.h>
|
| 1256 |
+
#include <ATen/ops/tensordot_meta.h>
|
| 1257 |
+
#include <ATen/ops/thnn_conv2d_meta.h>
|
| 1258 |
+
#include <ATen/ops/threshold_meta.h>
|
| 1259 |
+
#include <ATen/ops/threshold_backward_meta.h>
|
| 1260 |
+
#include <ATen/ops/tile_meta.h>
|
| 1261 |
+
#include <ATen/ops/to_meta.h>
|
| 1262 |
+
#include <ATen/ops/to_dense_meta.h>
|
| 1263 |
+
#include <ATen/ops/to_dense_backward_meta.h>
|
| 1264 |
+
#include <ATen/ops/to_mkldnn_meta.h>
|
| 1265 |
+
#include <ATen/ops/to_mkldnn_backward_meta.h>
|
| 1266 |
+
#include <ATen/ops/to_padded_tensor_meta.h>
|
| 1267 |
+
#include <ATen/ops/to_sparse_meta.h>
|
| 1268 |
+
#include <ATen/ops/to_sparse_bsc_meta.h>
|
| 1269 |
+
#include <ATen/ops/to_sparse_bsr_meta.h>
|
| 1270 |
+
#include <ATen/ops/to_sparse_csc_meta.h>
|
| 1271 |
+
#include <ATen/ops/to_sparse_csr_meta.h>
|
| 1272 |
+
#include <ATen/ops/topk_meta.h>
|
| 1273 |
+
#include <ATen/ops/trace_meta.h>
|
| 1274 |
+
#include <ATen/ops/trace_backward_meta.h>
|
| 1275 |
+
#include <ATen/ops/transpose_meta.h>
|
| 1276 |
+
#include <ATen/ops/transpose_copy_meta.h>
|
| 1277 |
+
#include <ATen/ops/trapezoid_meta.h>
|
| 1278 |
+
#include <ATen/ops/trapz_meta.h>
|
| 1279 |
+
#include <ATen/ops/triangular_solve_meta.h>
|
| 1280 |
+
#include <ATen/ops/tril_meta.h>
|
| 1281 |
+
#include <ATen/ops/tril_indices_meta.h>
|
| 1282 |
+
#include <ATen/ops/triplet_margin_loss_meta.h>
|
| 1283 |
+
#include <ATen/ops/triu_meta.h>
|
| 1284 |
+
#include <ATen/ops/triu_indices_meta.h>
|
| 1285 |
+
#include <ATen/ops/true_divide_meta.h>
|
| 1286 |
+
#include <ATen/ops/trunc_meta.h>
|
| 1287 |
+
#include <ATen/ops/type_as_meta.h>
|
| 1288 |
+
#include <ATen/ops/unbind_meta.h>
|
| 1289 |
+
#include <ATen/ops/unbind_copy_meta.h>
|
| 1290 |
+
#include <ATen/ops/unflatten_meta.h>
|
| 1291 |
+
#include <ATen/ops/unflatten_dense_tensors_meta.h>
|
| 1292 |
+
#include <ATen/ops/unfold_meta.h>
|
| 1293 |
+
#include <ATen/ops/unfold_backward_meta.h>
|
| 1294 |
+
#include <ATen/ops/unfold_copy_meta.h>
|
| 1295 |
+
#include <ATen/ops/uniform_meta.h>
|
| 1296 |
+
#include <ATen/ops/unique_consecutive_meta.h>
|
| 1297 |
+
#include <ATen/ops/unique_dim_meta.h>
|
| 1298 |
+
#include <ATen/ops/unique_dim_consecutive_meta.h>
|
| 1299 |
+
#include <ATen/ops/unsafe_chunk_meta.h>
|
| 1300 |
+
#include <ATen/ops/unsafe_split_meta.h>
|
| 1301 |
+
#include <ATen/ops/unsafe_split_with_sizes_meta.h>
|
| 1302 |
+
#include <ATen/ops/unsqueeze_meta.h>
|
| 1303 |
+
#include <ATen/ops/unsqueeze_copy_meta.h>
|
| 1304 |
+
#include <ATen/ops/upsample_bicubic2d_meta.h>
|
| 1305 |
+
#include <ATen/ops/upsample_bicubic2d_backward_meta.h>
|
| 1306 |
+
#include <ATen/ops/upsample_bilinear2d_meta.h>
|
| 1307 |
+
#include <ATen/ops/upsample_bilinear2d_backward_meta.h>
|
| 1308 |
+
#include <ATen/ops/upsample_linear1d_meta.h>
|
| 1309 |
+
#include <ATen/ops/upsample_linear1d_backward_meta.h>
|
| 1310 |
+
#include <ATen/ops/upsample_nearest1d_meta.h>
|
| 1311 |
+
#include <ATen/ops/upsample_nearest1d_backward_meta.h>
|
| 1312 |
+
#include <ATen/ops/upsample_nearest2d_meta.h>
|
| 1313 |
+
#include <ATen/ops/upsample_nearest2d_backward_meta.h>
|
| 1314 |
+
#include <ATen/ops/upsample_nearest3d_meta.h>
|
| 1315 |
+
#include <ATen/ops/upsample_nearest3d_backward_meta.h>
|
| 1316 |
+
#include <ATen/ops/upsample_trilinear3d_meta.h>
|
| 1317 |
+
#include <ATen/ops/upsample_trilinear3d_backward_meta.h>
|
| 1318 |
+
#include <ATen/ops/value_selecting_reduction_backward_meta.h>
|
| 1319 |
+
#include <ATen/ops/values_meta.h>
|
| 1320 |
+
#include <ATen/ops/values_copy_meta.h>
|
| 1321 |
+
#include <ATen/ops/vander_meta.h>
|
| 1322 |
+
#include <ATen/ops/var_meta.h>
|
| 1323 |
+
#include <ATen/ops/var_mean_meta.h>
|
| 1324 |
+
#include <ATen/ops/vdot_meta.h>
|
| 1325 |
+
#include <ATen/ops/view_meta.h>
|
| 1326 |
+
#include <ATen/ops/view_as_meta.h>
|
| 1327 |
+
#include <ATen/ops/view_as_complex_meta.h>
|
| 1328 |
+
#include <ATen/ops/view_as_complex_copy_meta.h>
|
| 1329 |
+
#include <ATen/ops/view_as_real_meta.h>
|
| 1330 |
+
#include <ATen/ops/view_as_real_copy_meta.h>
|
| 1331 |
+
#include <ATen/ops/view_copy_meta.h>
|
| 1332 |
+
#include <ATen/ops/vsplit_meta.h>
|
| 1333 |
+
#include <ATen/ops/vstack_meta.h>
|
| 1334 |
+
#include <ATen/ops/where_meta.h>
|
| 1335 |
+
#include <ATen/ops/xlogy_meta.h>
|
| 1336 |
+
#include <ATen/ops/xor_meta.h>
|
| 1337 |
+
#include <ATen/ops/zero_meta.h>
|
| 1338 |
+
#include <ATen/ops/zeros_meta.h>
|
| 1339 |
+
#include <ATen/ops/zeros_like_meta.h>
|
| 1340 |
+
|
| 1341 |
+
// The at::meta namespace is opened and closed here with nothing declared
// between the braces in this section.
namespace at {
|
| 1342 |
+
|
| 1343 |
+
namespace meta {
|
| 1344 |
+
|
| 1345 |
+
|
| 1346 |
+
|
| 1347 |
+
} // namespace meta
|
| 1348 |
+
} // namespace at
|
| 1349 |
+
|
| 1350 |
+
#else
|
| 1351 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 1352 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/NestedTensorImpl.h
ADDED
|
@@ -0,0 +1,292 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
#include <ATen/MemoryOverlap.h>
|
| 4 |
+
#include <ATen/Tensor.h>
|
| 5 |
+
#include <c10/core/DispatchKey.h>
|
| 6 |
+
#include <c10/core/DispatchKeySet.h>
|
| 7 |
+
#include <c10/core/MemoryFormat.h>
|
| 8 |
+
#include <c10/core/TensorImpl.h>
|
| 9 |
+
#include <c10/util/ArrayRef.h>
|
| 10 |
+
#include <c10/util/Exception.h>
|
| 11 |
+
#include <c10/util/Metaprogramming.h>
|
| 12 |
+
#include <c10/util/irange.h>
|
| 13 |
+
|
| 14 |
+
namespace at::native {
|
| 15 |
+
struct NestedTensorImpl;
|
| 16 |
+
inline bool nested_tensor_impl_is_contiguous(const NestedTensorImpl* nt);
|
| 17 |
+
int64_t get_numel_from_nested_size_tensor(const at::Tensor& tensor);
|
| 18 |
+
at::Tensor construct_nested_strides(const at::Tensor& nested_size);
|
| 19 |
+
at::Tensor construct_offsets(const at::Tensor& nested_size);
|
| 20 |
+
|
| 21 |
+
// TensorImpl subclass backing a nested tensor. Alongside the shared storage it
// keeps three metadata tensors (nested sizes, nested strides and storage
// offsets) that describe the underlying constituent tensors.
struct TORCH_API NestedTensorImpl : public c10::TensorImpl {
  // Fully explicit constructor: storage, dispatch keys, dtype and all three
  // metadata tensors are supplied by the caller.
  explicit NestedTensorImpl(
      Storage storage,
      c10::DispatchKeySet key_set,
      const caffe2::TypeMeta data_type,
      at::Tensor nested_sizes,
      at::Tensor nested_strides,
      at::Tensor storage_offsets);

  // Builds from an existing buffer tensor plus explicit metadata.
  explicit NestedTensorImpl(
      const at::Tensor& buffer,
      at::Tensor nested_sizes,
      at::Tensor nested_strides,
      at::Tensor storage_offsets);
  // Contiguous case: `nested_strides` and `offsets` are not passed because
  // they can be inferred from `nested_sizes`.
  explicit NestedTensorImpl(
      const at::Tensor& buffer,
      const at::Tensor& nested_sizes);

  // This constructor is used for creating view tensors from nested tensors.
  explicit NestedTensorImpl(
      c10::TensorImpl::ImplType impl_type,
      const at::Tensor& base_tensor,
      at::Tensor nested_sizes,
      at::Tensor nested_strides,
      at::Tensor storage_offsets);

  // TODO: don't expose private implementation details like this; in
  // particular, resizing this tensor will mess up our dim() and
  // callers cannot fix it.
  const Tensor& get_nested_sizes() const {
    return nested_sizes_;
  }
  // TODO: don't expose private implementation details like this
  const Tensor& get_nested_strides() const {
    return nested_strides_;
  }
  // Read-only access to the per-tensor storage offsets.
  const Tensor& get_storage_offsets() const {
    return storage_offsets_;
  }
  // Yields nullopt when dimension `d` is irregular. Dimension i of a
  // NestedTensor counts as regular when every unbound tensor has the same
  // size at its (i-1)th dimension.
  std::optional<int64_t> opt_size(int64_t d) const;

  // Size of dimension `d`; fails the TORCH_CHECK below when that dimension
  // is irregular (i.e. opt_size(d) has no value).
  int64_t size(int64_t d) const {
    const auto optional_size = this->opt_size(d);
    TORCH_CHECK(
        optional_size.has_value(),
        "Given dimension ",
        d,
        " is irregular and does not have a size.");
    return *optional_size;
  }
  /**
   * View this nested tensor as a 1-dimensional contiguous tensor.
   *
   * The returned buffer tensor shares its storage_impl with the nested
   * tensor, so it behaves as a view. The nested tensor must be contiguous,
   * otherwise the check below fails.
   *
   * @return A newly constructed view tensor
   */
  at::Tensor get_buffer() const {
    TORCH_CHECK(
        nested_tensor_impl_is_contiguous(this),
        "NestedTensor must be contiguous to get buffer.");
    return get_unsafe_storage_as_tensor();
  }
  /**
   * Prefer get_buffer() where possible. This returns the storage as a tensor
   * directly, which is not safe in general: the caller is responsible for
   * taking nested_sizes, nested_strides and storage_offsets into account.
   *
   * @return A newly constructed view tensor
   */
  at::Tensor get_unsafe_storage_as_tensor() const {
    const auto dense_key_set = generate_buffer_key_set();
    // Element count of the whole storage, used as the single dimension.
    const int64_t flat_numel = static_cast<int64_t>(get_buffer_size());
    auto buffer_tensor_impl = c10::make_intrusive<TensorImpl>(
        c10::TensorImpl::VIEW, Storage(storage_), dense_key_set, data_type_);
    buffer_tensor_impl->set_sizes_contiguous(c10::makeArrayRef(flat_numel));
    return Tensor(buffer_tensor_impl);
  }

  // Number of elements the storage holds: byte size divided by item size.
  size_t get_buffer_size() const {
    const auto total_bytes = storage_.nbytes();
    return total_bytes / data_type_.itemsize();
  }

 protected:
  const char* tensorimpl_type_name() const override;

  // TODO: numel_custom and is_contiguous_custom can be profitably overridden
  // with real implementations
  int64_t numel_custom() const override;
  c10::SymInt sym_numel_custom() const override;
  c10::SymBool sym_is_contiguous_custom(
      MemoryFormat /*memory_format*/) const override;
  // Both per-dimension size hooks defer to size(d) above.
  int64_t size_custom(int64_t d) const override {
    return size(d);
  }
  c10::SymInt sym_size_custom(int64_t d) const override {
    return c10::SymInt{size(d)};
  }
  IntArrayRef sizes_custom() const override;
  c10::SymIntArrayRef sym_sizes_custom() const override;
  IntArrayRef strides_custom() const override;
  c10::SymIntArrayRef sym_strides_custom() const override;

  // this one is real
  int64_t dim_custom() const override;

  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
      const c10::VariableVersion& version_counter,
      bool allow_tensor_metadata_change) const override;

  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
      c10::VariableVersion&& version_counter,
      bool allow_tensor_metadata_change) const override;

  // Copies tensor metadata from `impl` into this object, passing along this
  // object's own version counter and metadata-change permission.
  void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override {
    copy_tensor_metadata(
        /*src_impl=*/impl.get(),
        /*dest_impl=*/this,
        /*version_counter=*/version_counter(),
        /*allow_tensor_metadata_change=*/allow_tensor_metadata_change());
  }

 private:
  // Must be called after any changes to our dim() to sync the state
  // to TensorImpl.
  void refresh_dim();

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  const at::Tensor nested_sizes_;
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  const at::Tensor nested_strides_;
  // Starting positions of the underlying tensors in the contiguous buffer,
  // i.e. the buffer memory offsets needed to reach each underlying tensor.
  // This metadata is stored because, absent a strong enough constraint, it
  // cannot be derived from `nested_sizes_` and `nested_strides_`:
  // 1. when buffer has blanks, e.g. [tensor1, blank, tensor2]
  //    this can happen e.g. after slicing a nested tensor
  // 2. when multiple tensors share a same memory
  // 3. when the nesting ordering is changed, e.g. [tensor1, tensor3, tensor2]
  // Some strong enough constraints are:
  // 1. every underlying tensor is contiguous in memory
  //    && nesting in ascending order
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  const at::Tensor storage_offsets_;
  // NOTE: -1 here means the size is missing
  // Optional to allow it to be computed lazily from nested.
  // TODO: maybe we can remove this metadata since
  //       we can compute it from `nested_sizes_`
  mutable std::optional<std::vector<int64_t>> opt_sizes_;

  // Shared implementation declared for the shallow_copy_and_detach overloads;
  // templated on the version-counter argument type.
  template <typename VariableVersion>
  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach_core(
      VariableVersion&& version_counter,
      bool allow_tensor_metadata_change) const;

  /**
   * Derives a non-nested key_set from this nested tensor's key_set.
   *
   * Many nested tensor kernels create a buffer tensor and redispatch to a
   * non-nested kernel; this produces the key set for that buffer tensor.
   *
   * @return Appropriate key set for non-nested tensor
   */
  inline c10::DispatchKeySet generate_buffer_key_set() const {
    const c10::DispatchKeySet nested_keys = this->key_set();
    // Record autograd participation before the nested keys are stripped.
    const bool had_autograd =
        nested_keys.has_any(c10::autograd_dispatch_keyset);

    // Remove nested tensor specific keys
    c10::DispatchKeySet dense_keys = nested_keys -
        c10::DispatchKeySet{
            c10::DispatchKey::NestedTensor,
            c10::DispatchKey::AutogradNestedTensor};

    // Add dense tensor specific keys
    dense_keys = dense_keys | c10::DispatchKeySet{c10::DispatchKey::Dense};
    if (had_autograd) {
      dense_keys =
          c10::DispatchKeySet{c10::DispatchKey::Autograd} | dense_keys;
    }

    return dense_keys;
  }
};
|
| 211 |
+
|
| 212 |
+
// Returns the tensor's impl downcast to NestedTensorImpl when
// tensor.is_nested() is true, and nullptr otherwise.
inline NestedTensorImpl* get_nested_tensor_impl_or_null(
    const at::Tensor& tensor) {
  if (!tensor.is_nested()) {
    return nullptr;
  }
  auto* const raw_impl = tensor.unsafeGetTensorImpl();
  return static_cast<NestedTensorImpl*>(raw_impl);
}
|
| 219 |
+
|
| 220 |
+
// Checked variant of get_nested_tensor_impl_or_null: fails a TORCH_CHECK
// when the tensor is not nested, otherwise returns the downcast impl.
inline NestedTensorImpl* get_nested_tensor_impl(const at::Tensor& tensor) {
  TORCH_CHECK(
      tensor.is_nested(), "get_nested_tensor_impl requires a NestedTensor.");
  auto* const raw_impl = tensor.unsafeGetTensorImpl();
  return static_cast<NestedTensorImpl*>(raw_impl);
}
|
| 225 |
+
|
| 226 |
+
inline bool nested_tensor_impl_is_contiguous(const NestedTensorImpl* nt) {
|
| 227 |
+
int64_t ntensors = nt->size(0);
|
| 228 |
+
if (ntensors == 0) {
|
| 229 |
+
return true;
|
| 230 |
+
}
|
| 231 |
+
const Tensor &sizemat = nt->get_nested_sizes(),
|
| 232 |
+
&stridemat = nt->get_nested_strides();
|
| 233 |
+
const int64_t* offsets_ptr =
|
| 234 |
+
nt->get_storage_offsets().const_data_ptr<int64_t>();
|
| 235 |
+
int64_t orig_dim = sizemat.size(1);
|
| 236 |
+
// nesting scalars
|
| 237 |
+
if (orig_dim == 0) {
|
| 238 |
+
// each scalar must be contiguous
|
| 239 |
+
// if there is blank memory between underlying scalars
|
| 240 |
+
for (int64_t i = 0; i < ntensors; i++) {
|
| 241 |
+
if (offsets_ptr[i] != i) {
|
| 242 |
+
return false;
|
| 243 |
+
}
|
| 244 |
+
}
|
| 245 |
+
}
|
| 246 |
+
// nesting tensors
|
| 247 |
+
else {
|
| 248 |
+
// if any underlying tensor is non-contiguous
|
| 249 |
+
const int64_t *sizemat_ptr = sizemat.const_data_ptr<int64_t>(),
|
| 250 |
+
*stridemat_ptr = stridemat.const_data_ptr<int64_t>();
|
| 251 |
+
for (int64_t i = 0; i < ntensors; i++) {
|
| 252 |
+
if (stridemat_ptr[orig_dim - 1] != 1) {
|
| 253 |
+
return false;
|
| 254 |
+
}
|
| 255 |
+
int64_t product = sizemat_ptr[orig_dim - 1];
|
| 256 |
+
for (int64_t j = orig_dim - 2; j >= 0; j--) {
|
| 257 |
+
if (stridemat_ptr[j] != product) {
|
| 258 |
+
return false;
|
| 259 |
+
}
|
| 260 |
+
product *= sizemat_ptr[j];
|
| 261 |
+
}
|
| 262 |
+
sizemat_ptr += orig_dim;
|
| 263 |
+
stridemat_ptr += orig_dim;
|
| 264 |
+
}
|
| 265 |
+
// if there is blank memory between underlying tensors
|
| 266 |
+
if (offsets_ptr[0] != 0) {
|
| 267 |
+
return false;
|
| 268 |
+
}
|
| 269 |
+
sizemat_ptr = sizemat.const_data_ptr<int64_t>();
|
| 270 |
+
stridemat_ptr = stridemat.const_data_ptr<int64_t>();
|
| 271 |
+
for (int64_t i = 1; i < ntensors; i++) {
|
| 272 |
+
if (offsets_ptr[i] !=
|
| 273 |
+
offsets_ptr[i - 1] + *sizemat_ptr * *stridemat_ptr) {
|
| 274 |
+
return false;
|
| 275 |
+
}
|
| 276 |
+
sizemat_ptr += orig_dim;
|
| 277 |
+
stridemat_ptr += orig_dim;
|
| 278 |
+
}
|
| 279 |
+
}
|
| 280 |
+
// everything is fine
|
| 281 |
+
return true;
|
| 282 |
+
}
|
| 283 |
+
|
| 284 |
+
// Free-function shortcut: fetches the nested-sizes tensor of `tensor` through
// get_nested_tensor_impl(), so the nested-tensor check performed there applies.
inline const at::Tensor& get_nested_sizes(const at::Tensor& tensor) {
  const NestedTensorImpl* const impl = get_nested_tensor_impl(tensor);
  return impl->get_nested_sizes();
}
|
| 287 |
+
|
| 288 |
+
} // namespace at::native
|
| 289 |
+
|
| 290 |
+
#else
|
| 291 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 292 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/NumericUtils.h
ADDED
|
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
#ifdef __HIPCC__
|
| 5 |
+
#include <hip/hip_runtime.h>
|
| 6 |
+
#endif
|
| 7 |
+
|
| 8 |
+
#include <c10/macros/Macros.h>
|
| 9 |
+
#include <c10/util/BFloat16.h>
|
| 10 |
+
#include <c10/util/Float8_e4m3fn.h>
|
| 11 |
+
#include <c10/util/Float8_e4m3fnuz.h>
|
| 12 |
+
#include <c10/util/Float8_e5m2.h>
|
| 13 |
+
#include <c10/util/Float8_e5m2fnuz.h>
|
| 14 |
+
#include <c10/util/Half.h>
|
| 15 |
+
#include <c10/util/complex.h>
|
| 16 |
+
|
| 17 |
+
#include <cmath>
|
| 18 |
+
#include <type_traits>
|
| 19 |
+
|
| 20 |
+
namespace at {
|
| 21 |
+
|
| 22 |
+
// std::isnan isn't performant to use on integral types; it will
|
| 23 |
+
// (uselessly) convert to floating point and then do the test.
|
| 24 |
+
// This function is.
|
| 25 |
+
|
| 26 |
+
template <typename T, std::enable_if_t<std::is_integral_v<T>, int> = 0>
|
| 27 |
+
inline C10_HOST_DEVICE bool _isnan(T /*val*/) {
|
| 28 |
+
return false;
|
| 29 |
+
}
|
| 30 |
+
|
| 31 |
+
template <typename T, std::enable_if_t<std::is_floating_point_v<T>, int> = 0>
|
| 32 |
+
inline C10_HOST_DEVICE bool _isnan(T val) {
|
| 33 |
+
#if defined(__CUDACC__) || defined(__HIPCC__)
|
| 34 |
+
return ::isnan(val);
|
| 35 |
+
#else
|
| 36 |
+
return std::isnan(val);
|
| 37 |
+
#endif
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
template <typename T, std::enable_if_t<c10::is_complex<T>::value, int> = 0>
|
| 41 |
+
inline C10_HOST_DEVICE bool _isnan(T val) {
|
| 42 |
+
return std::isnan(val.real()) || std::isnan(val.imag());
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
template <typename T, std::enable_if_t<std::is_same_v<T, at::Half>, int> = 0>
|
| 46 |
+
inline C10_HOST_DEVICE bool _isnan(T val) {
|
| 47 |
+
return at::_isnan(static_cast<float>(val));
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
template <
|
| 51 |
+
typename T,
|
| 52 |
+
std::enable_if_t<std::is_same_v<T, at::BFloat16>, int> = 0>
|
| 53 |
+
inline C10_HOST_DEVICE bool _isnan(at::BFloat16 val) {
|
| 54 |
+
return at::_isnan(static_cast<float>(val));
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
inline C10_HOST_DEVICE bool _isnan(at::BFloat16 val) {
|
| 58 |
+
return at::_isnan(static_cast<float>(val));
|
| 59 |
+
}
|
| 60 |
+
|
| 61 |
+
template <
|
| 62 |
+
typename T,
|
| 63 |
+
std::enable_if_t<std::is_same_v<T, at::Float8_e5m2>, int> = 0>
|
| 64 |
+
inline C10_HOST_DEVICE bool _isnan(T val) {
|
| 65 |
+
return val.isnan();
|
| 66 |
+
}
|
| 67 |
+
|
| 68 |
+
template <
|
| 69 |
+
typename T,
|
| 70 |
+
std::enable_if_t<std::is_same_v<T, at::Float8_e4m3fn>, int> = 0>
|
| 71 |
+
inline C10_HOST_DEVICE bool _isnan(T val) {
|
| 72 |
+
return val.isnan();
|
| 73 |
+
}
|
| 74 |
+
|
| 75 |
+
template <
|
| 76 |
+
typename T,
|
| 77 |
+
std::enable_if_t<std::is_same_v<T, at::Float8_e5m2fnuz>, int> = 0>
|
| 78 |
+
inline C10_HOST_DEVICE bool _isnan(T val) {
|
| 79 |
+
return val.isnan();
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
template <
|
| 83 |
+
typename T,
|
| 84 |
+
std::enable_if_t<std::is_same_v<T, at::Float8_e4m3fnuz>, int> = 0>
|
| 85 |
+
inline C10_HOST_DEVICE bool _isnan(T val) {
|
| 86 |
+
return val.isnan();
|
| 87 |
+
}
|
| 88 |
+
|
| 89 |
+
// std::isinf isn't performant to use on integral types; it will
|
| 90 |
+
// (uselessly) convert to floating point and then do the test.
|
| 91 |
+
// This function is.
|
| 92 |
+
|
| 93 |
+
template <typename T, std::enable_if_t<std::is_integral_v<T>, int> = 0>
|
| 94 |
+
inline C10_HOST_DEVICE bool _isinf(T /*val*/) {
|
| 95 |
+
return false;
|
| 96 |
+
}
|
| 97 |
+
|
| 98 |
+
template <typename T, std::enable_if_t<std::is_floating_point_v<T>, int> = 0>
|
| 99 |
+
inline C10_HOST_DEVICE bool _isinf(T val) {
|
| 100 |
+
#if defined(__CUDACC__) || defined(__HIPCC__)
|
| 101 |
+
return ::isinf(val);
|
| 102 |
+
#else
|
| 103 |
+
return std::isinf(val);
|
| 104 |
+
#endif
|
| 105 |
+
}
|
| 106 |
+
|
| 107 |
+
inline C10_HOST_DEVICE bool _isinf(at::Half val) {
|
| 108 |
+
return at::_isinf(static_cast<float>(val));
|
| 109 |
+
}
|
| 110 |
+
|
| 111 |
+
inline C10_HOST_DEVICE bool _isinf(at::BFloat16 val) {
|
| 112 |
+
return at::_isinf(static_cast<float>(val));
|
| 113 |
+
}
|
| 114 |
+
|
| 115 |
+
inline C10_HOST_DEVICE bool _isinf(at::Float8_e5m2 val) {
|
| 116 |
+
return val.isinf();
|
| 117 |
+
}
|
| 118 |
+
|
| 119 |
+
inline C10_HOST_DEVICE bool _isinf(at::Float8_e4m3fn val [[maybe_unused]]) {
|
| 120 |
+
return false;
|
| 121 |
+
}
|
| 122 |
+
|
| 123 |
+
inline C10_HOST_DEVICE bool _isinf(at::Float8_e5m2fnuz val [[maybe_unused]]) {
|
| 124 |
+
return false;
|
| 125 |
+
}
|
| 126 |
+
|
| 127 |
+
inline C10_HOST_DEVICE bool _isinf(at::Float8_e4m3fnuz val [[maybe_unused]]) {
|
| 128 |
+
return false;
|
| 129 |
+
}
|
| 130 |
+
|
| 131 |
+
template <typename T>
|
| 132 |
+
C10_HOST_DEVICE inline T exp(T x) {
|
| 133 |
+
static_assert(
|
| 134 |
+
!std::is_same_v<T, double>,
|
| 135 |
+
"this template must be used with float or less precise type");
|
| 136 |
+
#if defined(__CUDA_ARCH__) || defined(__HIP_ARCH__)
|
| 137 |
+
// use __expf fast approximation for peak bandwidth
|
| 138 |
+
return __expf(x);
|
| 139 |
+
#else
|
| 140 |
+
return ::exp(x);
|
| 141 |
+
#endif
|
| 142 |
+
}
|
| 143 |
+
|
| 144 |
+
template <>
|
| 145 |
+
C10_HOST_DEVICE inline double exp<double>(double x) {
|
| 146 |
+
return ::exp(x);
|
| 147 |
+
}
|
| 148 |
+
|
| 149 |
+
template <typename T>
|
| 150 |
+
C10_HOST_DEVICE inline T log(T x) {
|
| 151 |
+
static_assert(
|
| 152 |
+
!std::is_same_v<T, double>,
|
| 153 |
+
"this template must be used with float or less precise type");
|
| 154 |
+
#if defined(__CUDA_ARCH__) || defined(__HIP_ARCH__)
|
| 155 |
+
// use __logf fast approximation for peak bandwidth
|
| 156 |
+
return __logf(x);
|
| 157 |
+
#else
|
| 158 |
+
return ::log(x);
|
| 159 |
+
#endif
|
| 160 |
+
}
|
| 161 |
+
|
| 162 |
+
template <>
|
| 163 |
+
C10_HOST_DEVICE inline double log<double>(double x) {
|
| 164 |
+
return ::log(x);
|
| 165 |
+
}
|
| 166 |
+
|
| 167 |
+
template <typename T>
|
| 168 |
+
C10_HOST_DEVICE inline T log1p(T x) {
|
| 169 |
+
static_assert(
|
| 170 |
+
!std::is_same_v<T, double>,
|
| 171 |
+
"this template must be used with float or less precise type");
|
| 172 |
+
#if defined(__CUDA_ARCH__) || defined(__HIP_ARCH__)
|
| 173 |
+
// use __logf fast approximation for peak bandwidth
|
| 174 |
+
// NOTE: There is no __log1pf so unfortunately we lose precision.
|
| 175 |
+
return __logf(1.0f + x);
|
| 176 |
+
#else
|
| 177 |
+
return ::log1p(x);
|
| 178 |
+
#endif
|
| 179 |
+
}
|
| 180 |
+
|
| 181 |
+
template <>
|
| 182 |
+
C10_HOST_DEVICE inline double log1p<double>(double x) {
|
| 183 |
+
return ::log1p(x);
|
| 184 |
+
}
|
| 185 |
+
|
| 186 |
+
template <typename T>
|
| 187 |
+
C10_HOST_DEVICE inline T tan(T x) {
|
| 188 |
+
static_assert(
|
| 189 |
+
!std::is_same_v<T, double>,
|
| 190 |
+
"this template must be used with float or less precise type");
|
| 191 |
+
#if defined(__CUDA_ARCH__) || defined(__HIP_ARCH__)
|
| 192 |
+
// use __tanf fast approximation for peak bandwidth
|
| 193 |
+
return __tanf(x);
|
| 194 |
+
#else
|
| 195 |
+
return ::tan(x);
|
| 196 |
+
#endif
|
| 197 |
+
}
|
| 198 |
+
|
| 199 |
+
template <>
|
| 200 |
+
C10_HOST_DEVICE inline double tan<double>(double x) {
|
| 201 |
+
return ::tan(x);
|
| 202 |
+
}
|
| 203 |
+
|
| 204 |
+
} // namespace at
|
| 205 |
+
|
| 206 |
+
#else
|
| 207 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 208 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/ParallelOpenMP.h
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
#include <algorithm>
|
| 5 |
+
#include <atomic>
|
| 6 |
+
#include <cstddef>
|
| 7 |
+
#include <exception>
|
| 8 |
+
|
| 9 |
+
#ifdef _OPENMP
|
| 10 |
+
#define INTRA_OP_PARALLEL
|
| 11 |
+
|
| 12 |
+
#include <omp.h>
|
| 13 |
+
#endif
|
| 14 |
+
|
| 15 |
+
#ifdef _OPENMP
|
| 16 |
+
namespace at::internal {
|
| 17 |
+
template <typename F>
|
| 18 |
+
inline void invoke_parallel(
|
| 19 |
+
int64_t begin,
|
| 20 |
+
int64_t end,
|
| 21 |
+
int64_t grain_size,
|
| 22 |
+
const F& f) {
|
| 23 |
+
std::atomic_flag err_flag = ATOMIC_FLAG_INIT;
|
| 24 |
+
std::exception_ptr eptr;
|
| 25 |
+
|
| 26 |
+
#pragma omp parallel
|
| 27 |
+
{
|
| 28 |
+
// choose number of tasks based on grain size and number of threads
|
| 29 |
+
// can't use num_threads clause due to bugs in GOMP's thread pool (See
|
| 30 |
+
// #32008)
|
| 31 |
+
int64_t num_threads = omp_get_num_threads();
|
| 32 |
+
if (grain_size > 0) {
|
| 33 |
+
num_threads = std::min(num_threads, divup((end - begin), grain_size));
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
int64_t tid = omp_get_thread_num();
|
| 37 |
+
int64_t chunk_size = divup((end - begin), num_threads);
|
| 38 |
+
int64_t begin_tid = begin + tid * chunk_size;
|
| 39 |
+
if (begin_tid < end) {
|
| 40 |
+
try {
|
| 41 |
+
internal::ThreadIdGuard tid_guard(tid);
|
| 42 |
+
f(begin_tid, std::min(end, chunk_size + begin_tid));
|
| 43 |
+
} catch (...) {
|
| 44 |
+
if (!err_flag.test_and_set()) {
|
| 45 |
+
eptr = std::current_exception();
|
| 46 |
+
}
|
| 47 |
+
}
|
| 48 |
+
}
|
| 49 |
+
}
|
| 50 |
+
if (eptr) {
|
| 51 |
+
std::rethrow_exception(eptr);
|
| 52 |
+
}
|
| 53 |
+
}
|
| 54 |
+
} // namespace at::internal
|
| 55 |
+
#endif // _OPENMP
|
| 56 |
+
|
| 57 |
+
#else
|
| 58 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 59 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/RedispatchFunctions.h
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/RegistrationDeclarations.h
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/SDPBackend.h
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
#include <cstdint>
|
| 4 |
+
|
| 5 |
+
namespace at {
|
| 6 |
+
|
| 7 |
+
// Count of the non-negative SDPBackend enumerators declared below
// (math through overrideable, values 0-4); `error` is not counted.
// NOTE(review): presumably this has to be updated by hand whenever an
// enumerator is added - confirm against the users of this constant.
constexpr int32_t num_sdp_backends = 5;
|
| 8 |
+
// Backend selector. Values are spelled out explicitly; `error` (-1) is the
// only negative value and the remaining enumerators are numbered
// consecutively from 0.
enum class SDPBackend {
|
| 9 |
+
error = -1,
|
| 10 |
+
math = 0,
|
| 11 |
+
flash_attention = 1,
|
| 12 |
+
efficient_attention = 2,
|
| 13 |
+
cudnn_attention = 3,
|
| 14 |
+
overrideable = 4
|
| 15 |
+
};
|
| 16 |
+
|
| 17 |
+
} // namespace at
|
| 18 |
+
|
| 19 |
+
#else
|
| 20 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 21 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/Scalar.h
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
#include <ATen/core/Scalar.h>
|
| 5 |
+
|
| 6 |
+
#else
|
| 7 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 8 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/StorageUtils.h
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
#include <c10/core/Storage.h>
|
| 5 |
+
#include <c10/core/StorageImpl.h>
|
| 6 |
+
#include <c10/util/intrusive_ptr.h>
|
| 7 |
+
|
| 8 |
+
namespace at {
|
| 9 |
+
|
| 10 |
+
class TensorBase;
|
| 11 |
+
|
| 12 |
+
// Here we define a series of utils to create/manipulate ATen backed
|
| 13 |
+
// c10 storage implementations.
|
| 14 |
+
|
| 15 |
+
/**
|
| 16 |
+
* Create a new shared memory storage impl managed by file descriptor
|
| 17 |
+
*
|
| 18 |
+
* @param size size in bytes
|
| 19 |
+
*/
|
| 20 |
+
C10_EXPORT c10::intrusive_ptr<c10::StorageImpl> new_shm_fd_storage(size_t size);
|
| 21 |
+
|
| 22 |
+
/**
|
| 23 |
+
* Copy src to dst
|
| 24 |
+
* Caller must guarantee the validity of the storage objects
|
| 25 |
+
* during the entire copy process, esp. when it's async.
|
| 26 |
+
*
|
| 27 |
+
* This can probably live in c10 namespace later if needed,
|
| 28 |
+
* but for now keep it in at to keep implementation simple.
|
| 29 |
+
*
|
| 30 |
+
* @param dst dst storage
|
| 31 |
+
* @param src src storage
|
| 32 |
+
* @param non_blocking (default false) if true, the copy need not block the caller
|
| 33 |
+
*/
|
| 34 |
+
C10_EXPORT void storage_copy(
|
| 35 |
+
c10::Storage& dst,
|
| 36 |
+
const c10::Storage& src,
|
| 37 |
+
bool non_blocking = false);
|
| 38 |
+
|
| 39 |
+
/**
|
| 40 |
+
* In place change the storage to shm based.
|
| 41 |
+
*
|
| 42 |
+
* This is only applicable to CPU tensors not already shared.
|
| 43 |
+
* Otherwise, it's a no op to mirror the THP tensor behavior:
|
| 44 |
+
* https://pytorch.org/docs/stable/generated/torch.Tensor.share_memory_.html
|
| 45 |
+
*
|
| 46 |
+
* @param t a tensor
|
| 47 |
+
*/
|
| 48 |
+
C10_EXPORT void share_memory_(TensorBase& t);
|
| 49 |
+
|
| 50 |
+
} // namespace at
|
| 51 |
+
|
| 52 |
+
#else
|
| 53 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 54 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/TensorAccessor.h
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
#include <ATen/core/TensorAccessor.h>
|
| 4 |
+
|
| 5 |
+
#else
|
| 6 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 7 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/ThreadLocalPythonObjects.h
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
#include <c10/core/SafePyObject.h>
|
| 5 |
+
#include <c10/macros/Macros.h>
|
| 6 |
+
#include <unordered_map>
|
| 7 |
+
|
| 8 |
+
namespace at::impl {
|
| 9 |
+
|
| 10 |
+
// String-keyed collection of shared SafePyObject handles. All public entry
// points are static; their definitions are not in this header.
// NOTE(review): judging by the type name the static accessors presumably act
// on a per-thread instance - confirm in the implementation file.
struct TORCH_API ThreadLocalPythonObjects {
|
| 11 |
+
// Associates `value` with `key`.
static void set(const std::string& key, std::shared_ptr<SafePyObject> value);
|
| 12 |
+
// Returns a reference to the handle stored under `key`.
// NOTE(review): behavior for a missing key is not visible here - confirm.
static const std::shared_ptr<SafePyObject>& get(const std::string& key);
|
| 13 |
+
// Reports whether an entry exists for `key`.
static bool contains(const std::string& key);
|
| 14 |
+
|
| 15 |
+
// Read access to the whole state object.
static const ThreadLocalPythonObjects& get_state();
|
| 16 |
+
// Replaces the whole state object; `state` is taken by value.
static void set_state(ThreadLocalPythonObjects state);
|
| 17 |
+
|
| 18 |
+
private:
|
| 19 |
+
// Backing map from key to shared object handle.
std::unordered_map<std::string, std::shared_ptr<c10::SafePyObject>> obj_dict_;
|
| 20 |
+
};
|
| 21 |
+
|
| 22 |
+
} // namespace at::impl
|
| 23 |
+
|
| 24 |
+
#else
|
| 25 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 26 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/ThreadLocalState.h
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
#include <c10/core/InferenceMode.h>
|
| 5 |
+
#include <c10/core/impl/LocalDispatchKeySet.h>
|
| 6 |
+
#include <c10/util/Exception.h>
|
| 7 |
+
#include <c10/util/ThreadLocalDebugInfo.h>
|
| 8 |
+
|
| 9 |
+
#include <ATen/FuncTorchTLS.h>
|
| 10 |
+
#include <ATen/PythonTorchFunctionTLS.h>
|
| 11 |
+
#include <ATen/SavedTensorHooks.h>
|
| 12 |
+
#include <ATen/ThreadLocalPythonObjects.h>
|
| 13 |
+
#include <ATen/record_function.h>
|
| 14 |
+
#include <c10/core/impl/PythonDispatcherTLS.h>
|
| 15 |
+
#include <c10/core/impl/TorchDispatchModeTLS.h>
|
| 16 |
+
|
| 17 |
+
namespace at {
|
| 18 |
+
|
| 19 |
+
// Thread local state contains values that are preserved across
|
| 20 |
+
// thread boundaries (e.g. at::launch/JIT fork, autograd).
|
| 21 |
+
// Note at::parallel_for doesn't preserve TLS across thread boundaries.
|
| 22 |
+
class TORCH_API ThreadLocalState {
|
| 23 |
+
public:
|
| 24 |
+
// Saves the thread local variables' values and
|
| 25 |
+
// returns them as a ThreadLocalState
|
| 26 |
+
ThreadLocalState();
|
| 27 |
+
|
| 28 |
+
// set_grad_mode - force the value of the grad mode TLS in
|
| 29 |
+
// the current state object. This is used for example in the
|
| 30 |
+
// autograd engine.
|
| 31 |
+
void set_grad_mode(bool enabled);
|
| 32 |
+
|
| 33 |
+
// set_multithreading_enabled - force the value of the multithreadinmaximum
|
| 34 |
+
// threads TLS in
|
| 35 |
+
// the current state object. This is used for example in the
|
| 36 |
+
// autograd engine.
|
| 37 |
+
void set_multithreading_enabled(bool enabled);
|
| 38 |
+
|
| 39 |
+
// Sets thread local variables in the current thread,
|
| 40 |
+
// according to the thread boundary specified
|
| 41 |
+
static void setThreadLocalState(const ThreadLocalState& state);
|
| 42 |
+
|
| 43 |
+
private:
|
| 44 |
+
c10::impl::LocalDispatchKeySet dispatch_key_;
|
| 45 |
+
|
| 46 |
+
// ThreadLocalDebugInfo does not change after being created
|
| 47 |
+
// with DebugInfoGuard
|
| 48 |
+
std::shared_ptr<c10::ThreadLocalDebugInfo> debug_info_;
|
| 49 |
+
|
| 50 |
+
// RecordFunction TLS
|
| 51 |
+
RecordFunctionTLS rf_tls_;
|
| 52 |
+
|
| 53 |
+
// TLS for out-of-tree functorch
|
| 54 |
+
// See NOTE [functorch TLS in pytorch/pytorch] for why this needs to be a
|
| 55 |
+
// pointer (spoiler alert: it's due to the indirection)
|
| 56 |
+
// This needs to be a shared_ptr instead of a unique_ptr because
|
| 57 |
+
// ThreadLocalState is copy-able and does indeed get copied. Maybe we can
|
| 58 |
+
// consider adding an explicit copy constructor for ThreadLocalState in the
|
| 59 |
+
// future but I didn't want to add one just for this.
|
| 60 |
+
std::shared_ptr<const functorch::FuncTorchTLSBase> functorch_tls_;
|
| 61 |
+
|
| 62 |
+
// TLS for AutogradModes
|
| 63 |
+
AutogradState autograd_tls_;
|
| 64 |
+
|
| 65 |
+
// TLS for enable_torch_dispatch_mode
|
| 66 |
+
c10::impl::TorchDispatchModeTLS torch_dispatch_mode_state_;
|
| 67 |
+
|
| 68 |
+
// TLS for enable_python_dispatcher
|
| 69 |
+
c10::impl::PyInterpreter* python_dispatcher_state_;
|
| 70 |
+
|
| 71 |
+
// TLS for __torch_function__ (mode and disable_torch_function)
|
| 72 |
+
at::impl::PythonTorchFunctionTLS python_torch_function_state_;
|
| 73 |
+
|
| 74 |
+
// TLS for saved tensors default hooks
|
| 75 |
+
at::impl::SavedTensorDefaultHooksTLS saved_tensors_default_hooks_state_;
|
| 76 |
+
|
| 77 |
+
bool functionalization_reapply_views_state_;
|
| 78 |
+
|
| 79 |
+
bool dtensor_allow_implicit_replication_;
|
| 80 |
+
|
| 81 |
+
// TLS for arbitrary python objects that is registered via hooks
|
| 82 |
+
at::impl::ThreadLocalPythonObjects saved_objects_;
|
| 83 |
+
|
| 84 |
+
#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) && \
|
| 85 |
+
!defined(BUILD_LITE_INTERPRETER)
|
| 86 |
+
// TLS for autocast dtypes
|
| 87 |
+
std::array<at::ScalarType, at::COMPILE_TIME_MAX_DEVICE_TYPES>
|
| 88 |
+
autocast_dtypes_{};
|
| 89 |
+
#endif
|
| 90 |
+
|
| 91 |
+
friend class ThreadLocalStateGuard;
|
| 92 |
+
};
|
| 93 |
+
|
| 94 |
+
// Guard to set and reset the thread local state
|
| 95 |
+
class TORCH_API ThreadLocalStateGuard {
|
| 96 |
+
public:
|
| 97 |
+
explicit ThreadLocalStateGuard(const ThreadLocalState& state)
|
| 98 |
+
: prev_state_(ThreadLocalState()) {
|
| 99 |
+
// set the given state across the thread boundary
|
| 100 |
+
ThreadLocalState::setThreadLocalState(state);
|
| 101 |
+
}
|
| 102 |
+
ThreadLocalStateGuard(ThreadLocalStateGuard&& other) = delete;
|
| 103 |
+
ThreadLocalStateGuard(const ThreadLocalStateGuard&) = delete;
|
| 104 |
+
ThreadLocalStateGuard& operator=(const ThreadLocalStateGuard&) = delete;
|
| 105 |
+
ThreadLocalStateGuard& operator=(ThreadLocalStateGuard&&) = delete;
|
| 106 |
+
|
| 107 |
+
~ThreadLocalStateGuard() {
|
| 108 |
+
// restore previously set variables
|
| 109 |
+
ThreadLocalState::setThreadLocalState(prev_state_);
|
| 110 |
+
}
|
| 111 |
+
|
| 112 |
+
private:
|
| 113 |
+
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
|
| 114 |
+
const ThreadLocalState prev_state_;
|
| 115 |
+
};
|
| 116 |
+
|
| 117 |
+
template <typename T>
|
| 118 |
+
auto wrapPropagateTLSState(T callback) {
|
| 119 |
+
return [tls_state = ThreadLocalState(),
|
| 120 |
+
callback = std::move(callback)](auto&&... args) {
|
| 121 |
+
ThreadLocalStateGuard g(tls_state);
|
| 122 |
+
// Propagate value returned by callback().
|
| 123 |
+
return callback(std::forward<decltype(args)>(args)...);
|
| 124 |
+
};
|
| 125 |
+
}
|
| 126 |
+
|
| 127 |
+
} // namespace at
|
| 128 |
+
|
| 129 |
+
#else
|
| 130 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 131 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/Utils.h
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
#include <ATen/EmptyTensor.h>
|
| 5 |
+
#include <ATen/Formatting.h>
|
| 6 |
+
#include <ATen/core/ATenGeneral.h>
|
| 7 |
+
#include <ATen/core/Generator.h>
|
| 8 |
+
#include <c10/core/ScalarType.h>
|
| 9 |
+
#include <c10/core/StorageImpl.h>
|
| 10 |
+
#include <c10/core/UndefinedTensorImpl.h>
|
| 11 |
+
#include <c10/util/ArrayRef.h>
|
| 12 |
+
#include <c10/util/Exception.h>
|
| 13 |
+
#include <c10/util/accumulate.h>
|
| 14 |
+
#include <c10/util/irange.h>
|
| 15 |
+
|
| 16 |
+
#include <algorithm>
|
| 17 |
+
|
| 18 |
+
// Deletes the copy constructor and the copy assignment operator of TypeName.
// Use inside the class body; the caller supplies the trailing semicolon.
// The assignment operator is declared with the conventional `TypeName&`
// return type so tooling recognises it as a copy assignment operator.
#define AT_DISALLOW_COPY_AND_ASSIGN(TypeName) \
  TypeName(const TypeName&) = delete;         \
  TypeName& operator=(const TypeName&) = delete
|
| 21 |
+
|
| 22 |
+
namespace at {
|
| 23 |
+
|
| 24 |
+
TORCH_API int _crash_if_asan(int /*arg*/);
|
| 25 |
+
|
| 26 |
+
// Converts a TensorList (i.e. ArrayRef<Tensor> to vector of TensorImpl*)
|
| 27 |
+
// NB: This is ONLY used by legacy TH bindings, and ONLY used by cat.
|
| 28 |
+
// Once cat is ported entirely to ATen this can be deleted!
|
| 29 |
+
inline std::vector<TensorImpl*> checked_dense_tensor_list_unwrap(
|
| 30 |
+
ArrayRef<Tensor> tensors,
|
| 31 |
+
const char* name,
|
| 32 |
+
int pos,
|
| 33 |
+
c10::DeviceType device_type,
|
| 34 |
+
ScalarType scalar_type) {
|
| 35 |
+
std::vector<TensorImpl*> unwrapped;
|
| 36 |
+
unwrapped.reserve(tensors.size());
|
| 37 |
+
for (const auto i : c10::irange(tensors.size())) {
|
| 38 |
+
const auto& expr = tensors[i];
|
| 39 |
+
if (expr.layout() != Layout::Strided) {
|
| 40 |
+
TORCH_CHECK(
|
| 41 |
+
false,
|
| 42 |
+
"Expected dense tensor but got ",
|
| 43 |
+
expr.layout(),
|
| 44 |
+
" for sequence element ",
|
| 45 |
+
i,
|
| 46 |
+
" in sequence argument at position #",
|
| 47 |
+
pos,
|
| 48 |
+
" '",
|
| 49 |
+
name,
|
| 50 |
+
"'");
|
| 51 |
+
}
|
| 52 |
+
if (expr.device().type() != device_type) {
|
| 53 |
+
TORCH_CHECK(
|
| 54 |
+
false,
|
| 55 |
+
"Expected object of device type ",
|
| 56 |
+
device_type,
|
| 57 |
+
" but got device type ",
|
| 58 |
+
expr.device().type(),
|
| 59 |
+
" for sequence element ",
|
| 60 |
+
i,
|
| 61 |
+
" in sequence argument at position #",
|
| 62 |
+
pos,
|
| 63 |
+
" '",
|
| 64 |
+
name,
|
| 65 |
+
"'");
|
| 66 |
+
}
|
| 67 |
+
if (expr.scalar_type() != scalar_type) {
|
| 68 |
+
TORCH_CHECK(
|
| 69 |
+
false,
|
| 70 |
+
"Expected object of scalar type ",
|
| 71 |
+
scalar_type,
|
| 72 |
+
" but got scalar type ",
|
| 73 |
+
expr.scalar_type(),
|
| 74 |
+
" for sequence element ",
|
| 75 |
+
i,
|
| 76 |
+
" in sequence argument at position #",
|
| 77 |
+
pos,
|
| 78 |
+
" '",
|
| 79 |
+
name,
|
| 80 |
+
"'");
|
| 81 |
+
}
|
| 82 |
+
unwrapped.emplace_back(expr.unsafeGetTensorImpl());
|
| 83 |
+
}
|
| 84 |
+
return unwrapped;
|
| 85 |
+
}
|
| 86 |
+
|
| 87 |
+
template <size_t N>
|
| 88 |
+
std::array<int64_t, N> check_intlist(
|
| 89 |
+
ArrayRef<int64_t> list,
|
| 90 |
+
const char* name,
|
| 91 |
+
int pos) {
|
| 92 |
+
if (list.empty()) {
|
| 93 |
+
// TODO: is this necessary? We used to treat nullptr-vs-not in IntList
|
| 94 |
+
// differently with strides as a way of faking optional.
|
| 95 |
+
list = {};
|
| 96 |
+
}
|
| 97 |
+
auto res = std::array<int64_t, N>();
|
| 98 |
+
if (list.size() == 1 && N > 1) {
|
| 99 |
+
res.fill(list[0]);
|
| 100 |
+
return res;
|
| 101 |
+
}
|
| 102 |
+
if (list.size() != N) {
|
| 103 |
+
TORCH_CHECK(
|
| 104 |
+
false,
|
| 105 |
+
"Expected a list of ",
|
| 106 |
+
N,
|
| 107 |
+
" ints but got ",
|
| 108 |
+
list.size(),
|
| 109 |
+
" for argument #",
|
| 110 |
+
pos,
|
| 111 |
+
" '",
|
| 112 |
+
name,
|
| 113 |
+
"'");
|
| 114 |
+
}
|
| 115 |
+
std::copy_n(list.begin(), N, res.begin());
|
| 116 |
+
return res;
|
| 117 |
+
}
|
| 118 |
+
|
| 119 |
+
using at::detail::check_size_nonnegative;
|
| 120 |
+
|
| 121 |
+
// Declarations only: four exported (TORCH_API) function templates that take
// an ArrayRef<T> of values plus TensorOptions and return a Tensor. Their
// definitions are not part of this header.
namespace detail {
|
| 122 |
+
|
| 123 |
+
// NOTE(review): presumably builds a CPU tensor from `values` - confirm
// against the definition.
template <typename T>
|
| 124 |
+
TORCH_API Tensor tensor_cpu(ArrayRef<T> values, const TensorOptions& options);
|
| 125 |
+
|
| 126 |
+
// NOTE(review): presumably the variant for non-CPU backends - confirm
// against the definition.
template <typename T>
|
| 127 |
+
TORCH_API Tensor
|
| 128 |
+
tensor_backend(ArrayRef<T> values, const TensorOptions& options);
|
| 129 |
+
|
| 130 |
+
// NOTE(review): presumably the complex-valued counterpart of tensor_cpu -
// confirm against the definition.
template <typename T>
|
| 131 |
+
TORCH_API Tensor
|
| 132 |
+
tensor_complex_cpu(ArrayRef<T> values, const TensorOptions& options);
|
| 133 |
+
|
| 134 |
+
// NOTE(review): presumably the complex-valued counterpart of
// tensor_backend - confirm against the definition.
template <typename T>
|
| 135 |
+
TORCH_API Tensor
|
| 136 |
+
tensor_complex_backend(ArrayRef<T> values, const TensorOptions& options);
|
| 137 |
+
} // namespace detail
|
| 138 |
+
|
| 139 |
+
} // namespace at
|
| 140 |
+
|
| 141 |
+
#else
|
| 142 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 143 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/ATen/cpp_custom_type_hack.h
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 3 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 4 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 5 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 6 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 7 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 8 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 9 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 10 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 11 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 12 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 13 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 14 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 15 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 16 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 17 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 18 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 19 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 20 |
+
|
| 21 |
+
// YOU ARE IN THE WRONG PLACE! TURN BACK NOW!
|
| 22 |
+
|
| 23 |
+
// This code was a temporary hack to enable embedding arbitrary C++ structures
|
| 24 |
+
// into Tensors. THIS IS UNSAFE AND IS NOT SUPPORTED. IF YOU USE THIS CODE,
|
| 25 |
+
// IT __WILL__ BREAK.
|
| 26 |
+
|
| 27 |
+
// This code has been superseded by custom classes:
|
| 28 |
+
// https://pytorch.org/tutorials/advanced/torch_script_custom_classes.html
|
| 29 |
+
|
| 30 |
+
// Please use custom classes and **DO NOT ADD MORE CALLSITES TO THINGS DEFINED
|
| 31 |
+
// IN THIS FILE**.
|
| 32 |
+
|
| 33 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 34 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 35 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 36 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 37 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 38 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 39 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 40 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 41 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 42 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 43 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 44 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 45 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 46 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 47 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 48 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 49 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 50 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 51 |
+
|
| 52 |
+
#include <ATen/TracerMode.h>
|
| 53 |
+
#include <ATen/core/Tensor.h>
|
| 54 |
+
|
| 55 |
+
#ifndef AT_PER_OPERATOR_HEADERS
|
| 56 |
+
#include <ATen/Functions.h>
|
| 57 |
+
#else
|
| 58 |
+
#include <ATen/ops/empty.h>
|
| 59 |
+
#endif
|
| 60 |
+
|
| 61 |
+
namespace at::cpp_custom_type_hack {
|
| 62 |
+
|
| 63 |
+
template <typename T>
|
| 64 |
+
[[deprecated(
|
| 65 |
+
"Use custom classes instead: "
|
| 66 |
+
"https://pytorch.org/tutorials/advanced/torch_script_custom_classes.html")]] bool
|
| 67 |
+
isa(const Tensor& packed) {
|
| 68 |
+
return (packed.scalar_type() == kByte) &&
|
| 69 |
+
(packed.storage().data_ptr().get_deleter() ==
|
| 70 |
+
caffe2::TypeMeta::Make<T>().deleteFn());
|
| 71 |
+
}
|
| 72 |
+
|
| 73 |
+
template <typename T>
|
| 74 |
+
[[deprecated(
|
| 75 |
+
"Use custom classes instead: "
|
| 76 |
+
"https://pytorch.org/tutorials/advanced/torch_script_custom_classes.html")]] T&
|
| 77 |
+
cast(const Tensor& packed) {
|
| 78 |
+
TORCH_CHECK(
|
| 79 |
+
packed.scalar_type() == kByte, "Expected temporary cpp type wrapper");
|
| 80 |
+
TORCH_CHECK(
|
| 81 |
+
packed.storage().data_ptr().get_deleter() ==
|
| 82 |
+
caffe2::TypeMeta::Make<T>().deleteFn(),
|
| 83 |
+
"Expected temporary cpp type wrapper of type ",
|
| 84 |
+
caffe2::TypeMeta::TypeName<T>());
|
| 85 |
+
return *reinterpret_cast<T*>(packed.storage().data_ptr().get());
|
| 86 |
+
}
|
| 87 |
+
|
| 88 |
+
template <typename T>
|
| 89 |
+
[[deprecated(
|
| 90 |
+
"Use custom classes instead: "
|
| 91 |
+
"https://pytorch.org/tutorials/advanced/torch_script_custom_classes.html")]] Tensor
|
| 92 |
+
create(std::unique_ptr<T> ptr, TensorOptions options) {
|
| 93 |
+
// None of this should trace, so turn off Tracer dispatching
|
| 94 |
+
at::AutoDispatchBelowADInplaceOrView guard; // TODO: remove
|
| 95 |
+
at::tracer::impl::NoTracerDispatchMode tracer_guard;
|
| 96 |
+
|
| 97 |
+
// We store this instance away in a Tensor and register a deleter function
|
| 98 |
+
// so that we do not leak memory. On the other side, we pull out the storage's
|
| 99 |
+
// data_ptr and get the right typed pointer.
|
| 100 |
+
void* raw_ptr = ptr.release();
|
| 101 |
+
at::DataPtr at_ptr(
|
| 102 |
+
raw_ptr, raw_ptr, caffe2::TypeMeta::Make<T>().deleteFn(), at::kCPU);
|
| 103 |
+
|
| 104 |
+
// size doesn't really matter, but we can align it to the actual size
|
| 105 |
+
// returning variables because one likely want to use this hack from python
|
| 106 |
+
auto retval = at::empty({sizeof(T)}, options.device(kCPU).dtype(at::kByte));
|
| 107 |
+
retval.storage().set_data_ptr_noswap(std::move(at_ptr));
|
| 108 |
+
return retval;
|
| 109 |
+
}
|
| 110 |
+
|
| 111 |
+
} // namespace at::cpp_custom_type_hack
|
| 112 |
+
|
| 113 |
+
#else
|
| 114 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 115 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/THC/THCAtomics.cuh
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
// TODO: Remove once torchvision has been updated to use the ATen header
|
| 4 |
+
#include <ATen/cuda/Atomic.cuh>
|
| 5 |
+
|
| 6 |
+
#else
|
| 7 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 8 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/THC/THCDeviceUtils.cuh
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
// TODO: Remove this header
|
| 4 |
+
#include <ATen/cuda/DeviceUtils.cuh>
|
| 5 |
+
|
| 6 |
+
#else
|
| 7 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 8 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/fbgemm/ConvUtils.h
ADDED
|
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
/*
|
| 3 |
+
* Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 4 |
+
* All rights reserved.
|
| 5 |
+
*
|
| 6 |
+
* This source code is licensed under the BSD-style license found in the
|
| 7 |
+
* LICENSE file in the root directory of this source tree.
|
| 8 |
+
*/
|
| 9 |
+
|
| 10 |
+
#pragma once
|
| 11 |
+
|
| 12 |
+
#include <array>
|
| 13 |
+
#include <stdexcept>
|
| 14 |
+
#include <string>
|
| 15 |
+
#include <type_traits>
|
| 16 |
+
|
| 17 |
+
namespace fbgemm {

/**
 * @brief Build a std::array<int, N> whose elements are all 1.
 *
 * Call as array_of_ones<N>(). The Vals pack is an accumulator: the recursive
 * overload appends a 1 until the pack holds N values, at which point this
 * terminal overload materialises the array.
 */
template <int N, int... Vals>
constexpr std::enable_if_t<N == sizeof...(Vals), std::array<int, N>>
array_of_ones() {
  return std::array<int, N>{{Vals...}};
}

/// Recursive step of array_of_ones: appends one more 1 to the pack.
template <int N, int... Vals>
constexpr std::enable_if_t<N != sizeof...(Vals), std::array<int, N>>
array_of_ones() {
  return array_of_ones<N, Vals..., 1>();
}

/**
 * @brief Build a std::array<int, N> whose elements are all 0.
 *
 * Call as array_of_zeroes<N>(). Same accumulator scheme as array_of_ones.
 */
template <int N, int... Vals>
constexpr std::enable_if_t<N == sizeof...(Vals), std::array<int, N>>
array_of_zeroes() {
  return std::array<int, N>{{Vals...}};
}

/// Recursive step of array_of_zeroes: appends one more 0 to the pack.
template <int N, int... Vals>
constexpr std::enable_if_t<N != sizeof...(Vals), std::array<int, N>>
array_of_zeroes() {
  return array_of_zeroes<N, Vals..., 0>();
}

/**
 * @brief A struct to conveniently store all convolution parameters.
 *
 * @tparam SPATIAL_DIM number of spatial dimensions of the convolution.
 */
template <int SPATIAL_DIM = 2>
struct conv_param_t {
  int MB; ///< Mini Batch size
  int IC; ///< Number of Input Channels
  int OC; ///< Number of Output Channels
  std::array<int, SPATIAL_DIM> IN_DIM; ///< Input Image Dimension
  int G; ///< Number of Groups
  std::array<int, SPATIAL_DIM> K; ///< Filter (Kernel) dimensions
  std::array<int, SPATIAL_DIM> stride; ///< Strides
  std::array<int, SPATIAL_DIM * 2>
      pad; ///< Padding (first SPATIAL_DIM is for prev/top/left padding, second
           ///< SPATIAL_DIM is for next/bottom/right padding)
  std::array<int, SPATIAL_DIM> dilation; ///< Kernel dilation

  // The following are derived parameters
  std::array<int, SPATIAL_DIM> OUT_DIM; ///< Output Image Dimension
  std::array<int, SPATIAL_DIM> IN_DIMP; ///< Input Image Dimension Padded

  // The following is for transposed convolution
  std::array<int, SPATIAL_DIM>
      output_pad; ///< Padding (next/bottom/right padding in output buffer)
  bool transposed; ///< Whether this describes a transposed convolution

  /**
   * @brief Constructor for initializing the convolution parameters.
   *
   * Stores the given parameters and derives IN_DIMP and OUT_DIM from them.
   *
   * @throws std::runtime_error if `g` is zero or does not divide both `ic`
   *         and `oc`.
   */
  conv_param_t(
      int mb,
      int ic,
      int oc,
      std::array<int, SPATIAL_DIM> in_dim,
      int g,
      std::array<int, SPATIAL_DIM> k,
      std::array<int, SPATIAL_DIM> strd,
      std::array<int, SPATIAL_DIM * 2> pd,
      std::array<int, SPATIAL_DIM> dilations = array_of_ones<SPATIAL_DIM>(),
      std::array<int, SPATIAL_DIM> otpt_pd = array_of_zeroes<SPATIAL_DIM>(),
      bool transposed = false)
      : MB(mb),
        IC(ic),
        OC(oc),
        IN_DIM(in_dim),
        G(g),
        K(k),
        stride(strd),
        pad(pd),
        dilation(dilations),
        output_pad(otpt_pd),
        transposed(transposed) {
    // `x % 0` is undefined behaviour, so a zero group count is rejected
    // before the divisibility tests instead of being evaluated.
    if (g == 0 || ic % g != 0) {
      throw std::runtime_error(
          "groups = " + std::to_string(g) +
          " does not divide number of input channels = " + std::to_string(ic));
    }
    if (oc % g != 0) {
      throw std::runtime_error(
          "groups = " + std::to_string(g) +
          " does not divide number of output channels = " + std::to_string(oc));
    }

    for (int d = 0; d < SPATIAL_DIM; ++d) {
      if (transposed) {
        this->IN_DIMP[d] = this->IN_DIM[d] +
            (this->dilation[d] * (this->K[d] - 1) - this->pad[d]) +
            (this->dilation[d] * (this->K[d] - 1) - this->pad[SPATIAL_DIM + d]);
        this->OUT_DIM[d] = (this->IN_DIM[d] - 1) * this->stride[d] -
            this->pad[d] - this->pad[SPATIAL_DIM + d] +
            this->dilation[d] * (this->K[d] - 1) + output_pad[d] + 1;
      } else {
        IN_DIMP[d] = IN_DIM[d] + pad[d] + pad[SPATIAL_DIM + d];
        OUT_DIM[d] =
            (IN_DIMP[d] - dilation[d] * (K[d] - 1) - 1) / stride[d] + 1;
      }
    }
  }

  /**
   * @brief Helper function to get convolution parameters as string.
   *
   * The result is a list of `label:value` items joined by ", " with no
   * trailing separator. For SPATIAL_DIM <= 3 the axes are labelled with the
   * trailing entries of T/H/W; for higher dimensions they are labelled with
   * their index. All 2 * SPATIAL_DIM padding entries are listed, and for
   * transposed convolutions the output padding is appended as further items.
   */
  std::string toString() const {
    std::string out;

    // Appends one "label:value" item, inserting the separator before every
    // item but the first so the result never ends in ", ".
    const auto append = [&out](const std::string& label, int value) {
      if (!out.empty()) {
        out += ", ";
      }
      out += label;
      out += ':';
      out += std::to_string(value);
    };

    // Label of spatial axis d, with d in [0, SPATIAL_DIM).
    const auto axis_name = [](int d) -> std::string {
      if constexpr (SPATIAL_DIM <= 3) {
        constexpr const char* kAxisNames[3] = {"T", "H", "W"};
        return kAxisNames[3 - SPATIAL_DIM + d];
      } else {
        return std::to_string(d);
      }
    };

    append("MB", MB);
    append("IC", IC);
    append("OC", OC);
    for (int d = 0; d < SPATIAL_DIM; ++d) {
      append("I" + axis_name(d), IN_DIM[d]);
    }
    append("G", G);
    for (int d = 0; d < SPATIAL_DIM; ++d) {
      append("K" + axis_name(d), K[d]);
    }
    for (int d = 0; d < SPATIAL_DIM; ++d) {
      append("stride_" + axis_name(d), stride[d]);
    }
    // pad holds the leading paddings followed by the trailing paddings, so
    // each axis label appears twice.
    for (int d = 0; d < SPATIAL_DIM * 2; ++d) {
      append("pad_" + axis_name(d % SPATIAL_DIM), pad[d]);
    }
    for (int d = 0; d < SPATIAL_DIM; ++d) {
      append("dilation_" + axis_name(d), dilation[d]);
    }
    if (transposed) {
      for (int d = 0; d < SPATIAL_DIM; ++d) {
        append("output_padding_" + std::to_string(d), output_pad[d]);
      }
    }
    return out;
  }
};

} // namespace fbgemm
|
| 192 |
+
|
| 193 |
+
#else
|
| 194 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 195 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/fbgemm/Fbgemm.h
ADDED
|
@@ -0,0 +1,1515 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
/*
|
| 3 |
+
* Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 4 |
+
* All rights reserved.
|
| 5 |
+
*
|
| 6 |
+
* This source code is licensed under the BSD-style license found in the
|
| 7 |
+
* LICENSE file in the root directory of this source tree.
|
| 8 |
+
*/
|
| 9 |
+
|
| 10 |
+
#pragma once
|
| 11 |
+
|
| 12 |
+
/**
|
| 13 |
+
* Top level include file for FBGEMM.
|
| 14 |
+
*/
|
| 15 |
+
#include <cassert>
|
| 16 |
+
#include <memory>
|
| 17 |
+
#include "./ConvUtils.h" // @manual
|
| 18 |
+
#include "./FbgemmBuild.h" // @manual
|
| 19 |
+
#include "./FbgemmEmbedding.h" // @manual
|
| 20 |
+
#include "./FbgemmI8DepthwiseAvx2.h" // @manual
|
| 21 |
+
#include "./FbgemmI8DirectconvAvx2.h" // @manual
|
| 22 |
+
#include "./FbgemmI8Spmdm.h" // @manual
|
| 23 |
+
#include "./FloatConversion.h" // @manual
|
| 24 |
+
#include "./QuantUtilsAvx2.h" // @manual
|
| 25 |
+
#include "./Types.h" // @manual
|
| 26 |
+
#include "./Utils.h" // @manual
|
| 27 |
+
|
| 28 |
+
// Turning on this option will print out time breakdown of each stage (e.g.,
|
| 29 |
+
// input packing, the main GEMM kernel, each output processing pipeline).
|
| 30 |
+
// Please note that currently this option won't report accurate timing if
|
| 31 |
+
// multiple threads are used.
|
| 32 |
+
// #define FBGEMM_MEASURE_TIME_BREAKDOWN
|
| 33 |
+
|
| 34 |
+
#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
|
| 35 |
+
#include <chrono>
|
| 36 |
+
#include <iostream>
|
| 37 |
+
extern double packing_time;
|
| 38 |
+
extern double computing_time;
|
| 39 |
+
extern double kernel_time;
|
| 40 |
+
extern double postprocessing_time;
|
| 41 |
+
extern double run_time;
|
| 42 |
+
#endif
|
| 43 |
+
|
| 44 |
+
namespace fbgemm {
|
| 45 |
+
|
| 46 |
+
/**
|
| 47 |
+
* @brief Templatized struct for packing parameters for A and B matrices.
|
| 48 |
+
*
|
| 49 |
+
* @tparam T input type
|
| 50 |
+
* @tparam accT the type used for accumulation
|
| 51 |
+
* @tparam instSet anyarch/avx2/avx512
|
| 52 |
+
* @tparam int8Type an auxiliary template parameter to specialize for 8-bit
|
| 53 |
+
* input types.
|
| 54 |
+
*/
|
| 55 |
+
template <
|
| 56 |
+
typename T,
|
| 57 |
+
typename accT,
|
| 58 |
+
inst_set_t instSet,
|
| 59 |
+
typename int8Type = void>
|
| 60 |
+
struct PackingTraits;
|
| 61 |
+
|
| 62 |
+
// type specialized implementation in an include file
|
| 63 |
+
#include "./PackingTraits-inl.h" // @manual
|
| 64 |
+
|
| 65 |
+
/**
|
| 66 |
+
* @brief Base class for packing matrices for higher GEMM performance.
|
| 67 |
+
*
|
| 68 |
+
* Matrix is tiled into blockRows() * blockCols() blocks.
|
| 69 |
+
* Each block is with size blockRowSize() * blockColSize().
|
| 70 |
+
* This class is designed using CRTP
|
| 71 |
+
* (https://en.wikipedia.org/wiki/Curiously_recurring_template_pattern)
|
| 72 |
+
*
|
| 73 |
+
* @tparam PT actual packing type, e.g., PackAWithRowOffset
|
| 74 |
+
*/
|
| 75 |
+
template <typename PT, typename inpType, typename accType = std::int32_t>
|
| 76 |
+
class PackMatrix {
|
| 77 |
+
public:
|
| 78 |
+
PackMatrix() = delete; // no default constructor
|
| 79 |
+
PackMatrix(const PackMatrix&) = delete; // no copy
|
| 80 |
+
PackMatrix& operator=(const PackMatrix&) = delete; // no copy
|
| 81 |
+
PackMatrix(PackMatrix&&) = delete; // no move
|
| 82 |
+
PackMatrix& operator=(PackMatrix&& rhs) noexcept = delete; // no move
|
| 83 |
+
|
| 84 |
+
/**
|
| 85 |
+
* @param rows total number of rows in the matrix
|
| 86 |
+
* (packed rows can be less than rows).
|
| 87 |
+
* @param cols total number of columns in the matrix
|
| 88 |
+
* @param pmat A buffer to contain the packed matrix.
|
| 89 |
+
* If nullptr, a buffer owned by PackMatrix will be allocated
|
| 90 |
+
* internally to contain the packed matrix.
|
| 91 |
+
* For non-constant matrices like activation matrices, the client
|
| 92 |
+
* code may want to pass a pre-allocated pmat to avoid the
|
| 93 |
+
* overhead of internal memory allocation every time a PackMatrix
|
| 94 |
+
* is constructed. The client code can query how big pmat should
|
| 95 |
+
* be with packedBufferSize function.
|
| 96 |
+
* @param groups when groups > 1, we compute groups number of GEMMs each
|
| 97 |
+
* multiplies A.rows by A.cols/A.groups matrix with
|
| 98 |
+
* B.rows/B.groups by B.cols matrix (in conventional BLAS
|
| 99 |
+
* terminology, this is a batched GEMM but we use the name group
|
| 100 |
+
* to follow deep learning terminology). The result matrix has
|
| 101 |
+
* dimension A.rows by B.cols*B.groups .
|
| 102 |
+
* A.groups must be same as B.groups, A.groups must divide
|
| 103 |
+
* A.cols, and B.groups must divide B.rows and C.cols.
|
| 104 |
+
*/
|
| 105 |
+
PackMatrix(
|
| 106 |
+
std::int32_t rows,
|
| 107 |
+
std::int32_t cols,
|
| 108 |
+
inpType* pmat,
|
| 109 |
+
int groups = 1,
|
| 110 |
+
const BlockingFactors* params = nullptr);
|
| 111 |
+
|
| 112 |
+
/**
|
| 113 |
+
* @return true usually when the matrix is constant matrix (e.g., weight
|
| 114 |
+
* matrices) that can be prepacked
|
| 115 |
+
*/
|
| 116 |
+
bool isPrePacked() const {
|
| 117 |
+
return static_cast<const PT*>(this)->isPrePacked();
|
| 118 |
+
}
|
| 119 |
+
|
| 120 |
+
/**
|
| 121 |
+
* @return true if this is the first input matrix in GEMM (i.e., A in C = A *
|
| 122 |
+
* B)
|
| 123 |
+
*/
|
| 124 |
+
static bool isA() {
|
| 125 |
+
return PT::isA();
|
| 126 |
+
}
|
| 127 |
+
|
| 128 |
+
/**
|
| 129 |
+
* @brief The size of the buffer used for packing (The size is in number of
|
| 130 |
+
* elements).
|
| 131 |
+
*
|
| 132 |
+
* rows and cols are only used for fully packing, i.e., for B matrix. The
|
| 133 |
+
* client code can use this function to query how big the buffer used for
|
| 134 |
+
* packing should be.
|
| 135 |
+
*/
|
| 136 |
+
static int packedBufferSize(
|
| 137 |
+
int rows = 0,
|
| 138 |
+
int cols = 0,
|
| 139 |
+
const BlockingFactors* params = nullptr);
|
| 140 |
+
|
| 141 |
+
FBGEMM_PUSH_WARNING_AND_DISABLE("-Wpragmas")
|
| 142 |
+
FBGEMM_PUSH_WARNING_AND_DISABLE("-Winfinite-recursion")
|
| 143 |
+
/**
|
| 144 |
+
* @return Pointer to a buffer containing row offset results. Some packing
|
| 145 |
+
* objects fuse row offset computation for later requantization step.
|
| 146 |
+
*/
|
| 147 |
+
std::int32_t* getRowOffsetBuffer() const {
|
| 148 |
+
return static_cast<const PT*>(this)->getRowOffsetBuffer();
|
| 149 |
+
}
|
| 150 |
+
/**
|
| 151 |
+
* @brief When k loop is also tiled/blocked, this function is used to check if
|
| 152 |
+
* we have executed computations for the last k block so that we can perform
|
| 153 |
+
* post-GEMM operations.
|
| 154 |
+
*/
|
| 155 |
+
bool isThisLastKBlock(int block_id) const {
|
| 156 |
+
return static_cast<const PT*>(this)->isThisLastKBlock(block_id);
|
| 157 |
+
}
|
| 158 |
+
FBGEMM_POP_WARNING
|
| 159 |
+
FBGEMM_POP_WARNING
|
| 160 |
+
|
| 161 |
+
/**
|
| 162 |
+
* @brief Actual packing of a block of the source matrix in pmat buffer.
|
| 163 |
+
*/
|
| 164 |
+
// Forwards the request to the derived packing type PT through a static_cast
// (the CRTP dispatch this base class is built around). The forwarding call is
// only compiled when FBGEMM_FBCODE is defined or the target is not aarch64;
// in the remaining configuration (aarch64 without FBGEMM_FBCODE) calling this
// function throws std::runtime_error instead of packing anything.
void pack(const block_type_t& block) {
|
| 165 |
+
#if defined(FBGEMM_FBCODE) || !defined(__aarch64__)
|
| 166 |
+
static_cast<PT*>(this)->pack(block);
|
| 167 |
+
#else
|
| 168 |
+
throw std::runtime_error("PackMatrix::pack() not implemented for aarch64");
|
| 169 |
+
#endif // defined(FBGEMM_FBCODE) || !defined(__aarch64__)
|
| 170 |
+
}
|
| 171 |
+
|
| 172 |
+
/**
 * @return The stored row count nrows_. Presumably this is the `rows` value
 *         handed to the constructor; the constructor body is not visible
 *         here, so confirm against its definition.
 */
std::int32_t numRows() const {
|
| 173 |
+
return nrows_;
|
| 174 |
+
}
|
| 175 |
+
|
| 176 |
+
/**
 * @return The stored column count ncols_. Presumably this is the `cols`
 *         value handed to the constructor; the constructor body is not
 *         visible here, so confirm against its definition.
 */
std::int32_t numCols() const {
|
| 177 |
+
return ncols_;
|
| 178 |
+
}
|
| 179 |
+
|
| 180 |
+
/**
|
| 181 |
+
* @return The number of rows in each block
|
| 182 |
+
*/
|
| 183 |
+
std::int32_t blockRowSize() const {
|
| 184 |
+
return brow_;
|
| 185 |
+
}
|
| 186 |
+
|
| 187 |
+
/**
|
| 188 |
+
* @return The number of columns in each block
|
| 189 |
+
*/
|
| 190 |
+
std::int32_t blockColSize() const {
|
| 191 |
+
return bcol_;
|
| 192 |
+
}
|
| 193 |
+
|
| 194 |
+
/**
|
| 195 |
+
* @return The number of blocks along rows
|
| 196 |
+
*/
|
| 197 |
+
std::int32_t blockRows() const {
|
| 198 |
+
return nbrow_;
|
| 199 |
+
}
|
| 200 |
+
|
| 201 |
+
/**
|
| 202 |
+
* @return The number of blocks along columns
|
| 203 |
+
*/
|
| 204 |
+
std::int32_t blockCols() const {
|
| 205 |
+
return nbcol_;
|
| 206 |
+
}
|
| 207 |
+
|
| 208 |
+
/**
|
| 209 |
+
* @return The number of the rows in the currently packed block of a matrix.
|
| 210 |
+
* For pre-packed (i.e., fully-packed), it's equal to the total number
|
| 211 |
+
* of rows.
|
| 212 |
+
*/
|
| 213 |
+
std::int32_t numPackedRows() const {
|
| 214 |
+
return packedBlock_.row_size;
|
| 215 |
+
}
|
| 216 |
+
|
| 217 |
+
/**
|
| 218 |
+
* @return The number of columns in the currently packed block of a matrix.
|
| 219 |
+
* For pre-packed (i.e., fully-packed), it's equal to the number of
|
| 220 |
+
* columns.
|
| 221 |
+
*/
|
| 222 |
+
std::int32_t numPackedCols() const {
|
| 223 |
+
return packedBlock_.col_size;
|
| 224 |
+
}
|
| 225 |
+
|
| 226 |
+
/**
|
| 227 |
+
* @return The first row of the block we're working on.
|
| 228 |
+
*/
|
| 229 |
+
std::int32_t packedRowStart() const {
|
| 230 |
+
return packedBlock_.row_start;
|
| 231 |
+
}
|
| 232 |
+
|
| 233 |
+
/**
|
| 234 |
+
* @return The first column of the block we're working on.
|
| 235 |
+
*/
|
| 236 |
+
std::int32_t packedColStart() const {
|
| 237 |
+
return packedBlock_.col_start;
|
| 238 |
+
}
|
| 239 |
+
|
| 240 |
+
/**
|
| 241 |
+
* @return The beginning of (rowBlockNum, colBlockNum)th block
|
| 242 |
+
*/
|
| 243 |
+
inpType* getBuf(std::int32_t rowBlockNum = 0, std::int32_t colBlockNum = 0) {
|
| 244 |
+
// The offset is counted in elements of inpType from buf_. One block holds
// blockRowSize() * blockColSize() elements; rowBlockNum advances by exactly
// one block, while colBlockNum advances by blockCols() blocks at a time.
// With both arguments left at 0 this is simply buf_.
// NOTE(review): the first argument is therefore the fastest-varying block
// index -- confirm that callers pass their indices in the order this
// layout expects, since the parameter names alone do not make it obvious.
return buf_ + blockRowSize() * blockColSize() * rowBlockNum +
|
| 245 |
+
blockRowSize() * blockColSize() * blockCols() * colBlockNum;
|
| 246 |
+
}
|
| 247 |
+
|
| 248 |
+
/**
|
| 249 |
+
* @brief Print the packed block.
|
| 250 |
+
*/
|
| 251 |
+
void printPackedMatrix(const std::string& name) {
|
| 252 |
+
static_cast<PT*>(this)->printPackedMatrix(name);
|
| 253 |
+
}
|
| 254 |
+
|
| 255 |
+
/**
|
| 256 |
+
* @return The number of rows in the last row block.
|
| 257 |
+
*/
|
| 258 |
+
std::int32_t lastBrow() const {
|
| 259 |
+
return last_brow_;
|
| 260 |
+
}
|
| 261 |
+
|
| 262 |
+
/**
|
| 263 |
+
* @return The number of columns in the last column block.
|
| 264 |
+
*/
|
| 265 |
+
std::int32_t lastBcol() const {
|
| 266 |
+
return last_bcol_;
|
| 267 |
+
}
|
| 268 |
+
|
| 269 |
+
/**
 * @return The stored group count G_. Presumably this is the `groups` value
 *         handed to the constructor (TODO confirm in the constructor
 *         definition, which is not visible here).
 */
int numGroups() const {
|
| 270 |
+
return G_;
|
| 271 |
+
}
|
| 272 |
+
|
| 273 |
+
/**
|
| 274 |
+
* @return True if the last column block has fewer columns than the block
|
| 275 |
+
* size.
|
| 276 |
+
*/
|
| 277 |
+
bool isThereColRemainder() const {
|
| 278 |
+
return last_bcol_ != blockColSize();
|
| 279 |
+
}
|
| 280 |
+
|
| 281 |
+
/**
 * Releases buf_ with fbgemmAlignedFree, but only when bufAllocatedHere_ is
 * set; otherwise the buffer is left untouched. The flag defaults to false.
 * NOTE(review): presumably the constructor sets it when it allocates the
 * buffer itself (pmat == nullptr) -- confirm in the constructor definition.
 */
virtual ~PackMatrix() {
|
| 282 |
+
if (bufAllocatedHere_) {
|
| 283 |
+
fbgemmAlignedFree(buf_);
|
| 284 |
+
}
|
| 285 |
+
}
|
| 286 |
+
|
| 287 |
+
protected:
|
| 288 |
+
/**
|
| 289 |
+
* Set which block we're packing
|
| 290 |
+
*/
|
| 291 |
+
// Records `block` as the block currently being packed and recomputes the
// derived tiling state (nbrow_, nbcol_, last_brow_, last_bcol_) from it.
void packedBlock(const block_type_t& block) {
|
| 292 |
+
// Must be assigned first: numPackedRows()/numPackedCols() below read
// packedBlock_.row_size / packedBlock_.col_size.
packedBlock_ = block;
|
| 293 |
+
// Ceiling division: number of row blocks needed to cover the packed rows.
nbrow_ = (numPackedRows() + blockRowSize() - 1) / blockRowSize();
|
| 294 |
+
// Ceiling division: number of column blocks needed to cover the packed cols.
nbcol_ = (numPackedCols() + blockColSize() - 1) / blockColSize();
|
| 295 |
+
|
| 296 |
+
// Size of the final row block: a full block when the packed rows divide
// evenly by the block size, otherwise the remainder.
last_brow_ = ((numPackedRows() % blockRowSize()) == 0)
|
| 297 |
+
? blockRowSize()
|
| 298 |
+
: (numPackedRows() % blockRowSize());
|
| 299 |
+
// Same rule for the final column block.
last_bcol_ = ((numPackedCols() % blockColSize()) == 0)
|
| 300 |
+
? blockColSize()
|
| 301 |
+
: (numPackedCols() % blockColSize());
|
| 302 |
+
}
|
| 303 |
+
|
| 304 |
+
inpType* buf_;
|
| 305 |
+
std::int32_t brow_; ///< the number of rows in each block
|
| 306 |
+
std::int32_t bcol_; ///< the number of columns in each block
|
| 307 |
+
std::int32_t nbrow_; ///< the number of blocks along rows
|
| 308 |
+
std::int32_t nbcol_; ///< the number of blocks along columns
|
| 309 |
+
bool bufAllocatedHere_{false};
|
| 310 |
+
const BlockingFactors*
|
| 311 |
+
blocking_params; ///< MCB, KCB, NCB, MR, NR, NR_MIN, ROW_INTERLEAVE;
|
| 312 |
+
|
| 313 |
+
private:
|
| 314 |
+
std::int32_t nrows_, ncols_;
|
| 315 |
+
int G_;
|
| 316 |
+
block_type_t packedBlock_; ///< The block in the source matrix just packed
|
| 317 |
+
std::int32_t last_brow_, last_bcol_;
|
| 318 |
+
};
|
| 319 |
+
|
| 320 |
+
/**
|
| 321 |
+
* @brief Matrix packed for the first input matrix in GEMM (usually
|
| 322 |
+
* activation). The source matrix is already quantized. Default
|
| 323 |
+
* accumulation type is int32.
|
| 324 |
+
*/
|
| 325 |
+
template <typename T, typename accT = std::int32_t>
|
| 326 |
+
class FBGEMM_API PackAMatrix final
|
| 327 |
+
: public PackMatrix<PackAMatrix<T, accT>, T, accT> {
|
| 328 |
+
public:
|
| 329 |
+
using This = PackAMatrix<T, accT>;
|
| 330 |
+
using BaseType = PackMatrix<This, T, accT>;
|
| 331 |
+
using inpType = T;
|
| 332 |
+
using accType = accT;
|
| 333 |
+
|
| 334 |
+
PackAMatrix() = delete; // no default constructor
|
| 335 |
+
|
| 336 |
+
PackAMatrix(
|
| 337 |
+
matrix_op_t trans,
|
| 338 |
+
std::int32_t nRow,
|
| 339 |
+
std::int32_t nCol,
|
| 340 |
+
const inpType* smat,
|
| 341 |
+
std::int32_t ld,
|
| 342 |
+
inpType* pmat = nullptr,
|
| 343 |
+
int groups = 1,
|
| 344 |
+
const BlockingFactors* params = nullptr);
|
| 345 |
+
|
| 346 |
+
/**
|
| 347 |
+
* Activation matrices are not constant so cannot amortize the cost of
|
| 348 |
+
* pre-packing.
|
| 349 |
+
*/
|
| 350 |
+
bool isPrePacked() const {
|
| 351 |
+
return false;
|
| 352 |
+
}
|
| 353 |
+
|
| 354 |
+
/**
|
| 355 |
+
* @return True if this is used as A matrix.
|
| 356 |
+
*/
|
| 357 |
+
static constexpr bool isA() {
|
| 358 |
+
return true;
|
| 359 |
+
}
|
| 360 |
+
|
| 361 |
+
/**
|
| 362 |
+
* @return A pointer to the row offset buffer. There is no row offset buffer
|
| 363 |
+
* calculations with this packing class, hence, it returns nullptr.
|
| 364 |
+
*/
|
| 365 |
+
std::int32_t* getRowOffsetBuffer() const {
|
| 366 |
+
return nullptr;
|
| 367 |
+
}
|
| 368 |
+
|
| 369 |
+
/**
|
| 370 |
+
* @return Offset of the element in the packed matrix that was at (i, j) in
|
| 371 |
+
* the source matrix.
|
| 372 |
+
*/
|
| 373 |
+
std::int32_t addr(std::int32_t i, std::int32_t j) const;
|
| 374 |
+
|
| 375 |
+
/**
|
| 376 |
+
* @brief Packs a block of source matrix into pmat buffer.
|
| 377 |
+
*/
|
| 378 |
+
void pack(const block_type_t& block);
|
| 379 |
+
|
| 380 |
+
/**
|
| 381 |
+
* @brief Print the packed block.
|
| 382 |
+
*/
|
| 383 |
+
void printPackedMatrix(const std::string& name);
|
| 384 |
+
|
| 385 |
+
private:
|
| 386 |
+
matrix_op_t trans_;
|
| 387 |
+
const T* smat_;
|
| 388 |
+
std::int32_t ld_;
|
| 389 |
+
std::int32_t row_interleave_B_;
|
| 390 |
+
};
|
| 391 |
+
|
| 392 |
+
/**
|
| 393 |
+
* @brief Matrix packed for the second input matrix in GEMM (usually weight).
|
| 394 |
+
* The source matrix is already quantized. Default accumulation
|
| 395 |
+
* type is int32.
|
| 396 |
+
*/
|
| 397 |
+
template <typename T, typename accT = std::int32_t>
|
| 398 |
+
class FBGEMM_API PackBMatrix final
|
| 399 |
+
: public PackMatrix<PackBMatrix<T, accT>, T, accT> {
|
| 400 |
+
public:
|
| 401 |
+
using This = PackBMatrix<T, accT>;
|
| 402 |
+
using BaseType = PackMatrix<This, T, accT>;
|
| 403 |
+
using inpType = T;
|
| 404 |
+
using accType = accT;
|
| 405 |
+
|
| 406 |
+
PackBMatrix() = delete; // no default constructor
|
| 407 |
+
|
| 408 |
+
/**
|
| 409 |
+
* @param groups if > 1 and trans == NoTranspose, smat is nRow x nCol with
|
| 410 |
+
* groups are vertically concatenated: each group is
|
| 411 |
+
* (nRow / groups) x nCol .
|
| 412 |
+
* if > 1 and trans == Transpose, smat is (nCol * groups) x
|
| 413 |
+
* (nRow / groups) with groups are horizontally concatenated:
|
| 414 |
+
* each group is nCol x (nRow / groups) . Each group is
|
| 415 |
+
* transposed and vertically concatenated to match with the
|
| 416 |
+
* NoTranspose case.
|
| 417 |
+
*/
|
| 418 |
+
PackBMatrix(
|
| 419 |
+
matrix_op_t trans,
|
| 420 |
+
std::int32_t nRow,
|
| 421 |
+
std::int32_t nCol,
|
| 422 |
+
const inpType* smat,
|
| 423 |
+
std::int32_t ld,
|
| 424 |
+
inpType* pmat = nullptr,
|
| 425 |
+
int groups = 1,
|
| 426 |
+
const BlockingFactors* params = nullptr);
|
| 427 |
+
|
| 428 |
+
/**
|
| 429 |
+
* Weight matrices are usually constant so worth pre-packing.
|
| 430 |
+
*/
|
| 431 |
+
bool isPrePacked() const {
|
| 432 |
+
return true;
|
| 433 |
+
}
|
| 434 |
+
|
| 435 |
+
/**
|
| 436 |
+
* @return True if to be used as A matrix, False otherwise.
|
| 437 |
+
*/
|
| 438 |
+
static constexpr bool isA() {
|
| 439 |
+
return false;
|
| 440 |
+
}
|
| 441 |
+
|
| 442 |
+
/**
|
| 443 |
+
* @brief When k loop is also tiled/blocked, this function is used to check if
|
| 444 |
+
* we have executed computations for the last k block so that we can perform
|
| 445 |
+
* post-GEMM operations.
|
| 446 |
+
*/
|
| 447 |
+
bool isThisLastKBlock(int block_id) const {
|
| 448 |
+
return (BaseType::blockRows() - 1) == block_id;
|
| 449 |
+
}
|
| 450 |
+
|
| 451 |
+
/**
|
| 452 |
+
* @return Offset of the element in the packed matrix that was at (i, j) in
|
| 453 |
+
* the source matrix.
|
| 454 |
+
*/
|
| 455 |
+
std::int32_t addr(std::int32_t i, std::int32_t j) const;
|
| 456 |
+
|
| 457 |
+
/**
|
| 458 |
+
* @brief Packs a block of source matrix into pmat buffer. The blocking
|
| 459 |
+
* parameters are needed to compute the buffer size of each group.
|
| 460 |
+
* It will use default blocking parameters if params is not provided.
|
| 461 |
+
*/
|
| 462 |
+
void pack(const block_type_t& block, const BlockingFactors* params = nullptr);
|
| 463 |
+
|
| 464 |
+
/**
|
| 465 |
+
* @brief Print the packed block.
|
| 466 |
+
*/
|
| 467 |
+
void printPackedMatrix(
|
| 468 |
+
const std::string& name,
|
| 469 |
+
const BlockingFactors* params = nullptr);
|
| 470 |
+
|
| 471 |
+
/**
|
| 472 |
+
* @return true if meta information like matrix shape is the same.
|
| 473 |
+
*/
|
| 474 |
+
bool metaEquals(const PackBMatrix<T, accT>& that) const;
|
| 475 |
+
/**
|
| 476 |
+
* @return true if matrices are the same.
|
| 477 |
+
*/
|
| 478 |
+
bool equals(const PackBMatrix<T, accT>& that) const;
|
| 479 |
+
|
| 480 |
+
/**
|
| 481 |
+
* @brief Unpack pmat buffer to the origin_buf (Used for the serialization to
|
| 482 |
+
* recover weight matrix).
|
| 483 |
+
*/
|
| 484 |
+
void unpack(T* origin_buf, const BlockingFactors* params = nullptr);
|
| 485 |
+
|
| 486 |
+
~PackBMatrix() override = default;
|
| 487 |
+
|
| 488 |
+
private:
|
| 489 |
+
matrix_op_t trans_;
|
| 490 |
+
const T* smat_;
|
| 491 |
+
std::int32_t ld_;
|
| 492 |
+
std::int32_t row_interleave_;
|
| 493 |
+
|
| 494 |
+
/**
|
| 495 |
+
* @brief Internal function performing both pack & unpack
|
| 496 |
+
*/
|
| 497 |
+
void pack_unpack_(
|
| 498 |
+
const block_type_t& block,
|
| 499 |
+
T* unpack_buf,
|
| 500 |
+
T* pack_buf,
|
| 501 |
+
bool ispack,
|
| 502 |
+
const BlockingFactors* params = nullptr);
|
| 503 |
+
};
|
| 504 |
+
|
| 505 |
+
/**
|
| 506 |
+
* @brief Matrix packed for direct group convolution.
|
| 507 |
+
* The source matrix is already quantized. Default accumulation
|
| 508 |
+
* type is int32.
|
| 509 |
+
*/
|
| 510 |
+
template <typename T, typename accT = std::int32_t, int SPATIAL_DIM = 2>
|
| 511 |
+
class FBGEMM_API PackWeightMatrixForGConv {
|
| 512 |
+
public:
|
| 513 |
+
using This = PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>;
|
| 514 |
+
using inpType = T;
|
| 515 |
+
using accType = accT;
|
| 516 |
+
|
| 517 |
+
PackWeightMatrixForGConv() = delete; // no default constructor
|
| 518 |
+
PackWeightMatrixForGConv(const PackWeightMatrixForGConv&) = delete; // no copy
|
| 519 |
+
PackWeightMatrixForGConv& operator=(const PackWeightMatrixForGConv&) =
|
| 520 |
+
delete; // no copy
|
| 521 |
+
|
| 522 |
+
PackWeightMatrixForGConv(PackWeightMatrixForGConv&&) = delete; // no move
|
| 523 |
+
PackWeightMatrixForGConv& operator=(PackWeightMatrixForGConv&&) =
|
| 524 |
+
delete; // no move
|
| 525 |
+
|
| 526 |
+
/**
|
| 527 |
+
* @param pdata if nullptr, a buffer is allocated and owned by this class.
|
| 528 |
+
*/
|
| 529 |
+
PackWeightMatrixForGConv(
|
| 530 |
+
matrix_op_t trans,
|
| 531 |
+
const conv_param_t<SPATIAL_DIM>& conv_param,
|
| 532 |
+
const inpType* sdata,
|
| 533 |
+
inpType* pdata = nullptr);
|
| 534 |
+
|
| 535 |
+
/**
|
| 536 |
+
* Number of groups we work at a time to fill the full simd width
|
| 537 |
+
* e.g., IC_PER_G = 4 and OC_PER_G = 4, we work on two groups at a time
|
| 538 |
+
* to fill the avx2 width of 256 bits.
|
| 539 |
+
*/
|
| 540 |
+
static int numOfGroupsTogether(const conv_param_t<SPATIAL_DIM>& conv_param);
|
| 541 |
+
|
| 542 |
+
/**
|
| 543 |
+
* @brief Packs a block of source matrix into pmat buffer.
|
| 544 |
+
*/
|
| 545 |
+
void pack();
|
| 546 |
+
|
| 547 |
+
/**
|
| 548 |
+
* @brief Unpacks a pmat buffer into source matrix.
|
| 549 |
+
*/
|
| 550 |
+
void unpack(T* origin_buf);
|
| 551 |
+
|
| 552 |
+
/**
|
| 553 |
+
* @brief Return packed data
|
| 554 |
+
*/
|
| 555 |
+
inpType* getBuf() {
|
| 556 |
+
return pdata_;
|
| 557 |
+
}
|
| 558 |
+
|
| 559 |
+
/**
 * Releases pdata_ with fbgemmAlignedFree, but only when bufAllocatedHere_ is
 * set; otherwise the buffer is left untouched. The flag defaults to false.
 * NOTE(review): presumably the constructor sets it when no pdata buffer was
 * supplied -- confirm in the constructor definition.
 */
~PackWeightMatrixForGConv() {
|
| 560 |
+
if (bufAllocatedHere_) {
|
| 561 |
+
fbgemmAlignedFree(pdata_);
|
| 562 |
+
}
|
| 563 |
+
}
|
| 564 |
+
|
| 565 |
+
private:
|
| 566 |
+
matrix_op_t trans_;
|
| 567 |
+
const conv_param_t<SPATIAL_DIM> conv_param_;
|
| 568 |
+
const T* sdata_;
|
| 569 |
+
T* pdata_;
|
| 570 |
+
bool bufAllocatedHere_{false};
|
| 571 |
+
// Number of groups we work at a time to fill the full simd width
|
| 572 |
+
int GTogether_;
|
| 573 |
+
|
| 574 |
+
/**
|
| 575 |
+
* @brief Internal function performing both pack & unpack
|
| 576 |
+
*/
|
| 577 |
+
void pack_unpack_(const T* src, T* dst, bool ispack);
|
| 578 |
+
|
| 579 |
+
/**
|
| 580 |
+
* @brief Get the index of the unpacked data
|
| 581 |
+
*/
|
| 582 |
+
int unpacked_index_(int t, int r, int s, int k, int g, int c, bool tr);
|
| 583 |
+
|
| 584 |
+
/**
|
| 585 |
+
* @brief Get the index of the packed data
|
| 586 |
+
*/
|
| 587 |
+
int packed_index_(int t, int r, int s, int k, int g, int c);
|
| 588 |
+
};
|
| 589 |
+
|
| 590 |
+
/**
|
| 591 |
+
* @brief A container class to keep packed weight tensor for convolution.
|
| 592 |
+
* The source tensor should already be quantized.
|
| 593 |
+
*
|
| 594 |
+
* @tparam SPATIAL_DIM is equal to 2 for 2D convolutions and 3 for 3D
|
| 595 |
+
* convolutions. Default value is 2.
|
| 596 |
+
* @tparam T is the datatype for source tensor. Default value is int8.
|
| 597 |
+
* @tparam accT is the datatype to accumulate into. Default value is int32.
|
| 598 |
+
*/
|
| 599 |
+
template <
|
| 600 |
+
int SPATIAL_DIM = 2,
|
| 601 |
+
typename T = std::int8_t,
|
| 602 |
+
typename accT = std::int32_t>
|
| 603 |
+
class FBGEMM_API PackWeightsForConv {
|
| 604 |
+
public:
|
| 605 |
+
using This = PackWeightsForConv<SPATIAL_DIM, T, accT>;
|
| 606 |
+
using inpType = T;
|
| 607 |
+
using accType = accT;
|
| 608 |
+
|
| 609 |
+
PackWeightsForConv() = delete; // no default constructor
|
| 610 |
+
|
| 611 |
+
PackWeightsForConv(
|
| 612 |
+
const conv_param_t<SPATIAL_DIM>& conv_param,
|
| 613 |
+
const inpType* sdata,
|
| 614 |
+
const BlockingFactors* blocking_params = nullptr);
|
| 615 |
+
|
| 616 |
+
/**
 * @return A copy of the shared_ptr W_im2col_packed_.
 *         NOTE(review): may be empty when a different packing variant was
 *         chosen at construction -- confirm in the constructor.
 */
std::shared_ptr<PackBMatrix<T, accT>> getPackedWForIm2col() {
|
| 617 |
+
return W_im2col_packed_;
|
| 618 |
+
}
|
| 619 |
+
|
| 620 |
+
// This accessor only exists when FBGEMM_FBCODE is defined or the target is
// not aarch64; the W_dw_packed_ member it returns is guarded by the same
// condition.
#if defined(FBGEMM_FBCODE) || !defined(__aarch64__)
|
| 621 |
+
/**
 * @return A copy of the shared_ptr W_dw_packed_.
 *         NOTE(review): may be empty when a different packing variant was
 *         chosen at construction -- confirm in the constructor.
 */
std::shared_ptr<PackedDepthWiseConvMatrix> getPackedWForDepthwise() {
|
| 622 |
+
return W_dw_packed_;
|
| 623 |
+
}
|
| 624 |
+
#endif // defined(FBGEMM_FBCODE) || !defined(__aarch64__)
|
| 625 |
+
|
| 626 |
+
/**
 * @return A copy of the shared_ptr W_dc_packed_.
 *         NOTE(review): may be empty when a different packing variant was
 *         chosen at construction -- confirm in the constructor.
 */
std::shared_ptr<PackedDirectConvMatrix> getPackedWForDirectconv() {
|
| 627 |
+
return W_dc_packed_;
|
| 628 |
+
}
|
| 629 |
+
|
| 630 |
+
/**
 * @return A copy of the shared_ptr W_gconv_packed_.
 *         NOTE(review): may be empty when a different packing variant was
 *         chosen at construction -- confirm in the constructor.
 */
std::shared_ptr<PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>>
|
| 631 |
+
getPackedWForGroupwise() {
|
| 632 |
+
return W_gconv_packed_;
|
| 633 |
+
}
|
| 634 |
+
|
| 635 |
+
/**
 * @return A copy of the shared_ptr W_pointwise_packed_.
 *         NOTE(review): may be empty when a different packing variant was
 *         chosen at construction -- confirm in the constructor.
 */
std::shared_ptr<PackBMatrix<T, accT>> getPackedWForPointwise() {
|
| 636 |
+
return W_pointwise_packed_;
|
| 637 |
+
}
|
| 638 |
+
|
| 639 |
+
/**
 * @return The IC field of the convolution parameters stored in conv_param_
 *         (presumably the input channel count, per the accessor name).
 */
int inputChannels() {
|
| 640 |
+
return conv_param_.IC;
|
| 641 |
+
}
|
| 642 |
+
|
| 643 |
+
/**
 * @return The OC field of the convolution parameters stored in conv_param_
 *         (presumably the output channel count, per the accessor name).
 */
int outputChannels() {
|
| 644 |
+
return conv_param_.OC;
|
| 645 |
+
}
|
| 646 |
+
|
| 647 |
+
/**
 * @return A copy of the K field of conv_param_ as an array with SPATIAL_DIM
 *         entries (presumably the kernel extent per spatial dimension, per
 *         the accessor name).
 */
std::array<int, SPATIAL_DIM> kernelDims() {
|
| 648 |
+
return conv_param_.K;
|
| 649 |
+
}
|
| 650 |
+
|
| 651 |
+
/**
 * @return The G field of the convolution parameters stored in conv_param_
 *         (presumably the number of convolution groups, per the accessor
 *         name).
 */
int groups() {
|
| 652 |
+
return conv_param_.G;
|
| 653 |
+
}
|
| 654 |
+
|
| 655 |
+
/**
|
| 656 |
+
* @brief Returns true if the packed weights would work for the given
|
| 657 |
+
* convolution parameters, and false otherwise
|
| 658 |
+
*/
|
| 659 |
+
bool isPackingCompliant(const conv_param_t<SPATIAL_DIM>& conv_p);
|
| 660 |
+
|
| 661 |
+
/**
|
| 662 |
+
* @brief Returns a string of mismatching parameters
|
| 663 |
+
*/
|
| 664 |
+
std::string mismatchingParams(const conv_param_t<SPATIAL_DIM>& conv_p);
|
| 665 |
+
|
| 666 |
+
/**
|
| 667 |
+
* @brief Unpack packed matrix into origin_buf (Used for the serialization to
|
| 668 |
+
* recover weight matrix).
|
| 669 |
+
*/
|
| 670 |
+
void unpack(T* origin_buf);
|
| 671 |
+
|
| 672 |
+
private:
|
| 673 |
+
const conv_param_t<SPATIAL_DIM> conv_param_;
|
| 674 |
+
// Packed weights if we use im2col based convolution implementation
|
| 675 |
+
std::shared_ptr<PackBMatrix<T, accT>> W_im2col_packed_;
|
| 676 |
+
#if defined(FBGEMM_FBCODE) || !defined(__aarch64__)
|
| 677 |
+
// Packed weights if we use depthwise convolution implementation
|
| 678 |
+
std::shared_ptr<PackedDepthWiseConvMatrix> W_dw_packed_;
|
| 679 |
+
#endif // __aarch64__
|
| 680 |
+
// Packed weights if we use direct convolution implementation
|
| 681 |
+
std::shared_ptr<PackedDirectConvMatrix> W_dc_packed_;
|
| 682 |
+
// Packed weights if we use groupwise (small channels per group) convolution
|
| 683 |
+
// implementation
|
| 684 |
+
std::shared_ptr<PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>>
|
| 685 |
+
W_gconv_packed_;
|
| 686 |
+
// Packed weights if we use direct gemm for pointwise convolution
|
| 687 |
+
std::shared_ptr<PackBMatrix<T, accT>> W_pointwise_packed_;
|
| 688 |
+
};
|
| 689 |
+
|
| 690 |
+
/**
|
| 691 |
+
* @brief Matrix packed for the first input matrix in GEMM (usually activation),
|
| 692 |
+
* and row offsets used for requantization is computed during packing.
|
| 693 |
+
* Im2col is fused with packing here. The source matrix is already
|
| 694 |
+
* quantized.
|
| 695 |
+
*/
|
| 696 |
+
template <typename T, typename accT = std::int32_t, int SPATIAL_DIM = 2>
|
| 697 |
+
class FBGEMM_API PackAWithIm2Col
|
| 698 |
+
: public PackMatrix<PackAWithIm2Col<T, accT, SPATIAL_DIM>, T, accT> {
|
| 699 |
+
public:
|
| 700 |
+
using This = PackAWithIm2Col<T, accT, SPATIAL_DIM>;
|
| 701 |
+
using BaseType = PackMatrix<This, T, accT>;
|
| 702 |
+
using inpType = T;
|
| 703 |
+
using accType = accT;
|
| 704 |
+
|
| 705 |
+
PackAWithIm2Col() = delete; // no default constructor
|
| 706 |
+
/**
|
| 707 |
+
* @param zero_pt the quantized value that maps to 0.0f floating-point number.
|
| 708 |
+
* @param row_offset If nullptr, this constructor internally allocates a
|
| 709 |
+
* buffer and owns it. Otherwise, this class doesn't own
|
| 710 |
+
* the buffer. The buffer will be populated when pack
|
| 711 |
+
* function is called.
|
| 712 |
+
* @param b_symmetric if true we skip row offset computation
|
| 713 |
+
*/
|
| 714 |
+
PackAWithIm2Col(
|
| 715 |
+
const conv_param_t<SPATIAL_DIM>& conv_param,
|
| 716 |
+
const T* sdata,
|
| 717 |
+
inpType* pmat = nullptr,
|
| 718 |
+
std::int32_t a_zero_pt = 0,
|
| 719 |
+
std::int32_t* row_offset = nullptr,
|
| 720 |
+
bool b_symmetric = false,
|
| 721 |
+
const BlockingFactors* params = nullptr);
|
| 722 |
+
|
| 723 |
+
PackAWithIm2Col(const PackAWithIm2Col&) = delete;
|
| 724 |
+
PackAWithIm2Col(PackAWithIm2Col&&) = delete;
|
| 725 |
+
PackAWithIm2Col& operator=(const PackAWithIm2Col&) = delete;
|
| 726 |
+
PackAWithIm2Col& operator=(PackAWithIm2Col&&) = delete;
|
| 727 |
+
|
| 728 |
+
/**
|
| 729 |
+
* Activation matrices are not constant so cannot amortize the cost of
|
| 730 |
+
* pre-packing.
|
| 731 |
+
*/
|
| 732 |
+
bool isPrePacked() const {
|
| 733 |
+
return false;
|
| 734 |
+
}
|
| 735 |
+
|
| 736 |
+
/**
|
| 737 |
+
* @return True if this is used as A matrix.
|
| 738 |
+
*/
|
| 739 |
+
static constexpr bool isA() {
|
| 740 |
+
return true;
|
| 741 |
+
}
|
| 742 |
+
|
| 743 |
+
/**
|
| 744 |
+
* @brief Packs a block of source matrix into pmat buffer.
|
| 745 |
+
*/
|
| 746 |
+
void pack(const block_type_t& block);
|
| 747 |
+
|
| 748 |
+
/**
|
| 749 |
+
* @return A pointer to the row offset buffer.
|
| 750 |
+
*/
|
| 751 |
+
std::int32_t* getRowOffsetBuffer() const {
|
| 752 |
+
return row_offset_;
|
| 753 |
+
}
|
| 754 |
+
|
| 755 |
+
/**
|
| 756 |
+
* @brief Print the packed block.
|
| 757 |
+
*/
|
| 758 |
+
void printPackedMatrix(const std::string& name);
|
| 759 |
+
|
| 760 |
+
/**
|
| 761 |
+
* @return Size of row offset buffer in number of elements
|
| 762 |
+
*/
|
| 763 |
+
static int rowOffsetBufferSize(const BlockingFactors* params = nullptr);
|
| 764 |
+
|
| 765 |
+
~PackAWithIm2Col() override {
|
| 766 |
+
if (rowOffsetAllocatedHere) {
|
| 767 |
+
fbgemmAlignedFree(row_offset_);
|
| 768 |
+
}
|
| 769 |
+
}
|
| 770 |
+
|
| 771 |
+
private:
|
| 772 |
+
const conv_param_t<SPATIAL_DIM> conv_p_;
|
| 773 |
+
const T* sdata_;
|
| 774 |
+
std::int32_t a_zero_pt_;
|
| 775 |
+
std::int32_t* row_offset_{nullptr};
|
| 776 |
+
bool rowOffsetAllocatedHere{false};
|
| 777 |
+
std::int32_t row_interleave_B_;
|
| 778 |
+
};
|
| 779 |
+
|
| 780 |
+
/**
|
| 781 |
+
* @brief Matrix packed for the first input matrix in GEMM (usually activation),
|
| 782 |
+
* and row offsets used for requantization is computed during packing.
|
| 783 |
+
* The source matrix is already quantized.
|
| 784 |
+
*/
|
| 785 |
+
template <typename T, typename accT = std::int32_t>
|
| 786 |
+
class FBGEMM_API PackAWithRowOffset final
|
| 787 |
+
: public PackMatrix<PackAWithRowOffset<T, accT>, T, accT> {
|
| 788 |
+
public:
|
| 789 |
+
using This = PackAWithRowOffset<T, accT>;
|
| 790 |
+
using BaseType = PackMatrix<This, T, accT>;
|
| 791 |
+
using inpType = T;
|
| 792 |
+
using accType = accT;
|
| 793 |
+
|
| 794 |
+
PackAWithRowOffset() = delete; // no default constructor
|
| 795 |
+
/**
|
| 796 |
+
* @param row_offset If nullptr, this constructor internally allocates a
|
| 797 |
+
* buffer and owns it. Otherwise, this class doesn't own
|
| 798 |
+
* the buffer. The buffer will be populated when pack
|
| 799 |
+
* function is called.
|
| 800 |
+
*/
|
| 801 |
+
PackAWithRowOffset(
|
| 802 |
+
matrix_op_t trans,
|
| 803 |
+
std::uint32_t nRow,
|
| 804 |
+
std::uint32_t nCol,
|
| 805 |
+
const T* smat,
|
| 806 |
+
std::uint32_t ld,
|
| 807 |
+
inpType* pmat = nullptr,
|
| 808 |
+
int groups = 1,
|
| 809 |
+
std::int32_t* row_offset = nullptr,
|
| 810 |
+
const BlockingFactors* params = nullptr);
|
| 811 |
+
|
| 812 |
+
PackAWithRowOffset(const PackAWithRowOffset&) = delete;
|
| 813 |
+
PackAWithRowOffset(PackAWithRowOffset&&) = delete;
|
| 814 |
+
PackAWithRowOffset& operator=(const PackAWithRowOffset&) = delete;
|
| 815 |
+
PackAWithRowOffset& operator=(PackAWithRowOffset&&) = delete;
|
| 816 |
+
|
| 817 |
+
/**
|
| 818 |
+
* Activation matrices are not constant so cannot amortize the cost of
|
| 819 |
+
* pre-packing.
|
| 820 |
+
*/
|
| 821 |
+
bool isPrePacked() const {
|
| 822 |
+
return false;
|
| 823 |
+
}
|
| 824 |
+
|
| 825 |
+
/**
|
| 826 |
+
* @return True if this is used as A matrix.
|
| 827 |
+
*/
|
| 828 |
+
static constexpr bool isA() {
|
| 829 |
+
return true;
|
| 830 |
+
}
|
| 831 |
+
|
| 832 |
+
/**
|
| 833 |
+
* @return Offset of the element in the packed matrix that was at (i, j) in
|
| 834 |
+
* the source matrix
|
| 835 |
+
*/
|
| 836 |
+
std::int32_t addr(std::int32_t i, std::int32_t j) const;
|
| 837 |
+
|
| 838 |
+
/**
|
| 839 |
+
* @brief Packs a block of source matrix into pmat buffer.
|
| 840 |
+
*/
|
| 841 |
+
void pack(const block_type_t& block);
|
| 842 |
+
|
| 843 |
+
/**
|
| 844 |
+
* @return A pointer to the row offset buffer.
|
| 845 |
+
*/
|
| 846 |
+
std::int32_t* getRowOffsetBuffer() const {
|
| 847 |
+
return row_offset_;
|
| 848 |
+
}
|
| 849 |
+
|
| 850 |
+
/**
|
| 851 |
+
* @brief Print the packed block.
|
| 852 |
+
*/
|
| 853 |
+
void printPackedMatrix(const std::string& name);
|
| 854 |
+
|
| 855 |
+
/**
|
| 856 |
+
* @return size of row offset buffer in number of elements
|
| 857 |
+
*/
|
| 858 |
+
static int rowOffsetBufferSize(const BlockingFactors* params = nullptr);
|
| 859 |
+
|
| 860 |
+
~PackAWithRowOffset() override {
|
| 861 |
+
if (rowOffsetAllocatedHere) {
|
| 862 |
+
fbgemmAlignedFree(row_offset_);
|
| 863 |
+
}
|
| 864 |
+
}
|
| 865 |
+
|
| 866 |
+
private:
|
| 867 |
+
matrix_op_t trans_;
|
| 868 |
+
const T* smat_;
|
| 869 |
+
std::uint32_t ld_;
|
| 870 |
+
std::int32_t* row_offset_{nullptr};
|
| 871 |
+
bool rowOffsetAllocatedHere{false};
|
| 872 |
+
std::int32_t row_interleave_B_;
|
| 873 |
+
};
|
| 874 |
+
|
| 875 |
+
/**
|
| 876 |
+
* @brief Matrix packed for the first input matrix in GEMM (usually activation),
|
| 877 |
+
* and row offsets used for requantization is computed during packing.
|
| 878 |
+
* The source matrix is in fp32 and quantized during packing.
|
| 879 |
+
*/
|
| 880 |
+
template <typename T, typename accT = std::int32_t>
|
| 881 |
+
class FBGEMM_API PackAWithQuantRowOffset final
|
| 882 |
+
: public PackMatrix<PackAWithQuantRowOffset<T, accT>, T, accT> {
|
| 883 |
+
public:
|
| 884 |
+
using This = PackAWithQuantRowOffset<T, accT>;
|
| 885 |
+
using BaseType = PackMatrix<This, T, accT>;
|
| 886 |
+
using inpType = T;
|
| 887 |
+
using accType = accT;
|
| 888 |
+
|
| 889 |
+
PackAWithQuantRowOffset() = delete; // no default constructor
|
| 890 |
+
/**
|
| 891 |
+
* @param row_offset If nullptr, this constructor internally allocates a
|
| 892 |
+
* buffer and owns it. Otherwise, this class doesn't own
|
| 893 |
+
* the buffer. The buffer will be populated when pack
|
| 894 |
+
* function is called.
|
| 895 |
+
*/
|
| 896 |
+
PackAWithQuantRowOffset(
|
| 897 |
+
matrix_op_t trans,
|
| 898 |
+
std::int32_t nRow,
|
| 899 |
+
std::int32_t nCol,
|
| 900 |
+
const float* smat,
|
| 901 |
+
std::int32_t ld,
|
| 902 |
+
inpType* pmat = nullptr,
|
| 903 |
+
float scale = 1.0f,
|
| 904 |
+
std::int32_t zero_pt = 0,
|
| 905 |
+
int groups = 1,
|
| 906 |
+
std::int32_t* row_offset = nullptr,
|
| 907 |
+
const BlockingFactors* params = nullptr);
|
| 908 |
+
PackAWithQuantRowOffset(const PackAWithQuantRowOffset&) = delete;
|
| 909 |
+
PackAWithQuantRowOffset(PackAWithQuantRowOffset&&) = delete;
|
| 910 |
+
PackAWithQuantRowOffset& operator=(const PackAWithQuantRowOffset&) = delete;
|
| 911 |
+
PackAWithQuantRowOffset& operator=(PackAWithQuantRowOffset&&) = delete;
|
| 912 |
+
|
| 913 |
+
/**
|
| 914 |
+
* Activation matrices are not constant so cannot amortize the cost of
|
| 915 |
+
* pre-packing.
|
| 916 |
+
*/
|
| 917 |
+
bool isPrePacked() const {
|
| 918 |
+
return false;
|
| 919 |
+
}
|
| 920 |
+
|
| 921 |
+
/**
|
| 922 |
+
* @return True if this is used as A matrix.
|
| 923 |
+
*/
|
| 924 |
+
static constexpr bool isA() {
|
| 925 |
+
return true;
|
| 926 |
+
}
|
| 927 |
+
|
| 928 |
+
/**
|
| 929 |
+
* @return offset of the element in the packed matrix that was at (i, j) in
|
| 930 |
+
* the source matrix
|
| 931 |
+
*/
|
| 932 |
+
std::int32_t addr(std::int32_t i, std::int32_t j) const;
|
| 933 |
+
|
| 934 |
+
/**
|
| 935 |
+
* @brief Packs a block of source matrix into pmat buffer.
|
| 936 |
+
*/
|
| 937 |
+
void pack(const block_type_t& block);
|
| 938 |
+
|
| 939 |
+
/**
|
| 940 |
+
* @return A pointer to the row offset buffer.
|
| 941 |
+
*/
|
| 942 |
+
std::int32_t* getRowOffsetBuffer() const {
|
| 943 |
+
return row_offset_;
|
| 944 |
+
}
|
| 945 |
+
|
| 946 |
+
/**
|
| 947 |
+
* @brief Print the packed block.
|
| 948 |
+
*/
|
| 949 |
+
void printPackedMatrix(const std::string& name);
|
| 950 |
+
|
| 951 |
+
/**
|
| 952 |
+
* @return Size of row offset buffer in number of elements
|
| 953 |
+
*/
|
| 954 |
+
static int rowOffsetBufferSize(const BlockingFactors* params = nullptr);
|
| 955 |
+
|
| 956 |
+
~PackAWithQuantRowOffset() override {
|
| 957 |
+
if (rowOffsetAllocatedHere) {
|
| 958 |
+
fbgemmAlignedFree(row_offset_);
|
| 959 |
+
}
|
| 960 |
+
}
|
| 961 |
+
|
| 962 |
+
private:
|
| 963 |
+
matrix_op_t trans_;
|
| 964 |
+
const float* smat_;
|
| 965 |
+
std::int32_t ld_;
|
| 966 |
+
float scale_;
|
| 967 |
+
std::int32_t zero_pt_;
|
| 968 |
+
std::int32_t* row_offset_{nullptr};
|
| 969 |
+
bool rowOffsetAllocatedHere{false};
|
| 970 |
+
std::int32_t row_interleave_B_;
|
| 971 |
+
};
|
| 972 |
+
|
| 973 |
+
/*
|
| 974 |
+
*
|
| 975 |
+
* Post Processing of outputs
|
| 976 |
+
*
|
| 977 |
+
*/
|
| 978 |
+
|
| 979 |
+
/**
|
| 980 |
+
* @brief Does nothing. NoOp. Used as the last operation in the output
|
| 981 |
+
* processing pipeline.
|
| 982 |
+
*
|
| 983 |
+
*/
|
| 984 |
+
template <typename outT = std::uint8_t, typename inT = std::uint8_t>
|
| 985 |
+
class FBGEMM_API DoNothing {
|
| 986 |
+
public:
|
| 987 |
+
using outType = outT;
|
| 988 |
+
using inpType = inT;
|
| 989 |
+
DoNothing() = default;
|
| 990 |
+
template <inst_set_t instSet>
|
| 991 |
+
int f(
|
| 992 |
+
outType* /* unused */,
|
| 993 |
+
inpType* /* unused */,
|
| 994 |
+
const block_type_t& /* unused */,
|
| 995 |
+
int /* unused */,
|
| 996 |
+
int /* unused */) const {
|
| 997 |
+
return 0;
|
| 998 |
+
}
|
| 999 |
+
};
|
| 1000 |
+
|
| 1001 |
+
/**
|
| 1002 |
+
* @brief Copy data pointed by inp ptr to out ptr when
|
| 1003 |
+
* inp ptr and out ptr are not the same.
|
| 1004 |
+
* inp buffer: row and column start points: (0, 0)
|
| 1005 |
+
* output buffer: row and column start points:
|
| 1006 |
+
* (block.row_start, block.col_start)
|
| 1007 |
+
*
|
| 1008 |
+
* This is the output processing stage that should passed when there is no
|
| 1009 |
+
* requantization and output is required in the same format as internal buffer
|
| 1010 |
+
* used for accumulation.
|
| 1011 |
+
*/
|
| 1012 |
+
template <
|
| 1013 |
+
typename outT = std::int32_t,
|
| 1014 |
+
typename inT = std::int32_t,
|
| 1015 |
+
typename nextOPType = DoNothing<outT, outT>>
|
| 1016 |
+
class FBGEMM_API memCopy {
|
| 1017 |
+
public:
|
| 1018 |
+
using outType = outT;
|
| 1019 |
+
using inpType = inT;
|
| 1020 |
+
explicit memCopy(nextOPType& nextop) : nextop_(nextop) {}
|
| 1021 |
+
template <inst_set_t instSet>
|
| 1022 |
+
inline int f(
|
| 1023 |
+
outType* out,
|
| 1024 |
+
inpType* inp,
|
| 1025 |
+
const block_type_t& block,
|
| 1026 |
+
int ld_out,
|
| 1027 |
+
int ld_in) const;
|
| 1028 |
+
|
| 1029 |
+
private:
|
| 1030 |
+
nextOPType& nextop_;
|
| 1031 |
+
};
|
| 1032 |
+
|
| 1033 |
+
/**
|
| 1034 |
+
* @brief Perform scaling on accumulated data.
|
| 1035 |
+
*/
|
| 1036 |
+
template <
|
| 1037 |
+
typename outT = std::int32_t,
|
| 1038 |
+
typename inT = std::int32_t,
|
| 1039 |
+
typename nextOPType = DoNothing<outT, outT>>
|
| 1040 |
+
class ScaleOP {
|
| 1041 |
+
public:
|
| 1042 |
+
using outType = outT;
|
| 1043 |
+
using inpType = inT;
|
| 1044 |
+
explicit ScaleOP(inpType scalingFactor) : scalingFactor_(scalingFactor) {}
|
| 1045 |
+
|
| 1046 |
+
template <inst_set_t instSet>
|
| 1047 |
+
inline int f(
|
| 1048 |
+
outType* out,
|
| 1049 |
+
inpType* inp,
|
| 1050 |
+
const block_type_t& block,
|
| 1051 |
+
int ld_out,
|
| 1052 |
+
int ld_in) const;
|
| 1053 |
+
|
| 1054 |
+
private:
|
| 1055 |
+
inpType scalingFactor_;
|
| 1056 |
+
};
|
| 1057 |
+
|
| 1058 |
+
/**
|
| 1059 |
+
* @brief Perform Relu on accumulated data.
|
| 1060 |
+
*/
|
| 1061 |
+
template <
|
| 1062 |
+
typename outT = std::int32_t,
|
| 1063 |
+
typename inT = std::int32_t,
|
| 1064 |
+
typename nextOPType = DoNothing<outT, outT>>
|
| 1065 |
+
class ReluOutput {
|
| 1066 |
+
public:
|
| 1067 |
+
using outType = outT;
|
| 1068 |
+
using inpType = inT;
|
| 1069 |
+
explicit ReluOutput(inpType zero_pt) : zero_pt_(zero_pt) {}
|
| 1070 |
+
|
| 1071 |
+
template <inst_set_t instSet>
|
| 1072 |
+
inline int f(
|
| 1073 |
+
outType* out,
|
| 1074 |
+
inpType* inp,
|
| 1075 |
+
const block_type_t& block,
|
| 1076 |
+
int ld_out,
|
| 1077 |
+
int ld_in) const;
|
| 1078 |
+
|
| 1079 |
+
private:
|
| 1080 |
+
inpType zero_pt_;
|
| 1081 |
+
};
|
| 1082 |
+
|
| 1083 |
+
/**
|
| 1084 |
+
* @brief Perform Dense-Matrix * Sparse-Matrix as a part the of output
|
| 1085 |
+
* processing pipeline.
|
| 1086 |
+
*
|
| 1087 |
+
* SPMDM (SParse Matrix times Dense Matrix) inplace on the 32-bit input buffer
|
| 1088 |
+
* (inp). After modifying the input buffer, pass it to the next op.
|
| 1089 |
+
* When groups > 1, each group is numRows() x (numCols()/groups) matrix.
|
| 1090 |
+
*/
|
| 1091 |
+
template <
|
| 1092 |
+
typename outT = std::int32_t,
|
| 1093 |
+
typename inT = std::int32_t,
|
| 1094 |
+
typename nextOPType = DoNothing<inT, inT>>
|
| 1095 |
+
class FBGEMM_API DoSpmdmOnInpBuffer {
|
| 1096 |
+
public:
|
| 1097 |
+
using outType = outT;
|
| 1098 |
+
using inpType = inT;
|
| 1099 |
+
DoSpmdmOnInpBuffer(
|
| 1100 |
+
nextOPType& nextop,
|
| 1101 |
+
const std::uint8_t* A,
|
| 1102 |
+
int lda,
|
| 1103 |
+
const CompressedSparseColumn& B_csc,
|
| 1104 |
+
int groups = 1)
|
| 1105 |
+
: nextop_(nextop), A_(A), lda_(lda), B_csc_(B_csc), groups_(groups) {}
|
| 1106 |
+
|
| 1107 |
+
template <inst_set_t instSet>
|
| 1108 |
+
inline int f(
|
| 1109 |
+
outT* out,
|
| 1110 |
+
inT* inp,
|
| 1111 |
+
const block_type_t& block,
|
| 1112 |
+
int ld_out,
|
| 1113 |
+
int ld_in) const;
|
| 1114 |
+
|
| 1115 |
+
private:
|
| 1116 |
+
nextOPType& nextop_;
|
| 1117 |
+
const std::uint8_t* A_;
|
| 1118 |
+
const int lda_;
|
| 1119 |
+
const CompressedSparseColumn& B_csc_;
|
| 1120 |
+
const int groups_;
|
| 1121 |
+
};
|
| 1122 |
+
|
| 1123 |
+
/**
|
| 1124 |
+
* @brief Perform Dense-Matrix * Sparse-Matrix as a part the of output
|
| 1125 |
+
* processing pipeline.
|
| 1126 |
+
*
|
| 1127 |
+
* SPMDM (SParse Matrix times Dense Matrix) inplace on the 32-bit input buffer
|
| 1128 |
+
* (inp). After modifying the input buffer, pass it to the next op.
|
| 1129 |
+
* When groups > 1, each group is numRows() x (numCols()/groups) matrix.
|
| 1130 |
+
*/
|
| 1131 |
+
template <
|
| 1132 |
+
typename outT = std::int32_t,
|
| 1133 |
+
typename inT = std::int32_t,
|
| 1134 |
+
typename nextOPType = DoNothing<inT, inT>>
|
| 1135 |
+
class FBGEMM_API DoSConvOnInpBuffer {
|
| 1136 |
+
public:
|
| 1137 |
+
using outType = outT;
|
| 1138 |
+
using inpType = inT;
|
| 1139 |
+
DoSConvOnInpBuffer(
|
| 1140 |
+
nextOPType& nextop,
|
| 1141 |
+
const std::uint8_t* A,
|
| 1142 |
+
const conv_param_t<>& conv_p,
|
| 1143 |
+
std::int32_t A_zero_point,
|
| 1144 |
+
const CompressedSparseColumn& B_csc)
|
| 1145 |
+
: nextop_(nextop),
|
| 1146 |
+
A_(A),
|
| 1147 |
+
conv_p_(conv_p),
|
| 1148 |
+
A_zero_point_(A_zero_point),
|
| 1149 |
+
B_csc_(B_csc) {}
|
| 1150 |
+
|
| 1151 |
+
template <inst_set_t instSet>
|
| 1152 |
+
inline int f(
|
| 1153 |
+
outT* out,
|
| 1154 |
+
inT* inp,
|
| 1155 |
+
const block_type_t& block,
|
| 1156 |
+
int ld_out,
|
| 1157 |
+
int ld_in) const;
|
| 1158 |
+
|
| 1159 |
+
private:
|
| 1160 |
+
nextOPType& nextop_;
|
| 1161 |
+
const std::uint8_t* A_;
|
| 1162 |
+
const conv_param_t<> conv_p_;
|
| 1163 |
+
const std::int32_t A_zero_point_;
|
| 1164 |
+
const CompressedSparseColumn& B_csc_;
|
| 1165 |
+
};
|
| 1166 |
+
|
| 1167 |
+
/**
|
| 1168 |
+
* @brief Requantize values in inp buffer and write to out buffer.
|
| 1169 |
+
* pass the out buffer to next op for further processing.
|
| 1170 |
+
*/
|
| 1171 |
+
template <
|
| 1172 |
+
bool FUSE_RELU,
|
| 1173 |
+
QuantizationGranularity Q_GRAN = QuantizationGranularity::TENSOR,
|
| 1174 |
+
typename BIAS_TYPE = std::int32_t,
|
| 1175 |
+
typename outT = std::uint8_t,
|
| 1176 |
+
typename inT = std::int32_t,
|
| 1177 |
+
typename nextOPType = DoNothing<outT, outT>>
|
| 1178 |
+
class FBGEMM_API ReQuantizeOutput {
|
| 1179 |
+
public:
|
| 1180 |
+
static constexpr int RELU_FUSED = FUSE_RELU;
|
| 1181 |
+
static constexpr QuantizationGranularity QGRANType = Q_GRAN;
|
| 1182 |
+
using BIAS_T = BIAS_TYPE;
|
| 1183 |
+
using outType = outT;
|
| 1184 |
+
using inpType = inT;
|
| 1185 |
+
/**
|
| 1186 |
+
* @param C_multiplier The length of this array is
|
| 1187 |
+
* 1 when Q_GRAN == QuantizationGranularity::TENSOR,
|
| 1188 |
+
* groups when Q_GRAN == QuantizationGranularity::GROUP,
|
| 1189 |
+
* nCol if Q_GRAN == QuantizationGranularity::OUT_CHANNEL
|
| 1190 |
+
* @param Bq_zero_point The length of this array should be the same as
|
| 1191 |
+
* C_multiplier.
|
| 1192 |
+
* @param row_offsets Typically, this should've been computed by a
|
| 1193 |
+
* PackAMatrix and should be obtained by
|
| 1194 |
+
* PackMatrix::getRowOffsetBuffer().
|
| 1195 |
+
* If Bq_zero_point == 0 (symmetric quantization of B
|
| 1196 |
+
* matrix), we can pass nullptr.
|
| 1197 |
+
* @param col_offsets This should be pre-computed for example using
|
| 1198 |
+
* col_offsets_with_zero_pt_s8acc32_ref.
|
| 1199 |
+
* The length should be nCol.
|
| 1200 |
+
* See PackedRequantizeTest.cc for an example.
|
| 1201 |
+
* TODO: if Aq_zero_point == 0, allow passing nullptr.
|
| 1202 |
+
* @param bias can be nullptr otherwise the length should be nCol
|
| 1203 |
+
* @param act_times_w_scale activation_scale * weight_scale. This is only
|
| 1204 |
+
* used if bias is unquantized (i.e., float).
|
| 1205 |
+
*/
|
| 1206 |
+
ReQuantizeOutput(
|
| 1207 |
+
nextOPType& nextop,
|
| 1208 |
+
const float* C_multiplier,
|
| 1209 |
+
std::int32_t C_zero_point,
|
| 1210 |
+
std::int32_t Aq_zero_point,
|
| 1211 |
+
const std::int32_t* Bq_zero_point,
|
| 1212 |
+
const std::int32_t* row_offsets,
|
| 1213 |
+
const std::int32_t* col_offsets,
|
| 1214 |
+
const BIAS_T* bias,
|
| 1215 |
+
std::uint32_t nCol,
|
| 1216 |
+
int groups = 1,
|
| 1217 |
+
const float* act_times_w_scale = nullptr)
|
| 1218 |
+
: nextop_(nextop),
|
| 1219 |
+
C_multiplier_(C_multiplier),
|
| 1220 |
+
C_zero_point_(C_zero_point),
|
| 1221 |
+
Aq_zero_point_(Aq_zero_point),
|
| 1222 |
+
Bq_zero_point_(Bq_zero_point),
|
| 1223 |
+
q_row_offsets_(row_offsets),
|
| 1224 |
+
q_col_offsets_(col_offsets),
|
| 1225 |
+
bias_(bias),
|
| 1226 |
+
ncols_(nCol),
|
| 1227 |
+
groups_(groups),
|
| 1228 |
+
act_times_w_scale_(act_times_w_scale) {}
|
| 1229 |
+
|
| 1230 |
+
template <inst_set_t instSet>
|
| 1231 |
+
inline int f(
|
| 1232 |
+
outT* out,
|
| 1233 |
+
const inT* inp,
|
| 1234 |
+
const block_type_t& block,
|
| 1235 |
+
int ld_out,
|
| 1236 |
+
int ld_in) const;
|
| 1237 |
+
|
| 1238 |
+
const float* getCMultiplier() const {
|
| 1239 |
+
return C_multiplier_;
|
| 1240 |
+
}
|
| 1241 |
+
std::int32_t getAZeroPoint() const {
|
| 1242 |
+
return Aq_zero_point_;
|
| 1243 |
+
}
|
| 1244 |
+
std::int32_t getCZeroPoint() const {
|
| 1245 |
+
return C_zero_point_;
|
| 1246 |
+
}
|
| 1247 |
+
const std::int32_t* getBZeroPoint() const {
|
| 1248 |
+
return Bq_zero_point_;
|
| 1249 |
+
}
|
| 1250 |
+
const std::int32_t* getRowOffsets() const {
|
| 1251 |
+
return q_row_offsets_;
|
| 1252 |
+
}
|
| 1253 |
+
const std::int32_t* getColOffsets() const {
|
| 1254 |
+
return q_col_offsets_;
|
| 1255 |
+
}
|
| 1256 |
+
const BIAS_T* getBias() const {
|
| 1257 |
+
return bias_;
|
| 1258 |
+
}
|
| 1259 |
+
std::uint32_t getNCols() const {
|
| 1260 |
+
return ncols_;
|
| 1261 |
+
}
|
| 1262 |
+
const float* getActWScale() const {
|
| 1263 |
+
return act_times_w_scale_;
|
| 1264 |
+
}
|
| 1265 |
+
|
| 1266 |
+
void setRowOffsets(const std::int32_t* row_offsets) {
|
| 1267 |
+
q_row_offsets_ = row_offsets;
|
| 1268 |
+
}
|
| 1269 |
+
|
| 1270 |
+
private:
|
| 1271 |
+
nextOPType& nextop_;
|
| 1272 |
+
const float* C_multiplier_;
|
| 1273 |
+
std::int32_t C_zero_point_;
|
| 1274 |
+
std::int32_t Aq_zero_point_;
|
| 1275 |
+
const std::int32_t* Bq_zero_point_;
|
| 1276 |
+
const std::int32_t* q_row_offsets_;
|
| 1277 |
+
const std::int32_t* q_col_offsets_;
|
| 1278 |
+
const BIAS_T* bias_;
|
| 1279 |
+
std::uint32_t ncols_;
|
| 1280 |
+
int groups_;
|
| 1281 |
+
const float* act_times_w_scale_;
|
| 1282 |
+
};
|
| 1283 |
+
|
| 1284 |
+
/**
|
| 1285 |
+
* @brief Requantize to convert accumulated data to be used as float, i.e., the
|
| 1286 |
+
* output would be used as float.
|
| 1287 |
+
*/
|
| 1288 |
+
template <
|
| 1289 |
+
bool FUSE_RELU,
|
| 1290 |
+
QuantizationGranularity Q_GRAN = QuantizationGranularity::TENSOR,
|
| 1291 |
+
typename outT = float,
|
| 1292 |
+
typename inT = std::int32_t,
|
| 1293 |
+
typename nextOPType = DoNothing<outT, outT>>
|
| 1294 |
+
class FBGEMM_API ReQuantizeForFloat {
|
| 1295 |
+
public:
|
| 1296 |
+
using outType = outT;
|
| 1297 |
+
using inpType = inT;
|
| 1298 |
+
/**
|
| 1299 |
+
* @param Bq_scale The length of this array is
|
| 1300 |
+
* 1 when Q_GRAN == QuantizationGranularity::TENSOR,
|
| 1301 |
+
* groups when Q_GRAN == QuantizationGranularity::GROUP,
|
| 1302 |
+
* nCol if Q_GRAN == QuantizationGranularity::OUT_CHANNEL
|
| 1303 |
+
* @param Bq_zero_point The length of this array should be the same as
|
| 1304 |
+
* Bq_scale.
|
| 1305 |
+
* @param row_offsets Typically, this should've been computed by a
|
| 1306 |
+
* PackAMatrix and should be obtained by
|
| 1307 |
+
* PackMatrix::getRowOffsetBuffer().
|
| 1308 |
+
* If Bq_zero_point == 0 (symmetric quantization of B
|
| 1309 |
+
* matrix), we can pass nullptr.
|
| 1310 |
+
* @param col_offsets This should be pre-computed for example using
|
| 1311 |
+
* col_offsets_with_zero_pt_s8acc32_ref.
|
| 1312 |
+
* The length should be nCol.
|
| 1313 |
+
* See PackedRequantizeTest.cc for an example.
|
| 1314 |
+
* TODO: if Aq_zero_point == 0, allow passing nullptr.
|
| 1315 |
+
* @param bias can be nullptr otherwise the length should be nCol
|
| 1316 |
+
*/
|
| 1317 |
+
ReQuantizeForFloat(
|
| 1318 |
+
nextOPType& nextop,
|
| 1319 |
+
float Aq_scale,
|
| 1320 |
+
const float* Bq_scale,
|
| 1321 |
+
std::int32_t Aq_zero_point,
|
| 1322 |
+
const std::int32_t* Bq_zero_point,
|
| 1323 |
+
const std::int32_t* row_offsets,
|
| 1324 |
+
const std::int32_t* col_offsets,
|
| 1325 |
+
const float* bias,
|
| 1326 |
+
std::uint32_t nCol,
|
| 1327 |
+
int groups = 1)
|
| 1328 |
+
: nextop_(nextop),
|
| 1329 |
+
Aq_scale_(Aq_scale),
|
| 1330 |
+
Bq_scale_(Bq_scale),
|
| 1331 |
+
Aq_zero_point_(Aq_zero_point),
|
| 1332 |
+
Bq_zero_point_(Bq_zero_point),
|
| 1333 |
+
q_row_offsets_(row_offsets),
|
| 1334 |
+
q_col_offsets_(col_offsets),
|
| 1335 |
+
bias_(bias),
|
| 1336 |
+
ncols_(nCol),
|
| 1337 |
+
groups_(groups) {}
|
| 1338 |
+
|
| 1339 |
+
template <inst_set_t instSet>
|
| 1340 |
+
inline int f(
|
| 1341 |
+
outT* out,
|
| 1342 |
+
inT* inp,
|
| 1343 |
+
const block_type_t& block,
|
| 1344 |
+
int ld_out,
|
| 1345 |
+
int ld_in) const;
|
| 1346 |
+
|
| 1347 |
+
private:
|
| 1348 |
+
nextOPType& nextop_;
|
| 1349 |
+
float Aq_scale_;
|
| 1350 |
+
const float* Bq_scale_;
|
| 1351 |
+
std::int32_t Aq_zero_point_;
|
| 1352 |
+
const std::int32_t* Bq_zero_point_;
|
| 1353 |
+
const std::int32_t* q_row_offsets_;
|
| 1354 |
+
const std::int32_t* q_col_offsets_;
|
| 1355 |
+
const float* bias_;
|
| 1356 |
+
std::uint32_t ncols_;
|
| 1357 |
+
int groups_;
|
| 1358 |
+
};
|
| 1359 |
+
|
| 1360 |
+
// type specialized implementation in an include file
|
| 1361 |
+
#include "./OutputProcessing-inl.h" // @manual
|
| 1362 |
+
|
| 1363 |
+
/*
|
| 1364 |
+
*
|
| 1365 |
+
* ####### GEMM related functions #######
|
| 1366 |
+
*
|
| 1367 |
+
*/
|
| 1368 |
+
|
| 1369 |
+
/**
|
| 1370 |
+
* Matrix B must be prepacked. For matrix A, packA.pack function is called to
|
| 1371 |
+
* pack it.
|
| 1372 |
+
*
|
| 1373 |
+
* @tparam packingAMatrix processing of A matrix while packing,
|
| 1374 |
+
* e.g., PackAWithQuantRowOffset
|
| 1375 |
+
*
|
| 1376 |
+
* @tparam packingBMatrix processing of B matrix while packing,
|
| 1377 |
+
* e.g., pre-multiply by alpha
|
| 1378 |
+
* @tparam cT data type of C matrix
|
| 1379 |
+
* @tparam processOutputType further processing of outputs, e.g., Relu
|
| 1380 |
+
*/
|
| 1381 |
+
template <
|
| 1382 |
+
typename packingAMatrix,
|
| 1383 |
+
typename packingBMatrix,
|
| 1384 |
+
typename cT,
|
| 1385 |
+
typename processOutputType>
|
| 1386 |
+
FBGEMM_API void fbgemmPacked(
|
| 1387 |
+
PackMatrix<
|
| 1388 |
+
packingAMatrix,
|
| 1389 |
+
typename packingAMatrix::inpType,
|
| 1390 |
+
typename packingAMatrix::accType>& packA,
|
| 1391 |
+
PackMatrix<
|
| 1392 |
+
packingBMatrix,
|
| 1393 |
+
typename packingBMatrix::inpType,
|
| 1394 |
+
typename packingBMatrix::accType>& packB,
|
| 1395 |
+
cT* C,
|
| 1396 |
+
std::int32_t* C_buffer,
|
| 1397 |
+
std::uint32_t ldc,
|
| 1398 |
+
const processOutputType& outProcess,
|
| 1399 |
+
int thread_id,
|
| 1400 |
+
int num_threads,
|
| 1401 |
+
const BlockingFactors* blocking_params = nullptr);
|
| 1402 |
+
|
| 1403 |
+
/**
|
| 1404 |
+
* @brief Perform small-channels-per-group groupwise convolution
|
| 1405 |
+
* Note: Currently threading is not supported. This function does
|
| 1406 |
+
* nothing for thread_ids > 0, i.e., returns early.
|
| 1407 |
+
*
|
| 1408 |
+
* @param rowOffsetBuf nullptr if B uses symmetric quantization
|
| 1409 |
+
* Note: Currently threading is not supported. This function does
|
| 1410 |
+
* nothing for thread_ids > 0, i.e., returns early.
|
| 1411 |
+
*/
|
| 1412 |
+
template <
|
| 1413 |
+
typename packed_W,
|
| 1414 |
+
typename outType,
|
| 1415 |
+
bool FUSE_RELU,
|
| 1416 |
+
QuantizationGranularity Q_GRAN,
|
| 1417 |
+
int SPATIAL_DIM = 2,
|
| 1418 |
+
typename BIAS_TYPE = std::int32_t>
|
| 1419 |
+
FBGEMM_API void fbgemmGroupwiseConv(
|
| 1420 |
+
const conv_param_t<SPATIAL_DIM>& conv_param,
|
| 1421 |
+
const std::uint8_t* activations,
|
| 1422 |
+
std::int32_t a_zero_point,
|
| 1423 |
+
std::int32_t* rowOffsetBuf,
|
| 1424 |
+
packed_W& packed_weights,
|
| 1425 |
+
outType* out,
|
| 1426 |
+
std::int32_t* outBuffer,
|
| 1427 |
+
const ReQuantizeOutput<FUSE_RELU, Q_GRAN, BIAS_TYPE>& outProcess,
|
| 1428 |
+
int thread_id,
|
| 1429 |
+
int num_threads);
|
| 1430 |
+
|
| 1431 |
+
template <
|
| 1432 |
+
int SPATIAL_DIM,
|
| 1433 |
+
QuantizationGranularity Q_GRAN,
|
| 1434 |
+
bool FUSE_RELU,
|
| 1435 |
+
typename BIAS_TYPE = std::int32_t>
|
| 1436 |
+
FBGEMM_API void fbgemmDirectConv(
|
| 1437 |
+
const conv_param_t<SPATIAL_DIM>& conv_p,
|
| 1438 |
+
const uint8_t* Aint8,
|
| 1439 |
+
PackedDirectConvMatrix& Bint8_tr,
|
| 1440 |
+
uint8_t* C,
|
| 1441 |
+
int32_t* C_buffer,
|
| 1442 |
+
const ReQuantizeOutput<FUSE_RELU, Q_GRAN, BIAS_TYPE>& outProcess,
|
| 1443 |
+
const BIAS_TYPE* bias,
|
| 1444 |
+
int thread_id,
|
| 1445 |
+
int num_threads);
|
| 1446 |
+
|
| 1447 |
+
/**
|
| 1448 |
+
* @return Size of row offset buffer in number of elements needed for
|
| 1449 |
+
* fbgemmGroupwiseConv
|
| 1450 |
+
*/
|
| 1451 |
+
template <int SPATIAL_DIM = 2>
|
| 1452 |
+
FBGEMM_API int rowOffsetBufferSizeGConv(
|
| 1453 |
+
const conv_param_t<SPATIAL_DIM>& conv_param);
|
| 1454 |
+
|
| 1455 |
+
/**
|
| 1456 |
+
* @brief Is this depthwise convolution optimized?
|
| 1457 |
+
*/
|
| 1458 |
+
template <int SPATIAL_DIM = 2, typename ACC_T = std::int32_t>
|
| 1459 |
+
bool takeDepthWiseFastPath(const conv_param_t<SPATIAL_DIM>& conv_p);
|
| 1460 |
+
|
| 1461 |
+
/**
|
| 1462 |
+
* @brief Is this groupwise convolution supported?
|
| 1463 |
+
*/
|
| 1464 |
+
template <int SPATIAL_DIM>
|
| 1465 |
+
FBGEMM_API bool fbgemmOptimizedGConv(const conv_param_t<SPATIAL_DIM>& conv_p);
|
| 1466 |
+
|
| 1467 |
+
/**
|
| 1468 |
+
* @brief Is this convolution a direct matrix-matrix multiplication, i.e., 1x1
|
| 1469 |
+
* (aka pointwise) with right paddings etc.?
|
| 1470 |
+
*/
|
| 1471 |
+
template <int SPATIAL_DIM>
|
| 1472 |
+
FBGEMM_API bool takePointWiseFastPath(const conv_param_t<SPATIAL_DIM>& conv_p);
|
| 1473 |
+
|
| 1474 |
+
/**
|
| 1475 |
+
* @brief Are we running on a fbgemm supported cpu?
|
| 1476 |
+
*/
|
| 1477 |
+
FBGEMM_API bool fbgemmSupportedCPU();
|
| 1478 |
+
|
| 1479 |
+
/**
|
| 1480 |
+
* @brief Performs convolution using fastest path available.
|
| 1481 |
+
*
|
| 1482 |
+
* @tparam SPATIAL_DIM It's 2 for 2D convolutions and 3 for 3D convolutions.
|
| 1483 |
+
*/
|
| 1484 |
+
template <
|
| 1485 |
+
typename processOutputType,
|
| 1486 |
+
int SPATIAL_DIM = 2,
|
| 1487 |
+
typename ACC_T = std::int32_t>
|
| 1488 |
+
FBGEMM_API int fbgemmConv(
|
| 1489 |
+
const conv_param_t<SPATIAL_DIM>& conv_p,
|
| 1490 |
+
const std::uint8_t* activations,
|
| 1491 |
+
PackWeightsForConv<SPATIAL_DIM, std::int8_t, ACC_T>& packed_weights,
|
| 1492 |
+
typename processOutputType::outType* out,
|
| 1493 |
+
std::int32_t* outBuffer,
|
| 1494 |
+
processOutputType& outProcess,
|
| 1495 |
+
int thread_id,
|
| 1496 |
+
int num_threads,
|
| 1497 |
+
const BlockingFactors* blocking_params = nullptr);
|
| 1498 |
+
|
| 1499 |
+
/**
|
| 1500 |
+
* @brief Returns which fast path to take
|
| 1501 |
+
*
|
| 1502 |
+
* @tparam SPATIAL_DIM It's 2 for 2D convolutions and 3 for 3D convolutions.
|
| 1503 |
+
*
|
| 1504 |
+
* @return optimized_conv_t::depthwise, optimized_conv_t::groupwise or
|
| 1505 |
+
* optimized_conv_t::im2col
|
| 1506 |
+
*
|
| 1507 |
+
*/
|
| 1508 |
+
template <int SPATIAL_DIM = 2, typename ACC_T = std::int32_t>
|
| 1509 |
+
FBGEMM_API optimized_conv_t
|
| 1510 |
+
ConvFastPath(const conv_param_t<SPATIAL_DIM>& conv_p);
|
| 1511 |
+
} // namespace fbgemm
|
| 1512 |
+
|
| 1513 |
+
#else
|
| 1514 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 1515 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/fbgemm/FbgemmBuild.h
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
/*
|
| 3 |
+
* Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 4 |
+
* All rights reserved.
|
| 5 |
+
*
|
| 6 |
+
* This source code is licensed under the BSD-style license found in the
|
| 7 |
+
* LICENSE file in the root directory of this source tree.
|
| 8 |
+
*/
|
| 9 |
+
|
| 10 |
+
#pragma once
|
| 11 |
+
|
| 12 |
+
// For details about dllexport/dllimport, check out the following Stack Overflow question
|
| 13 |
+
// https://stackoverflow.com/questions/57999/what-is-the-difference-between-dllexport-and-dllimport
|
| 14 |
+
#if !defined(FBGEMM_API)
|
| 15 |
+
#if defined(FBGEMM_STATIC)
|
| 16 |
+
#define FBGEMM_API
|
| 17 |
+
#define FBGEMM_ENUM_CLASS_API
|
| 18 |
+
#elif defined _WIN32 || defined __CYGWIN__
|
| 19 |
+
#if (__GNUC__ || __clang__) && !(__MINGW64__ || __MINGW32__)
|
| 20 |
+
#if defined(FBGEMM_EXPORTS)
|
| 21 |
+
#define FBGEMM_API __attribute__((__dllexport__))
|
| 22 |
+
#else
|
| 23 |
+
#define FBGEMM_API __attribute__((__dllimport__))
|
| 24 |
+
#endif
|
| 25 |
+
#else
|
| 26 |
+
#if defined(FBGEMM_EXPORTS)
|
| 27 |
+
#define FBGEMM_API __declspec(dllexport)
|
| 28 |
+
#else
|
| 29 |
+
#define FBGEMM_API __declspec(dllimport)
|
| 30 |
+
#endif
|
| 31 |
+
#endif
|
| 32 |
+
#define FBGEMM_ENUM_CLASS_API
|
| 33 |
+
#else
|
| 34 |
+
#if __clang__ || __GNUC__ || __INTEL_COMPILER
|
| 35 |
+
#define FBGEMM_API __attribute__((__visibility__("default")))
|
| 36 |
+
#else
|
| 37 |
+
#define FBGEMM_API
|
| 38 |
+
#endif
|
| 39 |
+
// Currently, enum classes need to be declared explicitly for shared build on
|
| 40 |
+
// macos
|
| 41 |
+
#if __clang__
|
| 42 |
+
#define FBGEMM_ENUM_CLASS_API __attribute__((__visibility__("default")))
|
| 43 |
+
#else
|
| 44 |
+
#define FBGEMM_ENUM_CLASS_API
|
| 45 |
+
#endif
|
| 46 |
+
#endif
|
| 47 |
+
#endif
|
| 48 |
+
|
| 49 |
+
// Use this to indicate to not inline functions
|
| 50 |
+
#if __clang__ || __GNUC__ || __INTEL_COMPILER
|
| 51 |
+
#define NOINLINE __attribute__((noinline))
|
| 52 |
+
#elif _MSC_VER
|
| 53 |
+
#define NOINLINE __declspec(noinline)
|
| 54 |
+
#else
|
| 55 |
+
#define NOINLINE
|
| 56 |
+
#endif
|
| 57 |
+
|
| 58 |
+
// Use this to indicate always inline functions
|
| 59 |
+
#if __clang__ || __GNUC__ || __INTEL_COMPILER
|
| 60 |
+
#define ALWAYS_INLINE inline __attribute__((__always_inline__))
|
| 61 |
+
#elif _MSC_VER
|
| 62 |
+
// Commented out because __forceinline takes too much time in MSVC
|
| 63 |
+
#define ALWAYS_INLINE // __forceinline
|
| 64 |
+
#else
|
| 65 |
+
#define ALWAYS_INLINE inline
|
| 66 |
+
#endif
|
| 67 |
+
|
| 68 |
+
// Use the C++11 keyword "alignas" if you can
|
| 69 |
+
#if _MSC_VER
|
| 70 |
+
#define ALIGNAS(byte_alignment) __declspec(align(byte_alignment))
|
| 71 |
+
#else
|
| 72 |
+
#define ALIGNAS(byte_alignment) __attribute__((aligned(byte_alignment)))
|
| 73 |
+
#endif
|
| 74 |
+
|
| 75 |
+
// Sanitizers annotations
|
| 76 |
+
#if defined(__has_attribute)
|
| 77 |
+
#if __has_attribute(no_sanitize)
|
| 78 |
+
#define NO_SANITIZE(what) __attribute__((no_sanitize(what)))
|
| 79 |
+
#endif
|
| 80 |
+
#endif
|
| 81 |
+
#if !defined(NO_SANITIZE)
|
| 82 |
+
#define NO_SANITIZE(what)
|
| 83 |
+
#endif
|
| 84 |
+
|
| 85 |
+
// Ignore __builtin_assume() when not supported by compiler.
|
| 86 |
+
#ifndef __has_builtin
|
| 87 |
+
#define __has_builtin(x) 0
|
| 88 |
+
#endif
|
| 89 |
+
#if !__has_builtin(__builtin_assume)
|
| 90 |
+
#define __builtin_assume(x) (static_cast<void>(0))
|
| 91 |
+
#endif
|
| 92 |
+
|
| 93 |
+
// Macro for silencing warnings
|
| 94 |
+
#if __clang__ || __GNUC__
|
| 95 |
+
// clang-format off
|
| 96 |
+
#define FBGEMM_PUSH_WARNING _Pragma("GCC diagnostic push")
|
| 97 |
+
#define FBGEMM_DISABLE_WARNING_INTERNAL2(warningName) #warningName
|
| 98 |
+
#define FBGEMM_DISABLE_WARNING(warningName) \
|
| 99 |
+
_Pragma( \
|
| 100 |
+
FBGEMM_DISABLE_WARNING_INTERNAL2(GCC diagnostic ignored warningName))
|
| 101 |
+
#define FBGEMM_PUSH_WARNING_AND_DISABLE(warningName) \
|
| 102 |
+
_Pragma("GCC diagnostic push") \
|
| 103 |
+
_Pragma( \
|
| 104 |
+
FBGEMM_DISABLE_WARNING_INTERNAL2(GCC diagnostic ignored warningName))
|
| 105 |
+
#define FBGEMM_POP_WARNING _Pragma("GCC diagnostic pop")
|
| 106 |
+
// clang-format on
|
| 107 |
+
#else
|
| 108 |
+
#define FBGEMM_PUSH_WARNING
|
| 109 |
+
#define FBGEMM_DISABLE_WARNING(NAME)
|
| 110 |
+
#define FBGEMM_PUSH_WARNING_AND_DISABLE(NAME)
|
| 111 |
+
#define FBGEMM_POP_WARNING
|
| 112 |
+
#endif
|
| 113 |
+
|
| 114 |
+
#else
|
| 115 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 116 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/fbgemm/FbgemmConvert.h
ADDED
|
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
/*
|
| 3 |
+
* Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 4 |
+
* All rights reserved.
|
| 5 |
+
*
|
| 6 |
+
* This source code is licensed under the BSD-style license found in the
|
| 7 |
+
* LICENSE file in the root directory of this source tree.
|
| 8 |
+
*/
|
| 9 |
+
|
| 10 |
+
#pragma once
|
| 11 |
+
|
| 12 |
+
#include <cstddef>
|
| 13 |
+
#include <cstdint>
|
| 14 |
+
#include "fbgemm/FbgemmBuild.h"
|
| 15 |
+
#include "fbgemm/Types.h"
|
| 16 |
+
|
| 17 |
+
namespace fbgemm {
|
| 18 |
+
|
| 19 |
+
/**
|
| 20 |
+
* @brief Transform all entries in a matrix from fp32 to bfloat16: reference
|
| 21 |
+
* implementation.
|
| 22 |
+
*
|
| 23 |
+
*/
|
| 24 |
+
FBGEMM_API void
|
| 25 |
+
FloatToBfloat16_ref(const float* src, bfloat16* dst, size_t size);
|
| 26 |
+
|
| 27 |
+
/**
|
| 28 |
+
* @brief Transform all entries in a matrix from bfloat16 to fp32: reference
|
| 29 |
+
* implementation.
|
| 30 |
+
*
|
| 31 |
+
*/
|
| 32 |
+
FBGEMM_API void
|
| 33 |
+
Bfloat16ToFloat_ref(const bfloat16* src, float* dst, size_t size);
|
| 34 |
+
|
| 35 |
+
/**
|
| 36 |
+
* @brief Transform all entries in a matrix from fp32 to bfloat16: simd
|
| 37 |
+
* implementation.
|
| 38 |
+
*
|
| 39 |
+
*/
|
| 40 |
+
FBGEMM_API void
|
| 41 |
+
FloatToBfloat16_simd(const float* src, bfloat16* dst, size_t size);
|
| 42 |
+
|
| 43 |
+
/**
|
| 44 |
+
* @brief Transform all entries in a matrix from bfloat16 to fp32: simd
|
| 45 |
+
* implementation.
|
| 46 |
+
*
|
| 47 |
+
*/
|
| 48 |
+
FBGEMM_API void
|
| 49 |
+
Bfloat16ToFloat_simd(const bfloat16* src, float* dst, size_t size);
|
| 50 |
+
|
| 51 |
+
#if !defined(__aarch64__)
|
| 52 |
+
/**
|
| 53 |
+
* @brief AVX2 implementation to convert fp32 numbers to bf16 numbers.
|
| 54 |
+
*
|
| 55 |
+
*/
|
| 56 |
+
FBGEMM_API void
|
| 57 |
+
FloatToBfloat16_avx2(const float* src, bfloat16* dst, size_t size);
|
| 58 |
+
|
| 59 |
+
/**
|
| 60 |
+
* @brief AVX512 implementation to convert fp32 numbers to bf16 numbers.
|
| 61 |
+
*
|
| 62 |
+
*/
|
| 63 |
+
FBGEMM_API void
|
| 64 |
+
FloatToBfloat16_avx512(const float* src, bfloat16* dst, size_t size);
|
| 65 |
+
|
| 66 |
+
/**
|
| 67 |
+
* @brief AVX2 implementation to convert bf16 numbers to fp32 numbers.
|
| 68 |
+
*
|
| 69 |
+
*/
|
| 70 |
+
FBGEMM_API void
|
| 71 |
+
Bfloat16ToFloat_avx2(const bfloat16* src, float* dst, size_t size);
|
| 72 |
+
|
| 73 |
+
/**
|
| 74 |
+
* @brief AVX512 implementation to convert bf16 numbers to fp32 numbers.
|
| 75 |
+
*
|
| 76 |
+
*/
|
| 77 |
+
FBGEMM_API void
|
| 78 |
+
Bfloat16ToFloat_avx512(const bfloat16* src, float* dst, size_t size);
|
| 79 |
+
#endif
|
| 80 |
+
|
| 81 |
+
/**
|
| 82 |
+
* @brief Transform all entries in a matrix from fp32 to float16: reference
|
| 83 |
+
* implementation.
|
| 84 |
+
*
|
| 85 |
+
* @param do_clip if true we saturate to fp16 min and max instead of generating
|
| 86 |
+
* infinities.
|
| 87 |
+
*/
|
| 88 |
+
FBGEMM_API void FloatToFloat16_ref(
|
| 89 |
+
const float* src,
|
| 90 |
+
float16* dst,
|
| 91 |
+
size_t size,
|
| 92 |
+
bool do_clip = false);
|
| 93 |
+
|
| 94 |
+
/**
|
| 95 |
+
* @brief Transform all entries in a matrix from float16 to fp32: reference
|
| 96 |
+
* implementation.
|
| 97 |
+
*
|
| 98 |
+
*/
|
| 99 |
+
FBGEMM_API void Float16ToFloat_ref(const float16* src, float* dst, size_t size);
|
| 100 |
+
|
| 101 |
+
/**
|
| 102 |
+
* @brief Transform all entries in a matrix from fp32 to float16: simd
|
| 103 |
+
* implementation.
|
| 104 |
+
*
|
| 105 |
+
* @param do_clip if true we saturate to fp16 min and max instead of generating
|
| 106 |
+
* infinities.
|
| 107 |
+
*/
|
| 108 |
+
FBGEMM_API void FloatToFloat16_simd(
|
| 109 |
+
const float* src,
|
| 110 |
+
float16* dst,
|
| 111 |
+
size_t size,
|
| 112 |
+
bool do_clip = false);
|
| 113 |
+
|
| 114 |
+
/**
|
| 115 |
+
* @brief Transform all entries in a matrix from float16 to fp32: simd
|
| 116 |
+
* implementation.
|
| 117 |
+
*
|
| 118 |
+
*/
|
| 119 |
+
FBGEMM_API void
|
| 120 |
+
Float16ToFloat_simd(const float16* src, float* dst, size_t size);
|
| 121 |
+
|
| 122 |
+
/**
|
| 123 |
+
* @brief AVX2 implementation to convert fp32 numbers to fp16 numbers.
|
| 124 |
+
*
|
| 125 |
+
*/
|
| 126 |
+
#if !defined(__aarch64__)
|
| 127 |
+
FBGEMM_API void FloatToFloat16_avx2(
|
| 128 |
+
const float* src,
|
| 129 |
+
float16* dst,
|
| 130 |
+
size_t size,
|
| 131 |
+
bool do_clip = false);
|
| 132 |
+
|
| 133 |
+
/**
|
| 134 |
+
* @brief AVX512 implementation to convert fp32 numbers to fp16 numbers.
|
| 135 |
+
*
|
| 136 |
+
*/
|
| 137 |
+
FBGEMM_API void FloatToFloat16_avx512(
|
| 138 |
+
const float* src,
|
| 139 |
+
float16* dst,
|
| 140 |
+
size_t size,
|
| 141 |
+
bool do_clip = false);
|
| 142 |
+
#endif
|
| 143 |
+
|
| 144 |
+
/**
|
| 145 |
+
* @brief SVE2 implementation to convert fp32 numbers to fp16 numbers.
|
| 146 |
+
*
|
| 147 |
+
*/
|
| 148 |
+
FBGEMM_API void FloatToFloat16_sve2(
|
| 149 |
+
const float* src,
|
| 150 |
+
float16* dst,
|
| 151 |
+
size_t size,
|
| 152 |
+
bool do_clip = false);
|
| 153 |
+
|
| 154 |
+
#if !defined(__aarch64__)
|
| 155 |
+
/**
|
| 156 |
+
* @brief AVX2 implementation to convert fp16 numbers to fp32 numbers.
|
| 157 |
+
*
|
| 158 |
+
*/
|
| 159 |
+
FBGEMM_API void
|
| 160 |
+
Float16ToFloat_avx2(const float16* src, float* dst, size_t size);
|
| 161 |
+
|
| 162 |
+
/**
|
| 163 |
+
* @brief AVX512 implementation to convert fp16 numbers to fp32 numbers.
|
| 164 |
+
*
|
| 165 |
+
*/
|
| 166 |
+
FBGEMM_API void
|
| 167 |
+
Float16ToFloat_avx512(const float16* src, float* dst, size_t size);
|
| 168 |
+
#endif
|
| 169 |
+
|
| 170 |
+
/**
|
| 171 |
+
* @brief Transform all entries in a matrix from fp32 to float16 and back to
|
| 172 |
+
* fp32.
|
| 173 |
+
*/
|
| 174 |
+
FBGEMM_API void RoundToFloat16(
|
| 175 |
+
const float* input,
|
| 176 |
+
float* output,
|
| 177 |
+
size_t size,
|
| 178 |
+
bool clamp = false,
|
| 179 |
+
bool clamp_denorms = false);
|
| 180 |
+
|
| 181 |
+
/**
|
| 182 |
+
* @brief Quantize float32 to float8. The code is a copy of float_to_hfp8() in
|
| 183 |
+
* fbgemm_gpu/quantize_ops_utils.h
|
| 184 |
+
*/
|
| 185 |
+
FBGEMM_API void FloatToFloat8_ref(
|
| 186 |
+
float input,
|
| 187 |
+
uint8_t* output,
|
| 188 |
+
int exponent_bits,
|
| 189 |
+
int exponent_bias);
|
| 190 |
+
|
| 191 |
+
/**
|
| 192 |
+
* @brief Dequantize float8 to float32. The code is a copy of hf8_to_float() in
|
| 193 |
+
* fbgemm_gpu/quantize_ops_utils.h
|
| 194 |
+
*/
|
| 195 |
+
FBGEMM_API void Float8ToFloat_ref(
|
| 196 |
+
uint8_t input,
|
| 197 |
+
float* output,
|
| 198 |
+
int exponent_bits,
|
| 199 |
+
int exponent_bias);
|
| 200 |
+
|
| 201 |
+
} // namespace fbgemm
|
| 202 |
+
|
| 203 |
+
#else
|
| 204 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 205 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/fbgemm/FbgemmEmbedding.h
ADDED
|
@@ -0,0 +1,383 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
/*
|
| 3 |
+
* Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 4 |
+
* All rights reserved.
|
| 5 |
+
*
|
| 6 |
+
* This source code is licensed under the BSD-style license found in the
|
| 7 |
+
* LICENSE file in the root directory of this source tree.
|
| 8 |
+
*/
|
| 9 |
+
|
| 10 |
+
#pragma once
|
| 11 |
+
#include <cstdint>
|
| 12 |
+
#include <functional>
|
| 13 |
+
|
| 14 |
+
#include "fbgemm/FbgemmBuild.h"
|
| 15 |
+
|
| 16 |
+
namespace fbgemm {
|
| 17 |
+
|
| 18 |
+
template <
|
| 19 |
+
typename InType,
|
| 20 |
+
typename IndexType,
|
| 21 |
+
typename OffsetType = std::int32_t,
|
| 22 |
+
typename OutType = float>
|
| 23 |
+
class EmbeddingSpMDMKernelSignature {
|
| 24 |
+
public:
|
| 25 |
+
/**
|
| 26 |
+
* Behavior is as in the following pseudocode
|
| 27 |
+
* (when use_offsets == true, lengths[i] == offsets[i + 1] - offsets[i])
|
| 28 |
+
* (when is_weight_positional == true, use weights[j - offsets[i]] instead of
|
| 29 |
+
* weights[j])
|
| 30 |
+
*
|
| 31 |
+
* for i in range(output_size):
|
| 32 |
+
* out[i * block_size : (i + 1) * block_size] = 0
|
| 33 |
+
* for j in range(offsets[i], offsets[i + 1]):
|
| 34 |
+
* for k in range(block_size):
|
| 35 |
+
* out[i * block_size + k] += input[indices[j] * block_size + k] *
|
| 36 |
+
* (weights ? weights[j] : 1);
|
| 37 |
+
* if normalize_by_lengths and lengths[i] > 0:
|
| 38 |
+
* out[i * block_size : (i + 1) * block_size] /= lengths[i]
|
| 39 |
+
*
|
| 40 |
+
* @param data_size the number of rows in embedding table
|
| 41 |
+
*/
|
| 42 |
+
using Type = std::function<bool(
|
| 43 |
+
std::int64_t output_size,
|
| 44 |
+
std::int64_t index_size,
|
| 45 |
+
std::int64_t data_size,
|
| 46 |
+
const InType* input,
|
| 47 |
+
const IndexType* indices,
|
| 48 |
+
const OffsetType* offsets_or_lengths,
|
| 49 |
+
const float* weights, // optional, can be null for non-weighted sum
|
| 50 |
+
OutType* out)>;
|
| 51 |
+
};
|
| 52 |
+
|
| 53 |
+
/**
|
| 54 |
+
* @tparam InType can be float, float16, or uint8_t
|
| 55 |
+
* @tparam IndexType can be int32_t or int64_t
|
| 56 |
+
* @tparam OffsetType can be int32_t or int64_t
|
| 57 |
+
*
|
| 58 |
+
* @param use_offsets If true, the generated code assumes we will pass offsets
|
| 59 |
+
* instead of lengths, which conforms to the PyTorch EmbeddingBag
|
| 60 |
+
* interface. In this case, the length of offsets array
|
| 61 |
+
* should be output_size + 1 and offsets[output_size] should
|
| 62 |
+
* be index_size.
|
| 63 |
+
* If false, the generated code assumes we will pass lengths
|
| 64 |
+
* that conform to the Caffe2 SparseLengthsSum interface.
|
| 65 |
+
*/
|
| 66 |
+
template <
|
| 67 |
+
typename InType,
|
| 68 |
+
typename IndexType,
|
| 69 |
+
typename OffsetType = std::int32_t,
|
| 70 |
+
typename OutType = float,
|
| 71 |
+
bool THREAD_LOCAL = false>
|
| 72 |
+
FBGEMM_API typename EmbeddingSpMDMKernelSignature<
|
| 73 |
+
InType,
|
| 74 |
+
IndexType,
|
| 75 |
+
OffsetType,
|
| 76 |
+
OutType>::Type
|
| 77 |
+
GenerateEmbeddingSpMDM(
|
| 78 |
+
const std::int64_t block_size,
|
| 79 |
+
bool has_weight,
|
| 80 |
+
bool normalize_by_lengths,
|
| 81 |
+
int prefetch = 16,
|
| 82 |
+
bool is_weight_positional = false,
|
| 83 |
+
bool use_offsets = true,
|
| 84 |
+
bool is_bf16_out = false,
|
| 85 |
+
bool is_bf16_in = false);
|
| 86 |
+
|
| 87 |
+
/**
|
| 88 |
+
* @param output_stride If -1, output_stride is same as block_size
|
| 89 |
+
* @param input_stride If -1, input_stride is same as block_size
|
| 90 |
+
* @param scale_bias_last if false, scale and bias appear at the beginning
|
| 91 |
+
* of each row and are in fp16 for table batched embedding (TBE)
|
| 92 |
+
* in FBGEMM_GPU. If false, it can also take -1 indices (output from
|
| 93 |
+
* pruned embedding id mapping)
|
| 94 |
+
*/
|
| 95 |
+
template <
|
| 96 |
+
typename InType,
|
| 97 |
+
typename IndexType,
|
| 98 |
+
typename OffsetType = std::int32_t,
|
| 99 |
+
typename OutType = float,
|
| 100 |
+
bool THREAD_LOCAL = false>
|
| 101 |
+
FBGEMM_API typename EmbeddingSpMDMKernelSignature<
|
| 102 |
+
InType,
|
| 103 |
+
IndexType,
|
| 104 |
+
OffsetType,
|
| 105 |
+
OutType>::Type
|
| 106 |
+
GenerateEmbeddingSpMDMWithStrides(
|
| 107 |
+
const std::int64_t block_size,
|
| 108 |
+
bool has_weight,
|
| 109 |
+
bool normalize_by_lengths,
|
| 110 |
+
int prefetch = 16,
|
| 111 |
+
bool is_weight_positional = false,
|
| 112 |
+
bool use_offsets = true,
|
| 113 |
+
std::int64_t output_stride = -1,
|
| 114 |
+
std::int64_t input_stride = -1,
|
| 115 |
+
bool scale_bias_last = true,
|
| 116 |
+
bool no_bag = false,
|
| 117 |
+
bool is_bf16_out = false,
|
| 118 |
+
bool is_bf16_in = false);
|
| 119 |
+
|
| 120 |
+
/**
|
| 121 |
+
* @tparam IndexType can be int32_t or int64_t
|
| 122 |
+
* @tparam OffsetType can be int32_t or int64_t
|
| 123 |
+
* @param bit_rate can be 2 or 4
|
| 124 |
+
*/
|
| 125 |
+
template <
|
| 126 |
+
typename IndexType,
|
| 127 |
+
typename OffsetType = std::int32_t,
|
| 128 |
+
typename OutType = float>
|
| 129 |
+
FBGEMM_API typename EmbeddingSpMDMKernelSignature<
|
| 130 |
+
std::uint8_t,
|
| 131 |
+
IndexType,
|
| 132 |
+
OffsetType,
|
| 133 |
+
OutType>::Type
|
| 134 |
+
GenerateEmbeddingSpMDMNBit(
|
| 135 |
+
int bit_rate,
|
| 136 |
+
const std::int64_t block_size,
|
| 137 |
+
bool has_weight,
|
| 138 |
+
bool normalize_by_lengths,
|
| 139 |
+
int prefetch = 16,
|
| 140 |
+
bool is_weight_positional = false,
|
| 141 |
+
bool use_offsets = true);
|
| 142 |
+
|
| 143 |
+
/**
|
| 144 |
+
* @param output_stride If -1, output_stride is same as block_size
|
| 145 |
+
* @param input_stride in Bytes. If -1, input_stride is same as
|
| 146 |
+
* block_size / num_elem_per_byte + 2 * sizeof(float16)
|
| 147 |
+
* @param scale_bias_last if false, scale and bias appear at the beginning
|
| 148 |
+
* of each row and are in fp16 for table batched embedding (TBE)
|
| 149 |
+
* in FBGEMM_GPU. If false, it can also take -1 indices (output from
|
| 150 |
+
* pruned embedding id mapping)
|
| 151 |
+
*/
|
| 152 |
+
template <
|
| 153 |
+
typename IndexType,
|
| 154 |
+
typename OffsetType = std::int32_t,
|
| 155 |
+
typename OutType = float,
|
| 156 |
+
bool THREAD_LOCAL = false>
|
| 157 |
+
FBGEMM_API typename EmbeddingSpMDMKernelSignature<
|
| 158 |
+
std::uint8_t,
|
| 159 |
+
IndexType,
|
| 160 |
+
OffsetType,
|
| 161 |
+
OutType>::Type
|
| 162 |
+
GenerateEmbeddingSpMDMNBitWithStrides(
|
| 163 |
+
const int input_bit_rate,
|
| 164 |
+
const std::int64_t block_size,
|
| 165 |
+
bool has_weight,
|
| 166 |
+
bool normalize_by_lengths,
|
| 167 |
+
int prefetch = 16,
|
| 168 |
+
bool is_weight_positional = false,
|
| 169 |
+
bool use_offsets = true,
|
| 170 |
+
std::int64_t output_stride = -1,
|
| 171 |
+
std::int64_t input_stride = -1,
|
| 172 |
+
bool scale_bias_last = true,
|
| 173 |
+
const bool is_bf16_out = false,
|
| 174 |
+
const bool no_bag = false,
|
| 175 |
+
int output_bit_rate = -1);
|
| 176 |
+
|
| 177 |
+
/**
|
| 178 |
+
* @param output_stride If -1, output_stride is same as block_size
|
| 179 |
+
* @param input_stride in Bytes. If -1, input_stride is same as
|
| 180 |
+
* block_size / num_elem_per_byte + 2 * sizeof(float16)
|
| 181 |
+
* @param exponent_bits is the number of exponent bits in the FP8 encoding
|
| 182 |
+
* (normally 4 or 5)
|
| 183 |
+
* @param exponent_bias is subtracted from the exponent to obtain the actual
|
| 184 |
+
* exponent for the floating-point number
|
| 185 |
+
*/
|
| 186 |
+
template <
|
| 187 |
+
typename IndexType,
|
| 188 |
+
typename OffsetType = std::int32_t,
|
| 189 |
+
typename OutType = float>
|
| 190 |
+
FBGEMM_API typename EmbeddingSpMDMKernelSignature<
|
| 191 |
+
std::uint8_t,
|
| 192 |
+
IndexType,
|
| 193 |
+
OffsetType,
|
| 194 |
+
OutType>::Type
|
| 195 |
+
GenerateEmbeddingSpMDMFP8WithStrides(
|
| 196 |
+
const std::int64_t block_size,
|
| 197 |
+
bool normalize_by_lengths,
|
| 198 |
+
bool is_weight_positional = false,
|
| 199 |
+
bool use_offsets = true,
|
| 200 |
+
std::int64_t output_stride = -1,
|
| 201 |
+
std::int64_t input_stride = -1,
|
| 202 |
+
int exponent_bits = 4,
|
| 203 |
+
int exponent_bias = 7,
|
| 204 |
+
bool is_bf16_out = false);
|
| 205 |
+
|
| 206 |
+
template <
|
| 207 |
+
typename InType,
|
| 208 |
+
typename IndexType,
|
| 209 |
+
typename OffsetType = std::int32_t>
|
| 210 |
+
class EmbeddingSpMDMRowWiseSparseKernelSignature {
|
| 211 |
+
public:
|
| 212 |
+
using Type = std::function<bool(
|
| 213 |
+
std::int64_t output_size,
|
| 214 |
+
std::int64_t index_size,
|
| 215 |
+
std::int64_t uncompressed_data_size,
|
| 216 |
+
// TODO: add compressed_data_size and check array bound
|
| 217 |
+
const InType* input,
|
| 218 |
+
const IndexType* indices,
|
| 219 |
+
const OffsetType* offsets_or_lengths,
|
| 220 |
+
const float* weights, // optional, can be null for non-weighted sum
|
| 221 |
+
float* out,
|
| 222 |
+
const std::int32_t* compressed_indices_table)>;
|
| 223 |
+
};
|
| 224 |
+
|
| 225 |
+
/**
|
| 226 |
+
* @tparam InType can be float, float16, or uint8_t
|
| 227 |
+
* @tparam IndexType can be int32_t or int64_t
|
| 228 |
+
* @tparam OffsetType can be int32_t or int64_t
|
| 229 |
+
*/
|
| 230 |
+
template <
|
| 231 |
+
typename InType,
|
| 232 |
+
typename IndexType,
|
| 233 |
+
typename OffsetType = std::int32_t>
|
| 234 |
+
FBGEMM_API typename EmbeddingSpMDMRowWiseSparseKernelSignature<
|
| 235 |
+
InType,
|
| 236 |
+
IndexType,
|
| 237 |
+
OffsetType>::Type
|
| 238 |
+
GenerateEmbeddingSpMDMRowWiseSparse(
|
| 239 |
+
const std::int64_t block_size,
|
| 240 |
+
bool has_weight,
|
| 241 |
+
bool normalize_by_lengths,
|
| 242 |
+
int prefetch = 16,
|
| 243 |
+
bool is_weight_positional = false,
|
| 244 |
+
bool use_offsets = true);
|
| 245 |
+
|
| 246 |
+
/**
|
| 247 |
+
* @tparam IndexType can be int32_t or int64_t
|
| 248 |
+
* @tparam OffsetType can be int32_t or int64_t
|
| 249 |
+
* @param bit_rate can be 2 or 4
|
| 250 |
+
*/
|
| 251 |
+
template <typename IndexType, typename OffsetType = std::int32_t>
|
| 252 |
+
FBGEMM_API typename EmbeddingSpMDMRowWiseSparseKernelSignature<
|
| 253 |
+
std::uint8_t,
|
| 254 |
+
IndexType,
|
| 255 |
+
OffsetType>::Type
|
| 256 |
+
GenerateEmbeddingSpMDMNBitRowWiseSparse(
|
| 257 |
+
int bit_rate,
|
| 258 |
+
const std::int64_t block_size,
|
| 259 |
+
bool has_weight,
|
| 260 |
+
bool normalize_by_lengths,
|
| 261 |
+
int prefetch = 16,
|
| 262 |
+
bool is_weight_positional = false,
|
| 263 |
+
bool use_offsets = true);
|
| 264 |
+
|
| 265 |
+
/**
|
| 266 |
+
* @return The number of rows processed. If smaller than num_rows, an error
|
| 267 |
+
* must have happened at the last row processed.
|
| 268 |
+
*/
|
| 269 |
+
template <typename IndexType>
|
| 270 |
+
class SparseAdaGradSignature {
|
| 271 |
+
public:
|
| 272 |
+
using Type = std::function<int(
|
| 273 |
+
int num_rows, // number of rows reading
|
| 274 |
+
std::uint64_t param_size, // total number of parameters
|
| 275 |
+
float* w, // input/output parameters
|
| 276 |
+
const float* g, // input gradients
|
| 277 |
+
float* h, // input/output momentums
|
| 278 |
+
const IndexType* indices, // indices of each row
|
| 279 |
+
float epsilon,
|
| 280 |
+
float lr,
|
| 281 |
+
float weight_decay,
|
| 282 |
+
const double* counter, // used for weight_decay adjusted for frequency
|
| 283 |
+
// nullptr when frequency adjustment is not used.
|
| 284 |
+
// ignored when the kernel is generated with
|
| 285 |
+
// use_weight_decay = false.
|
| 286 |
+
std::int64_t counter_halflife)>; // frequency adjust happens only after
|
| 287 |
+
};
|
| 288 |
+
|
| 289 |
+
template <typename IndexType>
|
| 290 |
+
FBGEMM_API typename SparseAdaGradSignature<IndexType>::Type
|
| 291 |
+
GenerateSparseAdaGrad(
|
| 292 |
+
int block_size, // number of parameters per row
|
| 293 |
+
bool rowwise = false,
|
| 294 |
+
int prefetch = 16,
|
| 295 |
+
bool use_weight_decay = false);
|
| 296 |
+
|
| 297 |
+
// RowWiseSparseAdaGrad fused with SLS gradient
|
| 298 |
+
// Weights can be either float or float16
|
| 299 |
+
template <
|
| 300 |
+
typename IndexType,
|
| 301 |
+
typename OffsetType = std::int32_t,
|
| 302 |
+
typename DataType = float>
|
| 303 |
+
class RowWiseSparseAdaGradFusedSignature {
|
| 304 |
+
public:
|
| 305 |
+
using Type = std::function<bool(
|
| 306 |
+
std::int64_t output_size,
|
| 307 |
+
std::int64_t index_size,
|
| 308 |
+
std::int64_t data_size, // number of rows in w
|
| 309 |
+
DataType* w, // input/output parameters
|
| 310 |
+
const float* g, // input gradients
|
| 311 |
+
float* h, // input/output momentums
|
| 312 |
+
const IndexType* indices, // indices of each row
|
| 313 |
+
const OffsetType* offsets_or_lengths,
|
| 314 |
+
float epsilon,
|
| 315 |
+
float lr)>;
|
| 316 |
+
};
|
| 317 |
+
|
| 318 |
+
/**
|
| 319 |
+
* @param grad_stride If -1, grad_stride is same as block size
|
| 320 |
+
*/
|
| 321 |
+
template <
|
| 322 |
+
typename IndexType,
|
| 323 |
+
typename OffsetType = std::int32_t,
|
| 324 |
+
typename DataType = float>
|
| 325 |
+
FBGEMM_API typename RowWiseSparseAdaGradFusedSignature<
|
| 326 |
+
IndexType,
|
| 327 |
+
OffsetType,
|
| 328 |
+
DataType>::Type
|
| 329 |
+
GenerateRowWiseSparseAdaGradFused(
|
| 330 |
+
int block_size, // number of parameters per row
|
| 331 |
+
int prefetch = 16,
|
| 332 |
+
bool use_offsets = true,
|
| 333 |
+
bool use_stochastic_rounding = true,
|
| 334 |
+
int grad_stride = -1);
|
| 335 |
+
|
| 336 |
+
namespace internal {
|
| 337 |
+
// Specialization for block size 1 internally called by GenerateEmbeddingSpMDM
|
| 338 |
+
template <typename InType, typename IndexType, typename OffsetType>
|
| 339 |
+
FBGEMM_API bool EmbeddingSpMDMBlockSize1_(
|
| 340 |
+
const std::int64_t output_size,
|
| 341 |
+
const std::int64_t index_size,
|
| 342 |
+
const std::int64_t data_size, // the number of rows in input
|
| 343 |
+
const InType* input,
|
| 344 |
+
const IndexType* indices,
|
| 345 |
+
const OffsetType* offsets_or_lengths,
|
| 346 |
+
const float* weights, // optional, can be null for non-weighted sum
|
| 347 |
+
bool normalize_by_lengths,
|
| 348 |
+
float* out,
|
| 349 |
+
bool is_weight_positional = false,
|
| 350 |
+
bool use_offsets = true,
|
| 351 |
+
bool is_bf16 = false);
|
| 352 |
+
|
| 353 |
+
#if !defined(__aarch64__)
|
| 354 |
+
template <typename IndexType, bool HAS_WEIGHTS>
|
| 355 |
+
void compressed_indices_remap_avx512(
|
| 356 |
+
std::int32_t offsets_numel,
|
| 357 |
+
const IndexType* indices,
|
| 358 |
+
const int32_t* compressed_indices_mapping,
|
| 359 |
+
const IndexType* offsets,
|
| 360 |
+
const float* weights, // optional, can be null,
|
| 361 |
+
IndexType* out_indices,
|
| 362 |
+
IndexType* out_offsets,
|
| 363 |
+
float* out_weights);
|
| 364 |
+
#endif
|
| 365 |
+
|
| 366 |
+
} // namespace internal
|
| 367 |
+
|
| 368 |
+
template <typename IndexType>
|
| 369 |
+
FBGEMM_API void compressed_indices_remap(
|
| 370 |
+
std::int32_t offsets_numel,
|
| 371 |
+
const IndexType* indices,
|
| 372 |
+
const int32_t* compressed_indices_mapping,
|
| 373 |
+
const IndexType* offsets,
|
| 374 |
+
const float* weights, // optional, can be null,
|
| 375 |
+
IndexType* out_indices,
|
| 376 |
+
IndexType* out_offsets,
|
| 377 |
+
float* out_weights);
|
| 378 |
+
|
| 379 |
+
} // namespace fbgemm
|
| 380 |
+
|
| 381 |
+
#else
|
| 382 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 383 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/fbgemm/FbgemmFP16.h
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
/*
|
| 3 |
+
* Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 4 |
+
* All rights reserved.
|
| 5 |
+
*
|
| 6 |
+
* This source code is licensed under the BSD-style license found in the
|
| 7 |
+
* LICENSE file in the root directory of this source tree.
|
| 8 |
+
*/
|
| 9 |
+
|
| 10 |
+
#pragma once
|
| 11 |
+
|
| 12 |
+
// WARNING: this is a legacy fp16 fbgemm implementation and will soon be
|
| 13 |
+
// upgraded to match with new fbgemm interface.
|
| 14 |
+
|
| 15 |
+
#include <cpuinfo.h>
|
| 16 |
+
|
| 17 |
+
#include "./FbgemmPackMatrixB.h" // @manual
|
| 18 |
+
#include "./FloatConversion.h" // @manual
|
| 19 |
+
#include "./Types.h" // @manual
|
| 20 |
+
#include "./Utils.h" // @manual
|
| 21 |
+
|
| 22 |
+
namespace fbgemm {
|
| 23 |
+
|
| 24 |
+
template <>
|
| 25 |
+
struct TypeConverter<float16> {
|
| 26 |
+
float16 operator()(float src) const {
|
| 27 |
+
constexpr float FP16_MAX = 65504.f;
|
| 28 |
+
const float fp16 = std::max(-FP16_MAX, std::min(src, FP16_MAX));
|
| 29 |
+
return cpu_float2half(fp16);
|
| 30 |
+
}
|
| 31 |
+
};
|
| 32 |
+
|
| 33 |
+
using PackedGemmMatrixFP16 = PackedGemmMatrixB<float16>;
|
| 34 |
+
|
| 35 |
+
template <typename T>
|
| 36 |
+
FBGEMM_API void cblas_gemm_compute(
|
| 37 |
+
const matrix_op_t transa,
|
| 38 |
+
const int m,
|
| 39 |
+
const float* A,
|
| 40 |
+
const PackedGemmMatrixB<T>& Bp,
|
| 41 |
+
const float beta,
|
| 42 |
+
float* C,
|
| 43 |
+
int thread_id = 0,
|
| 44 |
+
int num_threads = 1);
|
| 45 |
+
|
| 46 |
+
extern template void cblas_gemm_compute<float16>(
|
| 47 |
+
const matrix_op_t transa,
|
| 48 |
+
const int m,
|
| 49 |
+
const float* A,
|
| 50 |
+
const PackedGemmMatrixFP16& Bp,
|
| 51 |
+
const float beta,
|
| 52 |
+
float* C,
|
| 53 |
+
int thread_id,
|
| 54 |
+
int num_threads);
|
| 55 |
+
|
| 56 |
+
}; // namespace fbgemm
|
| 57 |
+
|
| 58 |
+
#else
|
| 59 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 60 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/fbgemm/FbgemmFP32.h
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
|
| 3 |
+
|
| 4 |
+
#pragma once
|
| 5 |
+
|
| 6 |
+
// WARNING: this is a legacy fp16 fbgemm implementation and will soon be
|
| 7 |
+
// upgraded to match with new fbgemm interface.
|
| 8 |
+
|
| 9 |
+
#include <cpuinfo.h>
|
| 10 |
+
|
| 11 |
+
#include "fbgemm/FbgemmFPCommon.h"
|
| 12 |
+
#include "fbgemm/FbgemmPackMatrixB.h"
|
| 13 |
+
#include "fbgemm/Utils.h"
|
| 14 |
+
|
| 15 |
+
namespace fbgemm {
|
| 16 |
+
template <>
|
| 17 |
+
struct TypeConverter<float> {
|
| 18 |
+
float operator()(float src) const {
|
| 19 |
+
return src;
|
| 20 |
+
}
|
| 21 |
+
};
|
| 22 |
+
|
| 23 |
+
using GemmParamsFP32 = GemmParams<float>;
|
| 24 |
+
using PackedGemmMatrixFP32 = PackedGemmMatrixB<float>;
|
| 25 |
+
|
| 26 |
+
template <typename T, int _kernel_ncol_blocks, int _brow>
|
| 27 |
+
void cblas_gemm_compute(
|
| 28 |
+
const matrix_op_t transa,
|
| 29 |
+
const int m,
|
| 30 |
+
const float* A,
|
| 31 |
+
const PackedGemmMatrixB<T>& Bp,
|
| 32 |
+
const float beta,
|
| 33 |
+
float* C,
|
| 34 |
+
int thread_id = 0,
|
| 35 |
+
int num_threads = 1);
|
| 36 |
+
|
| 37 |
+
extern template void cblas_gemm_compute(
|
| 38 |
+
const matrix_op_t transa,
|
| 39 |
+
const int m,
|
| 40 |
+
const float* A,
|
| 41 |
+
const PackedGemmMatrixFP32& Bp,
|
| 42 |
+
const float beta,
|
| 43 |
+
float* C,
|
| 44 |
+
int thread_id,
|
| 45 |
+
int num_threads);
|
| 46 |
+
|
| 47 |
+
template <>
|
| 48 |
+
const isa_descriptor<float>& getIsaHandlers(inst_set_t isa);
|
| 49 |
+
|
| 50 |
+
} // namespace fbgemm
|
| 51 |
+
|
| 52 |
+
#else
|
| 53 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 54 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/fbgemm/FbgemmFPCommon.h
ADDED
|
@@ -0,0 +1,319 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
/*
|
| 3 |
+
* Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 4 |
+
* Copyright 2024-2025 Arm Limited and/or its affiliates
|
| 5 |
+
* <open-source-office@arm.com> All rights reserved.
|
| 6 |
+
*
|
| 7 |
+
* This source code is licensed under the BSD-style license found in the
|
| 8 |
+
* LICENSE file in the root directory of this source tree.
|
| 9 |
+
*/
|
| 10 |
+
|
| 11 |
+
#pragma once
|
| 12 |
+
|
| 13 |
+
#include <fbgemm/FbgemmPackMatrixB.h>
|
| 14 |
+
#include <fbgemm/Types.h>
|
| 15 |
+
#include <fbgemm/Utils.h>
|
| 16 |
+
#include <array>
|
| 17 |
+
#include <memory>
|
| 18 |
+
|
| 19 |
+
#if defined(FBGEMM_FP16_FALLBACK_TO_REF_KERNEL) || \
|
| 20 |
+
defined(FBGEMM_FP32_FALLBACK_TO_REF_KERNEL)
|
| 21 |
+
#if defined(__APPLE__) && defined(__aarch64__)
|
| 22 |
+
#define FBGEMM_USE_REF_KERNEL
|
| 23 |
+
#endif
|
| 24 |
+
#endif
|
| 25 |
+
|
| 26 |
+
namespace fbgemm {
|
| 27 |
+
|
| 28 |
+
using partition_array_t = std::array<std::array<std::array<int, 2>, 2>, 121>;
|
| 29 |
+
extern partition_array_t partition_avx2;
|
| 30 |
+
extern partition_array_t partition_avx512;
|
| 31 |
+
extern partition_array_t partition_sve128;
|
| 32 |
+
#ifdef FBGEMM_ENABLE_KLEIDIAI
|
| 33 |
+
extern partition_array_t partition_neon;
|
| 34 |
+
#endif
|
| 35 |
+
|
| 36 |
+
template <typename T>
|
| 37 |
+
struct GemmParams {
|
| 38 |
+
uint64_t k;
|
| 39 |
+
float* A;
|
| 40 |
+
const T* B;
|
| 41 |
+
float beta;
|
| 42 |
+
float* C;
|
| 43 |
+
uint64_t ldc;
|
| 44 |
+
uint64_t b_block_cols;
|
| 45 |
+
uint64_t b_block_size;
|
| 46 |
+
};
|
| 47 |
+
|
| 48 |
+
template <>
|
| 49 |
+
struct GemmParams<float16> {
|
| 50 |
+
uint64_t k;
|
| 51 |
+
float* A;
|
| 52 |
+
const float16* B;
|
| 53 |
+
float beta;
|
| 54 |
+
float* C;
|
| 55 |
+
uint64_t ldc;
|
| 56 |
+
uint64_t b_block_cols;
|
| 57 |
+
#ifdef FBGEMM_ENABLE_KLEIDIAI
|
| 58 |
+
uint64_t lda;
|
| 59 |
+
#else
|
| 60 |
+
uint64_t b_block_size;
|
| 61 |
+
#endif
|
| 62 |
+
};
|
| 63 |
+
|
| 64 |
+
template <>
|
| 65 |
+
struct GemmParams<float> {
|
| 66 |
+
uint64_t k;
|
| 67 |
+
float* A;
|
| 68 |
+
const float* B;
|
| 69 |
+
float beta;
|
| 70 |
+
float* C;
|
| 71 |
+
uint64_t ldc;
|
| 72 |
+
uint64_t b_block_cols;
|
| 73 |
+
#ifdef FBGEMM_ENABLE_KLEIDIAI
|
| 74 |
+
uint64_t lda;
|
| 75 |
+
#else
|
| 76 |
+
uint64_t b_block_size;
|
| 77 |
+
#endif
|
| 78 |
+
};
|
| 79 |
+
|
| 80 |
+
template <typename T>
|
| 81 |
+
using funcptr_t = void (*)(GemmParams<T>*);
|
| 82 |
+
template <typename T>
|
| 83 |
+
using kernel_array_t = std::array<funcptr_t<T>, 15>;
|
| 84 |
+
template <typename T>
|
| 85 |
+
using isa_descriptor = std::tuple<kernel_array_t<T>, partition_array_t>;
|
| 86 |
+
|
| 87 |
+
template <typename T>
|
| 88 |
+
extern const isa_descriptor<T>& getIsaHandlers(inst_set_t isa);
|
| 89 |
+
|
| 90 |
+
void PackA(int nrow, int ncol, const float* from, int ldim, float* to);
|
| 91 |
+
|
| 92 |
+
// define fp16/fp32 kernels using a reference C implementation
|
| 93 |
+
#if defined(FBGEMM_FP16_FALLBACK_TO_REF_KERNEL) || \
|
| 94 |
+
defined(FBGEMM_FP32_FALLBACK_TO_REF_KERNEL)
|
| 95 |
+
template <typename T>
|
| 96 |
+
FBGEMM_API void ref_kernel(
|
| 97 |
+
int kernel_nrows,
|
| 98 |
+
GemmParams<T>* gp,
|
| 99 |
+
const float* C_base,
|
| 100 |
+
int m_total,
|
| 101 |
+
int n_total,
|
| 102 |
+
int vlen);
|
| 103 |
+
#endif
|
| 104 |
+
|
| 105 |
+
template <typename T>
|
| 106 |
+
FBGEMM_API void cblas_gemm_compute(
|
| 107 |
+
const matrix_op_t transa,
|
| 108 |
+
const int m,
|
| 109 |
+
const float* A,
|
| 110 |
+
const PackedGemmMatrixB<T>& Bp,
|
| 111 |
+
const float beta,
|
| 112 |
+
float* C,
|
| 113 |
+
int thread_id = 0,
|
| 114 |
+
int num_threads = 1);
|
| 115 |
+
|
| 116 |
+
#if defined(FBGEMM_EXPORTS)
|
| 117 |
+
// autotuned kernel splits for various cases m = 1:mb_max
|
| 118 |
+
template <typename T>
|
| 119 |
+
void cblas_gemm_compute(
|
| 120 |
+
const matrix_op_t transa [[maybe_unused]],
|
| 121 |
+
const int m,
|
| 122 |
+
const float* A,
|
| 123 |
+
const PackedGemmMatrixB<T>& Bp,
|
| 124 |
+
const float beta,
|
| 125 |
+
float* C,
|
| 126 |
+
int thread_id,
|
| 127 |
+
int num_threads) {
|
| 128 |
+
// ground truth
|
| 129 |
+
assert(cpuinfo_initialize());
|
| 130 |
+
#ifndef __aarch64__
|
| 131 |
+
assert(cpuinfo_has_x86_fma3());
|
| 132 |
+
assert(cpuinfo_has_x86_f16c());
|
| 133 |
+
#endif
|
| 134 |
+
assert(transa == matrix_op_t::NoTranspose);
|
| 135 |
+
|
| 136 |
+
// private scratchpad storage
|
| 137 |
+
static thread_local std::unique_ptr<std::array<float, 256 * 1024>> scratchpad(
|
| 138 |
+
new std::array<float, 256 * 1024>());
|
| 139 |
+
|
| 140 |
+
// constants
|
| 141 |
+
const int n = Bp.numCols(), k = Bp.numRows(), ldc = n;
|
| 142 |
+
const int mb_max = 120;
|
| 143 |
+
|
| 144 |
+
#if defined(FBGEMM_USE_REF_KERNEL) && defined(__APPLE__)
|
| 145 |
+
const auto& [_, partition] = getIsaHandlers<float16>(inst_set_t::sve);
|
| 146 |
+
#else
|
| 147 |
+
const auto iset = fbgemmInstructionSet();
|
| 148 |
+
const auto& [kernels, partition] = getIsaHandlers<T>(iset);
|
| 149 |
+
#endif
|
| 150 |
+
|
| 151 |
+
#ifdef FBGEMM_USE_REF_KERNEL
|
| 152 |
+
// By some reason, if packed B is using packing layout for avx2, we just use
|
| 153 |
+
// avx2 even if avx512 is available.
|
| 154 |
+
const int simd_width =
|
| 155 |
+
#ifndef __aarch64__
|
| 156 |
+
(iset == inst_set_t::avx512 || iset == inst_set_t::avx512_vnni) &&
|
| 157 |
+
(Bp.blockColSize() == 16 * Bp.kernelNumColBlocks())
|
| 158 |
+
? simd_info<inst_set_t::avx512>::WIDTH_32BIT_ELEMS
|
| 159 |
+
: simd_info<inst_set_t::avx2>::WIDTH_32BIT_ELEMS;
|
| 160 |
+
#else
|
| 161 |
+
simd_info<inst_set_t::sve>::WIDTH_32BIT_ELEMS;
|
| 162 |
+
#endif
|
| 163 |
+
#endif
|
| 164 |
+
|
| 165 |
+
GemmParams<T> gp;
|
| 166 |
+
int i_begin = 0, i_end = 0;
|
| 167 |
+
i_begin = 0;
|
| 168 |
+
i_end = m;
|
| 169 |
+
for (auto m0 = i_begin; m0 < i_end; m0 += mb_max) {
|
| 170 |
+
int mb = std::min(mb_max, i_end - m0);
|
| 171 |
+
assert(mb < static_cast<int64_t>(partition.size()));
|
| 172 |
+
for (auto k_ind = 0; k_ind < k; k_ind += Bp.blockRowSize()) {
|
| 173 |
+
// set up proper accumulation to avoid "Nan" problem
|
| 174 |
+
// accumulate of beta != 0.0
|
| 175 |
+
// do not!!! accumulate otherwise
|
| 176 |
+
float beta_ = beta;
|
| 177 |
+
if (k_ind != 0) {
|
| 178 |
+
// always accumulate with beta_ = 1.0f
|
| 179 |
+
beta_ = 1.0f;
|
| 180 |
+
}
|
| 181 |
+
|
| 182 |
+
const int kb = std::min(Bp.blockRowSize(), Bp.numRows() - k_ind);
|
| 183 |
+
|
| 184 |
+
auto m1 = m0;
|
| 185 |
+
auto const num_cycles = partition[mb].size();
|
| 186 |
+
for (size_t c = 0; c < num_cycles; ++c) {
|
| 187 |
+
auto kernel_nrows = partition[mb][c][0];
|
| 188 |
+
auto nkernel_nrows = partition[mb][c][1];
|
| 189 |
+
auto m_start = m1;
|
| 190 |
+
auto m_end = m1 + kernel_nrows * nkernel_nrows;
|
| 191 |
+
for (auto m2 = m_start; m2 < m_end; m2 += kernel_nrows) {
|
| 192 |
+
assert(kernel_nrows * kb < static_cast<int64_t>(scratchpad->size()));
|
| 193 |
+
if (m != 1) {
|
| 194 |
+
#ifdef FBGEMM_ENABLE_KLEIDIAI
|
| 195 |
+
if constexpr (
|
| 196 |
+
std::is_same<T, float16>::value ||
|
| 197 |
+
std::is_same<T, float>::value) {
|
| 198 |
+
gp.A = const_cast<float*>(&A[m2 * k + k_ind]);
|
| 199 |
+
} else {
|
| 200 |
+
#endif
|
| 201 |
+
PackA(
|
| 202 |
+
kernel_nrows, kb, &A[m2 * k + k_ind], k, scratchpad->data());
|
| 203 |
+
gp.A = scratchpad->data();
|
| 204 |
+
#ifdef FBGEMM_ENABLE_KLEIDIAI
|
| 205 |
+
}
|
| 206 |
+
#endif
|
| 207 |
+
} else {
|
| 208 |
+
// When m == 1, it is actually vector matrix multiplication. We
|
| 209 |
+
// don't need to do the transposition for packA here. Instead, we
|
| 210 |
+
// can just pass the pointer of the original A matrix buffer to the
|
| 211 |
+
// packed A buffer.
|
| 212 |
+
gp.A = const_cast<float*>(&A[k_ind]);
|
| 213 |
+
}
|
| 214 |
+
|
| 215 |
+
int nbcol = n / Bp.blockColSize();
|
| 216 |
+
gp.k = kb;
|
| 217 |
+
gp.B = &(Bp(k_ind, 0));
|
| 218 |
+
gp.beta = beta_;
|
| 219 |
+
gp.C = &C[m2 * ldc];
|
| 220 |
+
gp.ldc = ldc * sizeof(C[0]);
|
| 221 |
+
gp.b_block_cols = nbcol;
|
| 222 |
+
#ifdef FBGEMM_ENABLE_KLEIDIAI
|
| 223 |
+
if constexpr (
|
| 224 |
+
std::is_same<T, float16>::value ||
|
| 225 |
+
std::is_same<T, float>::value) {
|
| 226 |
+
gp.lda = k * sizeof(A[0]);
|
| 227 |
+
} else {
|
| 228 |
+
#endif
|
| 229 |
+
gp.b_block_size = gp.k * Bp.blockColSize() * sizeof(gp.B[0]);
|
| 230 |
+
#ifdef FBGEMM_ENABLE_KLEIDIAI
|
| 231 |
+
}
|
| 232 |
+
#endif
|
| 233 |
+
if ((n % Bp.blockColSize()) == 0) {
|
| 234 |
+
int64_t jb_begin = 0, jb_end = 0;
|
| 235 |
+
fbgemmPartition1D(
|
| 236 |
+
thread_id, num_threads, gp.b_block_cols, jb_begin, jb_end);
|
| 237 |
+
gp.B += gp.k * Bp.blockColSize() * jb_begin;
|
| 238 |
+
gp.C += Bp.blockColSize() * jb_begin;
|
| 239 |
+
gp.b_block_cols = jb_end - jb_begin;
|
| 240 |
+
if (gp.b_block_cols) {
|
| 241 |
+
#ifdef FBGEMM_USE_REF_KERNEL
|
| 242 |
+
ref_kernel<T>(kernel_nrows, &gp, C, m, n, simd_width);
|
| 243 |
+
#else
|
| 244 |
+
kernels[kernel_nrows](&gp);
|
| 245 |
+
#endif
|
| 246 |
+
}
|
| 247 |
+
} else {
|
| 248 |
+
int last_blk_col = nbcol * Bp.blockColSize();
|
| 249 |
+
if (nbcol) {
|
| 250 |
+
int64_t jb_begin = 0, jb_end = 0;
|
| 251 |
+
fbgemmPartition1D(
|
| 252 |
+
thread_id, num_threads, gp.b_block_cols, jb_begin, jb_end);
|
| 253 |
+
gp.B += gp.k * Bp.blockColSize() * jb_begin;
|
| 254 |
+
gp.C += Bp.blockColSize() * jb_begin;
|
| 255 |
+
gp.b_block_cols = jb_end - jb_begin;
|
| 256 |
+
if (gp.b_block_cols) {
|
| 257 |
+
#ifdef FBGEMM_USE_REF_KERNEL
|
| 258 |
+
ref_kernel(kernel_nrows, &gp, C, m, n, simd_width);
|
| 259 |
+
#else
|
| 260 |
+
kernels[kernel_nrows](&gp);
|
| 261 |
+
#endif
|
| 262 |
+
}
|
| 263 |
+
}
|
| 264 |
+
|
| 265 |
+
// use one thread to handle the fringe cases
|
| 266 |
+
if (thread_id == num_threads - 1) {
|
| 267 |
+
// leftover
|
| 268 |
+
const int rem [[maybe_unused]] = n - last_blk_col;
|
| 269 |
+
assert(rem < Bp.blockColSize());
|
| 270 |
+
|
| 271 |
+
// small temporary buffer: the size should be larger than the
|
| 272 |
+
// required kernel_nrow x kernel_ncols elements computed in the
|
| 273 |
+
// registers.
|
| 274 |
+
std::array<float, 14 * 32> c_tmp{0.f};
|
| 275 |
+
assert(
|
| 276 |
+
static_cast<int64_t>(c_tmp.size()) >=
|
| 277 |
+
kernel_nrows * Bp.blockColSize());
|
| 278 |
+
|
| 279 |
+
gp.B = &(Bp(k_ind, last_blk_col));
|
| 280 |
+
gp.C = c_tmp.data();
|
| 281 |
+
gp.ldc = Bp.blockColSize() * sizeof(C[0]);
|
| 282 |
+
gp.b_block_cols = 1;
|
| 283 |
+
#ifdef FBGEMM_USE_REF_KERNEL
|
| 284 |
+
ref_kernel<T>(
|
| 285 |
+
kernel_nrows, &gp, c_tmp.data(), 14, 32, simd_width);
|
| 286 |
+
#else
|
| 287 |
+
kernels[kernel_nrows](&gp);
|
| 288 |
+
#endif
|
| 289 |
+
for (int i = 0; i < kernel_nrows; i++) {
|
| 290 |
+
// Todo: use assembly
|
| 291 |
+
for (int j = last_blk_col; j < n; j++) {
|
| 292 |
+
assert(
|
| 293 |
+
i * Bp.blockColSize() + (j - last_blk_col) <
|
| 294 |
+
static_cast<int64_t>(sizeof(c_tmp) / sizeof(c_tmp[0])));
|
| 295 |
+
if (beta_ == 0.f) {
|
| 296 |
+
C[(m2 + i) * ldc + j] =
|
| 297 |
+
c_tmp[i * Bp.blockColSize() + (j - last_blk_col)];
|
| 298 |
+
} else {
|
| 299 |
+
C[(m2 + i) * ldc + j] = beta_ * C[(m2 + i) * ldc + j] +
|
| 300 |
+
c_tmp[i * Bp.blockColSize() + (j - last_blk_col)];
|
| 301 |
+
}
|
| 302 |
+
}
|
| 303 |
+
}
|
| 304 |
+
}
|
| 305 |
+
}
|
| 306 |
+
}
|
| 307 |
+
m1 += kernel_nrows * nkernel_nrows;
|
| 308 |
+
}
|
| 309 |
+
}
|
| 310 |
+
}
|
| 311 |
+
}
|
| 312 |
+
#endif
|
| 313 |
+
|
| 314 |
+
#undef FBGEMM_USE_REF_KERNEL
|
| 315 |
+
} // namespace fbgemm
|
| 316 |
+
|
| 317 |
+
#else
|
| 318 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 319 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/fbgemm/FbgemmI64.h
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
/*
|
| 3 |
+
* Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 4 |
+
* All rights reserved.
|
| 5 |
+
*
|
| 6 |
+
* This source code is licensed under the BSD-style license found in the
|
| 7 |
+
* LICENSE file in the root directory of this source tree.
|
| 8 |
+
*/
|
| 9 |
+
|
| 10 |
+
#pragma once
|
| 11 |
+
|
| 12 |
+
#include <cstdint>
|
| 13 |
+
|
| 14 |
+
#include "fbgemm/Utils.h"
|
| 15 |
+
|
| 16 |
+
namespace fbgemm {
|
| 17 |
+
|
| 18 |
+
FBGEMM_API void cblas_gemm_i64_i64acc(
|
| 19 |
+
matrix_op_t transa,
|
| 20 |
+
matrix_op_t transb,
|
| 21 |
+
int M,
|
| 22 |
+
int N,
|
| 23 |
+
int K,
|
| 24 |
+
const std::int64_t* A,
|
| 25 |
+
int lda,
|
| 26 |
+
const std::int64_t* B,
|
| 27 |
+
int ldb,
|
| 28 |
+
bool accumulate,
|
| 29 |
+
std::int64_t* C,
|
| 30 |
+
int ldc);
|
| 31 |
+
|
| 32 |
+
} // namespace fbgemm
|
| 33 |
+
|
| 34 |
+
#else
|
| 35 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 36 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
URSA/.venv_ursa/lib/python3.12/site-packages/torch/include/fbgemm/FbgemmI8DepthwiseAvx2.h
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
/*
|
| 3 |
+
* Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 4 |
+
* All rights reserved.
|
| 5 |
+
*
|
| 6 |
+
* This source code is licensed under the BSD-style license found in the
|
| 7 |
+
* LICENSE file in the root directory of this source tree.
|
| 8 |
+
*/
|
| 9 |
+
|
| 10 |
+
#pragma once
|
| 11 |
+
|
| 12 |
+
#include <cstdint>
|
| 13 |
+
#include "fbgemm/ConvUtils.h"
|
| 14 |
+
#include "fbgemm/FbgemmBuild.h"
|
| 15 |
+
#include "fbgemm/UtilsAvx2.h"
|
| 16 |
+
|
| 17 |
+
namespace fbgemm {
|
| 18 |
+
|
| 19 |
+
class FBGEMM_API PackedDepthWiseConvMatrix {
|
| 20 |
+
public:
|
| 21 |
+
/**
|
| 22 |
+
* @param IC the number of input channels (same as the number of groups
|
| 23 |
+
* because depth-wise convolution has one input channel per group)
|
| 24 |
+
* @param OC the number of output channels
|
| 25 |
+
* @param kernel_prod the product of all kernels. For example, kernel_prod =
|
| 26 |
+
* 9 for 3x3 conv, and 27 for 3x3x3 conv.
|
| 27 |
+
* @param smat the source unpacked weight in GRS layout
|
| 28 |
+
*/
|
| 29 |
+
PackedDepthWiseConvMatrix(int OC, int kernel_prod, const std::int8_t* smat);
|
| 30 |
+
PackedDepthWiseConvMatrix(const PackedDepthWiseConvMatrix&) = delete;
|
| 31 |
+
PackedDepthWiseConvMatrix(PackedDepthWiseConvMatrix&&) = delete;
|
| 32 |
+
PackedDepthWiseConvMatrix& operator=(const PackedDepthWiseConvMatrix&) =
|
| 33 |
+
delete;
|
| 34 |
+
PackedDepthWiseConvMatrix& operator=(PackedDepthWiseConvMatrix&&) = delete;
|
| 35 |
+
virtual ~PackedDepthWiseConvMatrix();
|
| 36 |
+
|
| 37 |
+
const std::int8_t* PackedMat() const {
|
| 38 |
+
return pmat_;
|
| 39 |
+
}
|
| 40 |
+
|
| 41 |
+
int GetKernelProduct() const {
|
| 42 |
+
return kernel_prod_;
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
/**
|
| 46 |
+
* @brief Unpacks pmat_ into unpack_data.
|
| 47 |
+
* Used for recovering the weight matrix into the original format
|
| 48 |
+
*/
|
| 49 |
+
void unpack(std::int8_t* unpacked_data);
|
| 50 |
+
|
| 51 |
+
/**
|
| 52 |
+
* @brief returns the index into pmat_ given the row and column for smat
|
| 53 |
+
*/
|
| 54 |
+
int addr(int r, int c);
|
| 55 |
+
|
| 56 |
+
private:
|
| 57 |
+
const int OC_; /**< the number of output channels */
|
| 58 |
+
const int kernel_prod_; /** the product of all kernel dims */
|
| 59 |
+
std::int8_t* pmat_; /** packed weight */
|
| 60 |
+
}; // PackedDepthWiseConvMatrix
|
| 61 |
+
|
| 62 |
+
/**
|
| 63 |
+
* Depth-wise convolution that results in the same output feature size as the
|
| 64 |
+
* input feature. That is PAD_T = PAD_B = (R - 1) / 2 and PAD_L = PAD_R =
|
| 65 |
+
* (S - 1) / 2. This function also does requantization.
|
| 66 |
+
* @param col_offsets nullptr if col_offsets are folded into bias
|
| 67 |
+
* @param act_times_w_scale Only used if BIAS_TYPE is float, i.e., bias is
|
| 68 |
+
* unquantized.
|
| 69 |
+
*/
|
| 70 |
+
template <QuantizationGranularity Q_GRAN, typename BIAS_TYPE = std::int32_t>
|
| 71 |
+
FBGEMM_API void depthwise_2d_same_pad(
|
| 72 |
+
int N,
|
| 73 |
+
int H,
|
| 74 |
+
int W,
|
| 75 |
+
int IC,
|
| 76 |
+
int OC,
|
| 77 |
+
int stride_h,
|
| 78 |
+
int stride_w,
|
| 79 |
+
std::int32_t A_zero_point,
|
| 80 |
+
const std::uint8_t* A,
|
| 81 |
+
const std::int32_t* B_zero_point,
|
| 82 |
+
const PackedDepthWiseConvMatrix& Bp,
|
| 83 |
+
const float* C_multiplier,
|
| 84 |
+
std::int32_t C_zero_point,
|
| 85 |
+
std::uint8_t* C,
|
| 86 |
+
const std::int32_t* col_offsets,
|
| 87 |
+
const BIAS_TYPE* bias,
|
| 88 |
+
bool fuse_relu = false,
|
| 89 |
+
const float* act_times_w_scale = nullptr,
|
| 90 |
+
int thread_id = 0,
|
| 91 |
+
int num_threads = 1);
|
| 92 |
+
|
| 93 |
+
/**
|
| 94 |
+
* @param col_offsets nullptr if col_offsets are folded into bias
|
| 95 |
+
*/
|
| 96 |
+
template <QuantizationGranularity Q_GRAN, typename BIAS_TYPE = std::int32_t>
|
| 97 |
+
FBGEMM_API void depthwise_3d_same_pad(
|
| 98 |
+
const conv_param_t<3>& conv_p,
|
| 99 |
+
std::int32_t A_zero_point,
|
| 100 |
+
const std::uint8_t* A,
|
| 101 |
+
const std::int32_t* B_zero_point,
|
| 102 |
+
const PackedDepthWiseConvMatrix& Bp,
|
| 103 |
+
const float* C_multiplier,
|
| 104 |
+
std::int32_t C_zero_point,
|
| 105 |
+
std::uint8_t* C,
|
| 106 |
+
const std::int32_t* col_offsets,
|
| 107 |
+
const BIAS_TYPE* bias,
|
| 108 |
+
bool fuse_relu = false,
|
| 109 |
+
const float* act_times_w_scale = nullptr,
|
| 110 |
+
int thread_id = 0,
|
| 111 |
+
int num_threads = 1);
|
| 112 |
+
|
| 113 |
+
} // namespace fbgemm
|
| 114 |
+
|
| 115 |
+
#else
|
| 116 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 117 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|