Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +1 -0
- .venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/split_cat.cpython-311.pyc +3 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/AccumulateType.h +173 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/Backend.h +2 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/CPUApplyUtils.h +343 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/CPUFixedAllocator.h +33 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/CPUFunctions.h +29 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/CPUGeneratorImpl.h +49 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/CompositeExplicitAutogradFunctions_inl.h +553 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/CompositeImplicitAutogradFunctions.h +29 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/CompositeImplicitAutogradNestedTensorFunctions.h +29 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/CompositeImplicitAutogradNestedTensorFunctions_inl.h +25 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/Config.h +21 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/Context.h +610 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/Device.h +2 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/DeviceAccelerator.h +27 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/Dispatch.h +808 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/EmptyTensor.h +166 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/ExpandBase.h +30 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/FuncTorchTLS.h +46 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/FunctionalStorageImpl.h +208 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/FunctionalTensorWrapper.h +454 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/InferSize.h +88 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/InitialTensorOptions.h +15 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/Layout.h +2 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/LegacyBatchedFallback.h +25 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/LegacyBatchedTensorImpl.h +160 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/LegacyVmapMode.h +26 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/LegacyVmapTransforms.h +183 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/MapAllocator.h +143 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/MatrixRef.h +107 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/MetaFunctions.h +29 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/MethodOperators.h +443 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/NamedTensor.h +1 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/NamedTensorUtils.h +214 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/NativeFunctions.h +1344 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/NestedTensorImpl.h +286 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/OpMathType.h +69 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/PadNd.h +28 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/Parallel.h +158 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/ParallelFuture.h +13 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/RegistrationDeclarations.h +0 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/SavedTensorHooks.h +66 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/Scalar.h +3 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/ScalarOps.h +53 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/ScalarType.h +4 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/SparseCsrTensorImpl.h +206 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/SparseCsrTensorUtils.h +441 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/Storage.h +2 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/Tensor.h +3 -0
.gitattributes
CHANGED
|
@@ -145,3 +145,4 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/_
|
|
| 145 |
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/cpp_wrapper_cpu.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 146 |
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/cpp.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 147 |
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/common.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 145 |
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/cpp_wrapper_cpu.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 146 |
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/cpp.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 147 |
.venv/lib/python3.11/site-packages/torch/_inductor/codegen/__pycache__/common.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 148 |
+
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/split_cat.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
.venv/lib/python3.11/site-packages/torch/_inductor/fx_passes/__pycache__/split_cat.cpython-311.pyc
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5b6d007391b31b1b010874c0dfd08680792c26a7542040197359230b54bc6612
|
| 3 |
+
size 114085
|
.venv/lib/python3.11/site-packages/torch/include/ATen/AccumulateType.h
ADDED
|
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <ATen/Config.h>
|
| 3 |
+
#include <c10/core/DeviceType.h>
|
| 4 |
+
#include <c10/core/ScalarType.h>
|
| 5 |
+
#include <c10/util/BFloat16.h>
|
| 6 |
+
#include <c10/util/Float8_e4m3fn.h>
|
| 7 |
+
#include <c10/util/Float8_e4m3fnuz.h>
|
| 8 |
+
#include <c10/util/Float8_e5m2.h>
|
| 9 |
+
#include <c10/util/Float8_e5m2fnuz.h>
|
| 10 |
+
#include <c10/util/Half.h>
|
| 11 |
+
|
| 12 |
+
// Defines the accumulation type for a scalar type.
|
| 13 |
+
// Example:
|
| 14 |
+
// using accscalar_t = acc_type<scalar_t, /*is_cuda*/true>;
|
| 15 |
+
//
|
| 16 |
+
// Accumulation types are an important concept in numeric computing
|
| 17 |
+
// because you frequently want to perform intermediate computations
|
| 18 |
+
// at a higher precision than the input and output precision, to avoid
|
| 19 |
+
// compounding internal rounding errors. Accumulation is the most
|
| 20 |
+
// well-known intermediate computation (it is of great importance for
|
| 21 |
+
// sum reduction and matrix multiply, for example), but in PyTorch
|
| 22 |
+
// acc_type ends up getting used for all sorts of other intermediate
|
| 23 |
+
// computations, so it perhaps would be more accurately (ahem) called an
|
| 24 |
+
// "accurate" type. acc_type is especially important for reduced
|
| 25 |
+
// precision operations like float16 and bfloat16, where relatively
|
| 26 |
+
// benign looking inputs can easily end up overflowing/underflowing.
|
| 27 |
+
//
|
| 28 |
+
// acc_type is parametrized by whether or not you are running on CUDA
|
| 29 |
+
// or not, because on CUDA double precision operations are expensive
|
| 30 |
+
// and so by default, we don't actually want to use double as an
|
| 31 |
+
// acc_type on CUDA. A lot of things are typed out below, but
|
| 32 |
+
// basically, the table is generated by a few rules:
|
| 33 |
+
//
|
| 34 |
+
// If bool:
|
| 35 |
+
// Use 'bool' as acc_type.
|
| 36 |
+
// If floating point:
|
| 37 |
+
// If CUDA, use 'float' as acc_type (unless scalar_t is double),
|
| 38 |
+
// otherwise (CPU) use 'double'
|
| 39 |
+
// If integral:
|
| 40 |
+
// Use 'int64_t' as acc_type
|
| 41 |
+
//
|
| 42 |
+
// You're not forced to use this template; if you happen to know
|
| 43 |
+
// something specific about your use case, you can specify your own
|
| 44 |
+
// desired behavior. This template, however, will give you a reasonable
|
| 45 |
+
// default that will work for all dtypes supported in PyTorch.
|
| 46 |
+
|
| 47 |
+
#if defined(__CUDACC__)
|
| 48 |
+
#include <cuda.h>
|
| 49 |
+
#include <cuda_fp16.h>
|
| 50 |
+
#elif defined(__HIPCC__)
|
| 51 |
+
#include <hip/hip_fp16.h>
|
| 52 |
+
#include <hip/hip_runtime.h>
|
| 53 |
+
#endif
|
| 54 |
+
|
| 55 |
+
namespace at {
|
| 56 |
+
|
| 57 |
+
template <typename T, c10::DeviceType D>
|
| 58 |
+
struct AccumulateTypeDevice {};
|
| 59 |
+
|
| 60 |
+
template <typename T, bool>
|
| 61 |
+
struct AccumulateType {};
|
| 62 |
+
|
| 63 |
+
template <typename T>
|
| 64 |
+
struct AccumulateType<T, false> {
|
| 65 |
+
using type = typename AccumulateTypeDevice<T, c10::DeviceType::CPU>::type;
|
| 66 |
+
};
|
| 67 |
+
|
| 68 |
+
template <typename T>
|
| 69 |
+
struct AccumulateType<T, true> {
|
| 70 |
+
using type = typename AccumulateTypeDevice<T, c10::DeviceType::CUDA>::type;
|
| 71 |
+
};
|
| 72 |
+
|
| 73 |
+
template <typename T, c10::DeviceType device>
|
| 74 |
+
using acc_type_device = typename AccumulateTypeDevice<T, device>::type;
|
| 75 |
+
|
| 76 |
+
template <typename T, bool is_cuda>
|
| 77 |
+
using acc_type = typename AccumulateType<T, is_cuda>::type;
|
| 78 |
+
|
| 79 |
+
#define ACC_TYPE(t, acc_t, device_type) \
|
| 80 |
+
template <> \
|
| 81 |
+
struct AccumulateTypeDevice<t, device_type> { \
|
| 82 |
+
using type = acc_t; \
|
| 83 |
+
};
|
| 84 |
+
#define MPS_ACC_TYPE(t, acc_t) ACC_TYPE(t, acc_t, c10::DeviceType::MPS)
|
| 85 |
+
#define XPU_ACC_TYPE(t, acc_t) ACC_TYPE(t, acc_t, c10::DeviceType::XPU)
|
| 86 |
+
#define CUDA_ACC_TYPE(t, acc_t) ACC_TYPE(t, acc_t, c10::DeviceType::CUDA)
|
| 87 |
+
#define CPU_ACC_TYPE(t, acc_t) ACC_TYPE(t, acc_t, c10::DeviceType::CPU)
|
| 88 |
+
|
| 89 |
+
MPS_ACC_TYPE(BFloat16, float);
|
| 90 |
+
MPS_ACC_TYPE(Half, float);
|
| 91 |
+
MPS_ACC_TYPE(Float8_e5m2, float);
|
| 92 |
+
MPS_ACC_TYPE(Float8_e4m3fn, float);
|
| 93 |
+
MPS_ACC_TYPE(Float8_e5m2fnuz, float);
|
| 94 |
+
MPS_ACC_TYPE(Float8_e4m3fnuz, float);
|
| 95 |
+
MPS_ACC_TYPE(float, float);
|
| 96 |
+
MPS_ACC_TYPE(double, float);
|
| 97 |
+
MPS_ACC_TYPE(int8_t, int64_t);
|
| 98 |
+
MPS_ACC_TYPE(uint8_t, int64_t);
|
| 99 |
+
MPS_ACC_TYPE(char, int64_t);
|
| 100 |
+
MPS_ACC_TYPE(int16_t, int64_t);
|
| 101 |
+
MPS_ACC_TYPE(int32_t, int64_t);
|
| 102 |
+
MPS_ACC_TYPE(int64_t, int64_t);
|
| 103 |
+
MPS_ACC_TYPE(bool, bool);
|
| 104 |
+
MPS_ACC_TYPE(c10::complex<Half>, c10::complex<float>);
|
| 105 |
+
MPS_ACC_TYPE(c10::complex<float>, c10::complex<float>);
|
| 106 |
+
MPS_ACC_TYPE(c10::complex<double>, c10::complex<float>);
|
| 107 |
+
|
| 108 |
+
XPU_ACC_TYPE(BFloat16, float);
|
| 109 |
+
XPU_ACC_TYPE(Half, float);
|
| 110 |
+
XPU_ACC_TYPE(Float8_e5m2, float);
|
| 111 |
+
XPU_ACC_TYPE(Float8_e4m3fn, float);
|
| 112 |
+
XPU_ACC_TYPE(Float8_e5m2fnuz, float);
|
| 113 |
+
XPU_ACC_TYPE(Float8_e4m3fnuz, float);
|
| 114 |
+
XPU_ACC_TYPE(float, float);
|
| 115 |
+
XPU_ACC_TYPE(double, double);
|
| 116 |
+
XPU_ACC_TYPE(int8_t, int64_t);
|
| 117 |
+
XPU_ACC_TYPE(uint8_t, int64_t);
|
| 118 |
+
XPU_ACC_TYPE(char, int64_t);
|
| 119 |
+
XPU_ACC_TYPE(int16_t, int64_t);
|
| 120 |
+
XPU_ACC_TYPE(int32_t, int64_t);
|
| 121 |
+
XPU_ACC_TYPE(int64_t, int64_t);
|
| 122 |
+
XPU_ACC_TYPE(bool, bool);
|
| 123 |
+
XPU_ACC_TYPE(c10::complex<Half>, c10::complex<float>);
|
| 124 |
+
XPU_ACC_TYPE(c10::complex<float>, c10::complex<float>);
|
| 125 |
+
XPU_ACC_TYPE(c10::complex<double>, c10::complex<double>);
|
| 126 |
+
|
| 127 |
+
#if defined(__CUDACC__) || defined(__HIPCC__)
|
| 128 |
+
CUDA_ACC_TYPE(half, float);
|
| 129 |
+
#endif
|
| 130 |
+
CUDA_ACC_TYPE(BFloat16, float);
|
| 131 |
+
CUDA_ACC_TYPE(Half, float);
|
| 132 |
+
CUDA_ACC_TYPE(Float8_e5m2, float);
|
| 133 |
+
CUDA_ACC_TYPE(Float8_e4m3fn, float);
|
| 134 |
+
CUDA_ACC_TYPE(Float8_e5m2fnuz, float);
|
| 135 |
+
CUDA_ACC_TYPE(Float8_e4m3fnuz, float);
|
| 136 |
+
CUDA_ACC_TYPE(float, float);
|
| 137 |
+
CUDA_ACC_TYPE(double, double);
|
| 138 |
+
CUDA_ACC_TYPE(int8_t, int64_t);
|
| 139 |
+
CUDA_ACC_TYPE(uint8_t, int64_t);
|
| 140 |
+
CUDA_ACC_TYPE(char, int64_t);
|
| 141 |
+
CUDA_ACC_TYPE(int16_t, int64_t);
|
| 142 |
+
CUDA_ACC_TYPE(int32_t, int64_t);
|
| 143 |
+
CUDA_ACC_TYPE(int64_t, int64_t);
|
| 144 |
+
CUDA_ACC_TYPE(bool, bool);
|
| 145 |
+
CUDA_ACC_TYPE(c10::complex<Half>, c10::complex<float>);
|
| 146 |
+
CUDA_ACC_TYPE(c10::complex<float>, c10::complex<float>);
|
| 147 |
+
CUDA_ACC_TYPE(c10::complex<double>, c10::complex<double>);
|
| 148 |
+
|
| 149 |
+
CPU_ACC_TYPE(BFloat16, float);
|
| 150 |
+
CPU_ACC_TYPE(Half, float);
|
| 151 |
+
CPU_ACC_TYPE(Float8_e5m2, float);
|
| 152 |
+
CPU_ACC_TYPE(Float8_e4m3fn, float);
|
| 153 |
+
CPU_ACC_TYPE(Float8_e5m2fnuz, float);
|
| 154 |
+
CPU_ACC_TYPE(Float8_e4m3fnuz, float);
|
| 155 |
+
CPU_ACC_TYPE(float, double);
|
| 156 |
+
CPU_ACC_TYPE(double, double);
|
| 157 |
+
CPU_ACC_TYPE(int8_t, int64_t);
|
| 158 |
+
CPU_ACC_TYPE(uint8_t, int64_t);
|
| 159 |
+
CPU_ACC_TYPE(char, int64_t);
|
| 160 |
+
CPU_ACC_TYPE(int16_t, int64_t);
|
| 161 |
+
CPU_ACC_TYPE(int32_t, int64_t);
|
| 162 |
+
CPU_ACC_TYPE(int64_t, int64_t);
|
| 163 |
+
CPU_ACC_TYPE(bool, bool);
|
| 164 |
+
CPU_ACC_TYPE(c10::complex<Half>, c10::complex<float>);
|
| 165 |
+
CPU_ACC_TYPE(c10::complex<float>, c10::complex<double>);
|
| 166 |
+
CPU_ACC_TYPE(c10::complex<double>, c10::complex<double>);
|
| 167 |
+
|
| 168 |
+
TORCH_API c10::ScalarType toAccumulateType(
|
| 169 |
+
c10::ScalarType type,
|
| 170 |
+
c10::DeviceType device);
|
| 171 |
+
TORCH_API c10::ScalarType toAccumulateType(c10::ScalarType type, bool is_cuda);
|
| 172 |
+
|
| 173 |
+
} // namespace at
|
.venv/lib/python3.11/site-packages/torch/include/ATen/Backend.h
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <c10/core/Backend.h>
|
.venv/lib/python3.11/site-packages/torch/include/ATen/CPUApplyUtils.h
ADDED
|
@@ -0,0 +1,343 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/CollapseDims.h>
|
| 4 |
+
#include <ATen/Parallel.h>
|
| 5 |
+
#include <ATen/TensorUtils.h>
|
| 6 |
+
#include <c10/util/irange.h>
|
| 7 |
+
#include <cstring>
|
| 8 |
+
#include <limits>
|
| 9 |
+
|
| 10 |
+
namespace at {
|
| 11 |
+
|
| 12 |
+
/*
|
| 13 |
+
* The basic strategy for apply is as follows:
|
| 14 |
+
*
|
| 15 |
+
* 1. Starting with the outermost index, loop until we reach a dimension where
|
| 16 |
+
* the data is no longer contiguous, i.e. the stride at that dimension is not
|
| 17 |
+
* equal to the size of the tensor defined by the outer dimensions. Let's call
|
| 18 |
+
* this outer (contiguous) tensor A. Note that if the Tensor is contiguous, then
|
| 19 |
+
* A is equal to the entire Tensor. Let's call the inner tensor B.
|
| 20 |
+
*
|
| 21 |
+
* 2. We loop through the indices in B, starting at its outermost dimension. For
|
| 22 |
+
* example, if B is a 2x2 matrix, then we do:
|
| 23 |
+
*
|
| 24 |
+
* B[0][0]
|
| 25 |
+
* B[0][1]
|
| 26 |
+
* B[1][0]
|
| 27 |
+
* B[1][1]
|
| 28 |
+
*
|
| 29 |
+
* We set the offset into the underlying storage as (storageOffset + stride_B *
|
| 30 |
+
* index_B), i.e. basically we compute the offset into the storage as we would
|
| 31 |
+
* normally for a Tensor. But because we are guaranteed the subsequent data is
|
| 32 |
+
* contiguous in memory, we can simply loop for sizeof(A) iterations and perform
|
| 33 |
+
* the operation, without having to follow the order described by the strides of
|
| 34 |
+
* A.
|
| 35 |
+
*
|
| 36 |
+
* 3. As an optimization, we merge dimensions of A that are contiguous in
|
| 37 |
+
* memory. For example, if A is a 3x3x3x3 tensor narrowed from a 3x3x4x3 tensor,
|
| 38 |
+
* then the first two dimensions can be merged for the purposes of APPLY,
|
| 39 |
+
* reducing the number of nested loops.
|
| 40 |
+
*/
|
| 41 |
+
|
| 42 |
+
inline Tensor sort_strides(Tensor& tensor_) {
|
| 43 |
+
IntArrayRef strides = tensor_.strides();
|
| 44 |
+
std::vector<int64_t> indices;
|
| 45 |
+
indices.reserve(tensor_.ndimension());
|
| 46 |
+
for (const auto i : c10::irange(tensor_.ndimension())) {
|
| 47 |
+
indices.push_back(i);
|
| 48 |
+
}
|
| 49 |
+
std::sort(indices.begin(), indices.end(), [&strides](int64_t i1, int64_t i2) {
|
| 50 |
+
return strides[i1] > strides[i2];
|
| 51 |
+
});
|
| 52 |
+
Tensor tensor = tensor_.permute(indices);
|
| 53 |
+
return tensor;
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
template <typename T, int N>
|
| 57 |
+
struct strided_tensor_iter_fixed {
|
| 58 |
+
public:
|
| 59 |
+
T* data_ = NULL;
|
| 60 |
+
int64_t dim_ = 0;
|
| 61 |
+
|
| 62 |
+
int64_t counter_[N] = {0};
|
| 63 |
+
int64_t sizes_[N] = {0};
|
| 64 |
+
int64_t strides_[N] = {0};
|
| 65 |
+
|
| 66 |
+
strided_tensor_iter_fixed(strided_tensor_iter_fixed const&) = delete;
|
| 67 |
+
void operator=(strided_tensor_iter_fixed const& x) = delete;
|
| 68 |
+
strided_tensor_iter_fixed(strided_tensor_iter_fixed&&) = default;
|
| 69 |
+
strided_tensor_iter_fixed(
|
| 70 |
+
Tensor& tensor,
|
| 71 |
+
C10_UNUSED bool sort_strides = false)
|
| 72 |
+
: data_(tensor.data_ptr<T>()) {
|
| 73 |
+
std::memset(counter_, 0, sizeof(int64_t) * N);
|
| 74 |
+
if (tensor.dim() > 0) {
|
| 75 |
+
std::memcpy(
|
| 76 |
+
sizes_, tensor.sizes().data(), tensor.dim() * sizeof(int64_t));
|
| 77 |
+
std::memcpy(
|
| 78 |
+
strides_, tensor.strides().data(), tensor.dim() * sizeof(int64_t));
|
| 79 |
+
}
|
| 80 |
+
dim_ = std::get<1>(collapse_dims(sizes_, strides_, tensor.ndimension()));
|
| 81 |
+
}
|
| 82 |
+
};
|
| 83 |
+
|
| 84 |
+
template <typename T>
|
| 85 |
+
struct strided_tensor_iter {
|
| 86 |
+
private:
|
| 87 |
+
public:
|
| 88 |
+
T* data_ = NULL;
|
| 89 |
+
int64_t dim_;
|
| 90 |
+
|
| 91 |
+
std::vector<int64_t> counter_;
|
| 92 |
+
std::vector<int64_t> sizes_;
|
| 93 |
+
std::vector<int64_t> strides_;
|
| 94 |
+
|
| 95 |
+
strided_tensor_iter(strided_tensor_iter const&) = delete;
|
| 96 |
+
void operator=(strided_tensor_iter const& x) = delete;
|
| 97 |
+
strided_tensor_iter(strided_tensor_iter&&) = default;
|
| 98 |
+
strided_tensor_iter(Tensor& tensor)
|
| 99 |
+
: data_(tensor.data_ptr<T>()),
|
| 100 |
+
dim_(tensor.ndimension()),
|
| 101 |
+
counter_(dim_, 0),
|
| 102 |
+
sizes_(tensor.sizes().vec()),
|
| 103 |
+
strides_(tensor.strides().vec()) {
|
| 104 |
+
dim_ = std::get<1>(collapse_dims(sizes_.data(), strides_.data(), dim_));
|
| 105 |
+
}
|
| 106 |
+
};
|
| 107 |
+
|
| 108 |
+
inline bool _all_equal_numel(at::ArrayRef<Tensor> tensors) {
|
| 109 |
+
if (tensors.empty())
|
| 110 |
+
return true;
|
| 111 |
+
int64_t all_numel = tensors[0].numel();
|
| 112 |
+
for (const auto i : c10::irange(1, tensors.size())) {
|
| 113 |
+
if (tensors[i].numel() != all_numel)
|
| 114 |
+
return false;
|
| 115 |
+
}
|
| 116 |
+
return true;
|
| 117 |
+
}
|
| 118 |
+
|
| 119 |
+
inline std::string _all_equal_numel_error(at::ArrayRef<Tensor> tensors) {
|
| 120 |
+
std::ostringstream oss;
|
| 121 |
+
oss << "inconsistent tensor size, expected ";
|
| 122 |
+
for (size_t i = 0; i < tensors.size() - 1; i++) {
|
| 123 |
+
oss << tensors[i].sizes() << ", ";
|
| 124 |
+
}
|
| 125 |
+
oss << "and " << tensors[tensors.size() - 1].sizes()
|
| 126 |
+
<< " to have the same number of elements, but got ";
|
| 127 |
+
for (size_t i = 0; i < tensors.size() - 1; i++) {
|
| 128 |
+
oss << tensors[i].numel() << ", ";
|
| 129 |
+
}
|
| 130 |
+
oss << "and " << tensors[tensors.size() - 1].numel()
|
| 131 |
+
<< " elements respectively";
|
| 132 |
+
return oss.str();
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
inline bool _apply_preamble(ArrayRef<Tensor> tensors) {
|
| 136 |
+
checkDeviceType("CPU_tensor_apply", tensors, kCPU);
|
| 137 |
+
checkLayout("CPU_tensor_apply", tensors, kStrided);
|
| 138 |
+
if (!_all_equal_numel(tensors))
|
| 139 |
+
AT_ERROR(_all_equal_numel_error(tensors));
|
| 140 |
+
// An empty tensor has no elements
|
| 141 |
+
for (auto& t : tensors)
|
| 142 |
+
if (t.numel() == 0)
|
| 143 |
+
return false;
|
| 144 |
+
return true;
|
| 145 |
+
}
|
| 146 |
+
|
| 147 |
+
inline int64_t _max_dim_tensors(ArrayRef<Tensor> tensors) {
|
| 148 |
+
int64_t dim = 0;
|
| 149 |
+
for (auto& t : tensors)
|
| 150 |
+
dim = std::max(dim, t.ndimension());
|
| 151 |
+
return dim;
|
| 152 |
+
}
|
| 153 |
+
|
| 154 |
+
inline void iterate(int64_t /*size*/){};
|
| 155 |
+
|
| 156 |
+
template <typename Arg, typename... Args>
|
| 157 |
+
inline void iterate(int64_t size, Arg& iter, Args&... iter_tail) {
|
| 158 |
+
iter.counter_[iter.dim_ - 1] += size;
|
| 159 |
+
iter.data_ = iter.data_ + size * iter.strides_[iter.dim_ - 1];
|
| 160 |
+
iterate(size, iter_tail...);
|
| 161 |
+
}
|
| 162 |
+
|
| 163 |
+
inline bool iterate_continue() {
|
| 164 |
+
return true;
|
| 165 |
+
};
|
| 166 |
+
|
| 167 |
+
template <typename Arg, typename... Args>
|
| 168 |
+
inline bool iterate_continue(Arg& iter, Args&... iter_tail) {
|
| 169 |
+
return iter.counter_[iter.dim_ - 1] < iter.sizes_[iter.dim_ - 1] &&
|
| 170 |
+
iterate_continue(iter_tail...);
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
inline int64_t max_iterate_size() {
|
| 174 |
+
return std::numeric_limits<int64_t>::max();
|
| 175 |
+
};
|
| 176 |
+
|
| 177 |
+
template <typename Arg, typename... Args>
|
| 178 |
+
inline int64_t max_iterate_size(Arg& iter, Args&... iter_tail) {
|
| 179 |
+
return std::min(
|
| 180 |
+
(iter.sizes_[iter.dim_ - 1] - iter.counter_[iter.dim_ - 1]),
|
| 181 |
+
max_iterate_size(iter_tail...));
|
| 182 |
+
}
|
| 183 |
+
|
| 184 |
+
inline void iterate_overflow(){};
|
| 185 |
+
|
| 186 |
+
template <typename Arg, typename... Args>
|
| 187 |
+
inline void iterate_overflow(Arg& iter, Args&... iter_tail) {
|
| 188 |
+
if (iter.counter_[iter.dim_ - 1] == iter.sizes_[iter.dim_ - 1]) {
|
| 189 |
+
for (int64_t i = iter.dim_ - 1; i > 0; i--) {
|
| 190 |
+
if (iter.counter_[i] == iter.sizes_[i]) {
|
| 191 |
+
iter.counter_[i] = 0;
|
| 192 |
+
iter.counter_[i - 1]++;
|
| 193 |
+
iter.data_ = iter.data_ - (iter.sizes_[i] * iter.strides_[i]) +
|
| 194 |
+
iter.strides_[i - 1];
|
| 195 |
+
}
|
| 196 |
+
}
|
| 197 |
+
}
|
| 198 |
+
iterate_overflow(iter_tail...);
|
| 199 |
+
}
|
| 200 |
+
|
| 201 |
+
inline void forward(int64_t /*offset*/){};
|
| 202 |
+
|
| 203 |
+
template <typename Arg, typename... Args>
|
| 204 |
+
inline void forward(int64_t offset, Arg& iter, Args&... iter_tail) {
|
| 205 |
+
int64_t multi = offset;
|
| 206 |
+
for (int64_t i = iter.dim_ - 1; i >= 0; i--) {
|
| 207 |
+
int64_t inc = multi % iter.sizes_[i];
|
| 208 |
+
multi = multi / iter.sizes_[i];
|
| 209 |
+
iter.data_ = iter.data_ + inc * iter.strides_[i];
|
| 210 |
+
iter.counter_[i] += inc;
|
| 211 |
+
}
|
| 212 |
+
forward(offset, iter_tail...);
|
| 213 |
+
}
|
| 214 |
+
|
| 215 |
+
inline int64_t max_dim() {
|
| 216 |
+
return 0;
|
| 217 |
+
}
|
| 218 |
+
|
| 219 |
+
template <typename Arg, typename... Args>
|
| 220 |
+
inline int64_t max_dim(Arg& iter, Args&... iter_tail) {
|
| 221 |
+
return std::max(iter.dim_, max_dim(iter_tail...));
|
| 222 |
+
}
|
| 223 |
+
|
| 224 |
+
inline void apply_op(){};
|
| 225 |
+
|
| 226 |
+
template <typename Op, typename... Args>
|
| 227 |
+
inline void apply_op(
|
| 228 |
+
int64_t numel,
|
| 229 |
+
int64_t offset,
|
| 230 |
+
const Op& op,
|
| 231 |
+
Args... iters) {
|
| 232 |
+
// For 0-dim tensors
|
| 233 |
+
if (numel == 1 && max_dim(iters...) == 0) {
|
| 234 |
+
op(*iters.data_...);
|
| 235 |
+
return;
|
| 236 |
+
}
|
| 237 |
+
if (offset > 0)
|
| 238 |
+
forward(offset, iters...);
|
| 239 |
+
// Splitting this into chunks helps the compiler create faster assembly
|
| 240 |
+
for (int64_t i = 0; i < numel;) {
|
| 241 |
+
for (; iterate_continue(iters...) && i < numel;) {
|
| 242 |
+
op(*iters.data_...);
|
| 243 |
+
iterate(1, iters...);
|
| 244 |
+
i++;
|
| 245 |
+
}
|
| 246 |
+
iterate_overflow(iters...);
|
| 247 |
+
}
|
| 248 |
+
}
|
| 249 |
+
|
| 250 |
+
/*
|
| 251 |
+
Apply a pointwise operator to sequence of tensors
|
| 252 |
+
|
| 253 |
+
The calling convention for op is a function/functor that takes the same
|
| 254 |
+
number of pointers of type scalar as the number of given tensors. For example,
|
| 255 |
+
to compute a = b * c, op would be of the form:
|
| 256 |
+
[](scalar* a_val, const scalar* b_val, const scalar* c_val) { a_val[0] =
|
| 257 |
+
b_val[0] * c_val[0]; };
|
| 258 |
+
*/
|
| 259 |
+
|
| 260 |
+
template <typename scalar1, typename scalar2, typename Op>
|
| 261 |
+
inline void CPU_tensor_apply2(Tensor tensor1, Tensor tensor2, const Op op) {
|
| 262 |
+
if (!_apply_preamble({tensor1, tensor2}))
|
| 263 |
+
return;
|
| 264 |
+
if (_max_dim_tensors({tensor1, tensor2}) <= 8) {
|
| 265 |
+
apply_op(
|
| 266 |
+
tensor1.numel(),
|
| 267 |
+
0,
|
| 268 |
+
op,
|
| 269 |
+
strided_tensor_iter_fixed<scalar1, 8>(tensor1),
|
| 270 |
+
strided_tensor_iter_fixed<scalar2, 8>(tensor2));
|
| 271 |
+
} else {
|
| 272 |
+
apply_op(
|
| 273 |
+
tensor1.numel(),
|
| 274 |
+
0,
|
| 275 |
+
op,
|
| 276 |
+
strided_tensor_iter<scalar1>(tensor1),
|
| 277 |
+
strided_tensor_iter<scalar2>(tensor2));
|
| 278 |
+
}
|
| 279 |
+
}
|
| 280 |
+
|
| 281 |
+
template <typename scalar1, typename scalar2, typename scalar3, typename Op>
|
| 282 |
+
inline void CPU_tensor_apply3(
|
| 283 |
+
Tensor tensor1,
|
| 284 |
+
Tensor tensor2,
|
| 285 |
+
Tensor tensor3,
|
| 286 |
+
const Op op) {
|
| 287 |
+
if (!_apply_preamble({tensor1, tensor2, tensor3}))
|
| 288 |
+
return;
|
| 289 |
+
if (_max_dim_tensors({tensor1, tensor2, tensor3}) <= 8) {
|
| 290 |
+
apply_op(
|
| 291 |
+
tensor1.numel(),
|
| 292 |
+
0,
|
| 293 |
+
op,
|
| 294 |
+
strided_tensor_iter_fixed<scalar1, 8>(tensor1),
|
| 295 |
+
strided_tensor_iter_fixed<scalar2, 8>(tensor2),
|
| 296 |
+
strided_tensor_iter_fixed<scalar3, 8>(tensor3));
|
| 297 |
+
} else {
|
| 298 |
+
apply_op(
|
| 299 |
+
tensor1.numel(),
|
| 300 |
+
0,
|
| 301 |
+
op,
|
| 302 |
+
strided_tensor_iter<scalar1>(tensor1),
|
| 303 |
+
strided_tensor_iter<scalar2>(tensor2),
|
| 304 |
+
strided_tensor_iter<scalar3>(tensor3));
|
| 305 |
+
}
|
| 306 |
+
}
|
| 307 |
+
|
| 308 |
+
template <
|
| 309 |
+
typename scalar1,
|
| 310 |
+
typename scalar2,
|
| 311 |
+
typename scalar3,
|
| 312 |
+
typename scalar4,
|
| 313 |
+
typename Op>
|
| 314 |
+
inline void CPU_tensor_apply4(
|
| 315 |
+
Tensor tensor1,
|
| 316 |
+
Tensor tensor2,
|
| 317 |
+
Tensor tensor3,
|
| 318 |
+
Tensor tensor4,
|
| 319 |
+
const Op op) {
|
| 320 |
+
if (!_apply_preamble({tensor1, tensor2, tensor3, tensor4}))
|
| 321 |
+
return;
|
| 322 |
+
if (_max_dim_tensors({tensor1, tensor2, tensor3, tensor4}) <= 8) {
|
| 323 |
+
apply_op(
|
| 324 |
+
tensor1.numel(),
|
| 325 |
+
0,
|
| 326 |
+
op,
|
| 327 |
+
strided_tensor_iter_fixed<scalar1, 8>(tensor1),
|
| 328 |
+
strided_tensor_iter_fixed<scalar2, 8>(tensor2),
|
| 329 |
+
strided_tensor_iter_fixed<scalar3, 8>(tensor3),
|
| 330 |
+
strided_tensor_iter_fixed<scalar4, 8>(tensor4));
|
| 331 |
+
} else {
|
| 332 |
+
apply_op(
|
| 333 |
+
tensor1.numel(),
|
| 334 |
+
0,
|
| 335 |
+
op,
|
| 336 |
+
strided_tensor_iter<scalar1>(tensor1),
|
| 337 |
+
strided_tensor_iter<scalar2>(tensor2),
|
| 338 |
+
strided_tensor_iter<scalar3>(tensor3),
|
| 339 |
+
strided_tensor_iter<scalar4>(tensor4));
|
| 340 |
+
}
|
| 341 |
+
}
|
| 342 |
+
|
| 343 |
+
} // namespace at
|
.venv/lib/python3.11/site-packages/torch/include/ATen/CPUFixedAllocator.h
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <c10/core/Allocator.h>
|
| 4 |
+
#include <c10/util/Exception.h>
|
| 5 |
+
|
| 6 |
+
// This file creates a fake allocator that just throws exceptions if
|
| 7 |
+
// it is actually used.
|
| 8 |
+
|
| 9 |
+
// state passed to the allocator is the std::function<void(void*)> called
|
| 10 |
+
// when the blob is release by ATen
|
| 11 |
+
|
| 12 |
+
namespace at {
|
| 13 |
+
|
| 14 |
+
static cpu_fixed_malloc(void*, ptrdiff_t) {
|
| 15 |
+
AT_ERROR("attempting to resize a tensor view of an external blob");
|
| 16 |
+
}
|
| 17 |
+
|
| 18 |
+
static cpu_fixed_realloc(void*, void*, ptrdiff_t) {
|
| 19 |
+
AT_ERROR("attempting to resize a tensor view of an external blob");
|
| 20 |
+
}
|
| 21 |
+
|
| 22 |
+
static cpu_fixed_free(void* state, void* allocation) {
|
| 23 |
+
auto on_release = static_cast<std::function<void(void*)>*>(state);
|
| 24 |
+
(*on_release)(allocation);
|
| 25 |
+
delete on_release;
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
static Allocator CPU_fixed_allocator = {
|
| 29 |
+
cpu_fixed_malloc,
|
| 30 |
+
cpu_fixed_realloc,
|
| 31 |
+
cpu_fixed_free};
|
| 32 |
+
|
| 33 |
+
} // namespace at
|
.venv/lib/python3.11/site-packages/torch/include/ATen/CPUFunctions.h
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <ATen/core/TensorBody.h>
|
| 2 |
+
|
| 3 |
+
// TODO Undo all logic introduced for Note [Avoiding Include Cycles In Static Dispatch]
|
| 4 |
+
// Code introduced to avoid cyclic dependency in static dispatch is no longer
|
| 5 |
+
// needed as static dispatch logic is moved from TensorBody.h, which caused cycles in the first place,
|
| 6 |
+
// to Operators.cpp for supporting multiple backends with multiple kernels.
|
| 7 |
+
//
|
| 8 |
+
// Note [Avoiding Include Cycles In Static Dispatch]
|
| 9 |
+
// In order to avoid #include cycles in the static dispatch build, we've carefully split out
|
| 10 |
+
// the static function definition files into {DispatchKey}Functions.h and {DispatchKey}Functions_inl.h.
|
| 11 |
+
//
|
| 12 |
+
// Without this split, the include cycle looks like TensorBody.h -> CPUFunctions.h -> TensorBody.h.
|
| 13 |
+
// - TensorBody.h #includes CPUFunctions.h in the static dispatch build, because the tensor methods
|
| 14 |
+
// all need to call into the fastpath C++ API defined in CPUFunctions.h. The methods are also all
|
| 15 |
+
// directly inlined into TensorBody.h.
|
| 16 |
+
// - CPUFunctions.h #includes TensorBody.h because it contains function declarations for the entire C++ API,
|
| 17 |
+
// which include functions that have defaultable std::optional<Tensor> arguments.
|
| 18 |
+
// That requires knowing the full Tensor class definition.
|
| 19 |
+
//
|
| 20 |
+
// We break the cycle by doing the following:
|
| 21 |
+
// - Split out CPUFunction.h into two files: CPUFunctions.h and CPUFunctions_inl.h
|
| 22 |
+
// - CPUFunction.h is a dummy file that just includes the Tensor class and includes CPUFunctions_inl.,
|
| 23 |
+
// - CPUFunctions_inl.h includes everything else
|
| 24 |
+
// - (only in the static dispatch build) TensorBody.h makes sure to finish defining the Tensor class,
|
| 25 |
+
// and then it includes CPUFunctions_inl.h.
|
| 26 |
+
// - All other files that want the cpu fastpath functions can include CPUFunctions.h directly.
|
| 27 |
+
// - This also means that static dispatch build, CPUFunctions.h only needs to
|
| 28 |
+
// #include TensorBody.h, and it will automatically bring in CPUFunctions_inl.h.
|
| 29 |
+
#include <ATen/CPUFunctions_inl.h>
|
.venv/lib/python3.11/site-packages/torch/include/ATen/CPUGeneratorImpl.h
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/core/Generator.h>
|
| 4 |
+
#include <ATen/core/MT19937RNGEngine.h>
|
| 5 |
+
#include <c10/core/GeneratorImpl.h>
|
| 6 |
+
#include <optional>
|
| 7 |
+
|
| 8 |
+
namespace at {
|
| 9 |
+
|
| 10 |
+
struct TORCH_API CPUGeneratorImpl : public c10::GeneratorImpl {
|
| 11 |
+
// Constructors
|
| 12 |
+
CPUGeneratorImpl(uint64_t seed_in = default_rng_seed_val);
|
| 13 |
+
~CPUGeneratorImpl() override = default;
|
| 14 |
+
|
| 15 |
+
// CPUGeneratorImpl methods
|
| 16 |
+
std::shared_ptr<CPUGeneratorImpl> clone() const;
|
| 17 |
+
void set_current_seed(uint64_t seed) override;
|
| 18 |
+
void set_offset(uint64_t offset) override;
|
| 19 |
+
uint64_t get_offset() const override;
|
| 20 |
+
uint64_t current_seed() const override;
|
| 21 |
+
uint64_t seed() override;
|
| 22 |
+
void set_state(const c10::TensorImpl& new_state) override;
|
| 23 |
+
c10::intrusive_ptr<c10::TensorImpl> get_state() const override;
|
| 24 |
+
static c10::DeviceType device_type();
|
| 25 |
+
uint32_t random();
|
| 26 |
+
uint64_t random64();
|
| 27 |
+
std::optional<float> next_float_normal_sample();
|
| 28 |
+
std::optional<double> next_double_normal_sample();
|
| 29 |
+
void set_next_float_normal_sample(std::optional<float> randn);
|
| 30 |
+
void set_next_double_normal_sample(std::optional<double> randn);
|
| 31 |
+
at::mt19937 engine();
|
| 32 |
+
void set_engine(at::mt19937 engine);
|
| 33 |
+
|
| 34 |
+
private:
|
| 35 |
+
CPUGeneratorImpl* clone_impl() const override;
|
| 36 |
+
at::mt19937 engine_;
|
| 37 |
+
std::optional<float> next_float_normal_sample_;
|
| 38 |
+
std::optional<double> next_double_normal_sample_;
|
| 39 |
+
};
|
| 40 |
+
|
| 41 |
+
namespace detail {
|
| 42 |
+
|
| 43 |
+
TORCH_API const Generator& getDefaultCPUGenerator();
|
| 44 |
+
TORCH_API Generator
|
| 45 |
+
createCPUGenerator(uint64_t seed_val = default_rng_seed_val);
|
| 46 |
+
|
| 47 |
+
} // namespace detail
|
| 48 |
+
|
| 49 |
+
} // namespace at
|
.venv/lib/python3.11/site-packages/torch/include/ATen/CompositeExplicitAutogradFunctions_inl.h
ADDED
|
@@ -0,0 +1,553 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
// @generated by torchgen/gen.py from DispatchKeyFunctions_inl.h
|
| 3 |
+
|
| 4 |
+
// NB: The implementing C++ file is RegisterDispatchKey.cpp
|
| 5 |
+
|
| 6 |
+
// The only #includes we need are for custom classes that have defaults in the C++ API
|
| 7 |
+
#include <c10/core/MemoryFormat.h>
|
| 8 |
+
#include <c10/core/Scalar.h>
|
| 9 |
+
#include <ATen/core/Reduction.h>
|
| 10 |
+
|
| 11 |
+
#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS)
|
| 12 |
+
#error This change adds a dependency on all pytorch operators, meaning the \
|
| 13 |
+
file will need to be re-compiled every time an operator is changed or added. \
|
| 14 |
+
Consider including a specific operator from \
|
| 15 |
+
<ATen/ops/{my_operator}_compositeexplicitautograd_dispatch.h>. \
|
| 16 |
+
See NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS].
|
| 17 |
+
#endif
|
| 18 |
+
|
| 19 |
+
#include <ATen/ops/_adaptive_avg_pool2d_compositeexplicitautograd_dispatch.h>
|
| 20 |
+
#include <ATen/ops/_adaptive_avg_pool2d_backward_compositeexplicitautograd_dispatch.h>
|
| 21 |
+
#include <ATen/ops/_adaptive_avg_pool3d_compositeexplicitautograd_dispatch.h>
|
| 22 |
+
#include <ATen/ops/_adaptive_avg_pool3d_backward_compositeexplicitautograd_dispatch.h>
|
| 23 |
+
#include <ATen/ops/_add_relu_compositeexplicitautograd_dispatch.h>
|
| 24 |
+
#include <ATen/ops/_aminmax_compositeexplicitautograd_dispatch.h>
|
| 25 |
+
#include <ATen/ops/_amp_foreach_non_finite_check_and_unscale_compositeexplicitautograd_dispatch.h>
|
| 26 |
+
#include <ATen/ops/_amp_update_scale_compositeexplicitautograd_dispatch.h>
|
| 27 |
+
#include <ATen/ops/_assert_scalar_compositeexplicitautograd_dispatch.h>
|
| 28 |
+
#include <ATen/ops/_batch_norm_no_update_compositeexplicitautograd_dispatch.h>
|
| 29 |
+
#include <ATen/ops/_batch_norm_with_update_compositeexplicitautograd_dispatch.h>
|
| 30 |
+
#include <ATen/ops/_cdist_backward_compositeexplicitautograd_dispatch.h>
|
| 31 |
+
#include <ATen/ops/_cdist_forward_compositeexplicitautograd_dispatch.h>
|
| 32 |
+
#include <ATen/ops/_cholesky_solve_helper_compositeexplicitautograd_dispatch.h>
|
| 33 |
+
#include <ATen/ops/_chunk_cat_compositeexplicitautograd_dispatch.h>
|
| 34 |
+
#include <ATen/ops/_coalesce_compositeexplicitautograd_dispatch.h>
|
| 35 |
+
#include <ATen/ops/_coalesced_compositeexplicitautograd_dispatch.h>
|
| 36 |
+
#include <ATen/ops/_conj_compositeexplicitautograd_dispatch.h>
|
| 37 |
+
#include <ATen/ops/_conj_copy_compositeexplicitautograd_dispatch.h>
|
| 38 |
+
#include <ATen/ops/_conj_physical_compositeexplicitautograd_dispatch.h>
|
| 39 |
+
#include <ATen/ops/_convolution_compositeexplicitautograd_dispatch.h>
|
| 40 |
+
#include <ATen/ops/_copy_from_compositeexplicitautograd_dispatch.h>
|
| 41 |
+
#include <ATen/ops/_copy_from_and_resize_compositeexplicitautograd_dispatch.h>
|
| 42 |
+
#include <ATen/ops/_ctc_loss_compositeexplicitautograd_dispatch.h>
|
| 43 |
+
#include <ATen/ops/_ctc_loss_backward_compositeexplicitautograd_dispatch.h>
|
| 44 |
+
#include <ATen/ops/_cudnn_ctc_loss_compositeexplicitautograd_dispatch.h>
|
| 45 |
+
#include <ATen/ops/_cudnn_init_dropout_state_compositeexplicitautograd_dispatch.h>
|
| 46 |
+
#include <ATen/ops/_cudnn_rnn_compositeexplicitautograd_dispatch.h>
|
| 47 |
+
#include <ATen/ops/_cudnn_rnn_backward_compositeexplicitautograd_dispatch.h>
|
| 48 |
+
#include <ATen/ops/_cudnn_rnn_flatten_weight_compositeexplicitautograd_dispatch.h>
|
| 49 |
+
#include <ATen/ops/_dirichlet_grad_compositeexplicitautograd_dispatch.h>
|
| 50 |
+
#include <ATen/ops/_efficientzerotensor_compositeexplicitautograd_dispatch.h>
|
| 51 |
+
#include <ATen/ops/_embedding_bag_compositeexplicitautograd_dispatch.h>
|
| 52 |
+
#include <ATen/ops/_embedding_bag_dense_backward_compositeexplicitautograd_dispatch.h>
|
| 53 |
+
#include <ATen/ops/_embedding_bag_forward_only_compositeexplicitautograd_dispatch.h>
|
| 54 |
+
#include <ATen/ops/_embedding_bag_per_sample_weights_backward_compositeexplicitautograd_dispatch.h>
|
| 55 |
+
#include <ATen/ops/_empty_affine_quantized_compositeexplicitautograd_dispatch.h>
|
| 56 |
+
#include <ATen/ops/_empty_per_channel_affine_quantized_compositeexplicitautograd_dispatch.h>
|
| 57 |
+
#include <ATen/ops/_euclidean_dist_compositeexplicitautograd_dispatch.h>
|
| 58 |
+
#include <ATen/ops/_fake_quantize_learnable_per_channel_affine_compositeexplicitautograd_dispatch.h>
|
| 59 |
+
#include <ATen/ops/_fake_quantize_learnable_per_tensor_affine_compositeexplicitautograd_dispatch.h>
|
| 60 |
+
#include <ATen/ops/_fake_quantize_per_tensor_affine_cachemask_tensor_qparams_compositeexplicitautograd_dispatch.h>
|
| 61 |
+
#include <ATen/ops/_foobar_compositeexplicitautograd_dispatch.h>
|
| 62 |
+
#include <ATen/ops/_foreach_abs_compositeexplicitautograd_dispatch.h>
|
| 63 |
+
#include <ATen/ops/_foreach_acos_compositeexplicitautograd_dispatch.h>
|
| 64 |
+
#include <ATen/ops/_foreach_add_compositeexplicitautograd_dispatch.h>
|
| 65 |
+
#include <ATen/ops/_foreach_addcdiv_compositeexplicitautograd_dispatch.h>
|
| 66 |
+
#include <ATen/ops/_foreach_addcmul_compositeexplicitautograd_dispatch.h>
|
| 67 |
+
#include <ATen/ops/_foreach_asin_compositeexplicitautograd_dispatch.h>
|
| 68 |
+
#include <ATen/ops/_foreach_atan_compositeexplicitautograd_dispatch.h>
|
| 69 |
+
#include <ATen/ops/_foreach_ceil_compositeexplicitautograd_dispatch.h>
|
| 70 |
+
#include <ATen/ops/_foreach_clamp_max_compositeexplicitautograd_dispatch.h>
|
| 71 |
+
#include <ATen/ops/_foreach_clamp_min_compositeexplicitautograd_dispatch.h>
|
| 72 |
+
#include <ATen/ops/_foreach_copy_compositeexplicitautograd_dispatch.h>
|
| 73 |
+
#include <ATen/ops/_foreach_cos_compositeexplicitautograd_dispatch.h>
|
| 74 |
+
#include <ATen/ops/_foreach_cosh_compositeexplicitautograd_dispatch.h>
|
| 75 |
+
#include <ATen/ops/_foreach_div_compositeexplicitautograd_dispatch.h>
|
| 76 |
+
#include <ATen/ops/_foreach_erf_compositeexplicitautograd_dispatch.h>
|
| 77 |
+
#include <ATen/ops/_foreach_erfc_compositeexplicitautograd_dispatch.h>
|
| 78 |
+
#include <ATen/ops/_foreach_exp_compositeexplicitautograd_dispatch.h>
|
| 79 |
+
#include <ATen/ops/_foreach_expm1_compositeexplicitautograd_dispatch.h>
|
| 80 |
+
#include <ATen/ops/_foreach_floor_compositeexplicitautograd_dispatch.h>
|
| 81 |
+
#include <ATen/ops/_foreach_frac_compositeexplicitautograd_dispatch.h>
|
| 82 |
+
#include <ATen/ops/_foreach_lerp_compositeexplicitautograd_dispatch.h>
|
| 83 |
+
#include <ATen/ops/_foreach_lgamma_compositeexplicitautograd_dispatch.h>
|
| 84 |
+
#include <ATen/ops/_foreach_log_compositeexplicitautograd_dispatch.h>
|
| 85 |
+
#include <ATen/ops/_foreach_log10_compositeexplicitautograd_dispatch.h>
|
| 86 |
+
#include <ATen/ops/_foreach_log1p_compositeexplicitautograd_dispatch.h>
|
| 87 |
+
#include <ATen/ops/_foreach_log2_compositeexplicitautograd_dispatch.h>
|
| 88 |
+
#include <ATen/ops/_foreach_max_compositeexplicitautograd_dispatch.h>
|
| 89 |
+
#include <ATen/ops/_foreach_maximum_compositeexplicitautograd_dispatch.h>
|
| 90 |
+
#include <ATen/ops/_foreach_minimum_compositeexplicitautograd_dispatch.h>
|
| 91 |
+
#include <ATen/ops/_foreach_mul_compositeexplicitautograd_dispatch.h>
|
| 92 |
+
#include <ATen/ops/_foreach_neg_compositeexplicitautograd_dispatch.h>
|
| 93 |
+
#include <ATen/ops/_foreach_norm_compositeexplicitautograd_dispatch.h>
|
| 94 |
+
#include <ATen/ops/_foreach_pow_compositeexplicitautograd_dispatch.h>
|
| 95 |
+
#include <ATen/ops/_foreach_reciprocal_compositeexplicitautograd_dispatch.h>
|
| 96 |
+
#include <ATen/ops/_foreach_round_compositeexplicitautograd_dispatch.h>
|
| 97 |
+
#include <ATen/ops/_foreach_sigmoid_compositeexplicitautograd_dispatch.h>
|
| 98 |
+
#include <ATen/ops/_foreach_sign_compositeexplicitautograd_dispatch.h>
|
| 99 |
+
#include <ATen/ops/_foreach_sin_compositeexplicitautograd_dispatch.h>
|
| 100 |
+
#include <ATen/ops/_foreach_sinh_compositeexplicitautograd_dispatch.h>
|
| 101 |
+
#include <ATen/ops/_foreach_sqrt_compositeexplicitautograd_dispatch.h>
|
| 102 |
+
#include <ATen/ops/_foreach_sub_compositeexplicitautograd_dispatch.h>
|
| 103 |
+
#include <ATen/ops/_foreach_tan_compositeexplicitautograd_dispatch.h>
|
| 104 |
+
#include <ATen/ops/_foreach_tanh_compositeexplicitautograd_dispatch.h>
|
| 105 |
+
#include <ATen/ops/_foreach_trunc_compositeexplicitautograd_dispatch.h>
|
| 106 |
+
#include <ATen/ops/_foreach_zero_compositeexplicitautograd_dispatch.h>
|
| 107 |
+
#include <ATen/ops/_functional_assert_scalar_compositeexplicitautograd_dispatch.h>
|
| 108 |
+
#include <ATen/ops/_functional_sym_constrain_range_compositeexplicitautograd_dispatch.h>
|
| 109 |
+
#include <ATen/ops/_functional_sym_constrain_range_for_size_compositeexplicitautograd_dispatch.h>
|
| 110 |
+
#include <ATen/ops/_fused_adagrad_compositeexplicitautograd_dispatch.h>
|
| 111 |
+
#include <ATen/ops/_fused_adam_compositeexplicitautograd_dispatch.h>
|
| 112 |
+
#include <ATen/ops/_fused_adamw_compositeexplicitautograd_dispatch.h>
|
| 113 |
+
#include <ATen/ops/_fused_dropout_compositeexplicitautograd_dispatch.h>
|
| 114 |
+
#include <ATen/ops/_fused_moving_avg_obs_fq_helper_compositeexplicitautograd_dispatch.h>
|
| 115 |
+
#include <ATen/ops/_fused_sgd_compositeexplicitautograd_dispatch.h>
|
| 116 |
+
#include <ATen/ops/_fw_primal_compositeexplicitautograd_dispatch.h>
|
| 117 |
+
#include <ATen/ops/_fw_primal_copy_compositeexplicitautograd_dispatch.h>
|
| 118 |
+
#include <ATen/ops/_grid_sampler_2d_cpu_fallback_compositeexplicitautograd_dispatch.h>
|
| 119 |
+
#include <ATen/ops/_has_same_storage_numel_compositeexplicitautograd_dispatch.h>
|
| 120 |
+
#include <ATen/ops/_histogramdd_bin_edges_compositeexplicitautograd_dispatch.h>
|
| 121 |
+
#include <ATen/ops/_histogramdd_from_bin_cts_compositeexplicitautograd_dispatch.h>
|
| 122 |
+
#include <ATen/ops/_histogramdd_from_bin_tensors_compositeexplicitautograd_dispatch.h>
|
| 123 |
+
#include <ATen/ops/_index_put_impl_compositeexplicitautograd_dispatch.h>
|
| 124 |
+
#include <ATen/ops/_indices_copy_compositeexplicitautograd_dispatch.h>
|
| 125 |
+
#include <ATen/ops/_is_all_true_compositeexplicitautograd_dispatch.h>
|
| 126 |
+
#include <ATen/ops/_is_any_true_compositeexplicitautograd_dispatch.h>
|
| 127 |
+
#include <ATen/ops/_lazy_clone_compositeexplicitautograd_dispatch.h>
|
| 128 |
+
#include <ATen/ops/_linalg_check_errors_compositeexplicitautograd_dispatch.h>
|
| 129 |
+
#include <ATen/ops/_lstm_mps_compositeexplicitautograd_dispatch.h>
|
| 130 |
+
#include <ATen/ops/_make_dual_compositeexplicitautograd_dispatch.h>
|
| 131 |
+
#include <ATen/ops/_make_dual_copy_compositeexplicitautograd_dispatch.h>
|
| 132 |
+
#include <ATen/ops/_make_per_channel_quantized_tensor_compositeexplicitautograd_dispatch.h>
|
| 133 |
+
#include <ATen/ops/_make_per_tensor_quantized_tensor_compositeexplicitautograd_dispatch.h>
|
| 134 |
+
#include <ATen/ops/_masked_scale_compositeexplicitautograd_dispatch.h>
|
| 135 |
+
#include <ATen/ops/_masked_softmax_compositeexplicitautograd_dispatch.h>
|
| 136 |
+
#include <ATen/ops/_masked_softmax_backward_compositeexplicitautograd_dispatch.h>
|
| 137 |
+
#include <ATen/ops/_mkldnn_reshape_compositeexplicitautograd_dispatch.h>
|
| 138 |
+
#include <ATen/ops/_mkldnn_transpose_compositeexplicitautograd_dispatch.h>
|
| 139 |
+
#include <ATen/ops/_mps_convolution_compositeexplicitautograd_dispatch.h>
|
| 140 |
+
#include <ATen/ops/_mps_convolution_transpose_compositeexplicitautograd_dispatch.h>
|
| 141 |
+
#include <ATen/ops/_native_batch_norm_legit_compositeexplicitautograd_dispatch.h>
|
| 142 |
+
#include <ATen/ops/_native_batch_norm_legit_no_training_compositeexplicitautograd_dispatch.h>
|
| 143 |
+
#include <ATen/ops/_native_multi_head_attention_compositeexplicitautograd_dispatch.h>
|
| 144 |
+
#include <ATen/ops/_neg_view_compositeexplicitautograd_dispatch.h>
|
| 145 |
+
#include <ATen/ops/_neg_view_copy_compositeexplicitautograd_dispatch.h>
|
| 146 |
+
#include <ATen/ops/_nested_from_padded_compositeexplicitautograd_dispatch.h>
|
| 147 |
+
#include <ATen/ops/_nested_from_padded_and_nested_example_compositeexplicitautograd_dispatch.h>
|
| 148 |
+
#include <ATen/ops/_nested_get_values_copy_compositeexplicitautograd_dispatch.h>
|
| 149 |
+
#include <ATen/ops/_nested_tensor_from_mask_compositeexplicitautograd_dispatch.h>
|
| 150 |
+
#include <ATen/ops/_nested_tensor_from_tensor_list_compositeexplicitautograd_dispatch.h>
|
| 151 |
+
#include <ATen/ops/_nested_tensor_size_compositeexplicitautograd_dispatch.h>
|
| 152 |
+
#include <ATen/ops/_nested_tensor_storage_offsets_compositeexplicitautograd_dispatch.h>
|
| 153 |
+
#include <ATen/ops/_nested_tensor_strides_compositeexplicitautograd_dispatch.h>
|
| 154 |
+
#include <ATen/ops/_nested_view_from_buffer_copy_compositeexplicitautograd_dispatch.h>
|
| 155 |
+
#include <ATen/ops/_nested_view_from_jagged_copy_compositeexplicitautograd_dispatch.h>
|
| 156 |
+
#include <ATen/ops/_new_zeros_with_same_feature_meta_compositeexplicitautograd_dispatch.h>
|
| 157 |
+
#include <ATen/ops/_nnpack_spatial_convolution_compositeexplicitautograd_dispatch.h>
|
| 158 |
+
#include <ATen/ops/_pack_padded_sequence_compositeexplicitautograd_dispatch.h>
|
| 159 |
+
#include <ATen/ops/_pdist_backward_compositeexplicitautograd_dispatch.h>
|
| 160 |
+
#include <ATen/ops/_pdist_forward_compositeexplicitautograd_dispatch.h>
|
| 161 |
+
#include <ATen/ops/_pin_memory_compositeexplicitautograd_dispatch.h>
|
| 162 |
+
#include <ATen/ops/_print_compositeexplicitautograd_dispatch.h>
|
| 163 |
+
#include <ATen/ops/_reshape_alias_copy_compositeexplicitautograd_dispatch.h>
|
| 164 |
+
#include <ATen/ops/_reshape_copy_compositeexplicitautograd_dispatch.h>
|
| 165 |
+
#include <ATen/ops/_resize_output_compositeexplicitautograd_dispatch.h>
|
| 166 |
+
#include <ATen/ops/_safe_softmax_compositeexplicitautograd_dispatch.h>
|
| 167 |
+
#include <ATen/ops/_sample_dirichlet_compositeexplicitautograd_dispatch.h>
|
| 168 |
+
#include <ATen/ops/_scaled_dot_product_fused_attention_overrideable_compositeexplicitautograd_dispatch.h>
|
| 169 |
+
#include <ATen/ops/_scaled_dot_product_fused_attention_overrideable_backward_compositeexplicitautograd_dispatch.h>
|
| 170 |
+
#include <ATen/ops/_segment_reduce_backward_compositeexplicitautograd_dispatch.h>
|
| 171 |
+
#include <ATen/ops/_slow_conv2d_backward_compositeexplicitautograd_dispatch.h>
|
| 172 |
+
#include <ATen/ops/_sparse_addmm_compositeexplicitautograd_dispatch.h>
|
| 173 |
+
#include <ATen/ops/_sparse_broadcast_to_copy_compositeexplicitautograd_dispatch.h>
|
| 174 |
+
#include <ATen/ops/_sparse_compressed_tensor_with_dims_compositeexplicitautograd_dispatch.h>
|
| 175 |
+
#include <ATen/ops/_sparse_coo_tensor_with_dims_compositeexplicitautograd_dispatch.h>
|
| 176 |
+
#include <ATen/ops/_sparse_coo_tensor_with_dims_and_tensors_compositeexplicitautograd_dispatch.h>
|
| 177 |
+
#include <ATen/ops/_sparse_csr_prod_compositeexplicitautograd_dispatch.h>
|
| 178 |
+
#include <ATen/ops/_sparse_csr_sum_compositeexplicitautograd_dispatch.h>
|
| 179 |
+
#include <ATen/ops/_sparse_log_softmax_compositeexplicitautograd_dispatch.h>
|
| 180 |
+
#include <ATen/ops/_sparse_log_softmax_backward_data_compositeexplicitautograd_dispatch.h>
|
| 181 |
+
#include <ATen/ops/_sparse_mask_projection_compositeexplicitautograd_dispatch.h>
|
| 182 |
+
#include <ATen/ops/_sparse_softmax_compositeexplicitautograd_dispatch.h>
|
| 183 |
+
#include <ATen/ops/_sparse_softmax_backward_data_compositeexplicitautograd_dispatch.h>
|
| 184 |
+
#include <ATen/ops/_sparse_sparse_matmul_compositeexplicitautograd_dispatch.h>
|
| 185 |
+
#include <ATen/ops/_sparse_sum_compositeexplicitautograd_dispatch.h>
|
| 186 |
+
#include <ATen/ops/_sparse_sum_backward_compositeexplicitautograd_dispatch.h>
|
| 187 |
+
#include <ATen/ops/_spdiags_compositeexplicitautograd_dispatch.h>
|
| 188 |
+
#include <ATen/ops/_stack_compositeexplicitautograd_dispatch.h>
|
| 189 |
+
#include <ATen/ops/_standard_gamma_compositeexplicitautograd_dispatch.h>
|
| 190 |
+
#include <ATen/ops/_standard_gamma_grad_compositeexplicitautograd_dispatch.h>
|
| 191 |
+
#include <ATen/ops/_test_autograd_multiple_dispatch_compositeexplicitautograd_dispatch.h>
|
| 192 |
+
#include <ATen/ops/_test_autograd_multiple_dispatch_view_compositeexplicitautograd_dispatch.h>
|
| 193 |
+
#include <ATen/ops/_test_autograd_multiple_dispatch_view_copy_compositeexplicitautograd_dispatch.h>
|
| 194 |
+
#include <ATen/ops/_test_functorch_fallback_compositeexplicitautograd_dispatch.h>
|
| 195 |
+
#include <ATen/ops/_test_optional_filled_intlist_compositeexplicitautograd_dispatch.h>
|
| 196 |
+
#include <ATen/ops/_test_optional_floatlist_compositeexplicitautograd_dispatch.h>
|
| 197 |
+
#include <ATen/ops/_test_optional_intlist_compositeexplicitautograd_dispatch.h>
|
| 198 |
+
#include <ATen/ops/_test_parallel_materialize_compositeexplicitautograd_dispatch.h>
|
| 199 |
+
#include <ATen/ops/_test_warn_in_autograd_compositeexplicitautograd_dispatch.h>
|
| 200 |
+
#include <ATen/ops/_thnn_fused_gru_cell_compositeexplicitautograd_dispatch.h>
|
| 201 |
+
#include <ATen/ops/_thnn_fused_gru_cell_backward_compositeexplicitautograd_dispatch.h>
|
| 202 |
+
#include <ATen/ops/_thnn_fused_lstm_cell_compositeexplicitautograd_dispatch.h>
|
| 203 |
+
#include <ATen/ops/_thnn_fused_lstm_cell_backward_impl_compositeexplicitautograd_dispatch.h>
|
| 204 |
+
#include <ATen/ops/_to_copy_compositeexplicitautograd_dispatch.h>
|
| 205 |
+
#include <ATen/ops/_to_dense_compositeexplicitautograd_dispatch.h>
|
| 206 |
+
#include <ATen/ops/_to_sparse_compositeexplicitautograd_dispatch.h>
|
| 207 |
+
#include <ATen/ops/_to_sparse_bsc_compositeexplicitautograd_dispatch.h>
|
| 208 |
+
#include <ATen/ops/_to_sparse_bsr_compositeexplicitautograd_dispatch.h>
|
| 209 |
+
#include <ATen/ops/_to_sparse_csc_compositeexplicitautograd_dispatch.h>
|
| 210 |
+
#include <ATen/ops/_to_sparse_csr_compositeexplicitautograd_dispatch.h>
|
| 211 |
+
#include <ATen/ops/_transform_bias_rescale_qkv_compositeexplicitautograd_dispatch.h>
|
| 212 |
+
#include <ATen/ops/_transformer_encoder_layer_fwd_compositeexplicitautograd_dispatch.h>
|
| 213 |
+
#include <ATen/ops/_trilinear_compositeexplicitautograd_dispatch.h>
|
| 214 |
+
#include <ATen/ops/_triton_multi_head_attention_compositeexplicitautograd_dispatch.h>
|
| 215 |
+
#include <ATen/ops/_triton_scaled_dot_attention_compositeexplicitautograd_dispatch.h>
|
| 216 |
+
#include <ATen/ops/_unique_compositeexplicitautograd_dispatch.h>
|
| 217 |
+
#include <ATen/ops/_unique2_compositeexplicitautograd_dispatch.h>
|
| 218 |
+
#include <ATen/ops/_unsafe_index_compositeexplicitautograd_dispatch.h>
|
| 219 |
+
#include <ATen/ops/_unsafe_index_put_compositeexplicitautograd_dispatch.h>
|
| 220 |
+
#include <ATen/ops/_unsafe_masked_index_compositeexplicitautograd_dispatch.h>
|
| 221 |
+
#include <ATen/ops/_unsafe_masked_index_put_accumulate_compositeexplicitautograd_dispatch.h>
|
| 222 |
+
#include <ATen/ops/_unsafe_view_compositeexplicitautograd_dispatch.h>
|
| 223 |
+
#include <ATen/ops/_values_copy_compositeexplicitautograd_dispatch.h>
|
| 224 |
+
#include <ATen/ops/_weight_norm_interface_compositeexplicitautograd_dispatch.h>
|
| 225 |
+
#include <ATen/ops/_weight_norm_interface_backward_compositeexplicitautograd_dispatch.h>
|
| 226 |
+
#include <ATen/ops/abs_compositeexplicitautograd_dispatch.h>
|
| 227 |
+
#include <ATen/ops/add_compositeexplicitautograd_dispatch.h>
|
| 228 |
+
#include <ATen/ops/addr_compositeexplicitautograd_dispatch.h>
|
| 229 |
+
#include <ATen/ops/affine_grid_generator_compositeexplicitautograd_dispatch.h>
|
| 230 |
+
#include <ATen/ops/alias_compositeexplicitautograd_dispatch.h>
|
| 231 |
+
#include <ATen/ops/alias_copy_compositeexplicitautograd_dispatch.h>
|
| 232 |
+
#include <ATen/ops/all_compositeexplicitautograd_dispatch.h>
|
| 233 |
+
#include <ATen/ops/allclose_compositeexplicitautograd_dispatch.h>
|
| 234 |
+
#include <ATen/ops/any_compositeexplicitautograd_dispatch.h>
|
| 235 |
+
#include <ATen/ops/arange_compositeexplicitautograd_dispatch.h>
|
| 236 |
+
#include <ATen/ops/as_strided_copy_compositeexplicitautograd_dispatch.h>
|
| 237 |
+
#include <ATen/ops/as_strided_scatter_compositeexplicitautograd_dispatch.h>
|
| 238 |
+
#include <ATen/ops/bartlett_window_compositeexplicitautograd_dispatch.h>
|
| 239 |
+
#include <ATen/ops/batch_norm_backward_elemt_compositeexplicitautograd_dispatch.h>
|
| 240 |
+
#include <ATen/ops/batch_norm_backward_reduce_compositeexplicitautograd_dispatch.h>
|
| 241 |
+
#include <ATen/ops/batch_norm_gather_stats_compositeexplicitautograd_dispatch.h>
|
| 242 |
+
#include <ATen/ops/batch_norm_gather_stats_with_counts_compositeexplicitautograd_dispatch.h>
|
| 243 |
+
#include <ATen/ops/batch_norm_stats_compositeexplicitautograd_dispatch.h>
|
| 244 |
+
#include <ATen/ops/batch_norm_update_stats_compositeexplicitautograd_dispatch.h>
|
| 245 |
+
#include <ATen/ops/bernoulli_compositeexplicitautograd_dispatch.h>
|
| 246 |
+
#include <ATen/ops/binary_cross_entropy_with_logits_compositeexplicitautograd_dispatch.h>
|
| 247 |
+
#include <ATen/ops/bincount_compositeexplicitautograd_dispatch.h>
|
| 248 |
+
#include <ATen/ops/binomial_compositeexplicitautograd_dispatch.h>
|
| 249 |
+
#include <ATen/ops/bitwise_and_compositeexplicitautograd_dispatch.h>
|
| 250 |
+
#include <ATen/ops/bitwise_left_shift_compositeexplicitautograd_dispatch.h>
|
| 251 |
+
#include <ATen/ops/bitwise_or_compositeexplicitautograd_dispatch.h>
|
| 252 |
+
#include <ATen/ops/bitwise_right_shift_compositeexplicitautograd_dispatch.h>
|
| 253 |
+
#include <ATen/ops/bitwise_xor_compositeexplicitautograd_dispatch.h>
|
| 254 |
+
#include <ATen/ops/blackman_window_compositeexplicitautograd_dispatch.h>
|
| 255 |
+
#include <ATen/ops/block_diag_compositeexplicitautograd_dispatch.h>
|
| 256 |
+
#include <ATen/ops/bucketize_compositeexplicitautograd_dispatch.h>
|
| 257 |
+
#include <ATen/ops/cauchy_compositeexplicitautograd_dispatch.h>
|
| 258 |
+
#include <ATen/ops/ccol_indices_compositeexplicitautograd_dispatch.h>
|
| 259 |
+
#include <ATen/ops/ccol_indices_copy_compositeexplicitautograd_dispatch.h>
|
| 260 |
+
#include <ATen/ops/celu_compositeexplicitautograd_dispatch.h>
|
| 261 |
+
#include <ATen/ops/channel_shuffle_compositeexplicitautograd_dispatch.h>
|
| 262 |
+
#include <ATen/ops/cholesky_solve_compositeexplicitautograd_dispatch.h>
|
| 263 |
+
#include <ATen/ops/clone_compositeexplicitautograd_dispatch.h>
|
| 264 |
+
#include <ATen/ops/col_indices_compositeexplicitautograd_dispatch.h>
|
| 265 |
+
#include <ATen/ops/col_indices_copy_compositeexplicitautograd_dispatch.h>
|
| 266 |
+
#include <ATen/ops/complex_compositeexplicitautograd_dispatch.h>
|
| 267 |
+
#include <ATen/ops/conj_physical_compositeexplicitautograd_dispatch.h>
|
| 268 |
+
#include <ATen/ops/constant_pad_nd_compositeexplicitautograd_dispatch.h>
|
| 269 |
+
#include <ATen/ops/conv_depthwise3d_compositeexplicitautograd_dispatch.h>
|
| 270 |
+
#include <ATen/ops/conv_tbc_compositeexplicitautograd_dispatch.h>
|
| 271 |
+
#include <ATen/ops/convolution_compositeexplicitautograd_dispatch.h>
|
| 272 |
+
#include <ATen/ops/convolution_backward_compositeexplicitautograd_dispatch.h>
|
| 273 |
+
#include <ATen/ops/convolution_backward_overrideable_compositeexplicitautograd_dispatch.h>
|
| 274 |
+
#include <ATen/ops/convolution_overrideable_compositeexplicitautograd_dispatch.h>
|
| 275 |
+
#include <ATen/ops/copy_compositeexplicitautograd_dispatch.h>
|
| 276 |
+
#include <ATen/ops/copy_sparse_to_sparse_compositeexplicitautograd_dispatch.h>
|
| 277 |
+
#include <ATen/ops/copysign_compositeexplicitautograd_dispatch.h>
|
| 278 |
+
#include <ATen/ops/count_nonzero_compositeexplicitautograd_dispatch.h>
|
| 279 |
+
#include <ATen/ops/crow_indices_compositeexplicitautograd_dispatch.h>
|
| 280 |
+
#include <ATen/ops/crow_indices_copy_compositeexplicitautograd_dispatch.h>
|
| 281 |
+
#include <ATen/ops/cudnn_affine_grid_generator_compositeexplicitautograd_dispatch.h>
|
| 282 |
+
#include <ATen/ops/cudnn_affine_grid_generator_backward_compositeexplicitautograd_dispatch.h>
|
| 283 |
+
#include <ATen/ops/cudnn_batch_norm_compositeexplicitautograd_dispatch.h>
|
| 284 |
+
#include <ATen/ops/cudnn_batch_norm_backward_compositeexplicitautograd_dispatch.h>
|
| 285 |
+
#include <ATen/ops/cudnn_convolution_add_relu_compositeexplicitautograd_dispatch.h>
|
| 286 |
+
#include <ATen/ops/cudnn_convolution_relu_compositeexplicitautograd_dispatch.h>
|
| 287 |
+
#include <ATen/ops/cudnn_convolution_transpose_compositeexplicitautograd_dispatch.h>
|
| 288 |
+
#include <ATen/ops/cudnn_grid_sampler_compositeexplicitautograd_dispatch.h>
|
| 289 |
+
#include <ATen/ops/cudnn_grid_sampler_backward_compositeexplicitautograd_dispatch.h>
|
| 290 |
+
#include <ATen/ops/cummax_compositeexplicitautograd_dispatch.h>
|
| 291 |
+
#include <ATen/ops/cummin_compositeexplicitautograd_dispatch.h>
|
| 292 |
+
#include <ATen/ops/deg2rad_compositeexplicitautograd_dispatch.h>
|
| 293 |
+
#include <ATen/ops/dense_dim_compositeexplicitautograd_dispatch.h>
|
| 294 |
+
#include <ATen/ops/dequantize_compositeexplicitautograd_dispatch.h>
|
| 295 |
+
#include <ATen/ops/detach_compositeexplicitautograd_dispatch.h>
|
| 296 |
+
#include <ATen/ops/detach_copy_compositeexplicitautograd_dispatch.h>
|
| 297 |
+
#include <ATen/ops/diag_embed_compositeexplicitautograd_dispatch.h>
|
| 298 |
+
#include <ATen/ops/diagonal_compositeexplicitautograd_dispatch.h>
|
| 299 |
+
#include <ATen/ops/diagonal_backward_compositeexplicitautograd_dispatch.h>
|
| 300 |
+
#include <ATen/ops/diagonal_copy_compositeexplicitautograd_dispatch.h>
|
| 301 |
+
#include <ATen/ops/diagonal_scatter_compositeexplicitautograd_dispatch.h>
|
| 302 |
+
#include <ATen/ops/dist_compositeexplicitautograd_dispatch.h>
|
| 303 |
+
#include <ATen/ops/div_compositeexplicitautograd_dispatch.h>
|
| 304 |
+
#include <ATen/ops/dot_compositeexplicitautograd_dispatch.h>
|
| 305 |
+
#include <ATen/ops/embedding_compositeexplicitautograd_dispatch.h>
|
| 306 |
+
#include <ATen/ops/embedding_dense_backward_compositeexplicitautograd_dispatch.h>
|
| 307 |
+
#include <ATen/ops/embedding_renorm_compositeexplicitautograd_dispatch.h>
|
| 308 |
+
#include <ATen/ops/empty_compositeexplicitautograd_dispatch.h>
|
| 309 |
+
#include <ATen/ops/empty_like_compositeexplicitautograd_dispatch.h>
|
| 310 |
+
#include <ATen/ops/empty_permuted_compositeexplicitautograd_dispatch.h>
|
| 311 |
+
#include <ATen/ops/empty_quantized_compositeexplicitautograd_dispatch.h>
|
| 312 |
+
#include <ATen/ops/empty_strided_compositeexplicitautograd_dispatch.h>
|
| 313 |
+
#include <ATen/ops/expand_compositeexplicitautograd_dispatch.h>
|
| 314 |
+
#include <ATen/ops/expand_copy_compositeexplicitautograd_dispatch.h>
|
| 315 |
+
#include <ATen/ops/exponential_compositeexplicitautograd_dispatch.h>
|
| 316 |
+
#include <ATen/ops/eye_compositeexplicitautograd_dispatch.h>
|
| 317 |
+
#include <ATen/ops/fake_quantize_per_channel_affine_cachemask_compositeexplicitautograd_dispatch.h>
|
| 318 |
+
#include <ATen/ops/fake_quantize_per_tensor_affine_cachemask_compositeexplicitautograd_dispatch.h>
|
| 319 |
+
#include <ATen/ops/fft_fftfreq_compositeexplicitautograd_dispatch.h>
|
| 320 |
+
#include <ATen/ops/fft_rfftfreq_compositeexplicitautograd_dispatch.h>
|
| 321 |
+
#include <ATen/ops/fill_compositeexplicitautograd_dispatch.h>
|
| 322 |
+
#include <ATen/ops/flip_compositeexplicitautograd_dispatch.h>
|
| 323 |
+
#include <ATen/ops/floor_divide_compositeexplicitautograd_dispatch.h>
|
| 324 |
+
#include <ATen/ops/fmod_compositeexplicitautograd_dispatch.h>
|
| 325 |
+
#include <ATen/ops/frexp_compositeexplicitautograd_dispatch.h>
|
| 326 |
+
#include <ATen/ops/from_file_compositeexplicitautograd_dispatch.h>
|
| 327 |
+
#include <ATen/ops/full_compositeexplicitautograd_dispatch.h>
|
| 328 |
+
#include <ATen/ops/full_like_compositeexplicitautograd_dispatch.h>
|
| 329 |
+
#include <ATen/ops/geometric_compositeexplicitautograd_dispatch.h>
|
| 330 |
+
#include <ATen/ops/glu_backward_jvp_compositeexplicitautograd_dispatch.h>
|
| 331 |
+
#include <ATen/ops/glu_jvp_compositeexplicitautograd_dispatch.h>
|
| 332 |
+
#include <ATen/ops/grid_sampler_2d_compositeexplicitautograd_dispatch.h>
|
| 333 |
+
#include <ATen/ops/grid_sampler_2d_backward_compositeexplicitautograd_dispatch.h>
|
| 334 |
+
#include <ATen/ops/grid_sampler_3d_compositeexplicitautograd_dispatch.h>
|
| 335 |
+
#include <ATen/ops/grid_sampler_3d_backward_compositeexplicitautograd_dispatch.h>
|
| 336 |
+
#include <ATen/ops/hamming_window_compositeexplicitautograd_dispatch.h>
|
| 337 |
+
#include <ATen/ops/hann_window_compositeexplicitautograd_dispatch.h>
|
| 338 |
+
#include <ATen/ops/hardswish_backward_compositeexplicitautograd_dispatch.h>
|
| 339 |
+
#include <ATen/ops/huber_loss_backward_compositeexplicitautograd_dispatch.h>
|
| 340 |
+
#include <ATen/ops/index_fill_compositeexplicitautograd_dispatch.h>
|
| 341 |
+
#include <ATen/ops/index_put_compositeexplicitautograd_dispatch.h>
|
| 342 |
+
#include <ATen/ops/indices_compositeexplicitautograd_dispatch.h>
|
| 343 |
+
#include <ATen/ops/indices_copy_compositeexplicitautograd_dispatch.h>
|
| 344 |
+
#include <ATen/ops/int_repr_compositeexplicitautograd_dispatch.h>
|
| 345 |
+
#include <ATen/ops/is_coalesced_compositeexplicitautograd_dispatch.h>
|
| 346 |
+
#include <ATen/ops/is_pinned_compositeexplicitautograd_dispatch.h>
|
| 347 |
+
#include <ATen/ops/is_same_size_compositeexplicitautograd_dispatch.h>
|
| 348 |
+
#include <ATen/ops/isinf_compositeexplicitautograd_dispatch.h>
|
| 349 |
+
#include <ATen/ops/isnan_compositeexplicitautograd_dispatch.h>
|
| 350 |
+
#include <ATen/ops/kaiser_window_compositeexplicitautograd_dispatch.h>
|
| 351 |
+
#include <ATen/ops/kthvalue_compositeexplicitautograd_dispatch.h>
|
| 352 |
+
#include <ATen/ops/lift_compositeexplicitautograd_dispatch.h>
|
| 353 |
+
#include <ATen/ops/lift_fresh_compositeexplicitautograd_dispatch.h>
|
| 354 |
+
#include <ATen/ops/lift_fresh_copy_compositeexplicitautograd_dispatch.h>
|
| 355 |
+
#include <ATen/ops/linalg_lstsq_compositeexplicitautograd_dispatch.h>
|
| 356 |
+
#include <ATen/ops/linalg_matrix_exp_compositeexplicitautograd_dispatch.h>
|
| 357 |
+
#include <ATen/ops/linalg_pinv_compositeexplicitautograd_dispatch.h>
|
| 358 |
+
#include <ATen/ops/linear_compositeexplicitautograd_dispatch.h>
|
| 359 |
+
#include <ATen/ops/linear_backward_compositeexplicitautograd_dispatch.h>
|
| 360 |
+
#include <ATen/ops/linspace_compositeexplicitautograd_dispatch.h>
|
| 361 |
+
#include <ATen/ops/log_normal_compositeexplicitautograd_dispatch.h>
|
| 362 |
+
#include <ATen/ops/log_softmax_compositeexplicitautograd_dispatch.h>
|
| 363 |
+
#include <ATen/ops/logcumsumexp_compositeexplicitautograd_dispatch.h>
|
| 364 |
+
#include <ATen/ops/logical_and_compositeexplicitautograd_dispatch.h>
|
| 365 |
+
#include <ATen/ops/logical_not_compositeexplicitautograd_dispatch.h>
|
| 366 |
+
#include <ATen/ops/logical_or_compositeexplicitautograd_dispatch.h>
|
| 367 |
+
#include <ATen/ops/logical_xor_compositeexplicitautograd_dispatch.h>
|
| 368 |
+
#include <ATen/ops/logspace_compositeexplicitautograd_dispatch.h>
|
| 369 |
+
#include <ATen/ops/logsumexp_compositeexplicitautograd_dispatch.h>
|
| 370 |
+
#include <ATen/ops/lshift_compositeexplicitautograd_dispatch.h>
|
| 371 |
+
#include <ATen/ops/lstm_mps_backward_compositeexplicitautograd_dispatch.h>
|
| 372 |
+
#include <ATen/ops/masked_fill_compositeexplicitautograd_dispatch.h>
|
| 373 |
+
#include <ATen/ops/masked_scatter_compositeexplicitautograd_dispatch.h>
|
| 374 |
+
#include <ATen/ops/masked_scatter_backward_compositeexplicitautograd_dispatch.h>
|
| 375 |
+
#include <ATen/ops/matmul_backward_compositeexplicitautograd_dispatch.h>
|
| 376 |
+
#include <ATen/ops/max_pool2d_backward_compositeexplicitautograd_dispatch.h>
|
| 377 |
+
#include <ATen/ops/mean_compositeexplicitautograd_dispatch.h>
|
| 378 |
+
#include <ATen/ops/median_compositeexplicitautograd_dispatch.h>
|
| 379 |
+
#include <ATen/ops/miopen_batch_norm_compositeexplicitautograd_dispatch.h>
|
| 380 |
+
#include <ATen/ops/miopen_batch_norm_backward_compositeexplicitautograd_dispatch.h>
|
| 381 |
+
#include <ATen/ops/miopen_convolution_compositeexplicitautograd_dispatch.h>
|
| 382 |
+
#include <ATen/ops/miopen_convolution_transpose_compositeexplicitautograd_dispatch.h>
|
| 383 |
+
#include <ATen/ops/miopen_depthwise_convolution_compositeexplicitautograd_dispatch.h>
|
| 384 |
+
#include <ATen/ops/miopen_rnn_compositeexplicitautograd_dispatch.h>
|
| 385 |
+
#include <ATen/ops/miopen_rnn_backward_compositeexplicitautograd_dispatch.h>
|
| 386 |
+
#include <ATen/ops/mkldnn_adaptive_avg_pool2d_backward_compositeexplicitautograd_dispatch.h>
|
| 387 |
+
#include <ATen/ops/mkldnn_convolution_compositeexplicitautograd_dispatch.h>
|
| 388 |
+
#include <ATen/ops/mkldnn_linear_compositeexplicitautograd_dispatch.h>
|
| 389 |
+
#include <ATen/ops/mkldnn_linear_backward_compositeexplicitautograd_dispatch.h>
|
| 390 |
+
#include <ATen/ops/mkldnn_linear_backward_input_compositeexplicitautograd_dispatch.h>
|
| 391 |
+
#include <ATen/ops/mkldnn_linear_backward_weights_compositeexplicitautograd_dispatch.h>
|
| 392 |
+
#include <ATen/ops/mkldnn_max_pool2d_compositeexplicitautograd_dispatch.h>
|
| 393 |
+
#include <ATen/ops/mkldnn_max_pool2d_backward_compositeexplicitautograd_dispatch.h>
|
| 394 |
+
#include <ATen/ops/mkldnn_max_pool3d_compositeexplicitautograd_dispatch.h>
|
| 395 |
+
#include <ATen/ops/mkldnn_max_pool3d_backward_compositeexplicitautograd_dispatch.h>
|
| 396 |
+
#include <ATen/ops/mkldnn_reorder_conv2d_weight_compositeexplicitautograd_dispatch.h>
|
| 397 |
+
#include <ATen/ops/mkldnn_reorder_conv3d_weight_compositeexplicitautograd_dispatch.h>
|
| 398 |
+
#include <ATen/ops/mkldnn_rnn_layer_compositeexplicitautograd_dispatch.h>
|
| 399 |
+
#include <ATen/ops/mkldnn_rnn_layer_backward_compositeexplicitautograd_dispatch.h>
|
| 400 |
+
#include <ATen/ops/mode_compositeexplicitautograd_dispatch.h>
|
| 401 |
+
#include <ATen/ops/mps_convolution_backward_compositeexplicitautograd_dispatch.h>
|
| 402 |
+
#include <ATen/ops/mps_convolution_transpose_backward_compositeexplicitautograd_dispatch.h>
|
| 403 |
+
#include <ATen/ops/mul_compositeexplicitautograd_dispatch.h>
|
| 404 |
+
#include <ATen/ops/mv_compositeexplicitautograd_dispatch.h>
|
| 405 |
+
#include <ATen/ops/mvlgamma_compositeexplicitautograd_dispatch.h>
|
| 406 |
+
#include <ATen/ops/nan_to_num_compositeexplicitautograd_dispatch.h>
|
| 407 |
+
#include <ATen/ops/nanmedian_compositeexplicitautograd_dispatch.h>
|
| 408 |
+
#include <ATen/ops/native_batch_norm_backward_compositeexplicitautograd_dispatch.h>
|
| 409 |
+
#include <ATen/ops/native_dropout_compositeexplicitautograd_dispatch.h>
|
| 410 |
+
#include <ATen/ops/native_dropout_backward_compositeexplicitautograd_dispatch.h>
|
| 411 |
+
#include <ATen/ops/native_group_norm_compositeexplicitautograd_dispatch.h>
|
| 412 |
+
#include <ATen/ops/native_group_norm_backward_compositeexplicitautograd_dispatch.h>
|
| 413 |
+
#include <ATen/ops/native_layer_norm_compositeexplicitautograd_dispatch.h>
|
| 414 |
+
#include <ATen/ops/native_layer_norm_backward_compositeexplicitautograd_dispatch.h>
|
| 415 |
+
#include <ATen/ops/native_norm_compositeexplicitautograd_dispatch.h>
|
| 416 |
+
#include <ATen/ops/new_empty_compositeexplicitautograd_dispatch.h>
|
| 417 |
+
#include <ATen/ops/new_empty_strided_compositeexplicitautograd_dispatch.h>
|
| 418 |
+
#include <ATen/ops/new_full_compositeexplicitautograd_dispatch.h>
|
| 419 |
+
#include <ATen/ops/new_ones_compositeexplicitautograd_dispatch.h>
|
| 420 |
+
#include <ATen/ops/new_zeros_compositeexplicitautograd_dispatch.h>
|
| 421 |
+
#include <ATen/ops/norm_compositeexplicitautograd_dispatch.h>
|
| 422 |
+
#include <ATen/ops/normal_compositeexplicitautograd_dispatch.h>
|
| 423 |
+
#include <ATen/ops/ones_compositeexplicitautograd_dispatch.h>
|
| 424 |
+
#include <ATen/ops/ones_like_compositeexplicitautograd_dispatch.h>
|
| 425 |
+
#include <ATen/ops/permute_compositeexplicitautograd_dispatch.h>
|
| 426 |
+
#include <ATen/ops/permute_copy_compositeexplicitautograd_dispatch.h>
|
| 427 |
+
#include <ATen/ops/pixel_shuffle_compositeexplicitautograd_dispatch.h>
|
| 428 |
+
#include <ATen/ops/pixel_unshuffle_compositeexplicitautograd_dispatch.h>
|
| 429 |
+
#include <ATen/ops/poisson_compositeexplicitautograd_dispatch.h>
|
| 430 |
+
#include <ATen/ops/polar_compositeexplicitautograd_dispatch.h>
|
| 431 |
+
#include <ATen/ops/polygamma_compositeexplicitautograd_dispatch.h>
|
| 432 |
+
#include <ATen/ops/prod_compositeexplicitautograd_dispatch.h>
|
| 433 |
+
#include <ATen/ops/put_compositeexplicitautograd_dispatch.h>
|
| 434 |
+
#include <ATen/ops/q_per_channel_scales_compositeexplicitautograd_dispatch.h>
|
| 435 |
+
#include <ATen/ops/q_per_channel_zero_points_compositeexplicitautograd_dispatch.h>
|
| 436 |
+
#include <ATen/ops/quantize_per_channel_compositeexplicitautograd_dispatch.h>
|
| 437 |
+
#include <ATen/ops/quantize_per_tensor_compositeexplicitautograd_dispatch.h>
|
| 438 |
+
#include <ATen/ops/quantize_per_tensor_dynamic_compositeexplicitautograd_dispatch.h>
|
| 439 |
+
#include <ATen/ops/quantized_batch_norm_compositeexplicitautograd_dispatch.h>
|
| 440 |
+
#include <ATen/ops/quantized_max_pool1d_compositeexplicitautograd_dispatch.h>
|
| 441 |
+
#include <ATen/ops/quantized_max_pool2d_compositeexplicitautograd_dispatch.h>
|
| 442 |
+
#include <ATen/ops/quantized_max_pool3d_compositeexplicitautograd_dispatch.h>
|
| 443 |
+
#include <ATen/ops/rad2deg_compositeexplicitautograd_dispatch.h>
|
| 444 |
+
#include <ATen/ops/rand_compositeexplicitautograd_dispatch.h>
|
| 445 |
+
#include <ATen/ops/rand_like_compositeexplicitautograd_dispatch.h>
|
| 446 |
+
#include <ATen/ops/randint_compositeexplicitautograd_dispatch.h>
|
| 447 |
+
#include <ATen/ops/randint_like_compositeexplicitautograd_dispatch.h>
|
| 448 |
+
#include <ATen/ops/randn_compositeexplicitautograd_dispatch.h>
|
| 449 |
+
#include <ATen/ops/randn_like_compositeexplicitautograd_dispatch.h>
|
| 450 |
+
#include <ATen/ops/random_compositeexplicitautograd_dispatch.h>
|
| 451 |
+
#include <ATen/ops/randperm_compositeexplicitautograd_dispatch.h>
|
| 452 |
+
#include <ATen/ops/range_compositeexplicitautograd_dispatch.h>
|
| 453 |
+
#include <ATen/ops/relu_compositeexplicitautograd_dispatch.h>
|
| 454 |
+
#include <ATen/ops/remainder_compositeexplicitautograd_dispatch.h>
|
| 455 |
+
#include <ATen/ops/repeat_compositeexplicitautograd_dispatch.h>
|
| 456 |
+
#include <ATen/ops/repeat_interleave_compositeexplicitautograd_dispatch.h>
|
| 457 |
+
#include <ATen/ops/resize_compositeexplicitautograd_dispatch.h>
|
| 458 |
+
#include <ATen/ops/resize_as_compositeexplicitautograd_dispatch.h>
|
| 459 |
+
#include <ATen/ops/resize_as_sparse_compositeexplicitautograd_dispatch.h>
|
| 460 |
+
#include <ATen/ops/roll_compositeexplicitautograd_dispatch.h>
|
| 461 |
+
#include <ATen/ops/rot90_compositeexplicitautograd_dispatch.h>
|
| 462 |
+
#include <ATen/ops/row_indices_compositeexplicitautograd_dispatch.h>
|
| 463 |
+
#include <ATen/ops/row_indices_copy_compositeexplicitautograd_dispatch.h>
|
| 464 |
+
#include <ATen/ops/rrelu_with_noise_backward_compositeexplicitautograd_dispatch.h>
|
| 465 |
+
#include <ATen/ops/rshift_compositeexplicitautograd_dispatch.h>
|
| 466 |
+
#include <ATen/ops/rsub_compositeexplicitautograd_dispatch.h>
|
| 467 |
+
#include <ATen/ops/scalar_tensor_compositeexplicitautograd_dispatch.h>
|
| 468 |
+
#include <ATen/ops/segment_reduce_compositeexplicitautograd_dispatch.h>
|
| 469 |
+
#include <ATen/ops/select_compositeexplicitautograd_dispatch.h>
|
| 470 |
+
#include <ATen/ops/select_backward_compositeexplicitautograd_dispatch.h>
|
| 471 |
+
#include <ATen/ops/select_copy_compositeexplicitautograd_dispatch.h>
|
| 472 |
+
#include <ATen/ops/select_scatter_compositeexplicitautograd_dispatch.h>
|
| 473 |
+
#include <ATen/ops/set_compositeexplicitautograd_dispatch.h>
|
| 474 |
+
#include <ATen/ops/slice_compositeexplicitautograd_dispatch.h>
|
| 475 |
+
#include <ATen/ops/slice_backward_compositeexplicitautograd_dispatch.h>
|
| 476 |
+
#include <ATen/ops/slice_copy_compositeexplicitautograd_dispatch.h>
|
| 477 |
+
#include <ATen/ops/slice_inverse_compositeexplicitautograd_dispatch.h>
|
| 478 |
+
#include <ATen/ops/slice_scatter_compositeexplicitautograd_dispatch.h>
|
| 479 |
+
#include <ATen/ops/slow_conv_dilated2d_compositeexplicitautograd_dispatch.h>
|
| 480 |
+
#include <ATen/ops/slow_conv_dilated3d_compositeexplicitautograd_dispatch.h>
|
| 481 |
+
#include <ATen/ops/smooth_l1_loss_backward_compositeexplicitautograd_dispatch.h>
|
| 482 |
+
#include <ATen/ops/soft_margin_loss_compositeexplicitautograd_dispatch.h>
|
| 483 |
+
#include <ATen/ops/soft_margin_loss_backward_compositeexplicitautograd_dispatch.h>
|
| 484 |
+
#include <ATen/ops/softmax_compositeexplicitautograd_dispatch.h>
|
| 485 |
+
#include <ATen/ops/sort_compositeexplicitautograd_dispatch.h>
|
| 486 |
+
#include <ATen/ops/sparse_compressed_tensor_compositeexplicitautograd_dispatch.h>
|
| 487 |
+
#include <ATen/ops/sparse_coo_tensor_compositeexplicitautograd_dispatch.h>
|
| 488 |
+
#include <ATen/ops/sparse_dim_compositeexplicitautograd_dispatch.h>
|
| 489 |
+
#include <ATen/ops/sparse_mask_compositeexplicitautograd_dispatch.h>
|
| 490 |
+
#include <ATen/ops/sparse_resize_compositeexplicitautograd_dispatch.h>
|
| 491 |
+
#include <ATen/ops/sparse_resize_and_clear_compositeexplicitautograd_dispatch.h>
|
| 492 |
+
#include <ATen/ops/special_chebyshev_polynomial_t_compositeexplicitautograd_dispatch.h>
|
| 493 |
+
#include <ATen/ops/special_chebyshev_polynomial_u_compositeexplicitautograd_dispatch.h>
|
| 494 |
+
#include <ATen/ops/special_chebyshev_polynomial_v_compositeexplicitautograd_dispatch.h>
|
| 495 |
+
#include <ATen/ops/special_chebyshev_polynomial_w_compositeexplicitautograd_dispatch.h>
|
| 496 |
+
#include <ATen/ops/special_hermite_polynomial_h_compositeexplicitautograd_dispatch.h>
|
| 497 |
+
#include <ATen/ops/special_hermite_polynomial_he_compositeexplicitautograd_dispatch.h>
|
| 498 |
+
#include <ATen/ops/special_laguerre_polynomial_l_compositeexplicitautograd_dispatch.h>
|
| 499 |
+
#include <ATen/ops/special_legendre_polynomial_p_compositeexplicitautograd_dispatch.h>
|
| 500 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_t_compositeexplicitautograd_dispatch.h>
|
| 501 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_u_compositeexplicitautograd_dispatch.h>
|
| 502 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_v_compositeexplicitautograd_dispatch.h>
|
| 503 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_w_compositeexplicitautograd_dispatch.h>
|
| 504 |
+
#include <ATen/ops/special_xlog1py_compositeexplicitautograd_dispatch.h>
|
| 505 |
+
#include <ATen/ops/special_zeta_compositeexplicitautograd_dispatch.h>
|
| 506 |
+
#include <ATen/ops/split_compositeexplicitautograd_dispatch.h>
|
| 507 |
+
#include <ATen/ops/split_copy_compositeexplicitautograd_dispatch.h>
|
| 508 |
+
#include <ATen/ops/split_with_sizes_compositeexplicitautograd_dispatch.h>
|
| 509 |
+
#include <ATen/ops/split_with_sizes_copy_compositeexplicitautograd_dispatch.h>
|
| 510 |
+
#include <ATen/ops/squeeze_compositeexplicitautograd_dispatch.h>
|
| 511 |
+
#include <ATen/ops/squeeze_copy_compositeexplicitautograd_dispatch.h>
|
| 512 |
+
#include <ATen/ops/stack_compositeexplicitautograd_dispatch.h>
|
| 513 |
+
#include <ATen/ops/std_mean_compositeexplicitautograd_dispatch.h>
|
| 514 |
+
#include <ATen/ops/sub_compositeexplicitautograd_dispatch.h>
|
| 515 |
+
#include <ATen/ops/sum_compositeexplicitautograd_dispatch.h>
|
| 516 |
+
#include <ATen/ops/sym_constrain_range_compositeexplicitautograd_dispatch.h>
|
| 517 |
+
#include <ATen/ops/sym_constrain_range_for_size_compositeexplicitautograd_dispatch.h>
|
| 518 |
+
#include <ATen/ops/t_compositeexplicitautograd_dispatch.h>
|
| 519 |
+
#include <ATen/ops/t_copy_compositeexplicitautograd_dispatch.h>
|
| 520 |
+
#include <ATen/ops/to_mkldnn_compositeexplicitautograd_dispatch.h>
|
| 521 |
+
#include <ATen/ops/to_padded_tensor_compositeexplicitautograd_dispatch.h>
|
| 522 |
+
#include <ATen/ops/trace_compositeexplicitautograd_dispatch.h>
|
| 523 |
+
#include <ATen/ops/transpose_compositeexplicitautograd_dispatch.h>
|
| 524 |
+
#include <ATen/ops/transpose_copy_compositeexplicitautograd_dispatch.h>
|
| 525 |
+
#include <ATen/ops/tril_indices_compositeexplicitautograd_dispatch.h>
|
| 526 |
+
#include <ATen/ops/triu_indices_compositeexplicitautograd_dispatch.h>
|
| 527 |
+
#include <ATen/ops/unbind_compositeexplicitautograd_dispatch.h>
|
| 528 |
+
#include <ATen/ops/unbind_copy_compositeexplicitautograd_dispatch.h>
|
| 529 |
+
#include <ATen/ops/unfold_backward_compositeexplicitautograd_dispatch.h>
|
| 530 |
+
#include <ATen/ops/unfold_copy_compositeexplicitautograd_dispatch.h>
|
| 531 |
+
#include <ATen/ops/uniform_compositeexplicitautograd_dispatch.h>
|
| 532 |
+
#include <ATen/ops/unique_consecutive_compositeexplicitautograd_dispatch.h>
|
| 533 |
+
#include <ATen/ops/unique_dim_compositeexplicitautograd_dispatch.h>
|
| 534 |
+
#include <ATen/ops/unique_dim_consecutive_compositeexplicitautograd_dispatch.h>
|
| 535 |
+
#include <ATen/ops/unsafe_split_compositeexplicitautograd_dispatch.h>
|
| 536 |
+
#include <ATen/ops/unsafe_split_with_sizes_compositeexplicitautograd_dispatch.h>
|
| 537 |
+
#include <ATen/ops/unsqueeze_compositeexplicitautograd_dispatch.h>
|
| 538 |
+
#include <ATen/ops/unsqueeze_copy_compositeexplicitautograd_dispatch.h>
|
| 539 |
+
#include <ATen/ops/values_compositeexplicitautograd_dispatch.h>
|
| 540 |
+
#include <ATen/ops/values_copy_compositeexplicitautograd_dispatch.h>
|
| 541 |
+
#include <ATen/ops/var_mean_compositeexplicitautograd_dispatch.h>
|
| 542 |
+
#include <ATen/ops/vdot_compositeexplicitautograd_dispatch.h>
|
| 543 |
+
#include <ATen/ops/view_compositeexplicitautograd_dispatch.h>
|
| 544 |
+
#include <ATen/ops/view_as_complex_copy_compositeexplicitautograd_dispatch.h>
|
| 545 |
+
#include <ATen/ops/view_as_real_copy_compositeexplicitautograd_dispatch.h>
|
| 546 |
+
#include <ATen/ops/view_copy_compositeexplicitautograd_dispatch.h>
|
| 547 |
+
#include <ATen/ops/xlogy_compositeexplicitautograd_dispatch.h>
|
| 548 |
+
#include <ATen/ops/zero_compositeexplicitautograd_dispatch.h>
|
| 549 |
+
#include <ATen/ops/zeros_compositeexplicitautograd_dispatch.h>
|
| 550 |
+
#include <ATen/ops/zeros_like_compositeexplicitautograd_dispatch.h>
|
| 551 |
+
|
| 552 |
+
|
| 553 |
+
|
.venv/lib/python3.11/site-packages/torch/include/ATen/CompositeImplicitAutogradFunctions.h
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <ATen/core/TensorBody.h>
|
| 2 |
+
|
| 3 |
+
// TODO Undo all logic introduced for Note [Avoiding Include Cycles In Static Dispatch]
|
| 4 |
+
// Code introduced to avoid cyclic dependency in static dispatch is no longer
|
| 5 |
+
// needed as static dispatch logic is moved from TensorBody.h, which caused cycles in the first place,
|
| 6 |
+
// to Operators.cpp for supporting multiple backends with multiple kernels.
|
| 7 |
+
//
|
| 8 |
+
// Note [Avoiding Include Cycles In Static Dispatch]
|
| 9 |
+
// In order to avoid #include cycles in the static dispatch build, we've carefully split out
|
| 10 |
+
// the static function definition files into {DispatchKey}Functions.h and {DispatchKey}Functions_inl.h.
|
| 11 |
+
//
|
| 12 |
+
// Without this split, the include cycle looks like TensorBody.h -> CPUFunctions.h -> TensorBody.h.
|
| 13 |
+
// - TensorBody.h #includes CPUFunctions.h in the static dispatch build, because the tensor methods
|
| 14 |
+
// all need to call into the fastpath C++ API defined in CPUFunctions.h. The methods are also all
|
| 15 |
+
// directly inlined into TensorBody.h.
|
| 16 |
+
// - CPUFunctions.h #includes TensorBody.h because it contains function declarations for the entire C++ API,
|
| 17 |
+
// which include functions that have defaultable std::optional<Tensor> arguments.
|
| 18 |
+
// That requires knowing the full Tensor class definition.
|
| 19 |
+
//
|
| 20 |
+
// We break the cycle by doing the following:
|
| 21 |
+
// - Split out CPUFunction.h into two files: CPUFunctions.h and CPUFunctions_inl.h
|
| 22 |
+
// - CPUFunction.h is a dummy file that just includes the Tensor class and includes CPUFunctions_inl.,
|
| 23 |
+
// - CPUFunctions_inl.h includes everything else
|
| 24 |
+
// - (only in the static dispatch build) TensorBody.h makes sure to finish defining the Tensor class,
|
| 25 |
+
// and then it includes CPUFunctions_inl.h.
|
| 26 |
+
// - All other files that want the cpu fastpath functions can include CPUFunctions.h directly.
|
| 27 |
+
// - This also means that static dispatch build, CPUFunctions.h only needs to
|
| 28 |
+
// #include TensorBody.h, and it will automatically bring in CPUFunctions_inl.h.
|
| 29 |
+
#include <ATen/CompositeImplicitAutogradFunctions_inl.h>
|
.venv/lib/python3.11/site-packages/torch/include/ATen/CompositeImplicitAutogradNestedTensorFunctions.h
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <ATen/core/TensorBody.h>
|
| 2 |
+
|
| 3 |
+
// TODO Undo all logic introduced for Note [Avoiding Include Cycles In Static Dispatch]
|
| 4 |
+
// Code introduced to avoid cyclic dependency in static dispatch is no longer
|
| 5 |
+
// needed as static dispatch logic is moved from TensorBody.h, which caused cycles in the first place,
|
| 6 |
+
// to Operators.cpp for supporting multiple backends with multiple kernels.
|
| 7 |
+
//
|
| 8 |
+
// Note [Avoiding Include Cycles In Static Dispatch]
|
| 9 |
+
// In order to avoid #include cycles in the static dispatch build, we've carefully split out
|
| 10 |
+
// the static function definition files into {DispatchKey}Functions.h and {DispatchKey}Functions_inl.h.
|
| 11 |
+
//
|
| 12 |
+
// Without this split, the include cycle looks like TensorBody.h -> CPUFunctions.h -> TensorBody.h.
|
| 13 |
+
// - TensorBody.h #includes CPUFunctions.h in the static dispatch build, because the tensor methods
|
| 14 |
+
// all need to call into the fastpath C++ API defined in CPUFunctions.h. The methods are also all
|
| 15 |
+
// directly inlined into TensorBody.h.
|
| 16 |
+
// - CPUFunctions.h #includes TensorBody.h because it contains function declarations for the entire C++ API,
|
| 17 |
+
// which include functions that have defaultable std::optional<Tensor> arguments.
|
| 18 |
+
// That requires knowing the full Tensor class definition.
|
| 19 |
+
//
|
| 20 |
+
// We break the cycle by doing the following:
|
| 21 |
+
// - Split out CPUFunction.h into two files: CPUFunctions.h and CPUFunctions_inl.h
|
| 22 |
+
// - CPUFunction.h is a dummy file that just includes the Tensor class and includes CPUFunctions_inl.,
|
| 23 |
+
// - CPUFunctions_inl.h includes everything else
|
| 24 |
+
// - (only in the static dispatch build) TensorBody.h makes sure to finish defining the Tensor class,
|
| 25 |
+
// and then it includes CPUFunctions_inl.h.
|
| 26 |
+
// - All other files that want the cpu fastpath functions can include CPUFunctions.h directly.
|
| 27 |
+
// - This also means that static dispatch build, CPUFunctions.h only needs to
|
| 28 |
+
// #include TensorBody.h, and it will automatically bring in CPUFunctions_inl.h.
|
| 29 |
+
#include <ATen/CompositeImplicitAutogradNestedTensorFunctions_inl.h>
|
.venv/lib/python3.11/site-packages/torch/include/ATen/CompositeImplicitAutogradNestedTensorFunctions_inl.h
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
// @generated by torchgen/gen.py from DispatchKeyFunctions_inl.h
|
| 3 |
+
|
| 4 |
+
// NB: The implementing C++ file is RegisterDispatchKey.cpp
|
| 5 |
+
|
| 6 |
+
// The only #includes we need are for custom classes that have defaults in the C++ API
|
| 7 |
+
#include <c10/core/MemoryFormat.h>
|
| 8 |
+
#include <c10/core/Scalar.h>
|
| 9 |
+
#include <ATen/core/Reduction.h>
|
| 10 |
+
|
| 11 |
+
#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS)
|
| 12 |
+
#error This change adds a dependency on all pytorch operators, meaning the \
|
| 13 |
+
file will need to be re-compiled every time an operator is changed or added. \
|
| 14 |
+
Consider including a specific operator from \
|
| 15 |
+
<ATen/ops/{my_operator}_compositeimplicitautogradnestedtensor_dispatch.h>. \
|
| 16 |
+
See NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS].
|
| 17 |
+
#endif
|
| 18 |
+
|
| 19 |
+
#include <ATen/ops/randn_like_compositeimplicitautogradnestedtensor_dispatch.h>
|
| 20 |
+
#include <ATen/ops/reshape_compositeimplicitautogradnestedtensor_dispatch.h>
|
| 21 |
+
#include <ATen/ops/reshape_as_compositeimplicitautogradnestedtensor_dispatch.h>
|
| 22 |
+
#include <ATen/ops/zeros_like_compositeimplicitautogradnestedtensor_dispatch.h>
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
|
.venv/lib/python3.11/site-packages/torch/include/ATen/Config.h
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// Test these using #if AT_MKL_ENABLED(), not #ifdef, so that it's
|
| 4 |
+
// obvious if you forgot to include Config.h
|
| 5 |
+
// c.f. https://stackoverflow.com/questions/33759787/generating-an-error-if-checked-boolean-macro-is-not-defined
|
| 6 |
+
//
|
| 7 |
+
// DO NOT put the macros for CUDA libraries in this file; they belong in cuda/CUDAConfig.h
|
| 8 |
+
|
| 9 |
+
#define AT_MKLDNN_ENABLED() 1
|
| 10 |
+
#define AT_MKLDNN_ACL_ENABLED() 0
|
| 11 |
+
#define AT_MKL_ENABLED() 1
|
| 12 |
+
#define AT_MKL_SEQUENTIAL() 0
|
| 13 |
+
#define AT_POCKETFFT_ENABLED() 0
|
| 14 |
+
#define AT_NNPACK_ENABLED() 1
|
| 15 |
+
#define CAFFE2_STATIC_LINK_CUDA() 0
|
| 16 |
+
#define AT_BUILD_WITH_BLAS() 1
|
| 17 |
+
#define AT_BUILD_WITH_LAPACK() 1
|
| 18 |
+
#define AT_PARALLEL_OPENMP 1
|
| 19 |
+
#define AT_PARALLEL_NATIVE 0
|
| 20 |
+
#define AT_BLAS_F2C() 0
|
| 21 |
+
#define AT_BLAS_USE_CBLAS_DOT() 0
|
.venv/lib/python3.11/site-packages/torch/include/ATen/Context.h
ADDED
|
@@ -0,0 +1,610 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/BlasBackend.h>
|
| 4 |
+
#include <ATen/CPUGeneratorImpl.h>
|
| 5 |
+
#include <ATen/DeviceAccelerator.h>
|
| 6 |
+
#include <ATen/LinalgBackend.h>
|
| 7 |
+
#include <ATen/core/ATenGeneral.h>
|
| 8 |
+
#include <ATen/core/DeprecatedTypeProperties.h>
|
| 9 |
+
#include <ATen/core/Generator.h>
|
| 10 |
+
#include <ATen/core/LegacyTypeDispatch.h>
|
| 11 |
+
#include <ATen/detail/AcceleratorHooksInterface.h>
|
| 12 |
+
#include <ATen/detail/CUDAHooksInterface.h>
|
| 13 |
+
#include <ATen/detail/HIPHooksInterface.h>
|
| 14 |
+
#include <ATen/detail/IPUHooksInterface.h>
|
| 15 |
+
#include <ATen/detail/MAIAHooksInterface.h>
|
| 16 |
+
#include <ATen/detail/MPSHooksInterface.h>
|
| 17 |
+
#include <ATen/detail/MTIAHooksInterface.h>
|
| 18 |
+
#include <ATen/detail/PrivateUse1HooksInterface.h>
|
| 19 |
+
#include <ATen/detail/XPUHooksInterface.h>
|
| 20 |
+
#include <c10/core/QEngine.h>
|
| 21 |
+
#include <c10/core/impl/DeviceGuardImplInterface.h>
|
| 22 |
+
#include <c10/util/CallOnce.h>
|
| 23 |
+
#include <c10/util/Exception.h>
|
| 24 |
+
#include <c10/util/env.h>
|
| 25 |
+
#include <c10/util/irange.h>
|
| 26 |
+
|
| 27 |
+
#include <cstdint>
|
| 28 |
+
#include <mutex>
|
| 29 |
+
|
| 30 |
+
namespace at {
|
| 31 |
+
|
| 32 |
+
class Tensor;
|
| 33 |
+
|
| 34 |
+
enum class TORCH_API Float32MatmulPrecision { HIGHEST, HIGH, MEDIUM };
|
| 35 |
+
|
| 36 |
+
class TORCH_API Context {
|
| 37 |
+
public:
|
| 38 |
+
Context();
|
| 39 |
+
|
| 40 |
+
const Generator& defaultGenerator(Device device) {
|
| 41 |
+
c10::DeviceType device_type = device.type();
|
| 42 |
+
initCUDAIfNeeded(device_type);
|
| 43 |
+
initHIPIfNeeded(device_type);
|
| 44 |
+
if (device_type == at::kCPU) {
|
| 45 |
+
return at::detail::getDefaultCPUGenerator();
|
| 46 |
+
} else if (device_type == at::kCUDA) {
|
| 47 |
+
return at::detail::getCUDAHooks().getDefaultCUDAGenerator(device.index());
|
| 48 |
+
} else if (device_type == at::kMPS) {
|
| 49 |
+
return at::detail::getMPSHooks().getDefaultMPSGenerator();
|
| 50 |
+
} else if (device_type == at::kXPU) {
|
| 51 |
+
return at::detail::getXPUHooks().getDefaultXPUGenerator(device.index());
|
| 52 |
+
} else if (device_type == at::kIPU) {
|
| 53 |
+
return at::detail::getIPUHooks().getDefaultIPUGenerator(device.index());
|
| 54 |
+
} else if (device_type == at::kPrivateUse1) {
|
| 55 |
+
return at::detail::getPrivateUse1Hooks().getDefaultGenerator(
|
| 56 |
+
device.index());
|
| 57 |
+
} else {
|
| 58 |
+
AT_ERROR(c10::DeviceTypeName(device_type), " device type not enabled.");
|
| 59 |
+
}
|
| 60 |
+
}
|
| 61 |
+
const AcceleratorHooksInterface& getAcceleratorHooksInterface(
|
| 62 |
+
std::optional<c10::DeviceType> opt_device_type = std::nullopt) {
|
| 63 |
+
c10::DeviceType device_type = opt_device_type.has_value()
|
| 64 |
+
? opt_device_type.value()
|
| 65 |
+
: at::getAccelerator(true).value();
|
| 66 |
+
if (device_type == at::kCUDA) {
|
| 67 |
+
return at::detail::getCUDAHooks();
|
| 68 |
+
} else if (device_type == at::kXPU) {
|
| 69 |
+
return at::detail::getXPUHooks();
|
| 70 |
+
} else if (device_type == at::kMPS) {
|
| 71 |
+
return at::detail::getMPSHooks();
|
| 72 |
+
} else if (device_type == at::kPrivateUse1) {
|
| 73 |
+
return at::detail::getPrivateUse1Hooks();
|
| 74 |
+
} else if (device_type == at::kMTIA) {
|
| 75 |
+
return at::detail::getMTIAHooks();
|
| 76 |
+
} else if (device_type == at::kHIP) {
|
| 77 |
+
return at::detail::getHIPHooks();
|
| 78 |
+
} else {
|
| 79 |
+
AT_ERROR(
|
| 80 |
+
c10::DeviceTypeName(device_type), " device type not an accelerator.");
|
| 81 |
+
}
|
| 82 |
+
}
|
| 83 |
+
Device getDeviceFromPtr(void* data, c10::DeviceType device_type) {
|
| 84 |
+
initCUDAIfNeeded(device_type);
|
| 85 |
+
initHIPIfNeeded(device_type);
|
| 86 |
+
initXPUIfNeeded(device_type);
|
| 87 |
+
if (device_type == at::kCPU) {
|
| 88 |
+
return c10::DeviceType::CPU;
|
| 89 |
+
} else if (device_type == at::kCUDA) {
|
| 90 |
+
return at::detail::getCUDAHooks().getDeviceFromPtr(data);
|
| 91 |
+
} else if (device_type == at::kXPU) {
|
| 92 |
+
return at::detail::getXPUHooks().getDeviceFromPtr(data);
|
| 93 |
+
} else if (device_type == at::kPrivateUse1) {
|
| 94 |
+
return at::detail::getPrivateUse1Hooks().getDeviceFromPtr(data);
|
| 95 |
+
} else {
|
| 96 |
+
AT_ERROR(c10::DeviceTypeName(device_type), " device type not enabled.");
|
| 97 |
+
}
|
| 98 |
+
}
|
| 99 |
+
bool isPinnedPtr(
|
| 100 |
+
const void* data,
|
| 101 |
+
std::optional<c10::DeviceType> device_type = std::nullopt) {
|
| 102 |
+
auto opt_device_type =
|
| 103 |
+
device_type.has_value() ? device_type : at::getAccelerator();
|
| 104 |
+
if (!opt_device_type.has_value() || // there is no accelerator
|
| 105 |
+
!at::isAccelerator(
|
| 106 |
+
opt_device_type.value())) { // passed device not an accelerator
|
| 107 |
+
return false;
|
| 108 |
+
}
|
| 109 |
+
return getAcceleratorHooksInterface(opt_device_type.value())
|
| 110 |
+
.isPinnedPtr(data);
|
| 111 |
+
}
|
| 112 |
+
Allocator* getPinnedMemoryAllocator(
|
| 113 |
+
std::optional<c10::DeviceType> device_type = std::nullopt) {
|
| 114 |
+
return getAcceleratorHooksInterface(device_type).getPinnedMemoryAllocator();
|
| 115 |
+
}
|
| 116 |
+
static bool hasOpenMP();
|
| 117 |
+
static bool hasMKL();
|
| 118 |
+
static bool hasLAPACK();
|
| 119 |
+
static bool hasMKLDNN();
|
| 120 |
+
static bool hasMAGMA() {
|
| 121 |
+
return detail::getCUDAHooks().hasMAGMA();
|
| 122 |
+
}
|
| 123 |
+
static bool hasCUDA() {
|
| 124 |
+
return detail::getCUDAHooks().hasCUDA();
|
| 125 |
+
}
|
| 126 |
+
static bool hasMTIA() {
|
| 127 |
+
return detail::getMTIAHooks().hasMTIA();
|
| 128 |
+
}
|
| 129 |
+
static bool hasCUDART() {
|
| 130 |
+
return detail::getCUDAHooks().hasCUDART();
|
| 131 |
+
}
|
| 132 |
+
static long versionCUDART() {
|
| 133 |
+
return detail::getCUDAHooks().versionCUDART();
|
| 134 |
+
}
|
| 135 |
+
static bool hasCuDNN() {
|
| 136 |
+
return detail::getCUDAHooks().hasCuDNN();
|
| 137 |
+
}
|
| 138 |
+
static long versionCuDNN() {
|
| 139 |
+
return detail::getCUDAHooks().versionCuDNN();
|
| 140 |
+
}
|
| 141 |
+
static bool hasCuSOLVER() {
|
| 142 |
+
return detail::getCUDAHooks().hasCuSOLVER();
|
| 143 |
+
}
|
| 144 |
+
static bool hasCuBLASLt() {
|
| 145 |
+
return detail::getCUDAHooks().hasCuBLASLt();
|
| 146 |
+
}
|
| 147 |
+
static bool hasHIP() {
|
| 148 |
+
return detail::getHIPHooks().hasHIP();
|
| 149 |
+
}
|
| 150 |
+
static bool hasMPS() {
|
| 151 |
+
return detail::getMPSHooks().hasMPS();
|
| 152 |
+
}
|
| 153 |
+
static bool hasIPU() {
|
| 154 |
+
return c10::impl::hasDeviceGuardImpl(c10::DeviceType::IPU);
|
| 155 |
+
}
|
| 156 |
+
static bool hasXLA() {
|
| 157 |
+
return c10::impl::hasDeviceGuardImpl(c10::DeviceType::XLA);
|
| 158 |
+
}
|
| 159 |
+
static bool hasXPU() {
|
| 160 |
+
return detail::getXPUHooks().hasXPU();
|
| 161 |
+
}
|
| 162 |
+
static bool hasLazy() {
|
| 163 |
+
return c10::impl::hasDeviceGuardImpl(c10::DeviceType::Lazy);
|
| 164 |
+
}
|
| 165 |
+
static bool hasMAIA() {
|
| 166 |
+
return c10::impl::hasDeviceGuardImpl(c10::DeviceType::MAIA);
|
| 167 |
+
}
|
| 168 |
+
// defined in header so that getNonVariableType has ability to inline
|
| 169 |
+
// call_once check. getNonVariableType is called fairly frequently
|
| 170 |
+
void lazyInitCUDA() {
|
| 171 |
+
c10::call_once(thc_init, [&] { detail::getCUDAHooks().initCUDA(); });
|
| 172 |
+
}
|
| 173 |
+
void lazyInitHIP() {
|
| 174 |
+
c10::call_once(thh_init, [&] { detail::getHIPHooks().initHIP(); });
|
| 175 |
+
}
|
| 176 |
+
void lazyInitXPU() {
|
| 177 |
+
c10::call_once(thx_init, [&] { detail::getXPUHooks().initXPU(); });
|
| 178 |
+
}
|
| 179 |
+
void lazyInitMTIA() {
|
| 180 |
+
c10::call_once(th_mtia_init, [&] { detail::getMTIAHooks().initMTIA(); });
|
| 181 |
+
}
|
| 182 |
+
void lazyInitPrivateUse1() {
|
| 183 |
+
c10::call_once(thp_init, [&] {
|
| 184 |
+
if (isPrivateUse1HooksRegistered()) {
|
| 185 |
+
at::detail::getPrivateUse1Hooks().initPrivateUse1();
|
| 186 |
+
}
|
| 187 |
+
});
|
| 188 |
+
}
|
| 189 |
+
static const at::cuda::NVRTC& getNVRTC() {
|
| 190 |
+
return detail::getCUDAHooks().nvrtc();
|
| 191 |
+
}
|
| 192 |
+
|
| 193 |
+
static bool setFlushDenormal(bool on);
|
| 194 |
+
|
| 195 |
+
// NB: This method is *purely* whether or not a user requested
|
| 196 |
+
// that CuDNN was enabled, it doesn't actually say anything about
|
| 197 |
+
// whether or not CuDNN is actually usable. Use cudnn_is_acceptable
|
| 198 |
+
// to test this instead
|
| 199 |
+
bool userEnabledCuDNN() const;
|
| 200 |
+
void setUserEnabledCuDNN(bool e);
|
| 201 |
+
bool userEnabledMkldnn() const;
|
| 202 |
+
void setUserEnabledMkldnn(bool e);
|
| 203 |
+
bool benchmarkCuDNN() const;
|
| 204 |
+
void setBenchmarkCuDNN(bool);
|
| 205 |
+
int benchmarkLimitCuDNN() const;
|
| 206 |
+
void setBenchmarkLimitCuDNN(int);
|
| 207 |
+
bool deterministicCuDNN() const;
|
| 208 |
+
void setDeterministicCuDNN(bool);
|
| 209 |
+
bool deterministicMkldnn() const;
|
| 210 |
+
void setDeterministicMkldnn(bool);
|
| 211 |
+
bool userEnabledNNPACK() const;
|
| 212 |
+
void setUserEnabledNNPACK(bool e);
|
| 213 |
+
|
| 214 |
+
// Note [Disabling Fused SDP Kernels]
|
| 215 |
+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 216 |
+
// Flash and Memory Efficient SDP kernels are enabled by default.
|
| 217 |
+
// However, they can be disabled by setting
|
| 218 |
+
// at::globalContext().setUserEnabledFlashSDP(false) flag.
|
| 219 |
+
// This is useful for debugging purposes. For example, if you want to
|
| 220 |
+
// compare the performance of the flash SDP kernels with the unfused
|
| 221 |
+
// kernel, you can disable the flash SDP kernels. By disabling
|
| 222 |
+
// the math SDP kernel, you can force your code to use flash kernels.
|
| 223 |
+
// The math SDP kernel can be disabled by setting
|
| 224 |
+
// at::globalContext().setUserEnabledMathSDP(false) flag.
|
| 225 |
+
void setSDPUseFlash(bool);
|
| 226 |
+
bool userEnabledFlashSDP() const;
|
| 227 |
+
|
| 228 |
+
void setSDPUseMemEfficient(bool);
|
| 229 |
+
bool userEnabledMemEfficientSDP() const;
|
| 230 |
+
|
| 231 |
+
void setSDPUseMath(bool);
|
| 232 |
+
bool userEnabledMathSDP() const;
|
| 233 |
+
|
| 234 |
+
void setSDPUseCuDNN(bool);
|
| 235 |
+
bool userEnabledCuDNNSDP() const;
|
| 236 |
+
|
| 237 |
+
void setAllowFP16BF16ReductionMathSDP(bool);
|
| 238 |
+
bool allowFP16BF16ReductionMathSDP() const;
|
| 239 |
+
|
| 240 |
+
void setSDPUseOverrideable(bool);
|
| 241 |
+
bool userEnabledOverrideableSDP() const;
|
| 242 |
+
|
| 243 |
+
at::LinalgBackend linalgPreferredBackend() const;
|
| 244 |
+
void setLinalgPreferredBackend(at::LinalgBackend);
|
| 245 |
+
|
| 246 |
+
at::BlasBackend blasPreferredBackend();
|
| 247 |
+
void setBlasPreferredBackend(at::BlasBackend);
|
| 248 |
+
|
| 249 |
+
// Note [Enabling Deterministic Operations]
|
| 250 |
+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 251 |
+
// Operations in PyTorch that normally act nondeterministically, but have an
|
| 252 |
+
// alternate deterministic implementation, should satisfy the following
|
| 253 |
+
// requirements:
|
| 254 |
+
//
|
| 255 |
+
// * Include this comment: "See Note [Enabling Deterministic Operations]"
|
| 256 |
+
//
|
| 257 |
+
// * Check the value of `at::globalContext().deterministicAlgorithms()` to
|
| 258 |
+
// toggle
|
| 259 |
+
// between nondeterministic and deterministic implementations.
|
| 260 |
+
//
|
| 261 |
+
// * Have an entry in the list of PyTorch operations that toggle between
|
| 262 |
+
// nondeterministic
|
| 263 |
+
// and deterministic implementations, in the docstring of
|
| 264 |
+
// `use_deterministic_algorithms()` in torch/__init__.py
|
| 265 |
+
//
|
| 266 |
+
// `example_func()` below shows an example of toggling between
|
| 267 |
+
// nondeterministic and deterministic implementations:
|
| 268 |
+
//
|
| 269 |
+
// void example_func() {
|
| 270 |
+
// // See Note [Enabling Deterministic Operations]
|
| 271 |
+
// if (at::globalContext().deterministicAlgorithms()) {
|
| 272 |
+
// example_func_deterministic();
|
| 273 |
+
// } else {
|
| 274 |
+
// example_func_nondeterministic();
|
| 275 |
+
// }
|
| 276 |
+
// }
|
| 277 |
+
|
| 278 |
+
bool deterministicAlgorithms() const;
|
| 279 |
+
bool deterministicAlgorithmsWarnOnly() const;
|
| 280 |
+
void setDeterministicAlgorithms(bool, bool);
|
| 281 |
+
bool deterministicFillUninitializedMemory() const;
|
| 282 |
+
void setDeterministicFillUninitializedMemory(bool);
|
| 283 |
+
|
| 284 |
+
// Note [Writing Nondeterministic Operations]
|
| 285 |
+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 286 |
+
// Operations in PyTorch that act nondeterministically and do not have an
|
| 287 |
+
// alternate deterministic implementation should satisfy the following
|
| 288 |
+
// requirements:
|
| 289 |
+
//
|
| 290 |
+
// * Include this comment: "See Note [Writing Nondeterministic Operations]"
|
| 291 |
+
//
|
| 292 |
+
// * Include a comment explaining why the operation is nondeterministic.
|
| 293 |
+
//
|
| 294 |
+
// * Throw an error when `Context::deterministicAlgorithms()` is true. Most
|
| 295 |
+
// of the time, this should be accomplished by calling
|
| 296 |
+
// `at::globalContext().alertNotDeterminstic()`. However, if the
|
| 297 |
+
// nondeterministic behavior is caused by the CuBLAS workspace
|
| 298 |
+
// configuration in CUDA >= 10.2,
|
| 299 |
+
// `at::globalContext().alertCuBLASConfigNotDeterministic()` should be
|
| 300 |
+
// called instead (in this case, a comment explaining why the operation is
|
| 301 |
+
// nondeterministic is not necessary). See below for details on these
|
| 302 |
+
// methods.
|
| 303 |
+
//
|
| 304 |
+
// * Have an entry in the list of nondeterministic PyTorch operations in the
|
| 305 |
+
// docstring of `use_deterministic_algorithms()` in torch/__init__.py
|
| 306 |
+
//
|
| 307 |
+
// * Have a test function in `test/test_torch.py` whose name begins with
|
| 308 |
+
// `test_nondeterministic_alert_`. Alternatively, if CuBLAS workspace
|
| 309 |
+
// configuration is the reason for nondeterminism, the operation should be
|
| 310 |
+
// included in the `test_cublas_config_nondeterministic_alert` test. Any new
|
| 311 |
+
// tests should ideally follow a pattern similar to the existing ones.
|
| 312 |
+
//
|
| 313 |
+
// `example_func()` below shows an example of the comments and error-throwing
|
| 314 |
+
// code for a nondeterministic operation:
|
| 315 |
+
//
|
| 316 |
+
// void example_func() {
|
| 317 |
+
// // See Note [Writing Nondeterministic Operations]
|
| 318 |
+
// // Nondeterministic because <reason>
|
| 319 |
+
// at::globalContext().alertNondeterministic("example_func");
|
| 320 |
+
// ...
|
| 321 |
+
// }
|
| 322 |
+
|
| 323 |
+
// Throws an error if `Context::deterministicAlgorithms()` is true
|
| 324 |
+
static void alertNotDeterministic(c10::string_view const& caller);
|
| 325 |
+
|
| 326 |
+
// Throws an error if `Context::deterministicAlgorithms()` is true, CUDA
|
| 327 |
+
// >= 10.2, and CUBLAS_WORKSPACE_CONFIG is not set to either ":16:8" or
|
| 328 |
+
// ":4096:8". For more details:
|
| 329 |
+
// https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility
|
| 330 |
+
void alertCuBLASConfigNotDeterministic() const;
|
| 331 |
+
|
| 332 |
+
void setFloat32MatmulPrecision(const std::string& s);
|
| 333 |
+
bool allowTF32CuDNN() const;
|
| 334 |
+
void setAllowTF32CuDNN(bool);
|
| 335 |
+
bool allowTF32CuBLAS() const;
|
| 336 |
+
void setAllowTF32CuBLAS(bool);
|
| 337 |
+
Float32MatmulPrecision float32MatmulPrecision() const;
|
| 338 |
+
void setFloat32MatmulPrecision(Float32MatmulPrecision p);
|
| 339 |
+
bool allowFP16ReductionCuBLAS() const;
|
| 340 |
+
void setAllowFP16ReductionCuBLAS(bool);
|
| 341 |
+
bool allowBF16ReductionCuBLAS() const;
|
| 342 |
+
void setAllowBF16ReductionCuBLAS(bool);
|
| 343 |
+
at::QEngine qEngine() const;
|
| 344 |
+
void setQEngine(at::QEngine e);
|
| 345 |
+
static const std::vector<at::QEngine>& supportedQEngines();
|
| 346 |
+
static bool isXNNPACKAvailable();
|
| 347 |
+
void setCheckSparseTensorInvariants(bool e);
|
| 348 |
+
bool checkSparseTensorInvariants() const;
|
| 349 |
+
// This method is used to release the original weight after pre-packing.
|
| 350 |
+
// It should be called once before loading/running the model.
|
| 351 |
+
// NB: By default it is set to true for mobile builds.
|
| 352 |
+
void setReleaseWeightsWhenPrepacking(bool e);
|
| 353 |
+
bool releaseWeightsWhenPrepacking() const;
|
| 354 |
+
|
| 355 |
+
void setDisplayVmapFallbackWarnings(bool enabled);
|
| 356 |
+
bool areVmapFallbackWarningsEnabled() const;
|
| 357 |
+
|
| 358 |
+
void setDefaultMobileCPUAllocator();
|
| 359 |
+
void unsetDefaultMobileCPUAllocator();
|
| 360 |
+
bool allowFP16ReductionCPU() const;
|
| 361 |
+
void setAllowFP16ReductionCPU(bool);
|
| 362 |
+
|
| 363 |
+
private:
|
| 364 |
+
void initCUDAIfNeeded(c10::DeviceType p) {
|
| 365 |
+
if (p == c10::DeviceType::CUDA) {
|
| 366 |
+
lazyInitCUDA();
|
| 367 |
+
}
|
| 368 |
+
}
|
| 369 |
+
void initHIPIfNeeded(c10::DeviceType p) {
|
| 370 |
+
if (p == c10::DeviceType::HIP) {
|
| 371 |
+
lazyInitHIP();
|
| 372 |
+
}
|
| 373 |
+
}
|
| 374 |
+
void initXPUIfNeeded(c10::DeviceType p) {
|
| 375 |
+
if (p == c10::DeviceType::XPU) {
|
| 376 |
+
lazyInitXPU();
|
| 377 |
+
}
|
| 378 |
+
}
|
| 379 |
+
static bool checkCuBLASConfigDeterministic();
|
| 380 |
+
c10::once_flag thc_init;
|
| 381 |
+
c10::once_flag thh_init;
|
| 382 |
+
c10::once_flag thx_init;
|
| 383 |
+
c10::once_flag th_mtia_init;
|
| 384 |
+
c10::once_flag thp_init;
|
| 385 |
+
bool enabled_cudnn = true;
|
| 386 |
+
bool deterministic_cudnn = false;
|
| 387 |
+
bool deterministic_mkldnn = false;
|
| 388 |
+
bool _deterministic_algorithms = false;
|
| 389 |
+
bool _deterministic_algorithms_warn_only = false;
|
| 390 |
+
bool _deterministic_fill_uninitialized_memory = true;
|
| 391 |
+
bool enabled_flashSDP = true;
|
| 392 |
+
bool enabled_mem_efficientSDP = true;
|
| 393 |
+
bool enabled_mathSDP = true;
|
| 394 |
+
bool enabled_cudnnSDP = true;
|
| 395 |
+
bool enabled_overrideable = true;
|
| 396 |
+
bool allow_fp16_bf16_reduction_mathSDP = false;
|
| 397 |
+
#ifdef USE_ROCM
|
| 398 |
+
bool benchmark_cudnn = true;
|
| 399 |
+
#else
|
| 400 |
+
bool benchmark_cudnn = false;
|
| 401 |
+
#endif
|
| 402 |
+
Float32MatmulPrecision float32_matmul_precision =
|
| 403 |
+
c10::utils::check_env("TORCH_ALLOW_TF32_CUBLAS_OVERRIDE") == true
|
| 404 |
+
? at::Float32MatmulPrecision::HIGH
|
| 405 |
+
: at::Float32MatmulPrecision::HIGHEST;
|
| 406 |
+
int benchmark_limit_cudnn = 10;
|
| 407 |
+
bool allow_tf32_cudnn = true;
|
| 408 |
+
bool allow_fp16_reduction_cublas = true;
|
| 409 |
+
bool allow_bf16_reduction_cublas = true;
|
| 410 |
+
bool enabled_mkldnn = true;
|
| 411 |
+
bool enabled_nnpack = true;
|
| 412 |
+
at::LinalgBackend linalg_preferred_backend =
|
| 413 |
+
c10::utils::check_env("TORCH_LINALG_PREFER_CUSOLVER") == true
|
| 414 |
+
? at::LinalgBackend::Cusolver
|
| 415 |
+
: at::LinalgBackend::Default;
|
| 416 |
+
at::BlasBackend blas_preferred_backend =
|
| 417 |
+
#ifdef USE_ROCM
|
| 418 |
+
(c10::utils::check_env("TORCH_BLAS_PREFER_HIPBLASLT") != false)
|
| 419 |
+
#else
|
| 420 |
+
(c10::utils::check_env("TORCH_BLAS_PREFER_CUBLASLT") == true)
|
| 421 |
+
#endif
|
| 422 |
+
? at::BlasBackend::Cublaslt
|
| 423 |
+
: at::BlasBackend::Cublas;
|
| 424 |
+
#ifdef C10_MOBILE
|
| 425 |
+
bool release_original_weights = true;
|
| 426 |
+
#else
|
| 427 |
+
bool release_original_weights = false;
|
| 428 |
+
#endif
|
| 429 |
+
bool display_vmap_fallback_warnings_ = false;
|
| 430 |
+
std::optional<at::QEngine> quantized_engine = std::nullopt;
|
| 431 |
+
bool enable_sparse_tensor_invariant_checks = false;
|
| 432 |
+
bool allow_fp16_reduction_cpu = false;
|
| 433 |
+
|
| 434 |
+
Allocator* prev_allocator_ptr_{nullptr};
|
| 435 |
+
};
|
| 436 |
+
|
| 437 |
+
TORCH_API Context& globalContext();
|
| 438 |
+
|
| 439 |
+
inline void init() {
|
| 440 |
+
globalContext();
|
| 441 |
+
}
|
| 442 |
+
|
| 443 |
+
TORCH_API Allocator* getCPUAllocator();
|
| 444 |
+
|
| 445 |
+
inline DeprecatedTypeProperties& getDeprecatedTypeProperties(
|
| 446 |
+
Backend p,
|
| 447 |
+
ScalarType s) {
|
| 448 |
+
return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
|
| 449 |
+
p, s);
|
| 450 |
+
}
|
| 451 |
+
|
| 452 |
+
inline DeprecatedTypeProperties& CPU(ScalarType s) {
|
| 453 |
+
return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
|
| 454 |
+
Backend::CPU, s);
|
| 455 |
+
}
|
| 456 |
+
|
| 457 |
+
inline DeprecatedTypeProperties& CUDA(ScalarType s) {
|
| 458 |
+
return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
|
| 459 |
+
Backend::CUDA, s);
|
| 460 |
+
}
|
| 461 |
+
|
| 462 |
+
inline DeprecatedTypeProperties& HIP(ScalarType s) {
|
| 463 |
+
return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
|
| 464 |
+
Backend::HIP, s);
|
| 465 |
+
}
|
| 466 |
+
|
| 467 |
+
inline DeprecatedTypeProperties& MPS(ScalarType s) {
|
| 468 |
+
return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
|
| 469 |
+
Backend::MPS, s);
|
| 470 |
+
}
|
| 471 |
+
|
| 472 |
+
inline bool hasCUDA() {
|
| 473 |
+
return globalContext().hasCUDA();
|
| 474 |
+
}
|
| 475 |
+
|
| 476 |
+
inline bool hasMTIA() {
|
| 477 |
+
return globalContext().hasMTIA();
|
| 478 |
+
}
|
| 479 |
+
|
| 480 |
+
inline bool hasHIP() {
|
| 481 |
+
return globalContext().hasHIP();
|
| 482 |
+
}
|
| 483 |
+
|
| 484 |
+
inline bool hasIPU() {
|
| 485 |
+
return globalContext().hasIPU();
|
| 486 |
+
}
|
| 487 |
+
|
| 488 |
+
inline bool hasXLA() {
|
| 489 |
+
return globalContext().hasXLA();
|
| 490 |
+
}
|
| 491 |
+
|
| 492 |
+
inline bool hasMPS() {
|
| 493 |
+
return globalContext().hasMPS();
|
| 494 |
+
}
|
| 495 |
+
|
| 496 |
+
inline bool hasMAIA() {
|
| 497 |
+
return globalContext().hasMAIA();
|
| 498 |
+
}
|
| 499 |
+
|
| 500 |
+
inline bool hasXPU() {
|
| 501 |
+
return globalContext().hasXPU();
|
| 502 |
+
}
|
| 503 |
+
|
| 504 |
+
// Despite its name, this function returns the number of *CUDA* GPUs.
|
| 505 |
+
inline size_t getNumGPUs() {
|
| 506 |
+
// WARNING: DO NOT ADD LOGIC TO HANDLE OTHER DEVICE TYPES TO THIS
|
| 507 |
+
// FUNCTION. If you are interested in interrogating the number of
|
| 508 |
+
// devices for a specific device type, add that function to the
|
| 509 |
+
// relevant library (e.g., similar to at::cuda::device_count())
|
| 510 |
+
if (hasCUDA() && hasHIP()) {
|
| 511 |
+
throw std::runtime_error(
|
| 512 |
+
"Enabling both CUDA and HIP in ATen is not supported, as HIP masquerades "
|
| 513 |
+
"to be CUDA (e.g., when you say CUDA, on a HIP build of ATen, this actually "
|
| 514 |
+
"means HIP. Rebuild PyTorch with one or the other disabled.");
|
| 515 |
+
} else if (hasCUDA()) {
|
| 516 |
+
return detail::getCUDAHooks().getNumGPUs();
|
| 517 |
+
} else if (hasHIP()) {
|
| 518 |
+
return detail::getHIPHooks().getNumGPUs();
|
| 519 |
+
} else {
|
| 520 |
+
return 0;
|
| 521 |
+
}
|
| 522 |
+
}
|
| 523 |
+
|
| 524 |
+
inline bool hasOpenMP() {
|
| 525 |
+
return globalContext().hasOpenMP();
|
| 526 |
+
}
|
| 527 |
+
|
| 528 |
+
inline bool hasMKL() {
|
| 529 |
+
return globalContext().hasMKL();
|
| 530 |
+
}
|
| 531 |
+
|
| 532 |
+
inline bool hasLAPACK() {
|
| 533 |
+
return globalContext().hasLAPACK();
|
| 534 |
+
}
|
| 535 |
+
|
| 536 |
+
inline bool hasMAGMA() {
|
| 537 |
+
return globalContext().hasMAGMA();
|
| 538 |
+
}
|
| 539 |
+
|
| 540 |
+
inline bool hasMKLDNN() {
|
| 541 |
+
return globalContext().hasMKLDNN();
|
| 542 |
+
}
|
| 543 |
+
|
| 544 |
+
inline void manual_seed(uint64_t seed) {
|
| 545 |
+
auto gen = globalContext().defaultGenerator(c10::DeviceType::CPU);
|
| 546 |
+
{
|
| 547 |
+
// See Note [Acquire lock when using random generators]
|
| 548 |
+
std::lock_guard<std::mutex> lock(gen.mutex());
|
| 549 |
+
gen.set_current_seed(seed);
|
| 550 |
+
}
|
| 551 |
+
// NB: Sometimes we build with CUDA, but we don't have any GPUs
|
| 552 |
+
// available. In that case, we must not seed CUDA; it will fail!
|
| 553 |
+
const auto cuda_num_gpus = detail::getCUDAHooks().getNumGPUs();
|
| 554 |
+
if (hasCUDA() && cuda_num_gpus > 0) {
|
| 555 |
+
for (const auto i : c10::irange(cuda_num_gpus)) {
|
| 556 |
+
auto cuda_gen = globalContext().defaultGenerator(
|
| 557 |
+
Device(at::kCUDA, static_cast<c10::DeviceIndex>(i)));
|
| 558 |
+
{
|
| 559 |
+
// See Note [Acquire lock when using random generators]
|
| 560 |
+
std::lock_guard<std::mutex> lock(cuda_gen.mutex());
|
| 561 |
+
cuda_gen.set_current_seed(seed);
|
| 562 |
+
}
|
| 563 |
+
}
|
| 564 |
+
}
|
| 565 |
+
|
| 566 |
+
const auto xpu_num_gpus = detail::getXPUHooks().getNumGPUs();
|
| 567 |
+
if (hasXPU() && xpu_num_gpus) {
|
| 568 |
+
for (const auto i : c10::irange(xpu_num_gpus)) {
|
| 569 |
+
auto xpu_gen = globalContext().defaultGenerator(
|
| 570 |
+
Device(at::kXPU, static_cast<c10::DeviceIndex>(i)));
|
| 571 |
+
{
|
| 572 |
+
// See Note [Acquire lock when using random generators]
|
| 573 |
+
std::lock_guard<std::mutex> lock(xpu_gen.mutex());
|
| 574 |
+
xpu_gen.set_current_seed(seed);
|
| 575 |
+
}
|
| 576 |
+
}
|
| 577 |
+
}
|
| 578 |
+
|
| 579 |
+
if (hasMPS()) {
|
| 580 |
+
auto mps_gen = globalContext().defaultGenerator(c10::DeviceType::MPS);
|
| 581 |
+
// See Note [Acquire lock when using random generators]
|
| 582 |
+
std::lock_guard<std::mutex> lock(mps_gen.mutex());
|
| 583 |
+
mps_gen.set_current_seed(seed);
|
| 584 |
+
}
|
| 585 |
+
}
|
| 586 |
+
|
| 587 |
+
// When the global flag `allow_tf32` is set to true, cuBLAS handles are
|
| 588 |
+
// automatically configured to use math mode CUBLAS_TF32_TENSOR_OP_MATH.
|
| 589 |
+
// For some operators, such as addmv, TF32 offers no performance improvement
|
| 590 |
+
// but causes precision loss. To help this case, this class implements
|
| 591 |
+
// a RAII guard that can be used to quickly disable TF32 within its scope.
|
| 592 |
+
//
|
| 593 |
+
// Usage:
|
| 594 |
+
// NoTF32Guard disable_tf32;
|
| 595 |
+
struct TORCH_API NoTF32Guard {
|
| 596 |
+
NoTF32Guard();
|
| 597 |
+
~NoTF32Guard();
|
| 598 |
+
static bool should_disable_tf32();
|
| 599 |
+
|
| 600 |
+
private:
|
| 601 |
+
bool changed = false;
|
| 602 |
+
};
|
| 603 |
+
|
| 604 |
+
struct TORCH_API ROCmBackwardPassGuard {
|
| 605 |
+
ROCmBackwardPassGuard();
|
| 606 |
+
~ROCmBackwardPassGuard();
|
| 607 |
+
static bool is_backward_pass();
|
| 608 |
+
};
|
| 609 |
+
|
| 610 |
+
} // namespace at
|
.venv/lib/python3.11/site-packages/torch/include/ATen/Device.h
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <c10/core/Device.h>
|
.venv/lib/python3.11/site-packages/torch/include/ATen/DeviceAccelerator.h
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <c10/core/DeviceType.h>
|
| 4 |
+
#include <c10/macros/Macros.h>
|
| 5 |
+
|
| 6 |
+
#include <ATen/detail/MTIAHooksInterface.h>
|
| 7 |
+
#include <optional>
|
| 8 |
+
|
| 9 |
+
// This file defines the top level Accelerator concept for PyTorch.
|
| 10 |
+
// A device is an accelerator per the definition here if:
|
| 11 |
+
// - It is mutually exclusive with all other accelerators
|
| 12 |
+
// - It performs asynchronous compute via a Stream/Event system
|
| 13 |
+
// - It provides a set of common APIs as defined by AcceleratorHooksInterface
|
| 14 |
+
//
|
| 15 |
+
// As of today, accelerator devices are (in no particular order):
|
| 16 |
+
// CUDA, MTIA, XPU, HIP, MPS, PrivateUse1
|
| 17 |
+
|
| 18 |
+
namespace at {
|
| 19 |
+
|
| 20 |
+
// Ensures that only one accelerator is available (at
|
| 21 |
+
// compile time if possible) and return it.
|
| 22 |
+
// When checked is true, the returned optional always has a value.
|
| 23 |
+
TORCH_API std::optional<c10::DeviceType> getAccelerator(bool checked = false);
|
| 24 |
+
|
| 25 |
+
TORCH_API bool isAccelerator(c10::DeviceType d);
|
| 26 |
+
|
| 27 |
+
} // namespace at
|
.venv/lib/python3.11/site-packages/torch/include/ATen/Dispatch.h
ADDED
|
@@ -0,0 +1,808 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/core/DeprecatedTypeProperties.h>
|
| 4 |
+
#include <c10/macros/Macros.h>
|
| 5 |
+
#include <c10/util/Exception.h>
|
| 6 |
+
#include <c10/util/Half.h>
|
| 7 |
+
#include <c10/util/Metaprogramming.h>
|
| 8 |
+
#include <c10/util/complex.h>
|
| 9 |
+
#include <c10/util/string_view.h>
|
| 10 |
+
|
| 11 |
+
#ifdef __CUDACC__
|
| 12 |
+
#include <cuda.h> // For CUDA_VERSION
|
| 13 |
+
#endif
|
| 14 |
+
|
| 15 |
+
#ifdef TEMPLATE_SELECTIVE_BUILD
|
| 16 |
+
#include <ATen/selected_mobile_ops.h>
|
| 17 |
+
#else
|
| 18 |
+
namespace at {
|
| 19 |
+
/**
|
| 20 |
+
* The method should_include_kernel_dtype() returns true/false
|
| 21 |
+
* based on whether the switching code for a specific dtype should be
|
| 22 |
+
* included based on build time constants generated from tracing model
|
| 23 |
+
* execution. This method will be implemented via code-generation and
|
| 24 |
+
* included in this file when code-gen is ready.
|
| 25 |
+
*/
|
| 26 |
+
inline constexpr bool should_include_kernel_dtype(
|
| 27 |
+
const char* /*kernel_tag_str*/,
|
| 28 |
+
at::ScalarType /*scalar_type*/
|
| 29 |
+
) {
|
| 30 |
+
return true;
|
| 31 |
+
}
|
| 32 |
+
} // namespace at
|
| 33 |
+
#endif
|
| 34 |
+
|
| 35 |
+
/**
|
| 36 |
+
* In the Facebook internal build (using BUCK), this macro is enabled by
|
| 37 |
+
* passing in -c pt.enable_record_kernel_dtype=1 when building the tracer
|
| 38 |
+
* binary.
|
| 39 |
+
*/
|
| 40 |
+
#if defined ENABLE_RECORD_KERNEL_FUNCTION_DTYPE
|
| 41 |
+
namespace at {
|
| 42 |
+
namespace detail {
|
| 43 |
+
TORCH_API void record_kernel_function_dtype(std::string name);
|
| 44 |
+
}
|
| 45 |
+
} // namespace at
|
| 46 |
+
|
| 47 |
+
#define RECORD_KERNEL_FUNCTION_DTYPE(NAME, enum_type) \
|
| 48 |
+
at::detail::record_kernel_function_dtype( \
|
| 49 |
+
std::string(NAME) + "$" + toString(enum_type));
|
| 50 |
+
#else
|
| 51 |
+
#define RECORD_KERNEL_FUNCTION_DTYPE(NAME, enum_type)
|
| 52 |
+
#endif
|
| 53 |
+
|
| 54 |
+
#define AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type) \
|
| 55 |
+
do { \
|
| 56 |
+
if constexpr (!at::should_include_kernel_dtype( \
|
| 57 |
+
at_dispatch_name, enum_type)) { \
|
| 58 |
+
AT_ERROR( \
|
| 59 |
+
"dtype '", \
|
| 60 |
+
toString(enum_type), \
|
| 61 |
+
"' not selected for kernel tag ", \
|
| 62 |
+
at_dispatch_name); \
|
| 63 |
+
} \
|
| 64 |
+
} while (0)
|
| 65 |
+
|
| 66 |
+
#define AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, HINT, ...) \
|
| 67 |
+
case enum_type: { \
|
| 68 |
+
AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type); \
|
| 69 |
+
using HINT C10_UNUSED = c10::impl::ScalarTypeToCPPTypeT<enum_type>; \
|
| 70 |
+
return __VA_ARGS__(); \
|
| 71 |
+
}
|
| 72 |
+
|
| 73 |
+
#define AT_DISPATCH_CASE(enum_type, ...) \
|
| 74 |
+
AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, scalar_t, __VA_ARGS__)
|
| 75 |
+
|
| 76 |
+
#define AT_DISPATCH_CASE_QINT(enum_type, scalar_type, ...) \
|
| 77 |
+
case enum_type: { \
|
| 78 |
+
AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type); \
|
| 79 |
+
using scalar_t = scalar_type; \
|
| 80 |
+
using underlying_t C10_UNUSED = typename scalar_t::underlying; \
|
| 81 |
+
const auto& SCALAR_TYPE C10_UNUSED = enum_type; \
|
| 82 |
+
const auto& UNDERLYING_TYPE C10_UNUSED = toUnderlying(enum_type); \
|
| 83 |
+
return __VA_ARGS__(); \
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
#define AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \
|
| 87 |
+
enum_type, scalar_type, bitwidth, qmin, qmax, ...) \
|
| 88 |
+
case enum_type: { \
|
| 89 |
+
AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type); \
|
| 90 |
+
using scalar_t = scalar_type; \
|
| 91 |
+
using underlying_t C10_UNUSED = typename scalar_t::underlying; \
|
| 92 |
+
const auto& SCALAR_TYPE C10_UNUSED = enum_type; \
|
| 93 |
+
const auto& UNDERLYING_TYPE C10_UNUSED = toUnderlying(enum_type); \
|
| 94 |
+
C10_UNUSED int bit_width = bitwidth; \
|
| 95 |
+
C10_UNUSED int64_t quant_min = qmin; \
|
| 96 |
+
C10_UNUSED int64_t quant_max = qmax; \
|
| 97 |
+
return __VA_ARGS__(); \
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
namespace detail {
|
| 101 |
+
|
| 102 |
+
inline at::ScalarType scalar_type(at::ScalarType s) {
|
| 103 |
+
return s;
|
| 104 |
+
}
|
| 105 |
+
|
| 106 |
+
C10_DEPRECATED_MESSAGE(
|
| 107 |
+
"passing at::DeprecatedTypeProperties to an AT_DISPATCH macro is deprecated, "
|
| 108 |
+
"pass an at::ScalarType instead")
|
| 109 |
+
inline at::ScalarType scalar_type(const at::DeprecatedTypeProperties& t) {
|
| 110 |
+
return t.scalarType();
|
| 111 |
+
}
|
| 112 |
+
|
| 113 |
+
C10_DEPRECATED_MESSAGE(
|
| 114 |
+
"AT_DISPATCH_ALL_TYPES_AND_HALF is deprecated, "
|
| 115 |
+
"use AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, ...) instead")
|
| 116 |
+
inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF() {}
|
| 117 |
+
|
| 118 |
+
C10_DEPRECATED_MESSAGE(
|
| 119 |
+
"AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX is deprecated, "
|
| 120 |
+
"use AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(at::ScalarType::Half, ...) "
|
| 121 |
+
"instead")
|
| 122 |
+
inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {}
|
| 123 |
+
|
| 124 |
+
} // namespace detail
|
| 125 |
+
|
| 126 |
+
// The AT_DISPATCH_* family of macros provides the ability to
|
| 127 |
+
// conveniently generate specializations of a kernel over all of the
|
| 128 |
+
// dtypes we care about in PyTorch. We call it "dispatch" because
|
| 129 |
+
// we are "dispatching" to the correct, dtype-specific kernel.
|
| 130 |
+
//
|
| 131 |
+
// A standard usage looks like:
|
| 132 |
+
//
|
| 133 |
+
// AT_DISPATCH_ALL_TYPES(self.scalar_type(), "op_name", [&] {
|
| 134 |
+
// // Your code here, with 'scalar_t' now defined to
|
| 135 |
+
// // be the dtype in question
|
| 136 |
+
// });
|
| 137 |
+
//
|
| 138 |
+
// There are many variations of this macro, so it's important to
|
| 139 |
+
// understand exactly /which/ dtypes you want to get instantiated, as
|
| 140 |
+
// well as what the "default" set is.
|
| 141 |
+
//
|
| 142 |
+
// The default set of dtypes that are instantiated (e.g., by
|
| 143 |
+
// AT_DISPATCH_ALL_TYPES) are floating point types (float, double),
|
| 144 |
+
// and integral types (int32_t, int64_t, int16_t, int8_t, uint8_t),
|
| 145 |
+
// but NOT booleans (bool), half-precision floats (Half) or
|
| 146 |
+
// complex number (c10::complex<float>, c10::complex<double>).
|
| 147 |
+
// This "cut" is somewhat historical (the default types are the
|
| 148 |
+
// ones that TH historically supported), but it also reflects the
|
| 149 |
+
// fact that the non-default types are "poorly" behaved (booleans
|
| 150 |
+
// are NOT integers mod 2, half precision operations ~essentially
|
| 151 |
+
// don't exist on CPU, complex numbers are an experimental application).
|
| 152 |
+
//
|
| 153 |
+
// Here are the questions you should generally ask to decide which
|
| 154 |
+
// dispatch you want:
|
| 155 |
+
//
|
| 156 |
+
// 1. Is this an integral or floating point specific operation?
|
| 157 |
+
// (If so, you'll want one of the FLOATING or INTEGRAL macros.)
|
| 158 |
+
//
|
| 159 |
+
// 2. Should half be supported? (If you're on CPU, the answer is almost
|
| 160 |
+
// definitely no. If you do want support, use one of the AND_HALF
|
| 161 |
+
// macros)
|
| 162 |
+
//
|
| 163 |
+
// Much rarer situations:
|
| 164 |
+
//
|
| 165 |
+
// 3. Should bool be supported? (You often have to write your kernel
|
| 166 |
+
// differently if arithmetic operations are involved.) If so,
|
| 167 |
+
// Use AT_DISPATCH_ALL_TYPES_AND along with ScalarType::Bool
|
| 168 |
+
//
|
| 169 |
+
// 4. Should complex be supported? The answer is almost always no,
|
| 170 |
+
// unless you are working on "generic" code that should work on
|
| 171 |
+
// all dtypes.
|
| 172 |
+
//
|
| 173 |
+
// Parameters:
|
| 174 |
+
// -----------
|
| 175 |
+
//
|
| 176 |
+
// 1. The NAME argument is a "tag" that is used to trace and then
|
| 177 |
+
// conditionally compile fragments of the case statements such
|
| 178 |
+
// that the kernel functions are specialized only for the dtypes
|
| 179 |
+
// that are needed. The NAME parameter *must* be a build time
|
| 180 |
+
// const char* (can't be std::string, etc...)
|
| 181 |
+
//
|
| 182 |
+
// Please ensure that the NAME is unique for every implementation
|
| 183 |
+
// or you run the risk of over-including code for the kernel
|
| 184 |
+
// functions. There is no risk of missing out on any code, so
|
| 185 |
+
// it's mostly a risk of a Type-2 error, and not a Type-1 error.
|
| 186 |
+
//
|
| 187 |
+
// Switch-like syntax:
|
| 188 |
+
// -------------------
|
| 189 |
+
// There is also a switch-case like syntax which is useful if a kernel
|
| 190 |
+
// needs to be specialized for particular scalar types
|
| 191 |
+
//
|
| 192 |
+
// AT_DISPATCH_SWITCH(self.scalar_type(), "op_name",
|
| 193 |
+
// AT_DISPATCH_CASE_INTEGRAL_TYPES([&] {
|
| 194 |
+
// op_integral<scalar_t>(iter);
|
| 195 |
+
// })
|
| 196 |
+
// AT_DISPATCH_CASE_FLOATING_TYPES([&] {
|
| 197 |
+
// op_floating<scalar_t>(iter);
|
| 198 |
+
// })
|
| 199 |
+
// AT_DISPATCH_CASE(kBool, [&] {
|
| 200 |
+
// op_bool(iter);
|
| 201 |
+
// })
|
| 202 |
+
// );
|
| 203 |
+
//
|
| 204 |
+
// For each AT_DISPATCH_FOO macro, there is a corresponding
|
| 205 |
+
// AT_DISPATCH_CASE_FOO macro which can be used inside of an
|
| 206 |
+
// AT_DISPATCH_SWITCH block.
|
| 207 |
+
|
| 208 |
+
// NB: the the_type variable is not used, but we have kept it for
|
| 209 |
+
// backwards compatibility. It's probably not used by anyone though;
|
| 210 |
+
// but we're just being safe (and it doesn't hurt.) Note we must
|
| 211 |
+
// use it to shut up warnings about unused store.
|
| 212 |
+
|
| 213 |
+
#define AT_DISPATCH_SWITCH(TYPE, NAME, ...) \
|
| 214 |
+
[&] { \
|
| 215 |
+
const auto& the_type = TYPE; \
|
| 216 |
+
constexpr const char* at_dispatch_name = NAME; \
|
| 217 |
+
/* don't use TYPE again in case it is an expensive or side-effect op */ \
|
| 218 |
+
at::ScalarType _st = ::detail::scalar_type(the_type); \
|
| 219 |
+
RECORD_KERNEL_FUNCTION_DTYPE(at_dispatch_name, _st); \
|
| 220 |
+
switch (_st) { \
|
| 221 |
+
__VA_ARGS__ \
|
| 222 |
+
default: \
|
| 223 |
+
AT_ERROR( \
|
| 224 |
+
'"', \
|
| 225 |
+
at_dispatch_name, \
|
| 226 |
+
"\" not implemented for '", \
|
| 227 |
+
toString(_st), \
|
| 228 |
+
"'"); \
|
| 229 |
+
} \
|
| 230 |
+
}()
|
| 231 |
+
|
| 232 |
+
#define AT_DISPATCH_CASE_FLOATING_TYPES(...) \
|
| 233 |
+
AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \
|
| 234 |
+
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)
|
| 235 |
+
|
| 236 |
+
#define AT_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
|
| 237 |
+
AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
|
| 238 |
+
|
| 239 |
+
#define AT_DISPATCH_CASE_FLOATING_TYPES_AND_HALF(...) \
|
| 240 |
+
AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \
|
| 241 |
+
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
|
| 242 |
+
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)
|
| 243 |
+
|
| 244 |
+
#define AT_DISPATCH_FLOATING_TYPES_AND_HALF(TYPE, NAME, ...) \
|
| 245 |
+
AT_DISPATCH_SWITCH( \
|
| 246 |
+
TYPE, NAME, AT_DISPATCH_CASE_FLOATING_TYPES_AND_HALF(__VA_ARGS__))
|
| 247 |
+
|
| 248 |
+
#define AT_DISPATCH_CASE_REDUCED_FLOATING_TYPES(...) \
|
| 249 |
+
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
|
| 250 |
+
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
|
| 251 |
+
|
| 252 |
+
#define AT_DISPATCH_REDUCED_FLOATING_TYPES(TYPE, NAME, ...) \
|
| 253 |
+
AT_DISPATCH_SWITCH( \
|
| 254 |
+
TYPE, NAME, AT_DISPATCH_CASE_REDUCED_FLOATING_TYPES(__VA_ARGS__))
|
| 255 |
+
|
| 256 |
+
#define AT_DISPATCH_CASE_FLOATING_TYPES_AND(SCALARTYPE, ...) \
|
| 257 |
+
AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__) \
|
| 258 |
+
AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)
|
| 259 |
+
|
| 260 |
+
#define AT_DISPATCH_FLOATING_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \
|
| 261 |
+
AT_DISPATCH_SWITCH( \
|
| 262 |
+
TYPE, \
|
| 263 |
+
NAME, \
|
| 264 |
+
AT_DISPATCH_CASE_FLOATING_TYPES_AND(SCALARTYPE, __VA_ARGS__))
|
| 265 |
+
|
| 266 |
+
#define AT_DISPATCH_CASE_FLOATING_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, ...) \
|
| 267 |
+
AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__) \
|
| 268 |
+
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
|
| 269 |
+
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__)
|
| 270 |
+
|
| 271 |
+
#define AT_DISPATCH_FLOATING_TYPES_AND2( \
|
| 272 |
+
SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \
|
| 273 |
+
AT_DISPATCH_SWITCH( \
|
| 274 |
+
TYPE, \
|
| 275 |
+
NAME, \
|
| 276 |
+
AT_DISPATCH_CASE_FLOATING_TYPES_AND2( \
|
| 277 |
+
SCALARTYPE1, SCALARTYPE2, __VA_ARGS__))
|
| 278 |
+
|
| 279 |
+
#define AT_DISPATCH_CASE_FLOATING_TYPES_AND3( \
|
| 280 |
+
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, ...) \
|
| 281 |
+
AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__) \
|
| 282 |
+
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
|
| 283 |
+
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
|
| 284 |
+
AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__)
|
| 285 |
+
|
| 286 |
+
#define AT_DISPATCH_FLOATING_TYPES_AND3( \
|
| 287 |
+
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \
|
| 288 |
+
AT_DISPATCH_SWITCH( \
|
| 289 |
+
TYPE, \
|
| 290 |
+
NAME, \
|
| 291 |
+
AT_DISPATCH_CASE_FLOATING_TYPES_AND3( \
|
| 292 |
+
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, __VA_ARGS__))
|
| 293 |
+
|
| 294 |
+
#define AT_DISPATCH_CASE_FLOATING_TYPES_AND4( \
|
| 295 |
+
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, ...) \
|
| 296 |
+
AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__) \
|
| 297 |
+
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
|
| 298 |
+
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
|
| 299 |
+
AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
|
| 300 |
+
AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__)
|
| 301 |
+
|
| 302 |
+
#define AT_DISPATCH_FLOATING_TYPES_AND4( \
|
| 303 |
+
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, TYPE, NAME, ...) \
|
| 304 |
+
AT_DISPATCH_SWITCH( \
|
| 305 |
+
TYPE, \
|
| 306 |
+
NAME, \
|
| 307 |
+
AT_DISPATCH_CASE_FLOATING_TYPES_AND4( \
|
| 308 |
+
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, __VA_ARGS__))
|
| 309 |
+
|
| 310 |
+
#define AT_DISPATCH_CASE_COMPLEX_TYPES(...) \
|
| 311 |
+
AT_DISPATCH_CASE(at::ScalarType::ComplexDouble, __VA_ARGS__) \
|
| 312 |
+
AT_DISPATCH_CASE(at::ScalarType::ComplexFloat, __VA_ARGS__)
|
| 313 |
+
|
| 314 |
+
#define AT_DISPATCH_COMPLEX_TYPES(TYPE, NAME, ...) \
|
| 315 |
+
AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_COMPLEX_TYPES(__VA_ARGS__))
|
| 316 |
+
|
| 317 |
+
#define AT_DISPATCH_CASE_COMPLEX_TYPES_AND(SCALARTYPE, ...) \
|
| 318 |
+
AT_DISPATCH_CASE_COMPLEX_TYPES(__VA_ARGS__) \
|
| 319 |
+
AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)
|
| 320 |
+
|
| 321 |
+
#define AT_DISPATCH_COMPLEX_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \
|
| 322 |
+
AT_DISPATCH_SWITCH( \
|
| 323 |
+
TYPE, NAME, AT_DISPATCH_CASE_COMPLEX_TYPES_AND(SCALARTYPE, __VA_ARGS__))
|
| 324 |
+
|
| 325 |
+
#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(...) \
|
| 326 |
+
AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__) \
|
| 327 |
+
AT_DISPATCH_CASE_COMPLEX_TYPES(__VA_ARGS__)
|
| 328 |
+
|
| 329 |
+
#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(TYPE, NAME, ...) \
|
| 330 |
+
AT_DISPATCH_SWITCH( \
|
| 331 |
+
TYPE, NAME, AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__))
|
| 332 |
+
|
| 333 |
+
#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND1(SCALARTYPE, ...) \
|
| 334 |
+
AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__) \
|
| 335 |
+
AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)
|
| 336 |
+
|
| 337 |
+
#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1( \
|
| 338 |
+
SCALARTYPE, TYPE, NAME, ...) \
|
| 339 |
+
AT_DISPATCH_SWITCH( \
|
| 340 |
+
TYPE, \
|
| 341 |
+
NAME, \
|
| 342 |
+
AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND1( \
|
| 343 |
+
SCALARTYPE, __VA_ARGS__))
|
| 344 |
+
|
| 345 |
+
#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND2( \
|
| 346 |
+
SCALARTYPE1, SCALARTYPE2, ...) \
|
| 347 |
+
AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__) \
|
| 348 |
+
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
|
| 349 |
+
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__)
|
| 350 |
+
|
| 351 |
+
#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2( \
|
| 352 |
+
SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \
|
| 353 |
+
AT_DISPATCH_SWITCH( \
|
| 354 |
+
TYPE, \
|
| 355 |
+
NAME, \
|
| 356 |
+
AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND2( \
|
| 357 |
+
SCALARTYPE1, SCALARTYPE2, __VA_ARGS__))
|
| 358 |
+
|
| 359 |
+
#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND3( \
|
| 360 |
+
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, ...) \
|
| 361 |
+
AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__) \
|
| 362 |
+
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
|
| 363 |
+
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
|
| 364 |
+
AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__)
|
| 365 |
+
|
| 366 |
+
#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND3( \
|
| 367 |
+
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \
|
| 368 |
+
AT_DISPATCH_SWITCH( \
|
| 369 |
+
TYPE, \
|
| 370 |
+
NAME, \
|
| 371 |
+
AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND3( \
|
| 372 |
+
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, __VA_ARGS__))
|
| 373 |
+
|
| 374 |
+
#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND4( \
|
| 375 |
+
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, ...) \
|
| 376 |
+
AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__) \
|
| 377 |
+
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
|
| 378 |
+
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
|
| 379 |
+
AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
|
| 380 |
+
AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__)
|
| 381 |
+
|
| 382 |
+
#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND4( \
|
| 383 |
+
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, TYPE, NAME, ...) \
|
| 384 |
+
AT_DISPATCH_SWITCH( \
|
| 385 |
+
TYPE, \
|
| 386 |
+
NAME, \
|
| 387 |
+
AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND4( \
|
| 388 |
+
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, __VA_ARGS__))
|
| 389 |
+
|
| 390 |
+
#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND5( \
|
| 391 |
+
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, SCALARTYPE5, ...) \
|
| 392 |
+
AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__) \
|
| 393 |
+
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
|
| 394 |
+
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
|
| 395 |
+
AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
|
| 396 |
+
AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) \
|
| 397 |
+
AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__)
|
| 398 |
+
|
| 399 |
+
#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND5( \
|
| 400 |
+
SCALARTYPE1, \
|
| 401 |
+
SCALARTYPE2, \
|
| 402 |
+
SCALARTYPE3, \
|
| 403 |
+
SCALARTYPE4, \
|
| 404 |
+
SCALARTYPE5, \
|
| 405 |
+
TYPE, \
|
| 406 |
+
NAME, \
|
| 407 |
+
...) \
|
| 408 |
+
AT_DISPATCH_SWITCH( \
|
| 409 |
+
TYPE, \
|
| 410 |
+
NAME, \
|
| 411 |
+
AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND5( \
|
| 412 |
+
SCALARTYPE1, \
|
| 413 |
+
SCALARTYPE2, \
|
| 414 |
+
SCALARTYPE3, \
|
| 415 |
+
SCALARTYPE4, \
|
| 416 |
+
SCALARTYPE5, \
|
| 417 |
+
__VA_ARGS__))
|
| 418 |
+
|
| 419 |
+
#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND6( \
|
| 420 |
+
SCALARTYPE1, \
|
| 421 |
+
SCALARTYPE2, \
|
| 422 |
+
SCALARTYPE3, \
|
| 423 |
+
SCALARTYPE4, \
|
| 424 |
+
SCALARTYPE5, \
|
| 425 |
+
SCALARTYPE6, \
|
| 426 |
+
...) \
|
| 427 |
+
AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__) \
|
| 428 |
+
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
|
| 429 |
+
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
|
| 430 |
+
AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
|
| 431 |
+
AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) \
|
| 432 |
+
AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__) \
|
| 433 |
+
AT_DISPATCH_CASE(SCALARTYPE6, __VA_ARGS__)
|
| 434 |
+
|
| 435 |
+
#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND6( \
|
| 436 |
+
SCALARTYPE1, \
|
| 437 |
+
SCALARTYPE2, \
|
| 438 |
+
SCALARTYPE3, \
|
| 439 |
+
SCALARTYPE4, \
|
| 440 |
+
SCALARTYPE5, \
|
| 441 |
+
SCALARTYPE6, \
|
| 442 |
+
TYPE, \
|
| 443 |
+
NAME, \
|
| 444 |
+
...) \
|
| 445 |
+
AT_DISPATCH_SWITCH( \
|
| 446 |
+
TYPE, \
|
| 447 |
+
NAME, \
|
| 448 |
+
AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND6( \
|
| 449 |
+
SCALARTYPE1, \
|
| 450 |
+
SCALARTYPE2, \
|
| 451 |
+
SCALARTYPE3, \
|
| 452 |
+
SCALARTYPE4, \
|
| 453 |
+
SCALARTYPE5, \
|
| 454 |
+
SCALARTYPE6, \
|
| 455 |
+
__VA_ARGS__))
|
| 456 |
+
|
| 457 |
+
#define AT_DISPATCH_CASE_INTEGRAL_TYPES(...) \
|
| 458 |
+
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
|
| 459 |
+
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
|
| 460 |
+
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
|
| 461 |
+
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) \
|
| 462 |
+
AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__)
|
| 463 |
+
|
| 464 |
+
#define AT_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
|
| 465 |
+
AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
|
| 466 |
+
|
| 467 |
+
#define AT_DISPATCH_CASE_INTEGRAL_TYPES_AND(SCALARTYPE, ...) \
|
| 468 |
+
AT_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__) \
|
| 469 |
+
AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)
|
| 470 |
+
|
| 471 |
+
#define AT_DISPATCH_INTEGRAL_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \
|
| 472 |
+
AT_DISPATCH_SWITCH( \
|
| 473 |
+
TYPE, \
|
| 474 |
+
NAME, \
|
| 475 |
+
AT_DISPATCH_CASE_INTEGRAL_TYPES_AND(SCALARTYPE, __VA_ARGS__))
|
| 476 |
+
|
| 477 |
+
#define AT_DISPATCH_CASE_ALL_TYPES(...) \
|
| 478 |
+
AT_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__) \
|
| 479 |
+
AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)
|
| 480 |
+
|
| 481 |
+
#define AT_DISPATCH_ALL_TYPES(TYPE, NAME, ...) \
|
| 482 |
+
AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__))
|
| 483 |
+
|
| 484 |
+
#define AT_DISPATCH_CASE_QINT_TYPES(...) \
|
| 485 |
+
AT_DISPATCH_CASE_QINT(at::kQInt8, at::qint8, __VA_ARGS__) \
|
| 486 |
+
AT_DISPATCH_CASE_QINT(at::kQUInt8, at::quint8, __VA_ARGS__) \
|
| 487 |
+
AT_DISPATCH_CASE_QINT(at::kQInt32, at::qint32, __VA_ARGS__)
|
| 488 |
+
|
| 489 |
+
#define AT_DISPATCH_QINT_TYPES(TYPE, NAME, ...) \
|
| 490 |
+
AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_QINT_TYPES(__VA_ARGS__))
|
| 491 |
+
|
| 492 |
+
#define AT_DISPATCH_CASE_QINT_TYPES_AND(SCALARTYPE, ...) \
|
| 493 |
+
AT_DISPATCH_CASE_QINT_TYPES(__VA_ARGS__) \
|
| 494 |
+
AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)
|
| 495 |
+
|
| 496 |
+
#define AT_DISPATCH_QINT_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \
|
| 497 |
+
AT_DISPATCH_SWITCH( \
|
| 498 |
+
TYPE, NAME, AT_DISPATCH_CASE_QINT_TYPES_AND(SCALARTYPE, __VA_ARGS__))
|
| 499 |
+
|
| 500 |
+
#define AT_DISPATCH_CASE_QINT_BYTE_TYPES(...) \
|
| 501 |
+
AT_DISPATCH_CASE_QINT(at::kQInt8, at::qint8, __VA_ARGS__) \
|
| 502 |
+
AT_DISPATCH_CASE_QINT(at::kQUInt8, at::quint8, __VA_ARGS__)
|
| 503 |
+
|
| 504 |
+
#define AT_DISPATCH_QINT_BYTE_TYPES(TYPE, NAME, ...) \
|
| 505 |
+
AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_QINT_BYTE_TYPES(__VA_ARGS__))
|
| 506 |
+
|
| 507 |
+
#define AT_DISPATCH_CASE_QINT_AND_SUB_BYTE_TYPES(...) \
|
| 508 |
+
AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \
|
| 509 |
+
at::kQInt8, at::qint8, CHAR_BIT, SCHAR_MIN, SCHAR_MAX, __VA_ARGS__) \
|
| 510 |
+
AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \
|
| 511 |
+
at::kQUInt8, at::quint8, CHAR_BIT, 0, UCHAR_MAX, __VA_ARGS__) \
|
| 512 |
+
AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \
|
| 513 |
+
at::kQInt32, \
|
| 514 |
+
at::qint32, \
|
| 515 |
+
CHAR_BIT * sizeof(int), \
|
| 516 |
+
INT_MIN, \
|
| 517 |
+
INT_MAX, \
|
| 518 |
+
__VA_ARGS__) \
|
| 519 |
+
AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \
|
| 520 |
+
at::kQUInt4x2, at::quint4x2, 4, 0, 15, __VA_ARGS__) \
|
| 521 |
+
AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \
|
| 522 |
+
at::kQUInt2x4, at::quint2x4, 2, 0, 3, __VA_ARGS__)
|
| 523 |
+
|
| 524 |
+
#define AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(TYPE, NAME, ...) \
|
| 525 |
+
AT_DISPATCH_SWITCH( \
|
| 526 |
+
TYPE, NAME, AT_DISPATCH_CASE_QINT_AND_SUB_BYTE_TYPES(__VA_ARGS__))
|
| 527 |
+
|
| 528 |
+
#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(...) \
|
| 529 |
+
AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__) \
|
| 530 |
+
AT_DISPATCH_CASE_COMPLEX_TYPES(__VA_ARGS__)
|
| 531 |
+
|
| 532 |
+
#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX(TYPE, NAME, ...) \
|
| 533 |
+
AT_DISPATCH_SWITCH( \
|
| 534 |
+
TYPE, NAME, AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__))
|
| 535 |
+
|
| 536 |
+
#define AT_DISPATCH_CASE_ALL_TYPES_AND(SCALARTYPE, ...) \
|
| 537 |
+
AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__) \
|
| 538 |
+
AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)
|
| 539 |
+
|
| 540 |
+
#define AT_DISPATCH_ALL_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \
|
| 541 |
+
AT_DISPATCH_SWITCH( \
|
| 542 |
+
TYPE, NAME, AT_DISPATCH_CASE_ALL_TYPES_AND(SCALARTYPE, __VA_ARGS__))
|
| 543 |
+
|
| 544 |
+
#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND(SCALARTYPE, ...) \
|
| 545 |
+
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \
|
| 546 |
+
AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)
|
| 547 |
+
|
| 548 |
+
#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(SCALARTYPE, TYPE, NAME, ...) \
|
| 549 |
+
AT_DISPATCH_SWITCH( \
|
| 550 |
+
TYPE, \
|
| 551 |
+
NAME, \
|
| 552 |
+
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND(SCALARTYPE, __VA_ARGS__))
|
| 553 |
+
|
| 554 |
+
#define AT_DISPATCH_CASE_ALL_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, ...) \
|
| 555 |
+
AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__) \
|
| 556 |
+
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
|
| 557 |
+
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__)
|
| 558 |
+
|
| 559 |
+
#define AT_DISPATCH_ALL_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \
|
| 560 |
+
AT_DISPATCH_SWITCH( \
|
| 561 |
+
TYPE, \
|
| 562 |
+
NAME, \
|
| 563 |
+
AT_DISPATCH_CASE_ALL_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, __VA_ARGS__))
|
| 564 |
+
|
| 565 |
+
#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND2( \
|
| 566 |
+
SCALARTYPE1, SCALARTYPE2, ...) \
|
| 567 |
+
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \
|
| 568 |
+
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
|
| 569 |
+
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__)
|
| 570 |
+
|
| 571 |
+
#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2( \
|
| 572 |
+
SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \
|
| 573 |
+
AT_DISPATCH_SWITCH( \
|
| 574 |
+
TYPE, \
|
| 575 |
+
NAME, \
|
| 576 |
+
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND2( \
|
| 577 |
+
SCALARTYPE1, SCALARTYPE2, __VA_ARGS__))
|
| 578 |
+
|
| 579 |
+
#define AT_DISPATCH_CASE_ALL_TYPES_AND3( \
|
| 580 |
+
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, ...) \
|
| 581 |
+
AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__) \
|
| 582 |
+
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
|
| 583 |
+
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
|
| 584 |
+
AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__)
|
| 585 |
+
|
| 586 |
+
#define AT_DISPATCH_ALL_TYPES_AND3( \
|
| 587 |
+
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \
|
| 588 |
+
AT_DISPATCH_SWITCH( \
|
| 589 |
+
TYPE, \
|
| 590 |
+
NAME, \
|
| 591 |
+
AT_DISPATCH_CASE_ALL_TYPES_AND3( \
|
| 592 |
+
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, __VA_ARGS__))
|
| 593 |
+
|
| 594 |
+
#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND3( \
|
| 595 |
+
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, ...) \
|
| 596 |
+
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \
|
| 597 |
+
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
|
| 598 |
+
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
|
| 599 |
+
AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__)
|
| 600 |
+
|
| 601 |
+
#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3( \
|
| 602 |
+
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \
|
| 603 |
+
AT_DISPATCH_SWITCH( \
|
| 604 |
+
TYPE, \
|
| 605 |
+
NAME, \
|
| 606 |
+
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND3( \
|
| 607 |
+
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, __VA_ARGS__))
|
| 608 |
+
|
| 609 |
+
#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND4( \
|
| 610 |
+
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, ...) \
|
| 611 |
+
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \
|
| 612 |
+
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
|
| 613 |
+
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
|
| 614 |
+
AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
|
| 615 |
+
AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__)
|
| 616 |
+
|
| 617 |
+
#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( \
|
| 618 |
+
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, TYPE, NAME, ...) \
|
| 619 |
+
AT_DISPATCH_SWITCH( \
|
| 620 |
+
TYPE, \
|
| 621 |
+
NAME, \
|
| 622 |
+
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND4( \
|
| 623 |
+
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, __VA_ARGS__))
|
| 624 |
+
|
| 625 |
+
#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND5( \
|
| 626 |
+
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, SCALARTYPE5, ...) \
|
| 627 |
+
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \
|
| 628 |
+
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
|
| 629 |
+
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
|
| 630 |
+
AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
|
| 631 |
+
AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) \
|
| 632 |
+
AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__)
|
| 633 |
+
|
| 634 |
+
#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND5( \
|
| 635 |
+
SCALARTYPE1, \
|
| 636 |
+
SCALARTYPE2, \
|
| 637 |
+
SCALARTYPE3, \
|
| 638 |
+
SCALARTYPE4, \
|
| 639 |
+
SCALARTYPE5, \
|
| 640 |
+
TYPE, \
|
| 641 |
+
NAME, \
|
| 642 |
+
...) \
|
| 643 |
+
AT_DISPATCH_SWITCH( \
|
| 644 |
+
TYPE, \
|
| 645 |
+
NAME, \
|
| 646 |
+
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND5( \
|
| 647 |
+
SCALARTYPE1, \
|
| 648 |
+
SCALARTYPE2, \
|
| 649 |
+
SCALARTYPE3, \
|
| 650 |
+
SCALARTYPE4, \
|
| 651 |
+
SCALARTYPE5, \
|
| 652 |
+
__VA_ARGS__))
|
| 653 |
+
|
| 654 |
+
#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND6( \
|
| 655 |
+
SCALARTYPE1, \
|
| 656 |
+
SCALARTYPE2, \
|
| 657 |
+
SCALARTYPE3, \
|
| 658 |
+
SCALARTYPE4, \
|
| 659 |
+
SCALARTYPE5, \
|
| 660 |
+
SCALARTYPE6, \
|
| 661 |
+
...) \
|
| 662 |
+
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \
|
| 663 |
+
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
|
| 664 |
+
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
|
| 665 |
+
AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
|
| 666 |
+
AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) \
|
| 667 |
+
AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__) \
|
| 668 |
+
AT_DISPATCH_CASE(SCALARTYPE6, __VA_ARGS__)
|
| 669 |
+
|
| 670 |
+
#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND6( \
|
| 671 |
+
SCALARTYPE1, \
|
| 672 |
+
SCALARTYPE2, \
|
| 673 |
+
SCALARTYPE3, \
|
| 674 |
+
SCALARTYPE4, \
|
| 675 |
+
SCALARTYPE5, \
|
| 676 |
+
SCALARTYPE6, \
|
| 677 |
+
TYPE, \
|
| 678 |
+
NAME, \
|
| 679 |
+
...) \
|
| 680 |
+
AT_DISPATCH_SWITCH( \
|
| 681 |
+
TYPE, \
|
| 682 |
+
NAME, \
|
| 683 |
+
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND6( \
|
| 684 |
+
SCALARTYPE1, \
|
| 685 |
+
SCALARTYPE2, \
|
| 686 |
+
SCALARTYPE3, \
|
| 687 |
+
SCALARTYPE4, \
|
| 688 |
+
SCALARTYPE5, \
|
| 689 |
+
SCALARTYPE6, \
|
| 690 |
+
__VA_ARGS__))
|
| 691 |
+
|
| 692 |
+
#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND7( \
|
| 693 |
+
SCALARTYPE1, \
|
| 694 |
+
SCALARTYPE2, \
|
| 695 |
+
SCALARTYPE3, \
|
| 696 |
+
SCALARTYPE4, \
|
| 697 |
+
SCALARTYPE5, \
|
| 698 |
+
SCALARTYPE6, \
|
| 699 |
+
SCALARTYPE7, \
|
| 700 |
+
...) \
|
| 701 |
+
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \
|
| 702 |
+
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
|
| 703 |
+
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
|
| 704 |
+
AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
|
| 705 |
+
AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) \
|
| 706 |
+
AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__) \
|
| 707 |
+
AT_DISPATCH_CASE(SCALARTYPE6, __VA_ARGS__) \
|
| 708 |
+
AT_DISPATCH_CASE(SCALARTYPE7, __VA_ARGS__)
|
| 709 |
+
|
| 710 |
+
#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND7( \
|
| 711 |
+
SCALARTYPE1, \
|
| 712 |
+
SCALARTYPE2, \
|
| 713 |
+
SCALARTYPE3, \
|
| 714 |
+
SCALARTYPE4, \
|
| 715 |
+
SCALARTYPE5, \
|
| 716 |
+
SCALARTYPE6, \
|
| 717 |
+
SCALARTYPE7, \
|
| 718 |
+
TYPE, \
|
| 719 |
+
NAME, \
|
| 720 |
+
...) \
|
| 721 |
+
AT_DISPATCH_SWITCH( \
|
| 722 |
+
TYPE, \
|
| 723 |
+
NAME, \
|
| 724 |
+
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND7( \
|
| 725 |
+
SCALARTYPE1, \
|
| 726 |
+
SCALARTYPE2, \
|
| 727 |
+
SCALARTYPE3, \
|
| 728 |
+
SCALARTYPE4, \
|
| 729 |
+
SCALARTYPE5, \
|
| 730 |
+
SCALARTYPE6, \
|
| 731 |
+
SCALARTYPE7, \
|
| 732 |
+
__VA_ARGS__))
|
| 733 |
+
|
| 734 |
+
#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND8( \
|
| 735 |
+
SCALARTYPE1, \
|
| 736 |
+
SCALARTYPE2, \
|
| 737 |
+
SCALARTYPE3, \
|
| 738 |
+
SCALARTYPE4, \
|
| 739 |
+
SCALARTYPE5, \
|
| 740 |
+
SCALARTYPE6, \
|
| 741 |
+
SCALARTYPE7, \
|
| 742 |
+
SCALARTYPE8, \
|
| 743 |
+
...) \
|
| 744 |
+
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \
|
| 745 |
+
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
|
| 746 |
+
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
|
| 747 |
+
AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
|
| 748 |
+
AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) \
|
| 749 |
+
AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__) \
|
| 750 |
+
AT_DISPATCH_CASE(SCALARTYPE6, __VA_ARGS__) \
|
| 751 |
+
AT_DISPATCH_CASE(SCALARTYPE7, __VA_ARGS__) \
|
| 752 |
+
AT_DISPATCH_CASE(SCALARTYPE8, __VA_ARGS__)
|
| 753 |
+
|
| 754 |
+
#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND8( \
|
| 755 |
+
SCALARTYPE1, \
|
| 756 |
+
SCALARTYPE2, \
|
| 757 |
+
SCALARTYPE3, \
|
| 758 |
+
SCALARTYPE4, \
|
| 759 |
+
SCALARTYPE5, \
|
| 760 |
+
SCALARTYPE6, \
|
| 761 |
+
SCALARTYPE7, \
|
| 762 |
+
SCALARTYPE8, \
|
| 763 |
+
TYPE, \
|
| 764 |
+
NAME, \
|
| 765 |
+
...) \
|
| 766 |
+
AT_DISPATCH_SWITCH( \
|
| 767 |
+
TYPE, \
|
| 768 |
+
NAME, \
|
| 769 |
+
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND8( \
|
| 770 |
+
SCALARTYPE1, \
|
| 771 |
+
SCALARTYPE2, \
|
| 772 |
+
SCALARTYPE3, \
|
| 773 |
+
SCALARTYPE4, \
|
| 774 |
+
SCALARTYPE5, \
|
| 775 |
+
SCALARTYPE6, \
|
| 776 |
+
SCALARTYPE7, \
|
| 777 |
+
SCALARTYPE8, \
|
| 778 |
+
__VA_ARGS__))
|
| 779 |
+
|
| 780 |
+
#define AT_DISPATCH_CASE_BIT_TYPES(...) \
|
| 781 |
+
AT_DISPATCH_CASE(at::ScalarType::Bits1x8, __VA_ARGS__) \
|
| 782 |
+
AT_DISPATCH_CASE(at::ScalarType::Bits2x4, __VA_ARGS__) \
|
| 783 |
+
AT_DISPATCH_CASE(at::ScalarType::Bits4x2, __VA_ARGS__) \
|
| 784 |
+
AT_DISPATCH_CASE(at::ScalarType::Bits8, __VA_ARGS__) \
|
| 785 |
+
AT_DISPATCH_CASE(at::ScalarType::Bits16, __VA_ARGS__)
|
| 786 |
+
|
| 787 |
+
#define AT_DISPATCH_BIT_TYPES(TYPE, NAME, ...) \
|
| 788 |
+
AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_BIT_TYPES(__VA_ARGS__))
|
| 789 |
+
|
| 790 |
+
#define AT_DISPATCH_INDEX_TYPES(TYPE, NAME, ...) \
|
| 791 |
+
AT_DISPATCH_SWITCH( \
|
| 792 |
+
TYPE, \
|
| 793 |
+
NAME, \
|
| 794 |
+
AT_PRIVATE_CASE_TYPE_USING_HINT( \
|
| 795 |
+
at::ScalarType::Int, index_t, __VA_ARGS__) \
|
| 796 |
+
AT_PRIVATE_CASE_TYPE_USING_HINT( \
|
| 797 |
+
at::ScalarType::Long, index_t, __VA_ARGS__))
|
| 798 |
+
|
| 799 |
+
// ----------------------------------------------------------------------------
|
| 800 |
+
// DEPRECATED MACROS, DON'T USE THESE
|
| 801 |
+
// ----------------------------------------------------------------------------
|
| 802 |
+
|
| 803 |
+
#define AT_DISPATCH_ALL_TYPES_AND_HALF(TYPE, NAME, ...) \
|
| 804 |
+
detail::deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF(); \
|
| 805 |
+
AT_DISPATCH_SWITCH( \
|
| 806 |
+
TYPE, \
|
| 807 |
+
NAME, \
|
| 808 |
+
AT_DISPATCH_CASE_ALL_TYPES_AND(at::ScalarType::Half, __VA_ARGS__))
|
.venv/lib/python3.11/site-packages/torch/include/ATen/EmptyTensor.h
ADDED
|
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <ATen/core/TensorBase.h>
|
| 3 |
+
|
| 4 |
+
namespace at::detail {
|
| 5 |
+
|
| 6 |
+
inline void check_size_nonnegative(ArrayRef<int64_t> size) {
|
| 7 |
+
for (const auto& x : size) {
|
| 8 |
+
TORCH_CHECK(
|
| 9 |
+
x >= 0,
|
| 10 |
+
"Trying to create tensor with negative dimension ",
|
| 11 |
+
x,
|
| 12 |
+
": ",
|
| 13 |
+
size);
|
| 14 |
+
}
|
| 15 |
+
}
|
| 16 |
+
|
| 17 |
+
inline void check_size_nonnegative(ArrayRef<c10::SymInt> size) {
|
| 18 |
+
for (const auto& x : size) {
|
| 19 |
+
TORCH_CHECK(
|
| 20 |
+
x.expect_size(__FILE__, __LINE__),
|
| 21 |
+
"Trying to create tensor with negative dimension ",
|
| 22 |
+
x,
|
| 23 |
+
": ",
|
| 24 |
+
size);
|
| 25 |
+
}
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
TORCH_API size_t computeStorageNbytesContiguous(
|
| 29 |
+
IntArrayRef sizes,
|
| 30 |
+
size_t itemsize,
|
| 31 |
+
size_t storage_offset = 0);
|
| 32 |
+
TORCH_API SymInt computeStorageNbytesContiguous(
|
| 33 |
+
SymIntArrayRef sizes,
|
| 34 |
+
const SymInt& itemsize,
|
| 35 |
+
const SymInt& storage_offset = 0);
|
| 36 |
+
TORCH_API size_t computeStorageNbytes(
|
| 37 |
+
IntArrayRef sizes,
|
| 38 |
+
IntArrayRef strides,
|
| 39 |
+
size_t itemsize,
|
| 40 |
+
size_t storage_offset = 0);
|
| 41 |
+
TORCH_API SymInt computeStorageNbytes(
|
| 42 |
+
SymIntArrayRef sizes,
|
| 43 |
+
SymIntArrayRef strides,
|
| 44 |
+
const SymInt& itemsize,
|
| 45 |
+
const SymInt& storage_offset = 0);
|
| 46 |
+
|
| 47 |
+
TORCH_API TensorBase empty_generic(
|
| 48 |
+
IntArrayRef size,
|
| 49 |
+
c10::Allocator* allocator,
|
| 50 |
+
c10::DispatchKeySet ks,
|
| 51 |
+
ScalarType scalar_type,
|
| 52 |
+
std::optional<c10::MemoryFormat> memory_format_opt);
|
| 53 |
+
|
| 54 |
+
TORCH_API TensorBase empty_generic_symint(
|
| 55 |
+
SymIntArrayRef size,
|
| 56 |
+
c10::Allocator* allocator,
|
| 57 |
+
c10::DispatchKeySet ks,
|
| 58 |
+
ScalarType scalar_type,
|
| 59 |
+
std::optional<c10::MemoryFormat> memory_format_opt);
|
| 60 |
+
|
| 61 |
+
TORCH_API TensorBase empty_strided_generic(
|
| 62 |
+
IntArrayRef size,
|
| 63 |
+
IntArrayRef stride,
|
| 64 |
+
c10::Allocator* allocator,
|
| 65 |
+
c10::DispatchKeySet ks,
|
| 66 |
+
ScalarType scalar_type);
|
| 67 |
+
|
| 68 |
+
TORCH_API TensorBase empty_strided_symint_generic(
|
| 69 |
+
SymIntArrayRef size,
|
| 70 |
+
SymIntArrayRef stride,
|
| 71 |
+
c10::Allocator* allocator,
|
| 72 |
+
c10::DispatchKeySet ks,
|
| 73 |
+
ScalarType scalar_type);
|
| 74 |
+
|
| 75 |
+
TORCH_API TensorBase empty_cpu(
|
| 76 |
+
IntArrayRef size,
|
| 77 |
+
ScalarType dtype,
|
| 78 |
+
bool pin_memory = false,
|
| 79 |
+
std::optional<c10::MemoryFormat> memory_format_opt = std::nullopt);
|
| 80 |
+
|
| 81 |
+
TORCH_API TensorBase empty_cpu(
|
| 82 |
+
IntArrayRef size,
|
| 83 |
+
std::optional<ScalarType> dtype_opt,
|
| 84 |
+
std::optional<Layout> layout_opt,
|
| 85 |
+
std::optional<Device> device_opt,
|
| 86 |
+
std::optional<bool> pin_memory_opt,
|
| 87 |
+
std::optional<c10::MemoryFormat> memory_format_opt);
|
| 88 |
+
|
| 89 |
+
TORCH_API TensorBase empty_cpu(IntArrayRef size, const TensorOptions& options);
|
| 90 |
+
|
| 91 |
+
TORCH_API TensorBase empty_strided_cpu(
|
| 92 |
+
IntArrayRef size,
|
| 93 |
+
IntArrayRef stride,
|
| 94 |
+
ScalarType dtype,
|
| 95 |
+
bool pin_memory = false);
|
| 96 |
+
|
| 97 |
+
TORCH_API TensorBase empty_strided_cpu(
|
| 98 |
+
IntArrayRef size,
|
| 99 |
+
IntArrayRef stride,
|
| 100 |
+
std::optional<ScalarType> dtype_opt,
|
| 101 |
+
std::optional<Layout> layout_opt,
|
| 102 |
+
std::optional<Device> device_opt,
|
| 103 |
+
std::optional<bool> pin_memory_opt);
|
| 104 |
+
|
| 105 |
+
TORCH_API TensorBase empty_strided_cpu(
|
| 106 |
+
IntArrayRef size,
|
| 107 |
+
IntArrayRef stride,
|
| 108 |
+
const TensorOptions& options);
|
| 109 |
+
|
| 110 |
+
TORCH_API TensorBase empty_meta(
|
| 111 |
+
IntArrayRef size,
|
| 112 |
+
ScalarType dtype,
|
| 113 |
+
std::optional<c10::MemoryFormat> memory_format_opt = std::nullopt);
|
| 114 |
+
|
| 115 |
+
TORCH_API TensorBase empty_meta(
|
| 116 |
+
IntArrayRef size,
|
| 117 |
+
std::optional<ScalarType> dtype_opt,
|
| 118 |
+
std::optional<Layout> layout_opt,
|
| 119 |
+
std::optional<Device> device_opt,
|
| 120 |
+
std::optional<bool> pin_memory_opt,
|
| 121 |
+
std::optional<c10::MemoryFormat> memory_format_opt);
|
| 122 |
+
|
| 123 |
+
TORCH_API TensorBase empty_symint_meta(
|
| 124 |
+
SymIntArrayRef size,
|
| 125 |
+
std::optional<ScalarType> dtype_opt,
|
| 126 |
+
std::optional<Layout> layout_opt,
|
| 127 |
+
std::optional<Device> device_opt,
|
| 128 |
+
std::optional<bool> pin_memory_opt,
|
| 129 |
+
std::optional<c10::MemoryFormat> memory_format_opt);
|
| 130 |
+
|
| 131 |
+
TORCH_API TensorBase empty_meta(IntArrayRef size, const TensorOptions& options);
|
| 132 |
+
|
| 133 |
+
TORCH_API TensorBase
|
| 134 |
+
empty_strided_meta(IntArrayRef size, IntArrayRef stride, ScalarType dtype);
|
| 135 |
+
|
| 136 |
+
TORCH_API TensorBase empty_strided_meta(
|
| 137 |
+
IntArrayRef size,
|
| 138 |
+
IntArrayRef stride,
|
| 139 |
+
std::optional<ScalarType> dtype_opt,
|
| 140 |
+
std::optional<Layout> layout_opt,
|
| 141 |
+
std::optional<Device> device_opt,
|
| 142 |
+
std::optional<bool> pin_memory_opt);
|
| 143 |
+
|
| 144 |
+
TORCH_API TensorBase empty_strided_meta(
|
| 145 |
+
IntArrayRef size,
|
| 146 |
+
IntArrayRef stride,
|
| 147 |
+
const TensorOptions& options);
|
| 148 |
+
|
| 149 |
+
TORCH_API TensorBase empty_strided_symint_meta(
|
| 150 |
+
SymIntArrayRef size,
|
| 151 |
+
SymIntArrayRef stride,
|
| 152 |
+
ScalarType dtype);
|
| 153 |
+
|
| 154 |
+
TORCH_API TensorBase empty_strided_symint_meta(
|
| 155 |
+
SymIntArrayRef size,
|
| 156 |
+
SymIntArrayRef stride,
|
| 157 |
+
std::optional<ScalarType> dtype_opt,
|
| 158 |
+
std::optional<Layout> layout_opt,
|
| 159 |
+
std::optional<Device> device_opt);
|
| 160 |
+
|
| 161 |
+
TORCH_API TensorBase empty_strided_symint_meta(
|
| 162 |
+
SymIntArrayRef size,
|
| 163 |
+
SymIntArrayRef stride,
|
| 164 |
+
const TensorOptions& options);
|
| 165 |
+
|
| 166 |
+
} // namespace at::detail
|
.venv/lib/python3.11/site-packages/torch/include/ATen/ExpandBase.h
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <ATen/core/TensorBase.h>
|
| 2 |
+
|
| 3 |
+
// Broadcasting utilities for working with TensorBase
|
| 4 |
+
namespace at {
|
| 5 |
+
namespace internal {
|
| 6 |
+
TORCH_API TensorBase expand_slow_path(const TensorBase& self, IntArrayRef size);
|
| 7 |
+
} // namespace internal
|
| 8 |
+
|
| 9 |
+
inline c10::MaybeOwned<TensorBase> expand_size(
|
| 10 |
+
const TensorBase& self,
|
| 11 |
+
IntArrayRef size) {
|
| 12 |
+
if (size.equals(self.sizes())) {
|
| 13 |
+
return c10::MaybeOwned<TensorBase>::borrowed(self);
|
| 14 |
+
}
|
| 15 |
+
return c10::MaybeOwned<TensorBase>::owned(
|
| 16 |
+
at::internal::expand_slow_path(self, size));
|
| 17 |
+
}
|
| 18 |
+
c10::MaybeOwned<TensorBase> expand_size(TensorBase&& self, IntArrayRef size) =
|
| 19 |
+
delete;
|
| 20 |
+
|
| 21 |
+
inline c10::MaybeOwned<TensorBase> expand_inplace(
|
| 22 |
+
const TensorBase& tensor,
|
| 23 |
+
const TensorBase& to_expand) {
|
| 24 |
+
return expand_size(to_expand, tensor.sizes());
|
| 25 |
+
}
|
| 26 |
+
c10::MaybeOwned<TensorBase> expand_inplace(
|
| 27 |
+
const TensorBase& tensor,
|
| 28 |
+
TensorBase&& to_expand) = delete;
|
| 29 |
+
|
| 30 |
+
} // namespace at
|
.venv/lib/python3.11/site-packages/torch/include/ATen/FuncTorchTLS.h
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <c10/macros/Macros.h>
|
| 4 |
+
#include <memory>
|
| 5 |
+
|
| 6 |
+
namespace at::functorch {
|
| 7 |
+
|
| 8 |
+
// NOTE [functorch TLS in pytorch/pytorch]
|
| 9 |
+
//
|
| 10 |
+
// functorch lives out-of-tree. However, it has some TLS that needs to be
|
| 11 |
+
// propagated. The solution for that is we store a pointer to the TLS
|
| 12 |
+
// inside pytorch/pytorch and extend FuncTorchTLSBase inside functorch to
|
| 13 |
+
// include whatever functorch needs.
|
| 14 |
+
//
|
| 15 |
+
// We need to store a pointer due to the indirection:
|
| 16 |
+
// inside functorch, we will create a subclass of FunctorchTLSBase called
|
| 17 |
+
// FuncTorchTLSImpl that actually contains metadata, like the DynamicLayerStack.
|
| 18 |
+
// FuncTorchTLSBase doesn't have any metadata because it hasn't been defined
|
| 19 |
+
// yet.
|
| 20 |
+
//
|
| 21 |
+
// Here in pytorch/pytorch, we will pass around FuncTorchTLSBase*, but inside
|
| 22 |
+
// functorch, we will assign a FuncTorchTLSImpl* to the FunctorchTLSBase*.
|
| 23 |
+
// We can't directly pass around FunctorchTLSBase (without a pointer) because
|
| 24 |
+
// FuncTorchTLSImpl does not fit inside a FuncTorchTLSBase by virtue of having
|
| 25 |
+
// more elements.
|
| 26 |
+
struct TORCH_API FuncTorchTLSBase {
|
| 27 |
+
virtual ~FuncTorchTLSBase() = default;
|
| 28 |
+
virtual std::unique_ptr<FuncTorchTLSBase> deepcopy() const = 0;
|
| 29 |
+
|
| 30 |
+
virtual int64_t checkSupportsSingleLevelAutogradFunction() const = 0;
|
| 31 |
+
virtual void checkSupportsCppAutogradFunction() const = 0;
|
| 32 |
+
virtual void checkSupportsInplaceRequiresGrad() const = 0;
|
| 33 |
+
virtual void checkSupportsRetainGrad() const = 0;
|
| 34 |
+
};
|
| 35 |
+
|
| 36 |
+
// returns deepcopy of the functorch tls
|
| 37 |
+
TORCH_API std::unique_ptr<FuncTorchTLSBase> getCopyOfFuncTorchTLS();
|
| 38 |
+
|
| 39 |
+
// sets the functorch tls. always does a deep copy.
|
| 40 |
+
TORCH_API void setFuncTorchTLS(
|
| 41 |
+
const std::shared_ptr<const FuncTorchTLSBase>& state);
|
| 42 |
+
|
| 43 |
+
// get a mutable reference to the functorch tls
|
| 44 |
+
TORCH_API std::unique_ptr<FuncTorchTLSBase>& functorchTLSAccessor();
|
| 45 |
+
|
| 46 |
+
} // namespace at::functorch
|
.venv/lib/python3.11/site-packages/torch/include/ATen/FunctionalStorageImpl.h
ADDED
|
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/Tensor.h>
|
| 4 |
+
|
| 5 |
+
#include <utility>
|
| 6 |
+
|
| 7 |
+
namespace at::functionalization {
|
| 8 |
+
|
| 9 |
+
// See Note [Functionalization Pass In Core]
|
| 10 |
+
|
| 11 |
+
// ViewMeta is a class used by the functionalization pass to navigate between
|
| 12 |
+
// a base tensor and a view tensor.
|
| 13 |
+
// For example, if I call `b = a.view1(...)`
|
| 14 |
+
// the functionalization pass will generate and store a ViewMeta on b that looks
|
| 15 |
+
// like:
|
| 16 |
+
//
|
| 17 |
+
// ViewMeta(
|
| 18 |
+
// [<captures>](const Tensor& base, int64_t mutated_view_idx) {
|
| 19 |
+
// return base.view1(...);
|
| 20 |
+
// },
|
| 21 |
+
// [<captures>](const at::Tensor& base, const at::Tensor& mutated_view,
|
| 22 |
+
// int64_t mutated_view_idx) -> at::Tensor {
|
| 23 |
+
// return at::functionalization::impl::view1_inverse(base, mutated_view,
|
| 24 |
+
// ...);
|
| 25 |
+
// }
|
| 26 |
+
//
|
| 27 |
+
// The forward_fn lambda describes how to replay view1 on a tensor.
|
| 28 |
+
//
|
| 29 |
+
// The reverse_fn lambda describes how, given a tensor that is already a view,
|
| 30 |
+
// how to get the corresponding base tensor. See Note [Functionalization Pass:
|
| 31 |
+
// View Inverses] for details.
|
| 32 |
+
struct ViewMeta {
|
| 33 |
+
ViewMeta(
|
| 34 |
+
std::function<Tensor(const Tensor&, int64_t)> forward,
|
| 35 |
+
std::function<Tensor(const Tensor&, const Tensor&, int64_t)> reverse,
|
| 36 |
+
bool has_symbolic_inputs,
|
| 37 |
+
bool is_multi_output = false,
|
| 38 |
+
bool is_as_strided = false,
|
| 39 |
+
int64_t out_idx = 0)
|
| 40 |
+
: forward_fn(std::move(forward)),
|
| 41 |
+
reverse_fn(std::move(reverse)),
|
| 42 |
+
out_index(out_idx),
|
| 43 |
+
is_multi_output(is_multi_output),
|
| 44 |
+
is_as_strided(is_as_strided),
|
| 45 |
+
has_symbolic_inputs(has_symbolic_inputs) {}
|
| 46 |
+
|
| 47 |
+
std::function<Tensor(const Tensor&, int64_t)> forward_fn;
|
| 48 |
+
std::function<Tensor(const Tensor&, const Tensor&, int64_t)> reverse_fn;
|
| 49 |
+
// See Note [out_idx in ViewMeta]
|
| 50 |
+
int64_t out_index;
|
| 51 |
+
|
| 52 |
+
// Tells us if this is a multi-output view
|
| 53 |
+
bool is_multi_output;
|
| 54 |
+
|
| 55 |
+
bool is_as_strided;
|
| 56 |
+
|
| 57 |
+
// Tells us if this view operation has any symbolic inputs
|
| 58 |
+
bool has_symbolic_inputs;
|
| 59 |
+
|
| 60 |
+
// Returns a copy of the current ViewMeta, if out_idx matches the current
|
| 61 |
+
// out_index. Otherwise, returns a new ViewMeta with the same forward/reverse
|
| 62 |
+
// functions, but a new out index.
|
| 63 |
+
ViewMeta to_out_idx(int64_t out_idx);
|
| 64 |
+
};
|
| 65 |
+
|
| 66 |
+
// FunctionalStorageImpl is a subclass of StorageImpl used by the
|
| 67 |
+
// functionalization pass. It has no underlying data (similar to meta storage).
|
| 68 |
+
// It also knows how to reflect mutations to tensors in the absence of a valid
|
| 69 |
+
// data pointer.
|
| 70 |
+
//
|
| 71 |
+
// A storage represents the state shared by (potentially multiple) views of the
|
| 72 |
+
// same tensor. For example, in the following code:
|
| 73 |
+
//
|
| 74 |
+
// b = a.view1(...)
|
| 75 |
+
// c = b.view2(...)
|
| 76 |
+
// b.add_(1)
|
| 77 |
+
// --> storage.add_update(b, {view1_meta})
|
| 78 |
+
//
|
| 79 |
+
// The call to add_(1) will result in a call to alias.add_update(b,
|
| 80 |
+
// {view1_meta}), queueing up the mutation from b onto the alias. Later, suppose
|
| 81 |
+
// c is used in an expression (e.g. you try to print c, or pass it to an
|
| 82 |
+
// operator). Doing so will involve "syncing" c. First we apply any pending
|
| 83 |
+
// updates to the alias, and then we regenerate c by replaying its views off of
|
| 84 |
+
// the updated alias. E.g:
|
| 85 |
+
//
|
| 86 |
+
// print(str(c))
|
| 87 |
+
// --> c.sync_()
|
| 88 |
+
// --> alias.apply_updates() // after this, the alias will be updated to
|
| 89 |
+
// reflect the mutation to b
|
| 90 |
+
struct TORCH_API FunctionalStorageImpl : public c10::StorageImpl {
|
| 91 |
+
public:
|
| 92 |
+
struct Update {
|
| 93 |
+
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
|
| 94 |
+
const at::Tensor new_val;
|
| 95 |
+
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
|
| 96 |
+
const std::vector<ViewMeta> view_metas;
|
| 97 |
+
};
|
| 98 |
+
|
| 99 |
+
explicit FunctionalStorageImpl(const Tensor& value);
|
| 100 |
+
|
| 101 |
+
void add_update(
|
| 102 |
+
const Tensor& updated_val,
|
| 103 |
+
const std::vector<ViewMeta>& view_metas);
|
| 104 |
+
bool apply_updates();
|
| 105 |
+
const Tensor& base() {
|
| 106 |
+
return base_;
|
| 107 |
+
}
|
| 108 |
+
size_t generation() const {
|
| 109 |
+
return generation_;
|
| 110 |
+
}
|
| 111 |
+
void freeze() {
|
| 112 |
+
frozen_ = true;
|
| 113 |
+
}
|
| 114 |
+
|
| 115 |
+
c10::SymInt get_storage_size(bool before) {
|
| 116 |
+
if (before) {
|
| 117 |
+
return original_storage_size_;
|
| 118 |
+
} else {
|
| 119 |
+
return curr_storage_size_;
|
| 120 |
+
}
|
| 121 |
+
}
|
| 122 |
+
|
| 123 |
+
~FunctionalStorageImpl() override = default;
|
| 124 |
+
|
| 125 |
+
void mark_mutation() {
|
| 126 |
+
mutation_counter_++;
|
| 127 |
+
}
|
| 128 |
+
void mark_mutation_during_no_grad_or_inference_mode() {
|
| 129 |
+
mutation_counter_during_no_grad_or_inference_mode_++;
|
| 130 |
+
}
|
| 131 |
+
void mark_mutation_hidden_from_autograd() {
|
| 132 |
+
mutation_counter_hidden_from_autograd_++;
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
bool are_all_mutations_under_no_grad_or_inference_mode() const {
|
| 136 |
+
auto non_autograd_mutations =
|
| 137 |
+
mutation_counter_during_no_grad_or_inference_mode_ +
|
| 138 |
+
mutation_counter_hidden_from_autograd_;
|
| 139 |
+
// The <= is because both counters will technically be incremented, if we
|
| 140 |
+
// perform e.g. a triton kernel mutation under no_grad
|
| 141 |
+
return mutation_counter_ <= non_autograd_mutations;
|
| 142 |
+
}
|
| 143 |
+
|
| 144 |
+
bool are_all_mutations_hidden_from_autograd() const {
|
| 145 |
+
// mutations under no_grad / inference_mode are technically not hidden from
|
| 146 |
+
// autograd - they change the version counter
|
| 147 |
+
return mutation_counter_ <= mutation_counter_hidden_from_autograd_;
|
| 148 |
+
}
|
| 149 |
+
|
| 150 |
+
void mark_inductor_storage_resize(c10::SymInt new_size) {
|
| 151 |
+
inductor_storage_resized_ = true;
|
| 152 |
+
curr_storage_size_ = std::move(new_size);
|
| 153 |
+
}
|
| 154 |
+
|
| 155 |
+
bool was_inductor_storage_resized() {
|
| 156 |
+
return inductor_storage_resized_;
|
| 157 |
+
}
|
| 158 |
+
|
| 159 |
+
private:
|
| 160 |
+
// NB: base_ should always point to a tensor BELOW the current
|
| 161 |
+
// functionalization layer. This is mainly to avoid reference cycles. e.g.
|
| 162 |
+
// given `b = a.view(...)` Both a.storage_ and b.storage_ are a
|
| 163 |
+
// FunctionStorageImpl containing an Walualias, with contains a Tensor
|
| 164 |
+
// `base_`. In this case (where a and b are FunctionalTensorWrapper's), base_
|
| 165 |
+
// should point not to a, but to a's unwrapped value, a.value_` See Note
|
| 166 |
+
// [Functionalization: Walualias Removal] for a diagram that shows this
|
| 167 |
+
// visually.
|
| 168 |
+
at::Tensor base_;
|
| 169 |
+
std::vector<Update> updates_;
|
| 170 |
+
// generation_ gets incremented every time a mutation is queued onto the
|
| 171 |
+
// alias. It is used to determine if a given tensor is "up to date", or if it
|
| 172 |
+
// needs to be regenerated from the alias.
|
| 173 |
+
size_t generation_ = 0;
|
| 174 |
+
// If frozen, no more mutations are allowed on this storage. Once frozen, a
|
| 175 |
+
// storage cannot be unfrozen.
|
| 176 |
+
bool frozen_ = false;
|
| 177 |
+
|
| 178 |
+
// These mutation counters are bumped on the storage
|
| 179 |
+
// whenever a FunctionalTensorWrapper experiences a mutation.
|
| 180 |
+
// When the mutation is under no_grad, or comes from a triton kernel, we also
|
| 181 |
+
// bump the corresponding during_no_grad or hidden_from_autograd counters. Why
|
| 182 |
+
// do we need to detect these two situations separately from "normal" input
|
| 183 |
+
// mutations? (1) "normal" input mutations can mutate autograd metadata like
|
| 184 |
+
// .grad_fn,
|
| 185 |
+
// in which case they need to be replayed outside of the compiled graph
|
| 186 |
+
// (2) "no_grad" input mutations are generally safe to keep in the graph (and
|
| 187 |
+
// compile),
|
| 188 |
+
// but they bump the tensor's VC, so we need to mark_dirty() on the inputs
|
| 189 |
+
// in torch.compile
|
| 190 |
+
// (3) mutations that are fully hidden from autograd (e.g. from a triton
|
| 191 |
+
// kernel)
|
| 192 |
+
// do not mutate any autograd state, and be fully kept in the graph
|
| 193 |
+
// When we detect that an input was mutated, we need to be able to tell if:
|
| 194 |
+
// (1) all of the mutations were from triton kernels
|
| 195 |
+
// (2) all of the mutations were under no_grad
|
| 196 |
+
uint64_t mutation_counter_during_no_grad_or_inference_mode_ = 0;
|
| 197 |
+
uint64_t mutation_counter_ = 0;
|
| 198 |
+
uint64_t mutation_counter_hidden_from_autograd_ = 0;
|
| 199 |
+
|
| 200 |
+
// Used to tell if:
|
| 201 |
+
// (1) There were any storage resizes on a graph input
|
| 202 |
+
// (2) The original/curr storage size tell us if these resizes result in a nop
|
| 203 |
+
bool inductor_storage_resized_ = false;
|
| 204 |
+
c10::SymInt original_storage_size_;
|
| 205 |
+
c10::SymInt curr_storage_size_;
|
| 206 |
+
};
|
| 207 |
+
|
| 208 |
+
} // namespace at::functionalization
|
.venv/lib/python3.11/site-packages/torch/include/ATen/FunctionalTensorWrapper.h
ADDED
|
@@ -0,0 +1,454 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
#include <ATen/ArrayRef.h>
|
| 5 |
+
#include <ATen/FunctionalStorageImpl.h>
|
| 6 |
+
#include <ATen/core/IListRef.h>
|
| 7 |
+
#include <ATen/core/List.h>
|
| 8 |
+
#include <ATen/core/boxing/BoxedKernel.h>
|
| 9 |
+
#include <ATen/core/boxing/impl/boxing.h>
|
| 10 |
+
#include <ATen/core/dispatch/Dispatcher.h>
|
| 11 |
+
|
| 12 |
+
#include <c10/core/DispatchKey.h>
|
| 13 |
+
|
| 14 |
+
namespace at {
|
| 15 |
+
|
| 16 |
+
// Note [Functionalization Pass In Core]
|
| 17 |
+
// The Functionalization pass is used to remove aliasing from a pytorch program.
|
| 18 |
+
//
|
| 19 |
+
// This is useful for backends that don't support aliasing, like XLA and Vulkan.
|
| 20 |
+
// It's also necessary in order to remove mutation from a program, which is
|
| 21 |
+
// needed in Functorch.
|
| 22 |
+
//
|
| 23 |
+
// Consider this program:
|
| 24 |
+
// a = torch.ones(...)
|
| 25 |
+
// b = a.view(...)
|
| 26 |
+
// b.add_(1)
|
| 27 |
+
//
|
| 28 |
+
// In this program, b is meant to alias with a due to the use of view(). At the
|
| 29 |
+
// end of the program, both a and b are full of 2's. However, backends that
|
| 30 |
+
// don't support aliasing aren't able to correctly implement the view()
|
| 31 |
+
// operator. Instead, they can opt into the Functionalization pass, which will
|
| 32 |
+
// sit between the user and the backend, and provide the necessary aliasing
|
| 33 |
+
// logic.
|
| 34 |
+
//
|
| 35 |
+
// The functionalization pass will turn the above program into a slightly
|
| 36 |
+
// different program that has the same semantics, transparently to the user,
|
| 37 |
+
// that backends like XLA/Vulkan are able to implement a = torch.ones(...) b =
|
| 38 |
+
// a.view_copy(...) # view() replaced with view_copy(). Backends like
|
| 39 |
+
// XLA/Vulkan can implement this! b.add_(1) a.add_(1) # Our functionalization
|
| 40 |
+
// pass machinery knows that a and b are aliased - it applies b's mutation to a
|
| 41 |
+
// too.
|
| 42 |
+
//
|
| 43 |
+
// So, how does the functionalization pass keep track of which tensors are
|
| 44 |
+
// aliased? The pass works by wrapping EVERY tensor in the program inside of a
|
| 45 |
+
// FunctionalTensorWrapper, which knows about its alias'd tensors.
|
| 46 |
+
//
|
| 47 |
+
// See Note [Functionalization: Alias Removal] for details on the aliasing
|
| 48 |
+
// machinery. See Note [Functionalization: Mutation Removal] for details on
|
| 49 |
+
// mutation removal.
|
| 50 |
+
struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
|
| 51 |
+
explicit FunctionalTensorWrapper(const Tensor& value);
|
| 52 |
+
// Additional constructor to create a FunctionalTensorWrapper directly from an
|
| 53 |
+
// underlying tensor that was created from a view. For example, the code b =
|
| 54 |
+
// a.view1() will generate a constructor call to FunctionalTensorWrapper(b, a,
|
| 55 |
+
// view1_meta)
|
| 56 |
+
explicit FunctionalTensorWrapper(
|
| 57 |
+
const Tensor& view_value,
|
| 58 |
+
const FunctionalTensorWrapper* base,
|
| 59 |
+
const functionalization::ViewMeta& meta);
|
| 60 |
+
|
| 61 |
+
// Get the underlying, actual tensor, that doesn't know anything about
|
| 62 |
+
// functionalization.
|
| 63 |
+
const Tensor& value() const {
|
| 64 |
+
return value_;
|
| 65 |
+
};
|
| 66 |
+
// The concept of "level" is only ever important to functorch; it's exposed
|
| 67 |
+
// here as more of a hook for functorch to use.
|
| 68 |
+
int64_t level() const {
|
| 69 |
+
return level_;
|
| 70 |
+
};
|
| 71 |
+
void set_level(int64_t level) {
|
| 72 |
+
level_ = level;
|
| 73 |
+
}
|
| 74 |
+
bool has_metadata_mutation() const {
|
| 75 |
+
return has_metadata_mutation_;
|
| 76 |
+
};
|
| 77 |
+
|
| 78 |
+
void mark_mutation() {
|
| 79 |
+
functional_storage_impl()->mark_mutation();
|
| 80 |
+
}
|
| 81 |
+
// Denotes a mutation that's hidden from autograd,
|
| 82 |
+
// e.g. for the purposes of passing a tensor to a triton kernel
|
| 83 |
+
void mark_mutation_hidden_from_autograd() {
|
| 84 |
+
functional_storage_impl()->mark_mutation_hidden_from_autograd();
|
| 85 |
+
}
|
| 86 |
+
void mark_mutation_during_no_grad_or_inference_mode() {
|
| 87 |
+
functional_storage_impl()->mark_mutation_during_no_grad_or_inference_mode();
|
| 88 |
+
}
|
| 89 |
+
// Are all the mutations happening to the tensor hidden from autograd
|
| 90 |
+
bool are_all_mutations_hidden_from_autograd() const {
|
| 91 |
+
return functional_storage_impl()->are_all_mutations_hidden_from_autograd();
|
| 92 |
+
}
|
| 93 |
+
// Did all mutations happen under no_grad or inference_mode
|
| 94 |
+
// (We also need to ignore mutations fully hidden from autograd here)
|
| 95 |
+
bool are_all_mutations_under_no_grad_or_inference_mode() const {
|
| 96 |
+
return functional_storage_impl()
|
| 97 |
+
->are_all_mutations_under_no_grad_or_inference_mode();
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
void maybe_mark_symbolic(const functionalization::ViewMeta& meta) {
|
| 101 |
+
is_symbolic_ = is_symbolic_ | meta.has_symbolic_inputs;
|
| 102 |
+
}
|
| 103 |
+
|
| 104 |
+
bool is_symbolic() const {
|
| 105 |
+
return is_symbolic_;
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
// Runs the forward_fn of every ViewMeta collected in the current instance
|
| 109 |
+
// to some other base.
|
| 110 |
+
Tensor apply_view_metas(const Tensor& base);
|
| 111 |
+
|
| 112 |
+
// Sync's the underlying tensor with its alias, if it's out of date. This
|
| 113 |
+
// involves two steps: 1) Apply any pending updates/mutations to the alias 2)
|
| 114 |
+
// Replay the views (if any) to regenerate the current tensor off of the
|
| 115 |
+
// updated alias.
|
| 116 |
+
void sync_();
|
| 117 |
+
// Performs step (1) of the sync. This is its own public API because it's
|
| 118 |
+
// needed by view_inplace ops like transpose_. See Note [Functionalization
|
| 119 |
+
// Pass - Inplace View Ops]
|
| 120 |
+
void regenerate_from_base();
|
| 121 |
+
// Performs step (2) of the sync. This is its own public API because it's
|
| 122 |
+
// needed by functorch. functorch wants to make sure that all input tensors to
|
| 123 |
+
// a functionalized program have been properly synced so it can properly
|
| 124 |
+
// propagate mutations to inputs. It can't just call sync_(), because the
|
| 125 |
+
// FunctionalTensorWrapper will look like it has no aliases and sync_ will be
|
| 126 |
+
// a noop. We use the reference count on storage_ to determine if the wrapper
|
| 127 |
+
// is aliased, and by the time functorch is ready to propagate updates to
|
| 128 |
+
// inputs, any intermediate views of the input created by the program will
|
| 129 |
+
// have been deallocated. This function also returns whether or not the base
|
| 130 |
+
// actually had any updates to apply.
|
| 131 |
+
bool apply_updates();
|
| 132 |
+
// Takes the current state of value_ and snapshots it, sending it as a pending
|
| 133 |
+
// update to the alias.
|
| 134 |
+
void commit_update();
|
| 135 |
+
// When any tensor is mutated, the tensor increments its alias's "generation".
|
| 136 |
+
// Separately, each tensor maintains its own "generation" counter, which is
|
| 137 |
+
// used to determine if it's up-to-date with its alias. The act of syncing a
|
| 138 |
+
// tensor will set a tensor's generation equal to its alias's generation.
|
| 139 |
+
bool is_up_to_date() const;
|
| 140 |
+
// Freezes the storage of this tensor, preventing subsequent mutations
|
| 141 |
+
void freeze_storage() const;
|
| 142 |
+
// Every FunctionalTensorWrapper contains a vector<ViewMeta> objects
|
| 143 |
+
// describing the series of view ops that ran to generate the current tensor
|
| 144 |
+
// from the base tensor. This method is used by inplace-view ops like
|
| 145 |
+
// transpose_. It appends a ViewMeta to the existing stack, and refreshes the
|
| 146 |
+
// tensor by replaying the views off of the alias.
|
| 147 |
+
void mutate_view_meta(const at::functionalization::ViewMeta& meta);
|
| 148 |
+
|
| 149 |
+
// Custom implementation of self.set_(src)
|
| 150 |
+
void set__impl(const FunctionalTensorWrapper* other);
|
| 151 |
+
|
| 152 |
+
// Custom implementation of resize_storage_bytes_(self, new_size)
|
| 153 |
+
void storage_resize_(const c10::SymInt& new_size);
|
| 154 |
+
|
| 155 |
+
// Returns whether the current tensor's data was ever mutated
|
| 156 |
+
bool has_data_mutation();
|
| 157 |
+
//
|
| 158 |
+
// Returns whether the current FunctionalTensorWrapper
|
| 159 |
+
// experienced a set_() call.
|
| 160 |
+
bool was_storage_changed() {
|
| 161 |
+
return was_storage_changed_;
|
| 162 |
+
}
|
| 163 |
+
|
| 164 |
+
void set_storage_changed() {
|
| 165 |
+
was_storage_changed_ = true;
|
| 166 |
+
}
|
| 167 |
+
|
| 168 |
+
// A FunctionalTensor is considered a base if its not a view of another
|
| 169 |
+
// tensor.
|
| 170 |
+
bool isBaseTensor() const {
|
| 171 |
+
return view_metas_.empty();
|
| 172 |
+
}
|
| 173 |
+
|
| 174 |
+
c10::SymInt get_storage_size(bool before) {
|
| 175 |
+
return functional_storage_impl()->get_storage_size(before);
|
| 176 |
+
}
|
| 177 |
+
|
| 178 |
+
// Returns whether the FunctionalTensor experienced an
|
| 179 |
+
// untyped_storage().resize_() call
|
| 180 |
+
bool was_inductor_storage_resized() {
|
| 181 |
+
return functional_storage_impl()->was_inductor_storage_resized();
|
| 182 |
+
}
|
| 183 |
+
|
| 184 |
+
// The functionalization pass can be used to remove mutations.
|
| 185 |
+
// It does so by replacing any mutation op with it's corresponding
|
| 186 |
+
// out-of-place op, followed by a call to replace_(). e.g:
|
| 187 |
+
//
|
| 188 |
+
// a.add_(1)
|
| 189 |
+
//
|
| 190 |
+
// will turn into:
|
| 191 |
+
//
|
| 192 |
+
// tmp = a.add(1)
|
| 193 |
+
// a.replace_(tmp)
|
| 194 |
+
//
|
| 195 |
+
// replace_() swaps out the wrapped tensor, value_, with tmp.
|
| 196 |
+
void replace_(const Tensor& other, bool from_lazy_regenerate = false);
|
| 197 |
+
|
| 198 |
+
bool is_multi_output_view() {
|
| 199 |
+
return is_multi_output_view_;
|
| 200 |
+
}
|
| 201 |
+
|
| 202 |
+
// See Note[resize_() in functionalization pass]
|
| 203 |
+
void maybe_replace_storage(const Tensor& other);
|
| 204 |
+
|
| 205 |
+
// Replaces the storage with a new functional storage,
|
| 206 |
+
// and clears the view_metas_ stack.
|
| 207 |
+
// WARNING: Calling this function will sever the aliasing relationship between
|
| 208 |
+
// the current FunctionalTensorWrapper and any of its outstanding aliases.
|
| 209 |
+
// Please only call if you know what you're doing.
|
| 210 |
+
void _unsafe_reset_storage();
|
| 211 |
+
|
| 212 |
+
c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
|
| 213 |
+
const c10::VariableVersion& version_counter,
|
| 214 |
+
bool allow_tensor_metadata_change) const override;
|
| 215 |
+
|
| 216 |
+
c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
|
| 217 |
+
c10::VariableVersion&& version_counter,
|
| 218 |
+
bool allow_tensor_metadata_change) const override;
|
| 219 |
+
|
| 220 |
+
~FunctionalTensorWrapper() override = default;
|
| 221 |
+
|
| 222 |
+
// FunctionalTensorWrapper overrides all custom size/stride function,
|
| 223 |
+
// so that if the inner tensor has a custom implementation
|
| 224 |
+
// we make sure to call that implementation.
|
| 225 |
+
at::IntArrayRef sizes_custom() const override;
|
| 226 |
+
at::IntArrayRef strides_custom() const override;
|
| 227 |
+
int64_t dim_custom() const override;
|
| 228 |
+
int64_t numel_custom() const override;
|
| 229 |
+
bool is_contiguous_custom(at::MemoryFormat memory_format) const override;
|
| 230 |
+
c10::SymIntArrayRef sym_sizes_custom() const override;
|
| 231 |
+
c10::SymInt sym_size_custom(int64_t d) const override;
|
| 232 |
+
c10::SymIntArrayRef sym_strides_custom() const override;
|
| 233 |
+
c10::SymInt sym_storage_offset_custom() const override;
|
| 234 |
+
c10::Device device_custom() const override;
|
| 235 |
+
c10::Layout layout_impl() const override;
|
| 236 |
+
|
| 237 |
+
private:
|
| 238 |
+
const char* tensorimpl_type_name() const override;
|
| 239 |
+
void set_constructor_metadata();
|
| 240 |
+
functionalization::FunctionalStorageImpl* functional_storage_impl() const;
|
| 241 |
+
|
| 242 |
+
// This is used to re-implement shallow_copy_and_detach for
|
| 243 |
+
// FunctionalTensorWrapper. The implementation is identical, but we just need
|
| 244 |
+
// to return a subclass instead of a plain TensorImpl.
|
| 245 |
+
// TODO: maybe it's possible to arrange for that to happen automatically
|
| 246 |
+
// without an override here?
|
| 247 |
+
template <typename VariableVersion>
|
| 248 |
+
c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach_core(
|
| 249 |
+
VariableVersion&& version_counter,
|
| 250 |
+
bool allow_tensor_metadata_change) const;
|
| 251 |
+
|
| 252 |
+
void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override;
|
| 253 |
+
void copy_tensor_metadata_and_refresh(
|
| 254 |
+
const FunctionalTensorWrapper* src_impl,
|
| 255 |
+
FunctionalTensorWrapper* dest_impl,
|
| 256 |
+
const c10::VariableVersion& version_counter,
|
| 257 |
+
bool allow_tensor_metadata_change) const;
|
| 258 |
+
|
| 259 |
+
// Note that value is not taken by reference: internally, the wrapper will
|
| 260 |
+
// change the value tensor that it points to over time.
|
| 261 |
+
Tensor value_;
|
| 262 |
+
int64_t level_{};
|
| 263 |
+
// These two counters are used for identifying
|
| 264 |
+
// whether all the mutations on a given tensor are hidden from autograd or
|
| 265 |
+
// not. If we have an input mutation that is hidden from autograd, then once
|
| 266 |
+
// we convert the input mutation to a copy_() we know it will be safe to hide
|
| 267 |
+
// the copy_() from autograd as well.
|
| 268 |
+
bool has_metadata_mutation_ = false;
|
| 269 |
+
bool is_multi_output_view_ = false;
|
| 270 |
+
// Did the tensor experience a set_() call.
|
| 271 |
+
bool was_storage_changed_ = false;
|
| 272 |
+
// Did the tensor experience any view operation with symbolic int.
|
| 273 |
+
bool is_symbolic_ = false;
|
| 274 |
+
|
| 275 |
+
size_t generation_ = 0;
|
| 276 |
+
std::vector<at::functionalization::ViewMeta> view_metas_;
|
| 277 |
+
|
| 278 |
+
protected:
|
| 279 |
+
static void copy_tensor_metadata(
|
| 280 |
+
const FunctionalTensorWrapper* src_impl,
|
| 281 |
+
FunctionalTensorWrapper* dest_impl,
|
| 282 |
+
const c10::VariableVersion& version_counter,
|
| 283 |
+
bool allow_tensor_metadata_change);
|
| 284 |
+
};
|
| 285 |
+
|
| 286 |
+
// Utility functions for the functionalization pass.
|
| 287 |
+
|
| 288 |
+
namespace functionalization {
|
| 289 |
+
namespace impl {
|
| 290 |
+
|
| 291 |
+
TORCH_API inline FunctionalTensorWrapper* unsafeGetFunctionalWrapper(
|
| 292 |
+
const Tensor& tensor) {
|
| 293 |
+
auto functional_impl =
|
| 294 |
+
static_cast<FunctionalTensorWrapper*>(tensor.unsafeGetTensorImpl());
|
| 295 |
+
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(functional_impl != nullptr);
|
| 296 |
+
return functional_impl;
|
| 297 |
+
}
|
| 298 |
+
|
| 299 |
+
TORCH_API bool isBaseTensor(const at::Tensor& tensor);
|
| 300 |
+
|
| 301 |
+
TORCH_API bool isFunctionalTensor(const at::Tensor& tensor);
|
| 302 |
+
TORCH_API bool isFunctionalTensor(const std::optional<Tensor>& t);
|
| 303 |
+
TORCH_API bool isFunctionalTensor(
|
| 304 |
+
const c10::List<std::optional<Tensor>>& t_list);
|
| 305 |
+
TORCH_API bool isFunctionalTensor(ITensorListRef list);
|
| 306 |
+
|
| 307 |
+
TORCH_API Tensor to_functional_tensor(const Tensor& tensor);
|
| 308 |
+
TORCH_API std::optional<Tensor> to_functional_tensor(
|
| 309 |
+
const std::optional<Tensor>& tensor);
|
| 310 |
+
TORCH_API c10::List<std::optional<Tensor>> to_functional_tensor(
|
| 311 |
+
const c10::List<std::optional<Tensor>>& t_list);
|
| 312 |
+
TORCH_API std::vector<Tensor> to_functional_tensor(ITensorListRef t_list);
|
| 313 |
+
|
| 314 |
+
TORCH_API void freeze_functional_tensor(const Tensor& tensor);
|
| 315 |
+
|
| 316 |
+
TORCH_API Tensor
|
| 317 |
+
from_functional_tensor(const Tensor& tensor, bool assert_functional = true);
|
| 318 |
+
TORCH_API std::optional<Tensor> from_functional_tensor(
|
| 319 |
+
const std::optional<Tensor>& t,
|
| 320 |
+
bool assert_functional = true);
|
| 321 |
+
TORCH_API c10::List<std::optional<Tensor>> from_functional_tensor(
|
| 322 |
+
const c10::List<std::optional<Tensor>>& t_list);
|
| 323 |
+
TORCH_API std::vector<Tensor> from_functional_tensor(ITensorListRef t_list);
|
| 324 |
+
|
| 325 |
+
TORCH_API void sync(const at::Tensor& t);
|
| 326 |
+
TORCH_API void sync(const std::optional<Tensor>& t);
|
| 327 |
+
TORCH_API void sync(const c10::List<std::optional<Tensor>>& t_list);
|
| 328 |
+
TORCH_API void sync(ITensorListRef t_list);
|
| 329 |
+
|
| 330 |
+
TORCH_API void replace_(const Tensor& functional_tensor, const Tensor& other);
|
| 331 |
+
TORCH_API void replace_(
|
| 332 |
+
const ITensorListRef functional_tensor,
|
| 333 |
+
ITensorListRef other);
|
| 334 |
+
|
| 335 |
+
TORCH_API void commit_update(const Tensor& functional_tensor);
|
| 336 |
+
TORCH_API void commit_update(ITensorListRef functional_tensor);
|
| 337 |
+
|
| 338 |
+
TORCH_API void unsafe_reset_storage(const Tensor& functional_tensor);
|
| 339 |
+
|
| 340 |
+
TORCH_API void mark_mutation_hidden_from_autograd(
|
| 341 |
+
const Tensor& functional_tensor);
|
| 342 |
+
|
| 343 |
+
TORCH_API bool are_all_mutations_hidden_from_autograd(
|
| 344 |
+
const Tensor& functional_tensor);
|
| 345 |
+
|
| 346 |
+
TORCH_API bool are_all_mutations_under_no_grad_or_inference_mode(
|
| 347 |
+
const Tensor& functional_tensor);
|
| 348 |
+
|
| 349 |
+
// These two methods are XLA-specific logic and are no-ops
|
| 350 |
+
// for the normal functionalization flow.
|
| 351 |
+
TORCH_API void propagate_xla_data(
|
| 352 |
+
const Tensor& functional_tensor,
|
| 353 |
+
const Tensor& other);
|
| 354 |
+
TORCH_API void propagate_xla_data(
|
| 355 |
+
const ITensorListRef functional_tensor,
|
| 356 |
+
ITensorListRef other);
|
| 357 |
+
|
| 358 |
+
TORCH_API void propagate_xla_data_direct(
|
| 359 |
+
const Tensor& tensor,
|
| 360 |
+
const Tensor& other);
|
| 361 |
+
TORCH_API void propagate_xla_data_direct(
|
| 362 |
+
const ITensorListRef tensor,
|
| 363 |
+
ITensorListRef other);
|
| 364 |
+
|
| 365 |
+
Tensor create_functional_tensor_with_view_meta(
|
| 366 |
+
const Tensor& view_to_wrap,
|
| 367 |
+
const Tensor& base,
|
| 368 |
+
functionalization::ViewMeta meta,
|
| 369 |
+
int64_t out_idx = 0);
|
| 370 |
+
std::vector<Tensor> create_functional_tensor_with_view_meta(
|
| 371 |
+
ITensorListRef view_to_wrap,
|
| 372 |
+
const Tensor& base,
|
| 373 |
+
const functionalization::ViewMeta& meta);
|
| 374 |
+
|
| 375 |
+
void mutate_view_meta(
|
| 376 |
+
const Tensor& self,
|
| 377 |
+
const functionalization::ViewMeta& meta);
|
| 378 |
+
|
| 379 |
+
void set_sizes_strides_offset(const Tensor& out, const Tensor& meta_out);
|
| 380 |
+
void set_sizes_strides_offset(
|
| 381 |
+
const std::vector<Tensor>& outs,
|
| 382 |
+
const std::vector<Tensor>& meta_outs);
|
| 383 |
+
|
| 384 |
+
// ~~~~~ TLS used in functionalization ~~~~~
|
| 385 |
+
|
| 386 |
+
TORCH_API bool getFunctionalizationReapplyViewsTLS();
|
| 387 |
+
TORCH_API void setFunctionalizationReapplyViewsTLS(bool reapply_views);
|
| 388 |
+
|
| 389 |
+
class TORCH_API FunctionalizationReapplyViewsGuard {
|
| 390 |
+
public:
|
| 391 |
+
FunctionalizationReapplyViewsGuard(bool reapply_views)
|
| 392 |
+
: prev_(getFunctionalizationReapplyViewsTLS()) {
|
| 393 |
+
setFunctionalizationReapplyViewsTLS(reapply_views);
|
| 394 |
+
}
|
| 395 |
+
|
| 396 |
+
~FunctionalizationReapplyViewsGuard() {
|
| 397 |
+
setFunctionalizationReapplyViewsTLS(prev_);
|
| 398 |
+
}
|
| 399 |
+
|
| 400 |
+
FunctionalizationReapplyViewsGuard(
|
| 401 |
+
const FunctionalizationReapplyViewsGuard&) = delete;
|
| 402 |
+
FunctionalizationReapplyViewsGuard operator=(
|
| 403 |
+
const FunctionalizationReapplyViewsGuard&) = delete;
|
| 404 |
+
FunctionalizationReapplyViewsGuard(FunctionalizationReapplyViewsGuard&&) =
|
| 405 |
+
delete;
|
| 406 |
+
FunctionalizationReapplyViewsGuard operator=(
|
| 407 |
+
FunctionalizationReapplyViewsGuard&&) = delete;
|
| 408 |
+
|
| 409 |
+
private:
|
| 410 |
+
bool prev_;
|
| 411 |
+
};
|
| 412 |
+
|
| 413 |
+
} // namespace impl
|
| 414 |
+
|
| 415 |
+
// Helper function to call an out-of-place composite aten kernel that may use
|
| 416 |
+
// mutations / views internally, and functionalize them.
|
| 417 |
+
TORCH_API void functionalize_op_helper(
|
| 418 |
+
const c10::OperatorHandle& op,
|
| 419 |
+
torch::jit::Stack* stack);
|
| 420 |
+
|
| 421 |
+
template <class Op, bool symint, class ReturnType, class... ParameterTypes>
|
| 422 |
+
struct _functionalize_aten_op final {};
|
| 423 |
+
|
| 424 |
+
template <class Op, bool symint, class ReturnType, class... ParameterTypes>
|
| 425 |
+
struct _functionalize_aten_op<Op, symint, ReturnType(ParameterTypes...)> final {
|
| 426 |
+
static ReturnType call(
|
| 427 |
+
typename c10::maybe_keep_symint<symint, ParameterTypes>::type... args) {
|
| 428 |
+
using FuncType = ReturnType(
|
| 429 |
+
typename c10::maybe_keep_symint<symint, ParameterTypes>::type...);
|
| 430 |
+
auto op = c10::Dispatcher::singleton()
|
| 431 |
+
.findSchemaOrThrow(
|
| 432 |
+
(const char*)Op::name, (const char*)Op::overload_name)
|
| 433 |
+
.typed<FuncType>();
|
| 434 |
+
|
| 435 |
+
return c10::impl::BoxedKernelWrapper<FuncType>::call(
|
| 436 |
+
c10::BoxedKernel::makeFromFunction<functionalize_op_helper>(),
|
| 437 |
+
op,
|
| 438 |
+
// BoxedKernelWrapper knows to ignore this keyset argument,
|
| 439 |
+
// because functionalize_op_helper doesn't take in a DispatchKeySet
|
| 440 |
+
c10::DispatchKeySet(),
|
| 441 |
+
args...);
|
| 442 |
+
}
|
| 443 |
+
};
|
| 444 |
+
|
| 445 |
+
template <class Op>
|
| 446 |
+
using functionalize_aten_op =
|
| 447 |
+
_functionalize_aten_op<Op, false, typename Op::schema>;
|
| 448 |
+
|
| 449 |
+
template <class Op>
|
| 450 |
+
using functionalize_aten_op_symint =
|
| 451 |
+
_functionalize_aten_op<Op, true, typename Op::schema>;
|
| 452 |
+
|
| 453 |
+
} // namespace functionalization
|
| 454 |
+
} // namespace at
|
.venv/lib/python3.11/site-packages/torch/include/ATen/InferSize.h
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/DimVector.h>
|
| 4 |
+
#include <c10/core/ScalarType.h>
|
| 5 |
+
#include <c10/core/SymIntArrayRef.h>
|
| 6 |
+
#include <c10/util/DimVector.h>
|
| 7 |
+
#include <optional>
|
| 8 |
+
#include <sstream>
|
| 9 |
+
#include <vector>
|
| 10 |
+
|
| 11 |
+
namespace at {
|
| 12 |
+
|
| 13 |
+
// Infers the size of a dim with size -1, if it exists. Also checks that new
|
| 14 |
+
// shape is compatible with the number of elements.
|
| 15 |
+
//
|
| 16 |
+
// templated to handle std::vector<int64_t> and DimVector use cases, see
|
| 17 |
+
// below
|
| 18 |
+
//
|
| 19 |
+
template <typename InputArrayRef, typename NumelType, typename ResultVec>
|
| 20 |
+
inline void infer_size_impl(
|
| 21 |
+
InputArrayRef shape,
|
| 22 |
+
NumelType numel,
|
| 23 |
+
ResultVec& res) {
|
| 24 |
+
NumelType newsize = 1;
|
| 25 |
+
// N.B. this is an index, not a sym dim!
|
| 26 |
+
std::optional<int64_t> infer_dim;
|
| 27 |
+
for (int64_t dim = 0, ndim = shape.size(); dim != ndim; dim++) {
|
| 28 |
+
if (shape[dim] == -1) {
|
| 29 |
+
if (infer_dim) {
|
| 30 |
+
throw std::runtime_error("only one dimension can be inferred");
|
| 31 |
+
}
|
| 32 |
+
infer_dim = dim;
|
| 33 |
+
} else if (shape[dim] >= 0) {
|
| 34 |
+
newsize *= shape[dim];
|
| 35 |
+
} else {
|
| 36 |
+
AT_ERROR("invalid shape dimension ", shape[dim]);
|
| 37 |
+
}
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(numel, newsize)) ||
|
| 41 |
+
(infer_dim && newsize > 0 && numel % newsize == 0)) {
|
| 42 |
+
if (infer_dim) {
|
| 43 |
+
// We have a degree of freedom here to select the dimension size; follow
|
| 44 |
+
// NumPy semantics and just bail. However, a nice error message is needed
|
| 45 |
+
// because users often use `view` as a way to flatten & unflatten
|
| 46 |
+
// dimensions and will otherwise be confused why
|
| 47 |
+
// empty_tensor.view( 0, 0)
|
| 48 |
+
// works yet
|
| 49 |
+
// empty_tensor.view(-1, 0)
|
| 50 |
+
// doesn't.
|
| 51 |
+
TORCH_CHECK(
|
| 52 |
+
newsize != 0,
|
| 53 |
+
"cannot reshape tensor of 0 elements into shape ",
|
| 54 |
+
shape,
|
| 55 |
+
" because the unspecified dimension size -1 can be any "
|
| 56 |
+
"value and is ambiguous");
|
| 57 |
+
res[*infer_dim] = numel / newsize;
|
| 58 |
+
}
|
| 59 |
+
return;
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
std::ostringstream ss;
|
| 63 |
+
ss << "shape '" << shape << "' is invalid for input of size " << numel;
|
| 64 |
+
throw std::runtime_error(ss.str());
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
inline std::vector<int64_t> infer_size(IntArrayRef shape, int64_t numel) {
|
| 68 |
+
auto res = shape.vec();
|
| 69 |
+
infer_size_impl(shape, numel, res);
|
| 70 |
+
return res;
|
| 71 |
+
}
|
| 72 |
+
|
| 73 |
+
inline at::DimVector infer_size_dv(IntArrayRef shape, int64_t numel) {
|
| 74 |
+
auto res = at::DimVector(shape);
|
| 75 |
+
infer_size_impl(shape, numel, res);
|
| 76 |
+
return res;
|
| 77 |
+
}
|
| 78 |
+
|
| 79 |
+
inline at::SymDimVector infer_size_dv(
|
| 80 |
+
c10::SymIntArrayRef shape,
|
| 81 |
+
c10::SymInt numel) {
|
| 82 |
+
auto res = at::SymDimVector(shape);
|
| 83 |
+
infer_size_impl<c10::SymIntArrayRef, c10::SymInt, at::SymDimVector>(
|
| 84 |
+
shape, std::move(numel), res);
|
| 85 |
+
return res;
|
| 86 |
+
}
|
| 87 |
+
|
| 88 |
+
} // namespace at
|
.venv/lib/python3.11/site-packages/torch/include/ATen/InitialTensorOptions.h
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <c10/core/TensorOptions.h>
|
| 4 |
+
|
| 5 |
+
namespace at {
|
| 6 |
+
|
| 7 |
+
// Represents the initial TensorOptions, before the "defaults" are ever changed.
|
| 8 |
+
// This is designed to be used in library code, where the explicit devices,
|
| 9 |
+
// dtypes, etc. are known. NOTE: this is not a stable API.
|
| 10 |
+
inline TensorOptions initialTensorOptions() {
|
| 11 |
+
return TensorOptions(kCPU).dtype(kFloat).layout(kStrided).requires_grad(
|
| 12 |
+
false);
|
| 13 |
+
}
|
| 14 |
+
|
| 15 |
+
} // namespace at
|
.venv/lib/python3.11/site-packages/torch/include/ATen/Layout.h
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <c10/core/Layout.h>
|
.venv/lib/python3.11/site-packages/torch/include/ATen/LegacyBatchedFallback.h
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <ATen/ATen.h>
|
| 3 |
+
#include <ATen/core/op_registration/op_registration.h>
|
| 4 |
+
#include <torch/library.h>
|
| 5 |
+
|
| 6 |
+
namespace at {
|
| 7 |
+
|
| 8 |
+
// If an operator doesn't have a batching rule implemented then we fallback
|
| 9 |
+
// to this implementation. The fallback only works on out-of-place operators
|
| 10 |
+
// that return only tensors with new memory. (e.g., no in-place operators, no
|
| 11 |
+
// view operations).
|
| 12 |
+
//
|
| 13 |
+
// The fallback effectively takes all of the BatchedTensors in `stack`, slices
|
| 14 |
+
// them, and runs `op` on all of the corresponding slices to produce slices
|
| 15 |
+
// of the outputs. The output slices then get `torch.stack`ed to create the
|
| 16 |
+
// final returns.
|
| 17 |
+
//
|
| 18 |
+
// The performance of the fallback is not very good because it introduces an
|
| 19 |
+
// extra copy from stacking the sliced outputs. Because of this, we prefer to
|
| 20 |
+
// write batching rules for operators whenever possible.
|
| 21 |
+
void batchedTensorForLoopFallback(
|
| 22 |
+
const c10::OperatorHandle& op,
|
| 23 |
+
torch::jit::Stack* stack);
|
| 24 |
+
|
| 25 |
+
} // namespace at
|
.venv/lib/python3.11/site-packages/torch/include/ATen/LegacyBatchedTensorImpl.h
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <bitset>
|
| 4 |
+
|
| 5 |
+
#include <ATen/ArrayRef.h>
|
| 6 |
+
#include <ATen/SmallVector.h>
|
| 7 |
+
#include <ATen/Tensor.h>
|
| 8 |
+
|
| 9 |
+
namespace at {
|
| 10 |
+
|
| 11 |
+
// We assume this in a few other places in the codebase,
|
| 12 |
+
// but there isn't a centralized definition.
|
| 13 |
+
constexpr int64_t kVmapMaxTensorDims = 64;
|
| 14 |
+
|
| 15 |
+
// The valid vmap levels range from [0, 64). This effectively means that we
|
| 16 |
+
// support a maximum of 64 nested vmaps.
|
| 17 |
+
constexpr int64_t kVmapNumLevels = 64;
|
| 18 |
+
|
| 19 |
+
// Store this number of elements of BatchDims on the stack. Most people will
|
| 20 |
+
// probably use <= 5 nested vmaps, but adjust this number as necessary.
|
| 21 |
+
constexpr int64_t kBatchDimsStackSize = 5;
|
| 22 |
+
|
| 23 |
+
// a BatchDim represents a "private" dimension on a Tensor created inside of
|
| 24 |
+
// vmap. It is a (level, dim) tuple, with the `dim` indicating which dimension
|
| 25 |
+
// is being vmap'ed over and the `level` being an identifier for which vmap
|
| 26 |
+
// said dimension was created inside. The `dim` corresponds to a "physical
|
| 27 |
+
// dim" - it is a dimension index on the underlying physical tensor that is
|
| 28 |
+
// being vmapped over.
|
| 29 |
+
struct BatchDim {
|
| 30 |
+
BatchDim(int64_t level, int64_t dim) : dim_(dim), level_(level) {}
|
| 31 |
+
int64_t dim() const {
|
| 32 |
+
return dim_;
|
| 33 |
+
}
|
| 34 |
+
int64_t level() const {
|
| 35 |
+
return level_;
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
private:
|
| 39 |
+
int64_t dim_;
|
| 40 |
+
int64_t level_;
|
| 41 |
+
};
|
| 42 |
+
|
| 43 |
+
using BatchDims = SmallVector<BatchDim, kBatchDimsStackSize>;
|
| 44 |
+
using BatchDimsRef = ArrayRef<BatchDim>;
|
| 45 |
+
|
| 46 |
+
// A BatchedTensorImpl holds an underlying Tensor and a list of BatchDim
|
| 47 |
+
// NB: We use the term "BatchedTensor" to mean a Tensor that is backed with a
|
| 48 |
+
// BatchedTensorImpl.
|
| 49 |
+
//
|
| 50 |
+
// The batch dimensions are treated as being "private"; they are not
|
| 51 |
+
// user-visible. For example, in the following Tensor,
|
| 52 |
+
// bt = BatchedTensorImpl(ones(2, 3, 5, 7), [(lvl=1, dim=0), (lvl=2, dim=1)])
|
| 53 |
+
// dimensions 0 and 1 are batch dimensions.
|
| 54 |
+
//
|
| 55 |
+
// bt.sizes() returns (5, 7); bt.sum(0) performs a reduction over the (public)
|
| 56 |
+
// dim 0, which is equivalent to dim 3 in the underlying ones(2, 3, 5, 7)
|
| 57 |
+
// tensor.
|
| 58 |
+
struct TORCH_API BatchedTensorImpl : public c10::TensorImpl {
|
| 59 |
+
explicit BatchedTensorImpl(Tensor value, BatchDims bdims);
|
| 60 |
+
|
| 61 |
+
// Returns a reference to BatchDims that represent which dimensions of this
|
| 62 |
+
// tensor are private.
|
| 63 |
+
BatchDimsRef bdims() const {
|
| 64 |
+
return bdims_;
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
// BatchedTensorImpl wraps a Tensor
|
| 68 |
+
const Tensor& value() const {
|
| 69 |
+
return value_;
|
| 70 |
+
};
|
| 71 |
+
|
| 72 |
+
// Given a public dimension index, return the dimension index in the
|
| 73 |
+
// underlying value() tensor. For example, if we have
|
| 74 |
+
// bt = BatchedTensorImpl(ones(2, 3, 5, 7), [(lvl=1, dim=0), (lvl=2,
|
| 75 |
+
// dim=2)])
|
| 76 |
+
// bt.actualDim(0) -> 1
|
| 77 |
+
// bt.actualDim(1) -> 3
|
| 78 |
+
// bt.actualDim(2) -> Error
|
| 79 |
+
int64_t actualDim(int64_t dim, bool wrap_dim = true) const;
|
| 80 |
+
|
| 81 |
+
// We have to override this because we opted into CustomStrides
|
| 82 |
+
IntArrayRef strides_custom() const override;
|
| 83 |
+
// Override a bunch of methods inherited from TensorImpl to return error
|
| 84 |
+
// messages.
|
| 85 |
+
bool is_contiguous_custom(at::MemoryFormat memory_format) const override;
|
| 86 |
+
void set_size(int64_t dim, int64_t new_size) override;
|
| 87 |
+
void set_stride(int64_t dim, int64_t new_stride) override;
|
| 88 |
+
void set_storage_offset(int64_t storage_offset) override;
|
| 89 |
+
#ifdef DEBUG
|
| 90 |
+
bool has_storage() const override;
|
| 91 |
+
#endif
|
| 92 |
+
|
| 93 |
+
private:
|
| 94 |
+
// see NOTE: [BatchedTensorImpl levels invariant]
|
| 95 |
+
void checkInvariants() const;
|
| 96 |
+
const char* tensorimpl_type_name() const override;
|
| 97 |
+
|
| 98 |
+
Tensor value_;
|
| 99 |
+
|
| 100 |
+
// Note: [BatchedTensorImpl levels invariant]
|
| 101 |
+
// There is an invariant that the BatchDims must be stored in increasing
|
| 102 |
+
// `level` order. That is, for i < j, bdims_[i].level must be less than
|
| 103 |
+
// bdims_[j].level.
|
| 104 |
+
BatchDims bdims_;
|
| 105 |
+
};
|
| 106 |
+
|
| 107 |
+
// NB: We use the term "BatchedTensor" to mean a Tensor that is backed with a
|
| 108 |
+
// BatchedTensorImpl.
|
| 109 |
+
inline bool isBatchedTensor(const Tensor& tensor) {
|
| 110 |
+
return tensor.unsafeGetTensorImpl()->key_set().has(DispatchKey::Batched);
|
| 111 |
+
}
|
| 112 |
+
|
| 113 |
+
// It is unsafe to call this on a Tensor that is not backed by a
|
| 114 |
+
// BatchedTensorImpl. Please use `maybeGetBatchedImpl` whenever possible.
|
| 115 |
+
inline BatchedTensorImpl* unsafeGetBatchedImpl(const Tensor& tensor) {
|
| 116 |
+
return static_cast<BatchedTensorImpl*>(tensor.unsafeGetTensorImpl());
|
| 117 |
+
}
|
| 118 |
+
|
| 119 |
+
inline BatchedTensorImpl* maybeGetBatchedImpl(const Tensor& tensor) {
|
| 120 |
+
if (!isBatchedTensor(tensor)) {
|
| 121 |
+
return nullptr;
|
| 122 |
+
}
|
| 123 |
+
return unsafeGetBatchedImpl(tensor);
|
| 124 |
+
}
|
| 125 |
+
|
| 126 |
+
// Returns a bitset. If bit i is set, then that means dim i is a batchdim.
|
| 127 |
+
inline std::bitset<kVmapMaxTensorDims> createBatchDimBitset(
|
| 128 |
+
BatchDimsRef bdims) {
|
| 129 |
+
std::bitset<kVmapMaxTensorDims> is_bdim;
|
| 130 |
+
for (const auto& bdim : bdims) {
|
| 131 |
+
is_bdim.set(bdim.dim());
|
| 132 |
+
}
|
| 133 |
+
return is_bdim;
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
// Creates a bitset for all of the levels present in `bdims`
|
| 137 |
+
inline std::bitset<kVmapNumLevels> createVmapLevelsBitset(BatchDimsRef bdims) {
|
| 138 |
+
std::bitset<kVmapNumLevels> result;
|
| 139 |
+
for (const auto& bdim : bdims) {
|
| 140 |
+
result.set(bdim.level());
|
| 141 |
+
}
|
| 142 |
+
return result;
|
| 143 |
+
}
|
| 144 |
+
|
| 145 |
+
inline std::ostream& operator<<(std::ostream& out, const BatchDim& bdim) {
|
| 146 |
+
out << "(lvl=" << bdim.level() << ", dim=" << bdim.dim() << ")";
|
| 147 |
+
return out;
|
| 148 |
+
}
|
| 149 |
+
|
| 150 |
+
// Use this to construct a BatchedTensor from a regular Tensor
|
| 151 |
+
TORCH_API Tensor makeBatched(const Tensor& tensor, BatchDims bdims);
|
| 152 |
+
|
| 153 |
+
// Adds a batch dim to `tensor`, returning a BatchedTensor
|
| 154 |
+
TORCH_API Tensor addBatchDim(const Tensor& tensor, int64_t level, int64_t dim);
|
| 155 |
+
|
| 156 |
+
// Checks if an inplace operation on self and other is "vmap compatible".
|
| 157 |
+
// See NOTE: [vmap-incompatible in-place operations] for the definition of this.
|
| 158 |
+
TORCH_API bool inplaceIsVmapCompatible(const Tensor& self, const Tensor& other);
|
| 159 |
+
|
| 160 |
+
} // namespace at
|
.venv/lib/python3.11/site-packages/torch/include/ATen/LegacyVmapMode.h
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <c10/core/impl/LocalDispatchKeySet.h>
|
| 4 |
+
|
| 5 |
+
namespace at::impl {
|
| 6 |
+
|
| 7 |
+
// VmapMode contains a thread local count of how many nested vmaps
|
| 8 |
+
// we are currently inside. That number is known as the `vmap level`.
|
| 9 |
+
// VmapMode is used in the implementation of the Python `torch.vmap` API.
|
| 10 |
+
//
|
| 11 |
+
// NOTE: this is NOT the c++ api for torch.vmap. That doesn't exist yet.
|
| 12 |
+
|
| 13 |
+
struct TORCH_API VmapMode {
|
| 14 |
+
// Returns the vmap level, aka the count of how many nested vmaps we're in.
|
| 15 |
+
static int64_t current_vmap_level();
|
| 16 |
+
|
| 17 |
+
// Increment the count of nested vmaps. If this causes the vmap level to be
|
| 18 |
+
// greater than 0, then it enables DispatchKey::VmapMode on all tensors.
|
| 19 |
+
static int64_t increment_nesting();
|
| 20 |
+
|
| 21 |
+
// Decrements the count of nested vmaps. If this causes the vmap level to be
|
| 22 |
+
// equal to 0, then it disables DispatchKey::VmapMode on all tensors.
|
| 23 |
+
static int64_t decrement_nesting();
|
| 24 |
+
};
|
| 25 |
+
|
| 26 |
+
} // namespace at::impl
|
.venv/lib/python3.11/site-packages/torch/include/ATen/LegacyVmapTransforms.h
ADDED
|
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/LegacyBatchedTensorImpl.h>
|
| 4 |
+
#include <ATen/core/IListRef.h>
|
| 5 |
+
|
| 6 |
+
namespace at {
|
| 7 |
+
|
| 8 |
+
// This file contains abstractions used for transforming *logical* vmap
|
| 9 |
+
// arguments into *physical* arguments. (Keep reading for definitions of these
|
| 10 |
+
// terms).
|
| 11 |
+
|
| 12 |
+
// NOTE: [Logical vs physical args]
|
| 13 |
+
// Consider the following vmap.
|
| 14 |
+
// vmap(vmap(func, in_dims=(2,)), in_dims=(0,))(torch.ones(2, 3, 4))
|
| 15 |
+
// This would produce a BatchedTensor wrapping a Tensor of size [2, 3, 4],
|
| 16 |
+
// with batch dims 0 and 2:
|
| 17 |
+
// BatchedTensor(ones(2, 3, 4), bdims=[(lvl=1,dim=0),(lvl=2,dim=2)])
|
| 18 |
+
//
|
| 19 |
+
// We say the *logical* view of the tensor has size [3] -- tensors inside
|
| 20 |
+
// `func` appear to have size [3].
|
| 21 |
+
// However, the *physical* underlying tensor (the one passed to vmap) has size
|
| 22 |
+
// [2, 3, 4].
|
| 23 |
+
//
|
| 24 |
+
// This notion of logical vs physical also extends to non-tensor arguments.
|
| 25 |
+
// Consider the previous tensor; let's assume the user called
|
| 26 |
+
// `torch.sum(tensor, dim=0)` inside of `func`. Then the logical
|
| 27 |
+
// dimension they are reducing over is dim 0 but the physical dim is dim 1
|
| 28 |
+
// (the first non-batch dimension)
|
| 29 |
+
|
| 30 |
+
// Forward declared; see NOTE: [What is a VmapPhysicalView?]
|
| 31 |
+
struct VmapPhysicalView;
|
| 32 |
+
|
| 33 |
+
// Most PyTorch operators take 4 or fewer inputs.
|
| 34 |
+
constexpr int64_t kVmapTransformStaticInputSize = 4;
|
| 35 |
+
using VmapPhysicalViewVec =
|
| 36 |
+
SmallVector<VmapPhysicalView, kVmapTransformStaticInputSize>;
|
| 37 |
+
|
| 38 |
+
// Pytorch generally advertises good performance for <= 5 dims.
|
| 39 |
+
// (see ATen/core/DimVector.h). We add a few extra dims (~3) for vmap
|
| 40 |
+
// dimensions to get 8. Adjust this number as necessary
|
| 41 |
+
constexpr int64_t kVmapStaticDimVecSize = 8;
|
| 42 |
+
using VmapDimVector = SmallVector<int64_t, kVmapStaticDimVecSize>;
|
| 43 |
+
using VmapSymDimVector = SmallVector<c10::SymInt, kVmapStaticDimVecSize>;
|
| 44 |
+
|
| 45 |
+
// NOTE: [What is an VmapTransform?]
|
| 46 |
+
// An *VmapTransform* converts logical views of tensors to physical views.
|
| 47 |
+
//
|
| 48 |
+
// Batching rules use VmapTransforms to convert logical arguments to
|
| 49 |
+
// physical arguments, then call one or more at:: operator that handles the
|
| 50 |
+
// physical arguments, and then converts the physical result back to a logical
|
| 51 |
+
// argument.
|
| 52 |
+
|
| 53 |
+
// VmapTransform for operators that take tensors with multiple batch dims.
|
| 54 |
+
// Given one or more logical views on Tensors, `logicalToPhysical`
|
| 55 |
+
// permutes all of the batch dims to the front of the tensor, aligns
|
| 56 |
+
// and expands the batch dims to match each other (according to their `level`),
|
| 57 |
+
// and returns a VmapPhysicalView on the tensor(s).
|
| 58 |
+
struct TORCH_API MultiBatchVmapTransform {
|
| 59 |
+
static VmapPhysicalView logicalToPhysical(const Tensor& logical_tensor);
|
| 60 |
+
static VmapPhysicalViewVec logicalToPhysical(ITensorListRef logical_tensors);
|
| 61 |
+
};
|
| 62 |
+
|
| 63 |
+
// VmapTransform for operators that broadcast all inputs.
|
| 64 |
+
// Given some logical views on Tensors, `logicalToPhysical`:
|
| 65 |
+
// - permutes all of the batch dims to the front of the tensors
|
| 66 |
+
// - aligns all the batch dims to the collective levels of all of the tensors.
|
| 67 |
+
// If a tensor does not have a batch dim for a vmap level, then it receives
|
| 68 |
+
// a size-one dimension for said level.
|
| 69 |
+
// - aligns the non-batch dims to have the same dimensionality, adding extra
|
| 70 |
+
// size-1 dimensions in between the batch dimensions and the non-batch
|
| 71 |
+
// dimensions so that the batch dimensions are lined up from the right.
|
| 72 |
+
//
|
| 73 |
+
// For example: given inputs of size (B, 2) and (B, 3, 2) where B is the batch
|
| 74 |
+
// dimension, BroadcastingVmapTransform returns VmapPhysicalViews that wrap
|
| 75 |
+
// tensors of size (B, 1, 2) and (B, 3, 2).
|
| 76 |
+
//
|
| 77 |
+
// Given inputs of size (B, 2) and (2,), BroadcastingVmapTransform returns
|
| 78 |
+
// VmapPhysicalViews wrapping tensors of size (B, 2) and (1, 2). We don't
|
| 79 |
+
// actually *need* to return a tensor of size (1, 2) for the second tensor
|
| 80 |
+
// because the broadcasting operation takes care of that for us, but we do
|
| 81 |
+
// it anyways to keep things simple.
|
| 82 |
+
struct TORCH_API BroadcastingVmapTransform {
|
| 83 |
+
static VmapPhysicalViewVec logicalToPhysical(TensorList logical_tensors);
|
| 84 |
+
};
|
| 85 |
+
|
| 86 |
+
// Forward declared, if you're reading this file head to toe, don't worry about
|
| 87 |
+
// it yet.
|
| 88 |
+
struct VmapPhysicalToLogicalMap;
|
| 89 |
+
|
| 90 |
+
// NOTE: [What is a VmapPhysicalView?]
|
| 91 |
+
// VmapPhysicalView represents a physical view on a Tensor.
|
| 92 |
+
//
|
| 93 |
+
// One can use it to further convert logical dimension indices, logical shapes,
|
| 94 |
+
// and more to their physical variants, or convert a new (physical) tensor into
|
| 95 |
+
// a logical BatchedTensor. (TODO(rzou): some of these are not yet implemented).
|
| 96 |
+
//
|
| 97 |
+
// VmapPhysicalView stores a physical tensor with all of its batch dimensions at
|
| 98 |
+
// the front and some levels that correspond to said batch dimensions.
|
| 99 |
+
//
|
| 100 |
+
// The levels bitset specifies which vmap levels correspond to the batch
|
| 101 |
+
// dimensions at the front of the tensor. In particular, the number of set bits
|
| 102 |
+
// corresponds to the number of batch dimensions on `tensor` and the rightmost
|
| 103 |
+
// bit of `levels` specifies the maximum number of nested vmaps we are in at
|
| 104 |
+
// this point in time.
|
| 105 |
+
// For example, given:
|
| 106 |
+
// physical_view = VmapPhysicalView(tensor=ones(2, 3, 4, 5, 6), levels={1, 3})
|
| 107 |
+
//
|
| 108 |
+
// Rightmost bit of `levels` is 3 indicating the number of nested vmaps less
|
| 109 |
+
// than or equal to 3.
|
| 110 |
+
// bitset: 010100
|
| 111 |
+
// ^
|
| 112 |
+
// |
|
| 113 |
+
// levels: 012345
|
| 114 |
+
struct TORCH_API VmapPhysicalView {
|
| 115 |
+
VmapPhysicalView(Tensor&& tensor, std::bitset<kVmapNumLevels> levels)
|
| 116 |
+
: levels_(levels), tensor_(std::move(tensor)) {
|
| 117 |
+
TORCH_INTERNAL_ASSERT(!isBatchedTensor(tensor_));
|
| 118 |
+
}
|
| 119 |
+
|
| 120 |
+
Tensor& tensor() {
|
| 121 |
+
return tensor_;
|
| 122 |
+
}
|
| 123 |
+
const Tensor& tensor() const {
|
| 124 |
+
return tensor_;
|
| 125 |
+
}
|
| 126 |
+
|
| 127 |
+
// Maps logical dim indices to physical dim indices. Also does dim wrapping.
|
| 128 |
+
//
|
| 129 |
+
// For example, given:
|
| 130 |
+
// physical_view = VmapPhysicalView(tensor=ones(2, 3, 4, 5), levels={1, 3})
|
| 131 |
+
//
|
| 132 |
+
// Then physical_view.getPhysicalDims({0, 1}) returns {2, 3}.
|
| 133 |
+
// This is because the size of levels tell us that the first two dimensions
|
| 134 |
+
// of `tensor_` are batch dimensions, so a logical dim of `n` is actually
|
| 135 |
+
// a physical dim of `n + 2`.
|
| 136 |
+
VmapDimVector getPhysicalDims(OptionalIntArrayRef logical_dims) const;
|
| 137 |
+
int64_t getPhysicalDim(int64_t logical_dim) const;
|
| 138 |
+
|
| 139 |
+
// Returns a VmapPhysicalToLogicalMap object. This can be used for
|
| 140 |
+
// mapping a physical tensor to a new logical tensor (BatchedTensor)
|
| 141 |
+
VmapPhysicalToLogicalMap getPhysicalToLogicalMap() const;
|
| 142 |
+
|
| 143 |
+
// Maps a logical shape to a physical shape by pre-pending the batch
|
| 144 |
+
// sizes to the logical shape.
|
| 145 |
+
VmapDimVector getPhysicalShape(IntArrayRef logical_shape) const;
|
| 146 |
+
|
| 147 |
+
int64_t numBatchDims() const;
|
| 148 |
+
|
| 149 |
+
private:
|
| 150 |
+
int64_t numLogicalDims() const;
|
| 151 |
+
|
| 152 |
+
std::bitset<kVmapNumLevels> levels_;
|
| 153 |
+
Tensor tensor_;
|
| 154 |
+
};
|
| 155 |
+
|
| 156 |
+
// Convenience struct used for mapping a physical tensor (a non-BatchedTensor)
|
| 157 |
+
// to a logical one (BatchedTensor). It holds some levels that are used to do
|
| 158 |
+
// the mapping and assumes that the batch dimensions in the physical tensor all
|
| 159 |
+
// occur at the front of the tensor.
|
| 160 |
+
struct TORCH_API VmapPhysicalToLogicalMap {
|
| 161 |
+
VmapPhysicalToLogicalMap(std::bitset<kVmapNumLevels> levels)
|
| 162 |
+
: levels_(levels) {}
|
| 163 |
+
|
| 164 |
+
// Maps a physical tensor to a new logical tensor (BatchedTensor).
|
| 165 |
+
// Assumes that all of the "batch dimensions" are at the front
|
| 166 |
+
// of the physical tensor. For example, given:
|
| 167 |
+
// - x = rank-4 Tensor with size 2, 3, 5, 7
|
| 168 |
+
// - levels = (2, 4)
|
| 169 |
+
// Returns:
|
| 170 |
+
// - BatchedTensor(x, bdims=[(dim=0,lvl=2), (dim=1, lvl=4)])
|
| 171 |
+
Tensor apply(const Tensor& physical_tensor) const;
|
| 172 |
+
|
| 173 |
+
// Given a vector of physical tensors,
|
| 174 |
+
// 1. maps each tensor to a new logical tensor. Assumes that all of the
|
| 175 |
+
// "batch dimensions" are at the front of the physical tensors.
|
| 176 |
+
// 2. stores the new logical tensors back into the passed-in vector. This is
|
| 177 |
+
// to avoid additional dynamic allocations.
|
| 178 |
+
void applyInplace(std::vector<Tensor>& physical_tensors) const;
|
| 179 |
+
|
| 180 |
+
std::bitset<kVmapNumLevels> levels_;
|
| 181 |
+
};
|
| 182 |
+
|
| 183 |
+
} // namespace at
|
.venv/lib/python3.11/site-packages/torch/include/ATen/MapAllocator.h
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <c10/core/Allocator.h>
|
| 4 |
+
#include <c10/util/string_view.h>
|
| 5 |
+
|
| 6 |
+
namespace at {
|
| 7 |
+
|
| 8 |
+
enum MappedAllocatorModes {
|
| 9 |
+
ALLOCATOR_MAPPED_SHARED = 1,
|
| 10 |
+
ALLOCATOR_MAPPED_SHAREDMEM = 2,
|
| 11 |
+
ALLOCATOR_MAPPED_EXCLUSIVE = 4,
|
| 12 |
+
ALLOCATOR_MAPPED_NOCREATE = 8,
|
| 13 |
+
ALLOCATOR_MAPPED_KEEPFD = 16,
|
| 14 |
+
ALLOCATOR_MAPPED_FROMFD = 32,
|
| 15 |
+
ALLOCATOR_MAPPED_UNLINK = 64
|
| 16 |
+
};
|
| 17 |
+
|
| 18 |
+
// Sentinel value/type to help distinguish the file descriptor constructor from
|
| 19 |
+
// the non-file descriptor constructor
|
| 20 |
+
enum WithFd { WITH_FD };
|
| 21 |
+
|
| 22 |
+
TORCH_API std::string NewProcessWideShmHandle();
|
| 23 |
+
|
| 24 |
+
class TORCH_API MapAllocator {
|
| 25 |
+
public:
|
| 26 |
+
MapAllocator(c10::string_view filename, int flags, size_t size);
|
| 27 |
+
MapAllocator(
|
| 28 |
+
WithFd,
|
| 29 |
+
c10::string_view filename,
|
| 30 |
+
int fd,
|
| 31 |
+
int flags,
|
| 32 |
+
size_t size);
|
| 33 |
+
MapAllocator(const MapAllocator&) = delete;
|
| 34 |
+
MapAllocator& operator=(const MapAllocator&) = delete;
|
| 35 |
+
MapAllocator(MapAllocator&&) = delete;
|
| 36 |
+
MapAllocator& operator=(MapAllocator&&) = delete;
|
| 37 |
+
|
| 38 |
+
const char* filename() const {
|
| 39 |
+
return filename_.c_str();
|
| 40 |
+
}
|
| 41 |
+
int fd() const {
|
| 42 |
+
#ifdef _WIN32
|
| 43 |
+
TORCH_CHECK(false, "MapAllocator::fd() is unsupported on Windows");
|
| 44 |
+
#else
|
| 45 |
+
return fd_;
|
| 46 |
+
#endif
|
| 47 |
+
}
|
| 48 |
+
ptrdiff_t size() const {
|
| 49 |
+
return size_;
|
| 50 |
+
}
|
| 51 |
+
// Return a pointer to the actual data for this allocator
|
| 52 |
+
// (in the case of the refcounted allocator, this is offset
|
| 53 |
+
// from the base pointer.)
|
| 54 |
+
virtual void* data() const {
|
| 55 |
+
return base_ptr_;
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
int flags() const {
|
| 59 |
+
return flags_;
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
static MapAllocator* fromDataPtr(const at::DataPtr&);
|
| 63 |
+
static at::DataPtr makeDataPtr(
|
| 64 |
+
c10::string_view filename,
|
| 65 |
+
int flags,
|
| 66 |
+
size_t size,
|
| 67 |
+
size_t* actual_size_out);
|
| 68 |
+
static at::DataPtr makeDataPtr(
|
| 69 |
+
WithFd,
|
| 70 |
+
const char* filename,
|
| 71 |
+
int fd,
|
| 72 |
+
int flags,
|
| 73 |
+
size_t size,
|
| 74 |
+
size_t* actual_size_out);
|
| 75 |
+
|
| 76 |
+
// Closes the data. Helps us avoid destructor shenanigans
|
| 77 |
+
virtual void close();
|
| 78 |
+
|
| 79 |
+
// This is very dangerous. You have to redefine this destructor for each
|
| 80 |
+
// subclass
|
| 81 |
+
virtual ~MapAllocator();
|
| 82 |
+
|
| 83 |
+
protected:
|
| 84 |
+
bool closed_ = false;
|
| 85 |
+
std::string filename_;
|
| 86 |
+
int flags_ = 0;
|
| 87 |
+
ptrdiff_t size_; /* mapped size */
|
| 88 |
+
#ifdef _WIN32
|
| 89 |
+
void* handle_;
|
| 90 |
+
void* event_;
|
| 91 |
+
std::string eventname_;
|
| 92 |
+
#else
|
| 93 |
+
int fd_ = -1;
|
| 94 |
+
#endif
|
| 95 |
+
void* base_ptr_ = nullptr;
|
| 96 |
+
};
|
| 97 |
+
|
| 98 |
+
// Base-from-member idiom
|
| 99 |
+
struct TORCH_API RefcountedMapAllocatorArgCheck {
|
| 100 |
+
RefcountedMapAllocatorArgCheck(int flags);
|
| 101 |
+
};
|
| 102 |
+
|
| 103 |
+
class TORCH_API RefcountedMapAllocator : private RefcountedMapAllocatorArgCheck,
|
| 104 |
+
public MapAllocator {
|
| 105 |
+
public:
|
| 106 |
+
RefcountedMapAllocator(const char* filename, int flags, size_t size);
|
| 107 |
+
RefcountedMapAllocator(
|
| 108 |
+
WithFd,
|
| 109 |
+
const char* filename,
|
| 110 |
+
int fd,
|
| 111 |
+
int flags,
|
| 112 |
+
size_t size);
|
| 113 |
+
|
| 114 |
+
static RefcountedMapAllocator* fromDataPtr(const at::DataPtr&);
|
| 115 |
+
static at::DataPtr makeDataPtr(
|
| 116 |
+
const char* filename,
|
| 117 |
+
int flags,
|
| 118 |
+
size_t size,
|
| 119 |
+
size_t* actual_size_out);
|
| 120 |
+
static at::DataPtr makeDataPtr(
|
| 121 |
+
WithFd,
|
| 122 |
+
const char* filename,
|
| 123 |
+
int fd,
|
| 124 |
+
int flags,
|
| 125 |
+
size_t size,
|
| 126 |
+
size_t* actual_size_out);
|
| 127 |
+
|
| 128 |
+
void* data() const override;
|
| 129 |
+
|
| 130 |
+
void incref();
|
| 131 |
+
int decref();
|
| 132 |
+
void close() override;
|
| 133 |
+
|
| 134 |
+
~RefcountedMapAllocator() override {
|
| 135 |
+
RefcountedMapAllocator::close();
|
| 136 |
+
}
|
| 137 |
+
|
| 138 |
+
protected:
|
| 139 |
+
void checkFlags();
|
| 140 |
+
void initializeAlloc();
|
| 141 |
+
};
|
| 142 |
+
|
| 143 |
+
} // namespace at
|
.venv/lib/python3.11/site-packages/torch/include/ATen/MatrixRef.h
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <ATen/Utils.h>
|
| 3 |
+
#include <c10/util/ArrayRef.h>
|
| 4 |
+
|
| 5 |
+
namespace at {
|
| 6 |
+
/// MatrixRef - Like an ArrayRef, but with an extra recorded strides so that
|
| 7 |
+
/// we can easily view it as a multidimensional array.
|
| 8 |
+
///
|
| 9 |
+
/// Like ArrayRef, this class does not own the underlying data, it is expected
|
| 10 |
+
/// to be used in situations where the data resides in some other buffer.
|
| 11 |
+
///
|
| 12 |
+
/// This is intended to be trivially copyable, so it should be passed by
|
| 13 |
+
/// value.
|
| 14 |
+
///
|
| 15 |
+
/// For now, 2D only (so the copies are actually cheap, without having
|
| 16 |
+
/// to write a SmallVector class) and contiguous only (so we can
|
| 17 |
+
/// return non-strided ArrayRef on index).
|
| 18 |
+
///
|
| 19 |
+
/// P.S. dimension 0 indexes rows, dimension 1 indexes columns
|
| 20 |
+
template <typename T>
|
| 21 |
+
class MatrixRef {
|
| 22 |
+
public:
|
| 23 |
+
typedef size_t size_type;
|
| 24 |
+
|
| 25 |
+
private:
|
| 26 |
+
/// Underlying ArrayRef
|
| 27 |
+
ArrayRef<T> arr;
|
| 28 |
+
|
| 29 |
+
/// Stride of dim 0 (outer dimension)
|
| 30 |
+
size_type stride0;
|
| 31 |
+
|
| 32 |
+
// Stride of dim 1 is assumed to be 1
|
| 33 |
+
|
| 34 |
+
public:
|
| 35 |
+
/// Construct an empty Matrixref.
|
| 36 |
+
/*implicit*/ MatrixRef() : arr(nullptr), stride0(0) {}
|
| 37 |
+
|
| 38 |
+
/// Construct an MatrixRef from an ArrayRef and outer stride.
|
| 39 |
+
/*implicit*/ MatrixRef(ArrayRef<T> arr, size_type stride0)
|
| 40 |
+
: arr(arr), stride0(stride0) {
|
| 41 |
+
TORCH_CHECK(
|
| 42 |
+
arr.size() % stride0 == 0,
|
| 43 |
+
"MatrixRef: ArrayRef size ",
|
| 44 |
+
arr.size(),
|
| 45 |
+
" not divisible by stride ",
|
| 46 |
+
stride0)
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
/// @}
|
| 50 |
+
/// @name Simple Operations
|
| 51 |
+
/// @{
|
| 52 |
+
|
| 53 |
+
/// empty - Check if the matrix is empty.
|
| 54 |
+
bool empty() const {
|
| 55 |
+
return arr.empty();
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
const T* data() const {
|
| 59 |
+
return arr.data();
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
/// size - Get size a dimension
|
| 63 |
+
size_t size(size_t dim) const {
|
| 64 |
+
if (dim == 0) {
|
| 65 |
+
return arr.size() / stride0;
|
| 66 |
+
} else if (dim == 1) {
|
| 67 |
+
return stride0;
|
| 68 |
+
} else {
|
| 69 |
+
TORCH_CHECK(
|
| 70 |
+
0, "MatrixRef: out of bounds dimension ", dim, "; expected 0 or 1");
|
| 71 |
+
}
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
size_t numel() const {
|
| 75 |
+
return arr.size();
|
| 76 |
+
}
|
| 77 |
+
|
| 78 |
+
/// equals - Check for element-wise equality.
|
| 79 |
+
bool equals(MatrixRef RHS) const {
|
| 80 |
+
return stride0 == RHS.stride0 && arr.equals(RHS.arr);
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
/// @}
|
| 84 |
+
/// @name Operator Overloads
|
| 85 |
+
/// @{
|
| 86 |
+
ArrayRef<T> operator[](size_t Index) const {
|
| 87 |
+
return arr.slice(Index * stride0, stride0);
|
| 88 |
+
}
|
| 89 |
+
|
| 90 |
+
/// Disallow accidental assignment from a temporary.
|
| 91 |
+
///
|
| 92 |
+
/// The declaration here is extra complicated so that "arrayRef = {}"
|
| 93 |
+
/// continues to select the move assignment operator.
|
| 94 |
+
template <typename U>
|
| 95 |
+
std::enable_if_t<std::is_same_v<U, T>, MatrixRef<T>>& operator=(
|
| 96 |
+
U&& Temporary) = delete;
|
| 97 |
+
|
| 98 |
+
/// Disallow accidental assignment from a temporary.
|
| 99 |
+
///
|
| 100 |
+
/// The declaration here is extra complicated so that "arrayRef = {}"
|
| 101 |
+
/// continues to select the move assignment operator.
|
| 102 |
+
template <typename U>
|
| 103 |
+
std::enable_if_t<std::is_same_v<U, T>, MatrixRef<T>>& operator=(
|
| 104 |
+
std::initializer_list<U>) = delete;
|
| 105 |
+
};
|
| 106 |
+
|
| 107 |
+
} // end namespace at
|
.venv/lib/python3.11/site-packages/torch/include/ATen/MetaFunctions.h
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <ATen/core/TensorBody.h>
|
| 2 |
+
|
| 3 |
+
// TODO Undo all logic introduced for Note [Avoiding Include Cycles In Static Dispatch]
|
| 4 |
+
// Code introduced to avoid cyclic dependency in static dispatch is no longer
|
| 5 |
+
// needed as static dispatch logic is moved from TensorBody.h, which caused cycles in the first place,
|
| 6 |
+
// to Operators.cpp for supporting multiple backends with multiple kernels.
|
| 7 |
+
//
|
| 8 |
+
// Note [Avoiding Include Cycles In Static Dispatch]
|
| 9 |
+
// In order to avoid #include cycles in the static dispatch build, we've carefully split out
|
| 10 |
+
// the static function definition files into {DispatchKey}Functions.h and {DispatchKey}Functions_inl.h.
|
| 11 |
+
//
|
| 12 |
+
// Without this split, the include cycle looks like TensorBody.h -> CPUFunctions.h -> TensorBody.h.
|
| 13 |
+
// - TensorBody.h #includes CPUFunctions.h in the static dispatch build, because the tensor methods
|
| 14 |
+
// all need to call into the fastpath C++ API defined in CPUFunctions.h. The methods are also all
|
| 15 |
+
// directly inlined into TensorBody.h.
|
| 16 |
+
// - CPUFunctions.h #includes TensorBody.h because it contains function declarations for the entire C++ API,
|
| 17 |
+
// which include functions that have defaultable std::optional<Tensor> arguments.
|
| 18 |
+
// That requires knowing the full Tensor class definition.
|
| 19 |
+
//
|
| 20 |
+
// We break the cycle by doing the following:
|
| 21 |
+
// - Split out CPUFunction.h into two files: CPUFunctions.h and CPUFunctions_inl.h
|
| 22 |
+
// - CPUFunction.h is a dummy file that just includes the Tensor class and includes CPUFunctions_inl.,
|
| 23 |
+
// - CPUFunctions_inl.h includes everything else
|
| 24 |
+
// - (only in the static dispatch build) TensorBody.h makes sure to finish defining the Tensor class,
|
| 25 |
+
// and then it includes CPUFunctions_inl.h.
|
| 26 |
+
// - All other files that want the cpu fastpath functions can include CPUFunctions.h directly.
|
| 27 |
+
// - This also means that static dispatch build, CPUFunctions.h only needs to
|
| 28 |
+
// #include TensorBody.h, and it will automatically bring in CPUFunctions_inl.h.
|
| 29 |
+
#include <ATen/MetaFunctions_inl.h>
|
.venv/lib/python3.11/site-packages/torch/include/ATen/MethodOperators.h
ADDED
|
@@ -0,0 +1,443 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// @generated by torchgen/gen.py from MethodOperators.h
|
| 4 |
+
|
| 5 |
+
#ifdef TORCH_ASSERT_NO_OPERATORS
|
| 6 |
+
#error This change adds a dependency on native_functions.yaml, \
|
| 7 |
+
meaning the file will need to be re-compiled every time an operator \
|
| 8 |
+
is changed or added. Consider if your change would be better placed in \
|
| 9 |
+
another file, or if a more specific header might achieve the same goal. \
|
| 10 |
+
See NOTE: [Tensor vs. TensorBase]
|
| 11 |
+
#endif
|
| 12 |
+
|
| 13 |
+
// Forward declarations of any types needed in the operator signatures.
|
| 14 |
+
// We can't directly include these classes because it will cause circular include dependencies.
|
| 15 |
+
// This file is included by TensorBody.h, which defines the Tensor class.
|
| 16 |
+
#include <ATen/core/ATen_fwd.h>
|
| 17 |
+
|
| 18 |
+
#include <ATen/ops/_addmm_activation_ops.h>
|
| 19 |
+
#include <ATen/ops/_autocast_to_full_precision_ops.h>
|
| 20 |
+
#include <ATen/ops/_autocast_to_reduced_precision_ops.h>
|
| 21 |
+
#include <ATen/ops/_backward_ops.h>
|
| 22 |
+
#include <ATen/ops/_coalesced_ops.h>
|
| 23 |
+
#include <ATen/ops/_conj_ops.h>
|
| 24 |
+
#include <ATen/ops/_conj_physical_ops.h>
|
| 25 |
+
#include <ATen/ops/_dimI_ops.h>
|
| 26 |
+
#include <ATen/ops/_dimV_ops.h>
|
| 27 |
+
#include <ATen/ops/_fw_primal_ops.h>
|
| 28 |
+
#include <ATen/ops/_indices_ops.h>
|
| 29 |
+
#include <ATen/ops/_is_all_true_ops.h>
|
| 30 |
+
#include <ATen/ops/_is_any_true_ops.h>
|
| 31 |
+
#include <ATen/ops/_is_zerotensor_ops.h>
|
| 32 |
+
#include <ATen/ops/_lazy_clone_ops.h>
|
| 33 |
+
#include <ATen/ops/_neg_view_ops.h>
|
| 34 |
+
#include <ATen/ops/_nested_tensor_size_ops.h>
|
| 35 |
+
#include <ATen/ops/_nested_tensor_storage_offsets_ops.h>
|
| 36 |
+
#include <ATen/ops/_nested_tensor_strides_ops.h>
|
| 37 |
+
#include <ATen/ops/_nnz_ops.h>
|
| 38 |
+
#include <ATen/ops/_reshape_alias_ops.h>
|
| 39 |
+
#include <ATen/ops/_sparse_mask_projection_ops.h>
|
| 40 |
+
#include <ATen/ops/_to_dense_ops.h>
|
| 41 |
+
#include <ATen/ops/_to_sparse_bsc_ops.h>
|
| 42 |
+
#include <ATen/ops/_to_sparse_bsr_ops.h>
|
| 43 |
+
#include <ATen/ops/_to_sparse_csc_ops.h>
|
| 44 |
+
#include <ATen/ops/_to_sparse_csr_ops.h>
|
| 45 |
+
#include <ATen/ops/_to_sparse_ops.h>
|
| 46 |
+
#include <ATen/ops/_values_ops.h>
|
| 47 |
+
#include <ATen/ops/_version_ops.h>
|
| 48 |
+
#include <ATen/ops/abs_ops.h>
|
| 49 |
+
#include <ATen/ops/absolute_ops.h>
|
| 50 |
+
#include <ATen/ops/acos_ops.h>
|
| 51 |
+
#include <ATen/ops/acosh_ops.h>
|
| 52 |
+
#include <ATen/ops/add_ops.h>
|
| 53 |
+
#include <ATen/ops/addbmm_ops.h>
|
| 54 |
+
#include <ATen/ops/addcdiv_ops.h>
|
| 55 |
+
#include <ATen/ops/addcmul_ops.h>
|
| 56 |
+
#include <ATen/ops/addmm_ops.h>
|
| 57 |
+
#include <ATen/ops/addmv_ops.h>
|
| 58 |
+
#include <ATen/ops/addr_ops.h>
|
| 59 |
+
#include <ATen/ops/adjoint_ops.h>
|
| 60 |
+
#include <ATen/ops/alias_ops.h>
|
| 61 |
+
#include <ATen/ops/align_as_ops.h>
|
| 62 |
+
#include <ATen/ops/align_to_ops.h>
|
| 63 |
+
#include <ATen/ops/all_ops.h>
|
| 64 |
+
#include <ATen/ops/allclose_ops.h>
|
| 65 |
+
#include <ATen/ops/amax_ops.h>
|
| 66 |
+
#include <ATen/ops/amin_ops.h>
|
| 67 |
+
#include <ATen/ops/aminmax_ops.h>
|
| 68 |
+
#include <ATen/ops/and_ops.h>
|
| 69 |
+
#include <ATen/ops/angle_ops.h>
|
| 70 |
+
#include <ATen/ops/any_ops.h>
|
| 71 |
+
#include <ATen/ops/arccos_ops.h>
|
| 72 |
+
#include <ATen/ops/arccosh_ops.h>
|
| 73 |
+
#include <ATen/ops/arcsin_ops.h>
|
| 74 |
+
#include <ATen/ops/arcsinh_ops.h>
|
| 75 |
+
#include <ATen/ops/arctan2_ops.h>
|
| 76 |
+
#include <ATen/ops/arctan_ops.h>
|
| 77 |
+
#include <ATen/ops/arctanh_ops.h>
|
| 78 |
+
#include <ATen/ops/argmax_ops.h>
|
| 79 |
+
#include <ATen/ops/argmin_ops.h>
|
| 80 |
+
#include <ATen/ops/argsort_ops.h>
|
| 81 |
+
#include <ATen/ops/argwhere_ops.h>
|
| 82 |
+
#include <ATen/ops/as_strided_ops.h>
|
| 83 |
+
#include <ATen/ops/as_strided_scatter_ops.h>
|
| 84 |
+
#include <ATen/ops/asin_ops.h>
|
| 85 |
+
#include <ATen/ops/asinh_ops.h>
|
| 86 |
+
#include <ATen/ops/atan2_ops.h>
|
| 87 |
+
#include <ATen/ops/atan_ops.h>
|
| 88 |
+
#include <ATen/ops/atanh_ops.h>
|
| 89 |
+
#include <ATen/ops/baddbmm_ops.h>
|
| 90 |
+
#include <ATen/ops/bernoulli_ops.h>
|
| 91 |
+
#include <ATen/ops/bincount_ops.h>
|
| 92 |
+
#include <ATen/ops/bitwise_and_ops.h>
|
| 93 |
+
#include <ATen/ops/bitwise_left_shift_ops.h>
|
| 94 |
+
#include <ATen/ops/bitwise_not_ops.h>
|
| 95 |
+
#include <ATen/ops/bitwise_or_ops.h>
|
| 96 |
+
#include <ATen/ops/bitwise_right_shift_ops.h>
|
| 97 |
+
#include <ATen/ops/bitwise_xor_ops.h>
|
| 98 |
+
#include <ATen/ops/bmm_ops.h>
|
| 99 |
+
#include <ATen/ops/broadcast_to_ops.h>
|
| 100 |
+
#include <ATen/ops/cauchy_ops.h>
|
| 101 |
+
#include <ATen/ops/ccol_indices_ops.h>
|
| 102 |
+
#include <ATen/ops/ceil_ops.h>
|
| 103 |
+
#include <ATen/ops/chalf_ops.h>
|
| 104 |
+
#include <ATen/ops/cholesky_inverse_ops.h>
|
| 105 |
+
#include <ATen/ops/cholesky_ops.h>
|
| 106 |
+
#include <ATen/ops/cholesky_solve_ops.h>
|
| 107 |
+
#include <ATen/ops/chunk_ops.h>
|
| 108 |
+
#include <ATen/ops/clamp_max_ops.h>
|
| 109 |
+
#include <ATen/ops/clamp_min_ops.h>
|
| 110 |
+
#include <ATen/ops/clamp_ops.h>
|
| 111 |
+
#include <ATen/ops/clip_ops.h>
|
| 112 |
+
#include <ATen/ops/clone_ops.h>
|
| 113 |
+
#include <ATen/ops/coalesce_ops.h>
|
| 114 |
+
#include <ATen/ops/col_indices_ops.h>
|
| 115 |
+
#include <ATen/ops/conj_ops.h>
|
| 116 |
+
#include <ATen/ops/conj_physical_ops.h>
|
| 117 |
+
#include <ATen/ops/contiguous_ops.h>
|
| 118 |
+
#include <ATen/ops/copy_ops.h>
|
| 119 |
+
#include <ATen/ops/copysign_ops.h>
|
| 120 |
+
#include <ATen/ops/corrcoef_ops.h>
|
| 121 |
+
#include <ATen/ops/cos_ops.h>
|
| 122 |
+
#include <ATen/ops/cosh_ops.h>
|
| 123 |
+
#include <ATen/ops/count_nonzero_ops.h>
|
| 124 |
+
#include <ATen/ops/cov_ops.h>
|
| 125 |
+
#include <ATen/ops/cross_ops.h>
|
| 126 |
+
#include <ATen/ops/crow_indices_ops.h>
|
| 127 |
+
#include <ATen/ops/cummax_ops.h>
|
| 128 |
+
#include <ATen/ops/cummin_ops.h>
|
| 129 |
+
#include <ATen/ops/cumprod_ops.h>
|
| 130 |
+
#include <ATen/ops/cumsum_ops.h>
|
| 131 |
+
#include <ATen/ops/data_ops.h>
|
| 132 |
+
#include <ATen/ops/deg2rad_ops.h>
|
| 133 |
+
#include <ATen/ops/dense_dim_ops.h>
|
| 134 |
+
#include <ATen/ops/dequantize_ops.h>
|
| 135 |
+
#include <ATen/ops/det_ops.h>
|
| 136 |
+
#include <ATen/ops/detach_ops.h>
|
| 137 |
+
#include <ATen/ops/diag_embed_ops.h>
|
| 138 |
+
#include <ATen/ops/diag_ops.h>
|
| 139 |
+
#include <ATen/ops/diagflat_ops.h>
|
| 140 |
+
#include <ATen/ops/diagonal_ops.h>
|
| 141 |
+
#include <ATen/ops/diagonal_scatter_ops.h>
|
| 142 |
+
#include <ATen/ops/diff_ops.h>
|
| 143 |
+
#include <ATen/ops/digamma_ops.h>
|
| 144 |
+
#include <ATen/ops/dist_ops.h>
|
| 145 |
+
#include <ATen/ops/div_ops.h>
|
| 146 |
+
#include <ATen/ops/divide_ops.h>
|
| 147 |
+
#include <ATen/ops/dot_ops.h>
|
| 148 |
+
#include <ATen/ops/dsplit_ops.h>
|
| 149 |
+
#include <ATen/ops/eq_ops.h>
|
| 150 |
+
#include <ATen/ops/equal_ops.h>
|
| 151 |
+
#include <ATen/ops/erf_ops.h>
|
| 152 |
+
#include <ATen/ops/erfc_ops.h>
|
| 153 |
+
#include <ATen/ops/erfinv_ops.h>
|
| 154 |
+
#include <ATen/ops/exp2_ops.h>
|
| 155 |
+
#include <ATen/ops/exp_ops.h>
|
| 156 |
+
#include <ATen/ops/expand_as_ops.h>
|
| 157 |
+
#include <ATen/ops/expand_ops.h>
|
| 158 |
+
#include <ATen/ops/expm1_ops.h>
|
| 159 |
+
#include <ATen/ops/exponential_ops.h>
|
| 160 |
+
#include <ATen/ops/fill_diagonal_ops.h>
|
| 161 |
+
#include <ATen/ops/fill_ops.h>
|
| 162 |
+
#include <ATen/ops/fix_ops.h>
|
| 163 |
+
#include <ATen/ops/flatten_ops.h>
|
| 164 |
+
#include <ATen/ops/flip_ops.h>
|
| 165 |
+
#include <ATen/ops/fliplr_ops.h>
|
| 166 |
+
#include <ATen/ops/flipud_ops.h>
|
| 167 |
+
#include <ATen/ops/float_power_ops.h>
|
| 168 |
+
#include <ATen/ops/floor_divide_ops.h>
|
| 169 |
+
#include <ATen/ops/floor_ops.h>
|
| 170 |
+
#include <ATen/ops/fmax_ops.h>
|
| 171 |
+
#include <ATen/ops/fmin_ops.h>
|
| 172 |
+
#include <ATen/ops/fmod_ops.h>
|
| 173 |
+
#include <ATen/ops/frac_ops.h>
|
| 174 |
+
#include <ATen/ops/frexp_ops.h>
|
| 175 |
+
#include <ATen/ops/gather_ops.h>
|
| 176 |
+
#include <ATen/ops/gcd_ops.h>
|
| 177 |
+
#include <ATen/ops/ge_ops.h>
|
| 178 |
+
#include <ATen/ops/geometric_ops.h>
|
| 179 |
+
#include <ATen/ops/geqrf_ops.h>
|
| 180 |
+
#include <ATen/ops/ger_ops.h>
|
| 181 |
+
#include <ATen/ops/greater_equal_ops.h>
|
| 182 |
+
#include <ATen/ops/greater_ops.h>
|
| 183 |
+
#include <ATen/ops/gt_ops.h>
|
| 184 |
+
#include <ATen/ops/hardshrink_backward_ops.h>
|
| 185 |
+
#include <ATen/ops/hardshrink_ops.h>
|
| 186 |
+
#include <ATen/ops/heaviside_ops.h>
|
| 187 |
+
#include <ATen/ops/histc_ops.h>
|
| 188 |
+
#include <ATen/ops/histogram_ops.h>
|
| 189 |
+
#include <ATen/ops/hsplit_ops.h>
|
| 190 |
+
#include <ATen/ops/hypot_ops.h>
|
| 191 |
+
#include <ATen/ops/i0_ops.h>
|
| 192 |
+
#include <ATen/ops/igamma_ops.h>
|
| 193 |
+
#include <ATen/ops/igammac_ops.h>
|
| 194 |
+
#include <ATen/ops/index_add_ops.h>
|
| 195 |
+
#include <ATen/ops/index_copy_ops.h>
|
| 196 |
+
#include <ATen/ops/index_fill_ops.h>
|
| 197 |
+
#include <ATen/ops/index_ops.h>
|
| 198 |
+
#include <ATen/ops/index_put_ops.h>
|
| 199 |
+
#include <ATen/ops/index_reduce_ops.h>
|
| 200 |
+
#include <ATen/ops/index_select_ops.h>
|
| 201 |
+
#include <ATen/ops/indices_ops.h>
|
| 202 |
+
#include <ATen/ops/inner_ops.h>
|
| 203 |
+
#include <ATen/ops/int_repr_ops.h>
|
| 204 |
+
#include <ATen/ops/inverse_ops.h>
|
| 205 |
+
#include <ATen/ops/is_coalesced_ops.h>
|
| 206 |
+
#include <ATen/ops/is_complex_ops.h>
|
| 207 |
+
#include <ATen/ops/is_conj_ops.h>
|
| 208 |
+
#include <ATen/ops/is_distributed_ops.h>
|
| 209 |
+
#include <ATen/ops/is_floating_point_ops.h>
|
| 210 |
+
#include <ATen/ops/is_inference_ops.h>
|
| 211 |
+
#include <ATen/ops/is_leaf_ops.h>
|
| 212 |
+
#include <ATen/ops/is_neg_ops.h>
|
| 213 |
+
#include <ATen/ops/is_nonzero_ops.h>
|
| 214 |
+
#include <ATen/ops/is_pinned_ops.h>
|
| 215 |
+
#include <ATen/ops/is_same_size_ops.h>
|
| 216 |
+
#include <ATen/ops/is_set_to_ops.h>
|
| 217 |
+
#include <ATen/ops/is_signed_ops.h>
|
| 218 |
+
#include <ATen/ops/isclose_ops.h>
|
| 219 |
+
#include <ATen/ops/isfinite_ops.h>
|
| 220 |
+
#include <ATen/ops/isinf_ops.h>
|
| 221 |
+
#include <ATen/ops/isnan_ops.h>
|
| 222 |
+
#include <ATen/ops/isneginf_ops.h>
|
| 223 |
+
#include <ATen/ops/isposinf_ops.h>
|
| 224 |
+
#include <ATen/ops/isreal_ops.h>
|
| 225 |
+
#include <ATen/ops/istft_ops.h>
|
| 226 |
+
#include <ATen/ops/item_ops.h>
|
| 227 |
+
#include <ATen/ops/kron_ops.h>
|
| 228 |
+
#include <ATen/ops/kthvalue_ops.h>
|
| 229 |
+
#include <ATen/ops/lcm_ops.h>
|
| 230 |
+
#include <ATen/ops/ldexp_ops.h>
|
| 231 |
+
#include <ATen/ops/le_ops.h>
|
| 232 |
+
#include <ATen/ops/lerp_ops.h>
|
| 233 |
+
#include <ATen/ops/less_equal_ops.h>
|
| 234 |
+
#include <ATen/ops/less_ops.h>
|
| 235 |
+
#include <ATen/ops/lgamma_ops.h>
|
| 236 |
+
#include <ATen/ops/log10_ops.h>
|
| 237 |
+
#include <ATen/ops/log1p_ops.h>
|
| 238 |
+
#include <ATen/ops/log2_ops.h>
|
| 239 |
+
#include <ATen/ops/log_normal_ops.h>
|
| 240 |
+
#include <ATen/ops/log_ops.h>
|
| 241 |
+
#include <ATen/ops/log_softmax_ops.h>
|
| 242 |
+
#include <ATen/ops/logaddexp2_ops.h>
|
| 243 |
+
#include <ATen/ops/logaddexp_ops.h>
|
| 244 |
+
#include <ATen/ops/logcumsumexp_ops.h>
|
| 245 |
+
#include <ATen/ops/logdet_ops.h>
|
| 246 |
+
#include <ATen/ops/logical_and_ops.h>
|
| 247 |
+
#include <ATen/ops/logical_not_ops.h>
|
| 248 |
+
#include <ATen/ops/logical_or_ops.h>
|
| 249 |
+
#include <ATen/ops/logical_xor_ops.h>
|
| 250 |
+
#include <ATen/ops/logit_ops.h>
|
| 251 |
+
#include <ATen/ops/logsumexp_ops.h>
|
| 252 |
+
#include <ATen/ops/lshift_ops.h>
|
| 253 |
+
#include <ATen/ops/lt_ops.h>
|
| 254 |
+
#include <ATen/ops/lu_solve_ops.h>
|
| 255 |
+
#include <ATen/ops/mH_ops.h>
|
| 256 |
+
#include <ATen/ops/mT_ops.h>
|
| 257 |
+
#include <ATen/ops/masked_fill_ops.h>
|
| 258 |
+
#include <ATen/ops/masked_scatter_ops.h>
|
| 259 |
+
#include <ATen/ops/masked_select_ops.h>
|
| 260 |
+
#include <ATen/ops/matmul_ops.h>
|
| 261 |
+
#include <ATen/ops/matrix_H_ops.h>
|
| 262 |
+
#include <ATen/ops/matrix_exp_ops.h>
|
| 263 |
+
#include <ATen/ops/matrix_power_ops.h>
|
| 264 |
+
#include <ATen/ops/max_ops.h>
|
| 265 |
+
#include <ATen/ops/maximum_ops.h>
|
| 266 |
+
#include <ATen/ops/mean_ops.h>
|
| 267 |
+
#include <ATen/ops/median_ops.h>
|
| 268 |
+
#include <ATen/ops/min_ops.h>
|
| 269 |
+
#include <ATen/ops/minimum_ops.h>
|
| 270 |
+
#include <ATen/ops/mm_ops.h>
|
| 271 |
+
#include <ATen/ops/mode_ops.h>
|
| 272 |
+
#include <ATen/ops/moveaxis_ops.h>
|
| 273 |
+
#include <ATen/ops/movedim_ops.h>
|
| 274 |
+
#include <ATen/ops/msort_ops.h>
|
| 275 |
+
#include <ATen/ops/mul_ops.h>
|
| 276 |
+
#include <ATen/ops/multinomial_ops.h>
|
| 277 |
+
#include <ATen/ops/multiply_ops.h>
|
| 278 |
+
#include <ATen/ops/mv_ops.h>
|
| 279 |
+
#include <ATen/ops/mvlgamma_ops.h>
|
| 280 |
+
#include <ATen/ops/nan_to_num_ops.h>
|
| 281 |
+
#include <ATen/ops/nanmean_ops.h>
|
| 282 |
+
#include <ATen/ops/nanmedian_ops.h>
|
| 283 |
+
#include <ATen/ops/nanquantile_ops.h>
|
| 284 |
+
#include <ATen/ops/nansum_ops.h>
|
| 285 |
+
#include <ATen/ops/narrow_copy_ops.h>
|
| 286 |
+
#include <ATen/ops/narrow_ops.h>
|
| 287 |
+
#include <ATen/ops/ne_ops.h>
|
| 288 |
+
#include <ATen/ops/neg_ops.h>
|
| 289 |
+
#include <ATen/ops/negative_ops.h>
|
| 290 |
+
#include <ATen/ops/new_empty_ops.h>
|
| 291 |
+
#include <ATen/ops/new_empty_strided_ops.h>
|
| 292 |
+
#include <ATen/ops/new_full_ops.h>
|
| 293 |
+
#include <ATen/ops/new_ones_ops.h>
|
| 294 |
+
#include <ATen/ops/new_zeros_ops.h>
|
| 295 |
+
#include <ATen/ops/nextafter_ops.h>
|
| 296 |
+
#include <ATen/ops/nonzero_numpy_ops.h>
|
| 297 |
+
#include <ATen/ops/nonzero_ops.h>
|
| 298 |
+
#include <ATen/ops/nonzero_static_ops.h>
|
| 299 |
+
#include <ATen/ops/norm_ops.h>
|
| 300 |
+
#include <ATen/ops/normal_ops.h>
|
| 301 |
+
#include <ATen/ops/not_equal_ops.h>
|
| 302 |
+
#include <ATen/ops/numpy_T_ops.h>
|
| 303 |
+
#include <ATen/ops/or_ops.h>
|
| 304 |
+
#include <ATen/ops/orgqr_ops.h>
|
| 305 |
+
#include <ATen/ops/ormqr_ops.h>
|
| 306 |
+
#include <ATen/ops/outer_ops.h>
|
| 307 |
+
#include <ATen/ops/output_nr_ops.h>
|
| 308 |
+
#include <ATen/ops/permute_ops.h>
|
| 309 |
+
#include <ATen/ops/pin_memory_ops.h>
|
| 310 |
+
#include <ATen/ops/pinverse_ops.h>
|
| 311 |
+
#include <ATen/ops/polygamma_ops.h>
|
| 312 |
+
#include <ATen/ops/positive_ops.h>
|
| 313 |
+
#include <ATen/ops/pow_ops.h>
|
| 314 |
+
#include <ATen/ops/prelu_ops.h>
|
| 315 |
+
#include <ATen/ops/prod_ops.h>
|
| 316 |
+
#include <ATen/ops/put_ops.h>
|
| 317 |
+
#include <ATen/ops/q_per_channel_axis_ops.h>
|
| 318 |
+
#include <ATen/ops/q_per_channel_scales_ops.h>
|
| 319 |
+
#include <ATen/ops/q_per_channel_zero_points_ops.h>
|
| 320 |
+
#include <ATen/ops/q_scale_ops.h>
|
| 321 |
+
#include <ATen/ops/q_zero_point_ops.h>
|
| 322 |
+
#include <ATen/ops/qr_ops.h>
|
| 323 |
+
#include <ATen/ops/qscheme_ops.h>
|
| 324 |
+
#include <ATen/ops/quantile_ops.h>
|
| 325 |
+
#include <ATen/ops/rad2deg_ops.h>
|
| 326 |
+
#include <ATen/ops/random_ops.h>
|
| 327 |
+
#include <ATen/ops/ravel_ops.h>
|
| 328 |
+
#include <ATen/ops/reciprocal_ops.h>
|
| 329 |
+
#include <ATen/ops/record_stream_ops.h>
|
| 330 |
+
#include <ATen/ops/refine_names_ops.h>
|
| 331 |
+
#include <ATen/ops/relu_ops.h>
|
| 332 |
+
#include <ATen/ops/remainder_ops.h>
|
| 333 |
+
#include <ATen/ops/rename_ops.h>
|
| 334 |
+
#include <ATen/ops/renorm_ops.h>
|
| 335 |
+
#include <ATen/ops/repeat_interleave_ops.h>
|
| 336 |
+
#include <ATen/ops/repeat_ops.h>
|
| 337 |
+
#include <ATen/ops/requires_grad_ops.h>
|
| 338 |
+
#include <ATen/ops/reshape_as_ops.h>
|
| 339 |
+
#include <ATen/ops/reshape_ops.h>
|
| 340 |
+
#include <ATen/ops/resize_as_ops.h>
|
| 341 |
+
#include <ATen/ops/resize_as_sparse_ops.h>
|
| 342 |
+
#include <ATen/ops/resize_ops.h>
|
| 343 |
+
#include <ATen/ops/resolve_conj_ops.h>
|
| 344 |
+
#include <ATen/ops/resolve_neg_ops.h>
|
| 345 |
+
#include <ATen/ops/retain_grad_ops.h>
|
| 346 |
+
#include <ATen/ops/retains_grad_ops.h>
|
| 347 |
+
#include <ATen/ops/roll_ops.h>
|
| 348 |
+
#include <ATen/ops/rot90_ops.h>
|
| 349 |
+
#include <ATen/ops/round_ops.h>
|
| 350 |
+
#include <ATen/ops/row_indices_ops.h>
|
| 351 |
+
#include <ATen/ops/rshift_ops.h>
|
| 352 |
+
#include <ATen/ops/rsqrt_ops.h>
|
| 353 |
+
#include <ATen/ops/scatter_add_ops.h>
|
| 354 |
+
#include <ATen/ops/scatter_ops.h>
|
| 355 |
+
#include <ATen/ops/scatter_reduce_ops.h>
|
| 356 |
+
#include <ATen/ops/select_ops.h>
|
| 357 |
+
#include <ATen/ops/select_scatter_ops.h>
|
| 358 |
+
#include <ATen/ops/set_data_ops.h>
|
| 359 |
+
#include <ATen/ops/set_ops.h>
|
| 360 |
+
#include <ATen/ops/sgn_ops.h>
|
| 361 |
+
#include <ATen/ops/sigmoid_ops.h>
|
| 362 |
+
#include <ATen/ops/sign_ops.h>
|
| 363 |
+
#include <ATen/ops/signbit_ops.h>
|
| 364 |
+
#include <ATen/ops/sin_ops.h>
|
| 365 |
+
#include <ATen/ops/sinc_ops.h>
|
| 366 |
+
#include <ATen/ops/sinh_ops.h>
|
| 367 |
+
#include <ATen/ops/size_ops.h>
|
| 368 |
+
#include <ATen/ops/slice_inverse_ops.h>
|
| 369 |
+
#include <ATen/ops/slice_ops.h>
|
| 370 |
+
#include <ATen/ops/slice_scatter_ops.h>
|
| 371 |
+
#include <ATen/ops/slogdet_ops.h>
|
| 372 |
+
#include <ATen/ops/smm_ops.h>
|
| 373 |
+
#include <ATen/ops/softmax_ops.h>
|
| 374 |
+
#include <ATen/ops/sort_ops.h>
|
| 375 |
+
#include <ATen/ops/sparse_dim_ops.h>
|
| 376 |
+
#include <ATen/ops/sparse_mask_ops.h>
|
| 377 |
+
#include <ATen/ops/sparse_resize_and_clear_ops.h>
|
| 378 |
+
#include <ATen/ops/sparse_resize_ops.h>
|
| 379 |
+
#include <ATen/ops/split_ops.h>
|
| 380 |
+
#include <ATen/ops/split_with_sizes_ops.h>
|
| 381 |
+
#include <ATen/ops/sqrt_ops.h>
|
| 382 |
+
#include <ATen/ops/square_ops.h>
|
| 383 |
+
#include <ATen/ops/squeeze_ops.h>
|
| 384 |
+
#include <ATen/ops/sspaddmm_ops.h>
|
| 385 |
+
#include <ATen/ops/std_ops.h>
|
| 386 |
+
#include <ATen/ops/stft_ops.h>
|
| 387 |
+
#include <ATen/ops/stride_ops.h>
|
| 388 |
+
#include <ATen/ops/sub_ops.h>
|
| 389 |
+
#include <ATen/ops/subtract_ops.h>
|
| 390 |
+
#include <ATen/ops/sum_ops.h>
|
| 391 |
+
#include <ATen/ops/sum_to_size_ops.h>
|
| 392 |
+
#include <ATen/ops/svd_ops.h>
|
| 393 |
+
#include <ATen/ops/swapaxes_ops.h>
|
| 394 |
+
#include <ATen/ops/swapdims_ops.h>
|
| 395 |
+
#include <ATen/ops/t_ops.h>
|
| 396 |
+
#include <ATen/ops/take_along_dim_ops.h>
|
| 397 |
+
#include <ATen/ops/take_ops.h>
|
| 398 |
+
#include <ATen/ops/tan_ops.h>
|
| 399 |
+
#include <ATen/ops/tanh_ops.h>
|
| 400 |
+
#include <ATen/ops/tensor_split_ops.h>
|
| 401 |
+
#include <ATen/ops/tile_ops.h>
|
| 402 |
+
#include <ATen/ops/to_dense_ops.h>
|
| 403 |
+
#include <ATen/ops/to_mkldnn_ops.h>
|
| 404 |
+
#include <ATen/ops/to_ops.h>
|
| 405 |
+
#include <ATen/ops/to_padded_tensor_ops.h>
|
| 406 |
+
#include <ATen/ops/to_sparse_bsc_ops.h>
|
| 407 |
+
#include <ATen/ops/to_sparse_bsr_ops.h>
|
| 408 |
+
#include <ATen/ops/to_sparse_csc_ops.h>
|
| 409 |
+
#include <ATen/ops/to_sparse_csr_ops.h>
|
| 410 |
+
#include <ATen/ops/to_sparse_ops.h>
|
| 411 |
+
#include <ATen/ops/topk_ops.h>
|
| 412 |
+
#include <ATen/ops/trace_ops.h>
|
| 413 |
+
#include <ATen/ops/transpose_ops.h>
|
| 414 |
+
#include <ATen/ops/triangular_solve_ops.h>
|
| 415 |
+
#include <ATen/ops/tril_ops.h>
|
| 416 |
+
#include <ATen/ops/triu_ops.h>
|
| 417 |
+
#include <ATen/ops/true_divide_ops.h>
|
| 418 |
+
#include <ATen/ops/trunc_ops.h>
|
| 419 |
+
#include <ATen/ops/type_as_ops.h>
|
| 420 |
+
#include <ATen/ops/unbind_ops.h>
|
| 421 |
+
#include <ATen/ops/unflatten_ops.h>
|
| 422 |
+
#include <ATen/ops/unfold_ops.h>
|
| 423 |
+
#include <ATen/ops/uniform_ops.h>
|
| 424 |
+
#include <ATen/ops/unsafe_chunk_ops.h>
|
| 425 |
+
#include <ATen/ops/unsafe_split_ops.h>
|
| 426 |
+
#include <ATen/ops/unsafe_split_with_sizes_ops.h>
|
| 427 |
+
#include <ATen/ops/unsqueeze_ops.h>
|
| 428 |
+
#include <ATen/ops/values_ops.h>
|
| 429 |
+
#include <ATen/ops/var_ops.h>
|
| 430 |
+
#include <ATen/ops/vdot_ops.h>
|
| 431 |
+
#include <ATen/ops/view_as_ops.h>
|
| 432 |
+
#include <ATen/ops/view_ops.h>
|
| 433 |
+
#include <ATen/ops/vsplit_ops.h>
|
| 434 |
+
#include <ATen/ops/where_ops.h>
|
| 435 |
+
#include <ATen/ops/xlogy_ops.h>
|
| 436 |
+
#include <ATen/ops/xor_ops.h>
|
| 437 |
+
#include <ATen/ops/zero_ops.h>
|
| 438 |
+
|
| 439 |
+
namespace at {
|
| 440 |
+
namespace _ops {
|
| 441 |
+
|
| 442 |
+
} // namespace _ops
|
| 443 |
+
} // namespace at
|
.venv/lib/python3.11/site-packages/torch/include/ATen/NamedTensor.h
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
#include <ATen/core/NamedTensor.h>
|
.venv/lib/python3.11/site-packages/torch/include/ATen/NamedTensorUtils.h
ADDED
|
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <ATen/NamedTensor.h>
|
| 3 |
+
#include <ATen/TensorNames.h>
|
| 4 |
+
#include <ATen/WrapDimUtilsMulti.h>
|
| 5 |
+
|
| 6 |
+
#include <ATen/core/DimVector.h>
|
| 7 |
+
#include <ATen/core/Tensor.h>
|
| 8 |
+
|
| 9 |
+
namespace at {
|
| 10 |
+
|
| 11 |
+
using NameVector = SmallVector<Dimname, kDimVectorStaticSize>;
|
| 12 |
+
|
| 13 |
+
inline bool has_names(const ITensorListRef& tensors) {
|
| 14 |
+
return std::any_of(tensors.begin(), tensors.end(), [](const Tensor& t) {
|
| 15 |
+
return t.has_names();
|
| 16 |
+
});
|
| 17 |
+
}
|
| 18 |
+
|
| 19 |
+
// Converts dim to an positional index. Errors if `dim` cannot be used to
|
| 20 |
+
// refer to any dimension of tensor.
|
| 21 |
+
TORCH_API int64_t dimname_to_position(const Tensor& tensor, Dimname dim);
|
| 22 |
+
TORCH_API std::vector<int64_t> dimnames_to_positions(
|
| 23 |
+
const Tensor& tensor,
|
| 24 |
+
DimnameList dims);
|
| 25 |
+
|
| 26 |
+
// Unifies two DimnameList to produce a third. This is useful for implementing
|
| 27 |
+
// the named inference rule for binary broadcasting operations like add.
|
| 28 |
+
//
|
| 29 |
+
// There are three main constraints:
|
| 30 |
+
// 1) Check matching: Names must match positionally from the right.
|
| 31 |
+
// 2) Check misaligned: If a name `n` is in `names`, then it must appear at
|
| 32 |
+
// the same index from the right in other.
|
| 33 |
+
// 3) The output names are obtained by unifying the names individually from the
|
| 34 |
+
// right.
|
| 35 |
+
TORCH_API std::vector<Dimname> unify_from_right(
|
| 36 |
+
DimnameList names,
|
| 37 |
+
DimnameList other,
|
| 38 |
+
const char* action = "broadcast");
|
| 39 |
+
|
| 40 |
+
[[noreturn]] inline void reportNYIDimnameOverload(const char* op_name) {
|
| 41 |
+
TORCH_CHECK(
|
| 42 |
+
false,
|
| 43 |
+
op_name,
|
| 44 |
+
": You passed a dimname (string) to this op in place of a dimension "
|
| 45 |
+
"index but it does not yet support this behavior. Please pass a dimension "
|
| 46 |
+
"index to work around this.");
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
// [NOTE] Writing name inference rules
|
| 50 |
+
//
|
| 51 |
+
// Operators that support named tensors are either composed of operations that
|
| 52 |
+
// support named tensors or implement some name inference rule. An op that
|
| 53 |
+
// implements its own name inference rule generally looks like the following:
|
| 54 |
+
//
|
| 55 |
+
// Tensor op(...) {
|
| 56 |
+
// perform_shape_checks(...);
|
| 57 |
+
// # (1)
|
| 58 |
+
// auto maybe_outnames = compute_outnames(...);
|
| 59 |
+
// auto result = [&]() {
|
| 60 |
+
// NoNamesGuard guard;
|
| 61 |
+
// return op_impl(...);
|
| 62 |
+
// }();
|
| 63 |
+
// # (2)
|
| 64 |
+
// propagate_names_if_nonempty(result, maybe_outnames);
|
| 65 |
+
//
|
| 66 |
+
// Each op has (1) a compute outnames step and (2) a propagate names step.
|
| 67 |
+
//
|
| 68 |
+
// compute_outnames is responsible for checking that input names match and
|
| 69 |
+
// determining what the output names should be. It returns either:
|
| 70 |
+
// - {} (if the inputs tensors are all unnamed)
|
| 71 |
+
// - non-empty outnames.
|
| 72 |
+
//
|
| 73 |
+
// propagate_names_if_nonempty propagates the outnames if they exist to the
|
| 74 |
+
// result tensors.
|
| 75 |
+
//
|
| 76 |
+
// The {} case is an optimization; if the user does not use named tensors they
|
| 77 |
+
// pay no perf cost for it.
|
| 78 |
+
|
| 79 |
+
namespace namedinference {
|
| 80 |
+
|
| 81 |
+
const Tensor& propagate_names_if_present_and_nonempty(
|
| 82 |
+
const Tensor& result,
|
| 83 |
+
std::optional<DimnameList> maybe_names,
|
| 84 |
+
bool validate_names = false);
|
| 85 |
+
// Propagates `names` to `result` if `names` is not empty.
|
| 86 |
+
// `names` can be empty; see [NOTE] Writing name inference rules
|
| 87 |
+
// If `names` is not empty, `names.size()` should equal `result.dim()`.
|
| 88 |
+
// When in doubt, use this overload instead of the others.
|
| 89 |
+
TORCH_API const Tensor& propagate_names_if_nonempty(
|
| 90 |
+
const Tensor& result,
|
| 91 |
+
DimnameList maybe_names,
|
| 92 |
+
bool validate_names = false);
|
| 93 |
+
|
| 94 |
+
// Propagates `names` to `result`. Only use this if we are certain that there
|
| 95 |
+
// are names to propagate (that names is not empty).
|
| 96 |
+
TORCH_API const Tensor& propagate_names(
|
| 97 |
+
const Tensor& result,
|
| 98 |
+
DimnameList names,
|
| 99 |
+
bool validate_names = false);
|
| 100 |
+
|
| 101 |
+
// Propagates all names from src to result.
|
| 102 |
+
TORCH_API void propagate_names(const Tensor& result, const Tensor& src);
|
| 103 |
+
|
| 104 |
+
// Propagates all names except for those at the excluded_idxs.
|
| 105 |
+
TORCH_API void propagate_names_except(
|
| 106 |
+
const Tensor& result,
|
| 107 |
+
const Tensor& src,
|
| 108 |
+
IntArrayRef excluded_idxs);
|
| 109 |
+
|
| 110 |
+
// Used for reduction ops that have a `keepdim` arg.
|
| 111 |
+
TORCH_API void propagate_names_for_reduction(
|
| 112 |
+
const Tensor& result,
|
| 113 |
+
const Tensor& src,
|
| 114 |
+
IntArrayRef excluded_idxs,
|
| 115 |
+
bool keepdim);
|
| 116 |
+
|
| 117 |
+
TORCH_API void propagate_names_for_expand(
|
| 118 |
+
const Tensor& result,
|
| 119 |
+
const Tensor& self);
|
| 120 |
+
|
| 121 |
+
TORCH_API std::vector<Dimname> compute_cat_outnames(
|
| 122 |
+
const MaterializedITensorListRef& tensors);
|
| 123 |
+
|
| 124 |
+
TORCH_API std::vector<Dimname> compute_broadcast_outnames(
|
| 125 |
+
const Tensor& self,
|
| 126 |
+
const Tensor& other);
|
| 127 |
+
|
| 128 |
+
TORCH_API std::vector<Dimname> broadcast_to_outnames(
|
| 129 |
+
const Tensor& tensor,
|
| 130 |
+
const Tensor& reference_tensor,
|
| 131 |
+
const char* op_name);
|
| 132 |
+
|
| 133 |
+
TORCH_API std::vector<Dimname> compute_matmul_outnames(
|
| 134 |
+
const Tensor& self,
|
| 135 |
+
const Tensor& other);
|
| 136 |
+
|
| 137 |
+
TORCH_API std::vector<Dimname> compute_cdist_outnames(
|
| 138 |
+
const Tensor& self,
|
| 139 |
+
const Tensor& other);
|
| 140 |
+
|
| 141 |
+
TORCH_API std::vector<Dimname> compute_bmm_outnames(
|
| 142 |
+
const Tensor& result,
|
| 143 |
+
const Tensor& self,
|
| 144 |
+
const Tensor& other);
|
| 145 |
+
|
| 146 |
+
TORCH_API std::vector<Dimname> compute_squeeze_outnames(const Tensor& tensor);
|
| 147 |
+
TORCH_API std::vector<Dimname> compute_squeeze_outnames(
|
| 148 |
+
const Tensor& tensor,
|
| 149 |
+
std::bitset<dim_bitset_size> dims);
|
| 150 |
+
|
| 151 |
+
std::vector<Dimname> compute_diagonal_outnames(
|
| 152 |
+
const Tensor& tensor,
|
| 153 |
+
int64_t dim1,
|
| 154 |
+
int64_t dim2);
|
| 155 |
+
|
| 156 |
+
// TensorImpl* overloads for Legacy TH/THC code. Use these sparingly.
|
| 157 |
+
|
| 158 |
+
TORCH_API TensorImpl* propagate_names_if_nonempty(
|
| 159 |
+
TensorImpl* result,
|
| 160 |
+
DimnameList maybe_names,
|
| 161 |
+
bool validate_names = false);
|
| 162 |
+
|
| 163 |
+
TORCH_API TensorImpl* propagate_names(
|
| 164 |
+
TensorImpl* result,
|
| 165 |
+
DimnameList names,
|
| 166 |
+
bool validate_names = false);
|
| 167 |
+
|
| 168 |
+
TORCH_API void propagate_names(TensorImpl* result, /*const */ TensorImpl* src);
|
| 169 |
+
|
| 170 |
+
TORCH_API inline void propagate_names(
|
| 171 |
+
const TensorBase& result,
|
| 172 |
+
DimnameList names,
|
| 173 |
+
bool validate_names = false) {
|
| 174 |
+
propagate_names(result.unsafeGetTensorImpl(), names, validate_names);
|
| 175 |
+
}
|
| 176 |
+
|
| 177 |
+
TORCH_API inline void propagate_names_if_nonempty(
|
| 178 |
+
const TensorBase& result,
|
| 179 |
+
DimnameList names,
|
| 180 |
+
bool validate_names = false) {
|
| 181 |
+
propagate_names_if_nonempty(
|
| 182 |
+
result.unsafeGetTensorImpl(), names, validate_names);
|
| 183 |
+
}
|
| 184 |
+
|
| 185 |
+
TORCH_API inline void propagate_names(
|
| 186 |
+
const TensorBase& result,
|
| 187 |
+
const TensorBase& src) {
|
| 188 |
+
propagate_names(result.unsafeGetTensorImpl(), src.unsafeGetTensorImpl());
|
| 189 |
+
}
|
| 190 |
+
|
| 191 |
+
// result = m1 @ m2 + bias
|
| 192 |
+
TORCH_API std::vector<Dimname> propagate_names_for_addmm(
|
| 193 |
+
const Tensor& m1,
|
| 194 |
+
const Tensor& m2,
|
| 195 |
+
const Tensor& bias);
|
| 196 |
+
|
| 197 |
+
TORCH_API std::vector<Dimname> propagate_names_for_addmv(
|
| 198 |
+
const Tensor& mat,
|
| 199 |
+
const Tensor& vec,
|
| 200 |
+
const Tensor& bias);
|
| 201 |
+
|
| 202 |
+
TORCH_API void check_names_for_dot(TensorImpl* vec1, TensorImpl* vec2);
|
| 203 |
+
|
| 204 |
+
TORCH_API std::vector<Dimname> compute_baddbmm_outnames(
|
| 205 |
+
const Tensor& result,
|
| 206 |
+
const Tensor& self,
|
| 207 |
+
const Tensor& other,
|
| 208 |
+
const Tensor& bias);
|
| 209 |
+
|
| 210 |
+
TORCH_API bool are_names_equal(TensorImpl* self, TensorImpl* other);
|
| 211 |
+
|
| 212 |
+
} // namespace namedinference
|
| 213 |
+
|
| 214 |
+
} // namespace at
|
.venv/lib/python3.11/site-packages/torch/include/ATen/NativeFunctions.h
ADDED
|
@@ -0,0 +1,1344 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// @generated by torchgen/gen.py from NativeFunctions.h
|
| 4 |
+
|
| 5 |
+
#ifdef TORCH_ASSERT_NO_OPERATORS
|
| 6 |
+
#error This change adds a dependency on native_functions.yaml, \
|
| 7 |
+
meaning the file will need to be re-compiled every time an operator \
|
| 8 |
+
is changed or added. Consider if your change would be better placed in \
|
| 9 |
+
another file, or if a more specific header might achieve the same goal. \
|
| 10 |
+
See NOTE: [Tensor vs. TensorBase]
|
| 11 |
+
#endif
|
| 12 |
+
|
| 13 |
+
#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS)
|
| 14 |
+
#error This change adds a dependency on all pytorch operators, meaning the \
|
| 15 |
+
file will need to be re-compiled every time an operator is changed or added. \
|
| 16 |
+
Consider including a specific operator from <ATen/ops/{my_operator}_native.h> \
|
| 17 |
+
and see NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS].
|
| 18 |
+
#endif
|
| 19 |
+
|
| 20 |
+
#include <c10/core/Scalar.h>
|
| 21 |
+
#include <c10/core/Storage.h>
|
| 22 |
+
#include <c10/core/TensorOptions.h>
|
| 23 |
+
#include <c10/util/Deprecated.h>
|
| 24 |
+
#include <optional>
|
| 25 |
+
#include <c10/core/QScheme.h>
|
| 26 |
+
#include <ATen/core/Reduction.h>
|
| 27 |
+
#include <ATen/core/Tensor.h>
|
| 28 |
+
#include <tuple>
|
| 29 |
+
#include <vector>
|
| 30 |
+
|
| 31 |
+
#include <ATen/ops/_adaptive_avg_pool2d_native.h>
|
| 32 |
+
#include <ATen/ops/_adaptive_avg_pool2d_backward_native.h>
|
| 33 |
+
#include <ATen/ops/_adaptive_avg_pool3d_native.h>
|
| 34 |
+
#include <ATen/ops/_adaptive_avg_pool3d_backward_native.h>
|
| 35 |
+
#include <ATen/ops/_add_batch_dim_native.h>
|
| 36 |
+
#include <ATen/ops/_add_relu_native.h>
|
| 37 |
+
#include <ATen/ops/_addmm_activation_native.h>
|
| 38 |
+
#include <ATen/ops/_aminmax_native.h>
|
| 39 |
+
#include <ATen/ops/_amp_foreach_non_finite_check_and_unscale_native.h>
|
| 40 |
+
#include <ATen/ops/_amp_update_scale_native.h>
|
| 41 |
+
#include <ATen/ops/_assert_async_native.h>
|
| 42 |
+
#include <ATen/ops/_assert_scalar_native.h>
|
| 43 |
+
#include <ATen/ops/_assert_tensor_metadata_native.h>
|
| 44 |
+
#include <ATen/ops/_autocast_to_full_precision_native.h>
|
| 45 |
+
#include <ATen/ops/_autocast_to_reduced_precision_native.h>
|
| 46 |
+
#include <ATen/ops/_backward_native.h>
|
| 47 |
+
#include <ATen/ops/_batch_norm_impl_index_native.h>
|
| 48 |
+
#include <ATen/ops/_batch_norm_impl_index_backward_native.h>
|
| 49 |
+
#include <ATen/ops/_batch_norm_no_update_native.h>
|
| 50 |
+
#include <ATen/ops/_batch_norm_with_update_native.h>
|
| 51 |
+
#include <ATen/ops/_cast_Byte_native.h>
|
| 52 |
+
#include <ATen/ops/_cast_Char_native.h>
|
| 53 |
+
#include <ATen/ops/_cast_Double_native.h>
|
| 54 |
+
#include <ATen/ops/_cast_Float_native.h>
|
| 55 |
+
#include <ATen/ops/_cast_Half_native.h>
|
| 56 |
+
#include <ATen/ops/_cast_Int_native.h>
|
| 57 |
+
#include <ATen/ops/_cast_Long_native.h>
|
| 58 |
+
#include <ATen/ops/_cast_Short_native.h>
|
| 59 |
+
#include <ATen/ops/_cdist_backward_native.h>
|
| 60 |
+
#include <ATen/ops/_cdist_forward_native.h>
|
| 61 |
+
#include <ATen/ops/_cholesky_solve_helper_native.h>
|
| 62 |
+
#include <ATen/ops/_choose_qparams_per_tensor_native.h>
|
| 63 |
+
#include <ATen/ops/_chunk_cat_native.h>
|
| 64 |
+
#include <ATen/ops/_coalesce_native.h>
|
| 65 |
+
#include <ATen/ops/_coalesced_native.h>
|
| 66 |
+
#include <ATen/ops/_compute_linear_combination_native.h>
|
| 67 |
+
#include <ATen/ops/_conj_native.h>
|
| 68 |
+
#include <ATen/ops/_conj_copy_native.h>
|
| 69 |
+
#include <ATen/ops/_conj_physical_native.h>
|
| 70 |
+
#include <ATen/ops/_conv_depthwise2d_native.h>
|
| 71 |
+
#include <ATen/ops/_convert_indices_from_coo_to_csr_native.h>
|
| 72 |
+
#include <ATen/ops/_convert_indices_from_csr_to_coo_native.h>
|
| 73 |
+
#include <ATen/ops/_convert_weight_to_int4pack_native.h>
|
| 74 |
+
#include <ATen/ops/_convolution_native.h>
|
| 75 |
+
#include <ATen/ops/_convolution_double_backward_native.h>
|
| 76 |
+
#include <ATen/ops/_convolution_mode_native.h>
|
| 77 |
+
#include <ATen/ops/_copy_from_native.h>
|
| 78 |
+
#include <ATen/ops/_copy_from_and_resize_native.h>
|
| 79 |
+
#include <ATen/ops/_cslt_compress_native.h>
|
| 80 |
+
#include <ATen/ops/_cslt_sparse_mm_native.h>
|
| 81 |
+
#include <ATen/ops/_cslt_sparse_mm_search_native.h>
|
| 82 |
+
#include <ATen/ops/_ctc_loss_native.h>
|
| 83 |
+
#include <ATen/ops/_ctc_loss_backward_native.h>
|
| 84 |
+
#include <ATen/ops/_cudnn_ctc_loss_native.h>
|
| 85 |
+
#include <ATen/ops/_cudnn_init_dropout_state_native.h>
|
| 86 |
+
#include <ATen/ops/_cudnn_rnn_native.h>
|
| 87 |
+
#include <ATen/ops/_cudnn_rnn_backward_native.h>
|
| 88 |
+
#include <ATen/ops/_cudnn_rnn_flatten_weight_native.h>
|
| 89 |
+
#include <ATen/ops/_cufft_clear_plan_cache_native.h>
|
| 90 |
+
#include <ATen/ops/_cufft_get_plan_cache_max_size_native.h>
|
| 91 |
+
#include <ATen/ops/_cufft_get_plan_cache_size_native.h>
|
| 92 |
+
#include <ATen/ops/_cufft_set_plan_cache_max_size_native.h>
|
| 93 |
+
#include <ATen/ops/_cummax_helper_native.h>
|
| 94 |
+
#include <ATen/ops/_cummin_helper_native.h>
|
| 95 |
+
#include <ATen/ops/_debug_has_internal_overlap_native.h>
|
| 96 |
+
#include <ATen/ops/_dimI_native.h>
|
| 97 |
+
#include <ATen/ops/_dimV_native.h>
|
| 98 |
+
#include <ATen/ops/_dim_arange_native.h>
|
| 99 |
+
#include <ATen/ops/_dirichlet_grad_native.h>
|
| 100 |
+
#include <ATen/ops/_efficient_attention_backward_native.h>
|
| 101 |
+
#include <ATen/ops/_efficient_attention_forward_native.h>
|
| 102 |
+
#include <ATen/ops/_efficientzerotensor_native.h>
|
| 103 |
+
#include <ATen/ops/_embedding_bag_native.h>
|
| 104 |
+
#include <ATen/ops/_embedding_bag_backward_native.h>
|
| 105 |
+
#include <ATen/ops/_embedding_bag_dense_backward_native.h>
|
| 106 |
+
#include <ATen/ops/_embedding_bag_forward_only_native.h>
|
| 107 |
+
#include <ATen/ops/_embedding_bag_per_sample_weights_backward_native.h>
|
| 108 |
+
#include <ATen/ops/_embedding_bag_sparse_backward_native.h>
|
| 109 |
+
#include <ATen/ops/_empty_affine_quantized_native.h>
|
| 110 |
+
#include <ATen/ops/_empty_per_channel_affine_quantized_native.h>
|
| 111 |
+
#include <ATen/ops/_euclidean_dist_native.h>
|
| 112 |
+
#include <ATen/ops/_fake_quantize_learnable_per_channel_affine_native.h>
|
| 113 |
+
#include <ATen/ops/_fake_quantize_learnable_per_channel_affine_backward_native.h>
|
| 114 |
+
#include <ATen/ops/_fake_quantize_learnable_per_tensor_affine_native.h>
|
| 115 |
+
#include <ATen/ops/_fake_quantize_learnable_per_tensor_affine_backward_native.h>
|
| 116 |
+
#include <ATen/ops/_fake_quantize_per_tensor_affine_cachemask_tensor_qparams_native.h>
|
| 117 |
+
#include <ATen/ops/_fft_c2c_native.h>
|
| 118 |
+
#include <ATen/ops/_fft_c2r_native.h>
|
| 119 |
+
#include <ATen/ops/_fft_r2c_native.h>
|
| 120 |
+
#include <ATen/ops/_fill_mem_eff_dropout_mask_native.h>
|
| 121 |
+
#include <ATen/ops/_flash_attention_backward_native.h>
|
| 122 |
+
#include <ATen/ops/_flash_attention_forward_native.h>
|
| 123 |
+
#include <ATen/ops/_foobar_native.h>
|
| 124 |
+
#include <ATen/ops/_foreach_abs_native.h>
|
| 125 |
+
#include <ATen/ops/_foreach_acos_native.h>
|
| 126 |
+
#include <ATen/ops/_foreach_add_native.h>
|
| 127 |
+
#include <ATen/ops/_foreach_addcdiv_native.h>
|
| 128 |
+
#include <ATen/ops/_foreach_addcmul_native.h>
|
| 129 |
+
#include <ATen/ops/_foreach_asin_native.h>
|
| 130 |
+
#include <ATen/ops/_foreach_atan_native.h>
|
| 131 |
+
#include <ATen/ops/_foreach_ceil_native.h>
|
| 132 |
+
#include <ATen/ops/_foreach_clamp_max_native.h>
|
| 133 |
+
#include <ATen/ops/_foreach_clamp_min_native.h>
|
| 134 |
+
#include <ATen/ops/_foreach_copy_native.h>
|
| 135 |
+
#include <ATen/ops/_foreach_cos_native.h>
|
| 136 |
+
#include <ATen/ops/_foreach_cosh_native.h>
|
| 137 |
+
#include <ATen/ops/_foreach_div_native.h>
|
| 138 |
+
#include <ATen/ops/_foreach_erf_native.h>
|
| 139 |
+
#include <ATen/ops/_foreach_erfc_native.h>
|
| 140 |
+
#include <ATen/ops/_foreach_exp_native.h>
|
| 141 |
+
#include <ATen/ops/_foreach_expm1_native.h>
|
| 142 |
+
#include <ATen/ops/_foreach_floor_native.h>
|
| 143 |
+
#include <ATen/ops/_foreach_frac_native.h>
|
| 144 |
+
#include <ATen/ops/_foreach_lerp_native.h>
|
| 145 |
+
#include <ATen/ops/_foreach_lgamma_native.h>
|
| 146 |
+
#include <ATen/ops/_foreach_log_native.h>
|
| 147 |
+
#include <ATen/ops/_foreach_log10_native.h>
|
| 148 |
+
#include <ATen/ops/_foreach_log1p_native.h>
|
| 149 |
+
#include <ATen/ops/_foreach_log2_native.h>
|
| 150 |
+
#include <ATen/ops/_foreach_max_native.h>
|
| 151 |
+
#include <ATen/ops/_foreach_maximum_native.h>
|
| 152 |
+
#include <ATen/ops/_foreach_minimum_native.h>
|
| 153 |
+
#include <ATen/ops/_foreach_mul_native.h>
|
| 154 |
+
#include <ATen/ops/_foreach_neg_native.h>
|
| 155 |
+
#include <ATen/ops/_foreach_norm_native.h>
|
| 156 |
+
#include <ATen/ops/_foreach_pow_native.h>
|
| 157 |
+
#include <ATen/ops/_foreach_reciprocal_native.h>
|
| 158 |
+
#include <ATen/ops/_foreach_round_native.h>
|
| 159 |
+
#include <ATen/ops/_foreach_sigmoid_native.h>
|
| 160 |
+
#include <ATen/ops/_foreach_sign_native.h>
|
| 161 |
+
#include <ATen/ops/_foreach_sin_native.h>
|
| 162 |
+
#include <ATen/ops/_foreach_sinh_native.h>
|
| 163 |
+
#include <ATen/ops/_foreach_sqrt_native.h>
|
| 164 |
+
#include <ATen/ops/_foreach_sub_native.h>
|
| 165 |
+
#include <ATen/ops/_foreach_tan_native.h>
|
| 166 |
+
#include <ATen/ops/_foreach_tanh_native.h>
|
| 167 |
+
#include <ATen/ops/_foreach_trunc_native.h>
|
| 168 |
+
#include <ATen/ops/_foreach_zero_native.h>
|
| 169 |
+
#include <ATen/ops/_functional_assert_async_native.h>
|
| 170 |
+
#include <ATen/ops/_functional_assert_scalar_native.h>
|
| 171 |
+
#include <ATen/ops/_functional_sym_constrain_range_native.h>
|
| 172 |
+
#include <ATen/ops/_functional_sym_constrain_range_for_size_native.h>
|
| 173 |
+
#include <ATen/ops/_fused_adagrad_native.h>
|
| 174 |
+
#include <ATen/ops/_fused_adam_native.h>
|
| 175 |
+
#include <ATen/ops/_fused_adamw_native.h>
|
| 176 |
+
#include <ATen/ops/_fused_dropout_native.h>
|
| 177 |
+
#include <ATen/ops/_fused_moving_avg_obs_fq_helper_native.h>
|
| 178 |
+
#include <ATen/ops/_fused_sdp_choice_native.h>
|
| 179 |
+
#include <ATen/ops/_fused_sgd_native.h>
|
| 180 |
+
#include <ATen/ops/_fw_primal_native.h>
|
| 181 |
+
#include <ATen/ops/_fw_primal_copy_native.h>
|
| 182 |
+
#include <ATen/ops/_gather_sparse_backward_native.h>
|
| 183 |
+
#include <ATen/ops/_grid_sampler_2d_cpu_fallback_native.h>
|
| 184 |
+
#include <ATen/ops/_grid_sampler_2d_cpu_fallback_backward_native.h>
|
| 185 |
+
#include <ATen/ops/_has_compatible_shallow_copy_type_native.h>
|
| 186 |
+
#include <ATen/ops/_has_same_storage_numel_native.h>
|
| 187 |
+
#include <ATen/ops/_histogramdd_bin_edges_native.h>
|
| 188 |
+
#include <ATen/ops/_histogramdd_from_bin_cts_native.h>
|
| 189 |
+
#include <ATen/ops/_histogramdd_from_bin_tensors_native.h>
|
| 190 |
+
#include <ATen/ops/_index_put_impl_native.h>
|
| 191 |
+
#include <ATen/ops/_indices_native.h>
|
| 192 |
+
#include <ATen/ops/_indices_copy_native.h>
|
| 193 |
+
#include <ATen/ops/_int_mm_native.h>
|
| 194 |
+
#include <ATen/ops/_is_all_true_native.h>
|
| 195 |
+
#include <ATen/ops/_is_any_true_native.h>
|
| 196 |
+
#include <ATen/ops/_is_zerotensor_native.h>
|
| 197 |
+
#include <ATen/ops/_jagged_to_padded_dense_forward_native.h>
|
| 198 |
+
#include <ATen/ops/_lazy_clone_native.h>
|
| 199 |
+
#include <ATen/ops/_linalg_check_errors_native.h>
|
| 200 |
+
#include <ATen/ops/_linalg_det_native.h>
|
| 201 |
+
#include <ATen/ops/_linalg_eigh_native.h>
|
| 202 |
+
#include <ATen/ops/_linalg_eigvals_native.h>
|
| 203 |
+
#include <ATen/ops/_linalg_slogdet_native.h>
|
| 204 |
+
#include <ATen/ops/_linalg_solve_ex_native.h>
|
| 205 |
+
#include <ATen/ops/_linalg_svd_native.h>
|
| 206 |
+
#include <ATen/ops/_local_scalar_dense_native.h>
|
| 207 |
+
#include <ATen/ops/_log_softmax_native.h>
|
| 208 |
+
#include <ATen/ops/_log_softmax_backward_data_native.h>
|
| 209 |
+
#include <ATen/ops/_logcumsumexp_native.h>
|
| 210 |
+
#include <ATen/ops/_lstm_mps_native.h>
|
| 211 |
+
#include <ATen/ops/_lu_with_info_native.h>
|
| 212 |
+
#include <ATen/ops/_make_dep_token_native.h>
|
| 213 |
+
#include <ATen/ops/_make_dual_native.h>
|
| 214 |
+
#include <ATen/ops/_make_dual_copy_native.h>
|
| 215 |
+
#include <ATen/ops/_make_per_channel_quantized_tensor_native.h>
|
| 216 |
+
#include <ATen/ops/_make_per_tensor_quantized_tensor_native.h>
|
| 217 |
+
#include <ATen/ops/_masked_scale_native.h>
|
| 218 |
+
#include <ATen/ops/_masked_softmax_native.h>
|
| 219 |
+
#include <ATen/ops/_masked_softmax_backward_native.h>
|
| 220 |
+
#include <ATen/ops/_mixed_dtypes_linear_native.h>
|
| 221 |
+
#include <ATen/ops/_mkldnn_reshape_native.h>
|
| 222 |
+
#include <ATen/ops/_mkldnn_transpose_native.h>
|
| 223 |
+
#include <ATen/ops/_mps_convolution_native.h>
|
| 224 |
+
#include <ATen/ops/_mps_convolution_transpose_native.h>
|
| 225 |
+
#include <ATen/ops/_native_batch_norm_legit_native.h>
|
| 226 |
+
#include <ATen/ops/_native_batch_norm_legit_no_training_native.h>
|
| 227 |
+
#include <ATen/ops/_native_multi_head_attention_native.h>
|
| 228 |
+
#include <ATen/ops/_neg_view_native.h>
|
| 229 |
+
#include <ATen/ops/_neg_view_copy_native.h>
|
| 230 |
+
#include <ATen/ops/_nested_compute_contiguous_strides_offsets_native.h>
|
| 231 |
+
#include <ATen/ops/_nested_from_padded_native.h>
|
| 232 |
+
#include <ATen/ops/_nested_from_padded_and_nested_example_native.h>
|
| 233 |
+
#include <ATen/ops/_nested_get_jagged_dummy_native.h>
|
| 234 |
+
#include <ATen/ops/_nested_get_lengths_native.h>
|
| 235 |
+
#include <ATen/ops/_nested_get_max_seqlen_native.h>
|
| 236 |
+
#include <ATen/ops/_nested_get_min_seqlen_native.h>
|
| 237 |
+
#include <ATen/ops/_nested_get_offsets_native.h>
|
| 238 |
+
#include <ATen/ops/_nested_get_ragged_idx_native.h>
|
| 239 |
+
#include <ATen/ops/_nested_get_values_native.h>
|
| 240 |
+
#include <ATen/ops/_nested_get_values_copy_native.h>
|
| 241 |
+
#include <ATen/ops/_nested_select_backward_native.h>
|
| 242 |
+
#include <ATen/ops/_nested_sum_backward_native.h>
|
| 243 |
+
#include <ATen/ops/_nested_tensor_from_mask_native.h>
|
| 244 |
+
#include <ATen/ops/_nested_tensor_from_mask_left_aligned_native.h>
|
| 245 |
+
#include <ATen/ops/_nested_tensor_from_tensor_list_native.h>
|
| 246 |
+
#include <ATen/ops/_nested_tensor_size_native.h>
|
| 247 |
+
#include <ATen/ops/_nested_tensor_softmax_with_shape_native.h>
|
| 248 |
+
#include <ATen/ops/_nested_tensor_storage_offsets_native.h>
|
| 249 |
+
#include <ATen/ops/_nested_tensor_strides_native.h>
|
| 250 |
+
#include <ATen/ops/_nested_view_from_buffer_native.h>
|
| 251 |
+
#include <ATen/ops/_nested_view_from_buffer_copy_native.h>
|
| 252 |
+
#include <ATen/ops/_nested_view_from_jagged_native.h>
|
| 253 |
+
#include <ATen/ops/_nested_view_from_jagged_copy_native.h>
|
| 254 |
+
#include <ATen/ops/_new_zeros_with_same_feature_meta_native.h>
|
| 255 |
+
#include <ATen/ops/_nnpack_available_native.h>
|
| 256 |
+
#include <ATen/ops/_nnpack_spatial_convolution_native.h>
|
| 257 |
+
#include <ATen/ops/_nnz_native.h>
|
| 258 |
+
#include <ATen/ops/_pack_padded_sequence_native.h>
|
| 259 |
+
#include <ATen/ops/_pack_padded_sequence_backward_native.h>
|
| 260 |
+
#include <ATen/ops/_pad_circular_native.h>
|
| 261 |
+
#include <ATen/ops/_pad_enum_native.h>
|
| 262 |
+
#include <ATen/ops/_pad_packed_sequence_native.h>
|
| 263 |
+
#include <ATen/ops/_padded_dense_to_jagged_forward_native.h>
|
| 264 |
+
#include <ATen/ops/_pdist_backward_native.h>
|
| 265 |
+
#include <ATen/ops/_pdist_forward_native.h>
|
| 266 |
+
#include <ATen/ops/_pin_memory_native.h>
|
| 267 |
+
#include <ATen/ops/_prelu_kernel_native.h>
|
| 268 |
+
#include <ATen/ops/_prelu_kernel_backward_native.h>
|
| 269 |
+
#include <ATen/ops/_print_native.h>
|
| 270 |
+
#include <ATen/ops/_propagate_xla_data_native.h>
|
| 271 |
+
#include <ATen/ops/_remove_batch_dim_native.h>
|
| 272 |
+
#include <ATen/ops/_reshape_alias_native.h>
|
| 273 |
+
#include <ATen/ops/_reshape_alias_copy_native.h>
|
| 274 |
+
#include <ATen/ops/_reshape_copy_native.h>
|
| 275 |
+
#include <ATen/ops/_reshape_from_tensor_native.h>
|
| 276 |
+
#include <ATen/ops/_resize_output_native.h>
|
| 277 |
+
#include <ATen/ops/_rowwise_prune_native.h>
|
| 278 |
+
#include <ATen/ops/_safe_softmax_native.h>
|
| 279 |
+
#include <ATen/ops/_sample_dirichlet_native.h>
|
| 280 |
+
#include <ATen/ops/_saturate_weight_to_fp16_native.h>
|
| 281 |
+
#include <ATen/ops/_scaled_dot_product_attention_math_native.h>
|
| 282 |
+
#include <ATen/ops/_scaled_dot_product_attention_math_for_mps_native.h>
|
| 283 |
+
#include <ATen/ops/_scaled_dot_product_cudnn_attention_native.h>
|
| 284 |
+
#include <ATen/ops/_scaled_dot_product_cudnn_attention_backward_native.h>
|
| 285 |
+
#include <ATen/ops/_scaled_dot_product_efficient_attention_native.h>
|
| 286 |
+
#include <ATen/ops/_scaled_dot_product_efficient_attention_backward_native.h>
|
| 287 |
+
#include <ATen/ops/_scaled_dot_product_flash_attention_native.h>
|
| 288 |
+
#include <ATen/ops/_scaled_dot_product_flash_attention_backward_native.h>
|
| 289 |
+
#include <ATen/ops/_scaled_dot_product_flash_attention_for_cpu_native.h>
|
| 290 |
+
#include <ATen/ops/_scaled_dot_product_flash_attention_for_cpu_backward_native.h>
|
| 291 |
+
#include <ATen/ops/_scaled_dot_product_fused_attention_overrideable_native.h>
|
| 292 |
+
#include <ATen/ops/_scaled_dot_product_fused_attention_overrideable_backward_native.h>
|
| 293 |
+
#include <ATen/ops/_scaled_mm_native.h>
|
| 294 |
+
#include <ATen/ops/_segment_reduce_backward_native.h>
|
| 295 |
+
#include <ATen/ops/_shape_as_tensor_native.h>
|
| 296 |
+
#include <ATen/ops/_slow_conv2d_backward_native.h>
|
| 297 |
+
#include <ATen/ops/_slow_conv2d_forward_native.h>
|
| 298 |
+
#include <ATen/ops/_sobol_engine_draw_native.h>
|
| 299 |
+
#include <ATen/ops/_sobol_engine_ff_native.h>
|
| 300 |
+
#include <ATen/ops/_sobol_engine_initialize_state_native.h>
|
| 301 |
+
#include <ATen/ops/_sobol_engine_scramble_native.h>
|
| 302 |
+
#include <ATen/ops/_softmax_native.h>
|
| 303 |
+
#include <ATen/ops/_softmax_backward_data_native.h>
|
| 304 |
+
#include <ATen/ops/_sparse_addmm_native.h>
|
| 305 |
+
#include <ATen/ops/_sparse_broadcast_to_native.h>
|
| 306 |
+
#include <ATen/ops/_sparse_broadcast_to_copy_native.h>
|
| 307 |
+
#include <ATen/ops/_sparse_bsc_tensor_unsafe_native.h>
|
| 308 |
+
#include <ATen/ops/_sparse_bsr_tensor_unsafe_native.h>
|
| 309 |
+
#include <ATen/ops/_sparse_compressed_tensor_unsafe_native.h>
|
| 310 |
+
#include <ATen/ops/_sparse_compressed_tensor_with_dims_native.h>
|
| 311 |
+
#include <ATen/ops/_sparse_coo_tensor_unsafe_native.h>
|
| 312 |
+
#include <ATen/ops/_sparse_coo_tensor_with_dims_native.h>
|
| 313 |
+
#include <ATen/ops/_sparse_coo_tensor_with_dims_and_tensors_native.h>
|
| 314 |
+
#include <ATen/ops/_sparse_csc_tensor_unsafe_native.h>
|
| 315 |
+
#include <ATen/ops/_sparse_csr_prod_native.h>
|
| 316 |
+
#include <ATen/ops/_sparse_csr_sum_native.h>
|
| 317 |
+
#include <ATen/ops/_sparse_csr_tensor_unsafe_native.h>
|
| 318 |
+
#include <ATen/ops/_sparse_log_softmax_native.h>
|
| 319 |
+
#include <ATen/ops/_sparse_log_softmax_backward_data_native.h>
|
| 320 |
+
#include <ATen/ops/_sparse_mask_projection_native.h>
|
| 321 |
+
#include <ATen/ops/_sparse_mm_native.h>
|
| 322 |
+
#include <ATen/ops/_sparse_mm_reduce_impl_native.h>
|
| 323 |
+
#include <ATen/ops/_sparse_mm_reduce_impl_backward_native.h>
|
| 324 |
+
#include <ATen/ops/_sparse_semi_structured_addmm_native.h>
|
| 325 |
+
#include <ATen/ops/_sparse_semi_structured_apply_native.h>
|
| 326 |
+
#include <ATen/ops/_sparse_semi_structured_apply_dense_native.h>
|
| 327 |
+
#include <ATen/ops/_sparse_semi_structured_linear_native.h>
|
| 328 |
+
#include <ATen/ops/_sparse_semi_structured_mm_native.h>
|
| 329 |
+
#include <ATen/ops/_sparse_semi_structured_tile_native.h>
|
| 330 |
+
#include <ATen/ops/_sparse_softmax_native.h>
|
| 331 |
+
#include <ATen/ops/_sparse_softmax_backward_data_native.h>
|
| 332 |
+
#include <ATen/ops/_sparse_sparse_matmul_native.h>
|
| 333 |
+
#include <ATen/ops/_sparse_sum_native.h>
|
| 334 |
+
#include <ATen/ops/_sparse_sum_backward_native.h>
|
| 335 |
+
#include <ATen/ops/_spdiags_native.h>
|
| 336 |
+
#include <ATen/ops/_spsolve_native.h>
|
| 337 |
+
#include <ATen/ops/_stack_native.h>
|
| 338 |
+
#include <ATen/ops/_standard_gamma_native.h>
|
| 339 |
+
#include <ATen/ops/_standard_gamma_grad_native.h>
|
| 340 |
+
#include <ATen/ops/_test_ambiguous_defaults_native.h>
|
| 341 |
+
#include <ATen/ops/_test_autograd_multiple_dispatch_native.h>
|
| 342 |
+
#include <ATen/ops/_test_autograd_multiple_dispatch_view_native.h>
|
| 343 |
+
#include <ATen/ops/_test_autograd_multiple_dispatch_view_copy_native.h>
|
| 344 |
+
#include <ATen/ops/_test_check_tensor_native.h>
|
| 345 |
+
#include <ATen/ops/_test_functorch_fallback_native.h>
|
| 346 |
+
#include <ATen/ops/_test_optional_filled_intlist_native.h>
|
| 347 |
+
#include <ATen/ops/_test_optional_floatlist_native.h>
|
| 348 |
+
#include <ATen/ops/_test_optional_intlist_native.h>
|
| 349 |
+
#include <ATen/ops/_test_parallel_materialize_native.h>
|
| 350 |
+
#include <ATen/ops/_test_serialization_subcmul_native.h>
|
| 351 |
+
#include <ATen/ops/_test_string_default_native.h>
|
| 352 |
+
#include <ATen/ops/_test_warn_in_autograd_native.h>
|
| 353 |
+
#include <ATen/ops/_thnn_differentiable_gru_cell_backward_native.h>
|
| 354 |
+
#include <ATen/ops/_thnn_differentiable_lstm_cell_backward_native.h>
|
| 355 |
+
#include <ATen/ops/_thnn_fused_gru_cell_native.h>
|
| 356 |
+
#include <ATen/ops/_thnn_fused_gru_cell_backward_native.h>
|
| 357 |
+
#include <ATen/ops/_thnn_fused_lstm_cell_native.h>
|
| 358 |
+
#include <ATen/ops/_thnn_fused_lstm_cell_backward_native.h>
|
| 359 |
+
#include <ATen/ops/_thnn_fused_lstm_cell_backward_impl_native.h>
|
| 360 |
+
#include <ATen/ops/_to_copy_native.h>
|
| 361 |
+
#include <ATen/ops/_to_cpu_native.h>
|
| 362 |
+
#include <ATen/ops/_to_dense_native.h>
|
| 363 |
+
#include <ATen/ops/_to_sparse_native.h>
|
| 364 |
+
#include <ATen/ops/_to_sparse_bsc_native.h>
|
| 365 |
+
#include <ATen/ops/_to_sparse_bsr_native.h>
|
| 366 |
+
#include <ATen/ops/_to_sparse_csc_native.h>
|
| 367 |
+
#include <ATen/ops/_to_sparse_csr_native.h>
|
| 368 |
+
#include <ATen/ops/_to_sparse_semi_structured_native.h>
|
| 369 |
+
#include <ATen/ops/_transform_bias_rescale_qkv_native.h>
|
| 370 |
+
#include <ATen/ops/_transformer_encoder_layer_fwd_native.h>
|
| 371 |
+
#include <ATen/ops/_trilinear_native.h>
|
| 372 |
+
#include <ATen/ops/_triton_multi_head_attention_native.h>
|
| 373 |
+
#include <ATen/ops/_triton_scaled_dot_attention_native.h>
|
| 374 |
+
#include <ATen/ops/_unique_native.h>
|
| 375 |
+
#include <ATen/ops/_unique2_native.h>
|
| 376 |
+
#include <ATen/ops/_unpack_dual_native.h>
|
| 377 |
+
#include <ATen/ops/_unsafe_index_native.h>
|
| 378 |
+
#include <ATen/ops/_unsafe_index_put_native.h>
|
| 379 |
+
#include <ATen/ops/_unsafe_masked_index_native.h>
|
| 380 |
+
#include <ATen/ops/_unsafe_masked_index_put_accumulate_native.h>
|
| 381 |
+
#include <ATen/ops/_unsafe_view_native.h>
|
| 382 |
+
#include <ATen/ops/_upsample_bicubic2d_aa_native.h>
|
| 383 |
+
#include <ATen/ops/_upsample_bicubic2d_aa_backward_native.h>
|
| 384 |
+
#include <ATen/ops/_upsample_bilinear2d_aa_native.h>
|
| 385 |
+
#include <ATen/ops/_upsample_bilinear2d_aa_backward_native.h>
|
| 386 |
+
#include <ATen/ops/_upsample_nearest_exact1d_native.h>
|
| 387 |
+
#include <ATen/ops/_upsample_nearest_exact1d_backward_native.h>
|
| 388 |
+
#include <ATen/ops/_upsample_nearest_exact2d_native.h>
|
| 389 |
+
#include <ATen/ops/_upsample_nearest_exact2d_backward_native.h>
|
| 390 |
+
#include <ATen/ops/_upsample_nearest_exact3d_native.h>
|
| 391 |
+
#include <ATen/ops/_upsample_nearest_exact3d_backward_native.h>
|
| 392 |
+
#include <ATen/ops/_use_cudnn_ctc_loss_native.h>
|
| 393 |
+
#include <ATen/ops/_use_cudnn_rnn_flatten_weight_native.h>
|
| 394 |
+
#include <ATen/ops/_validate_compressed_sparse_indices_native.h>
|
| 395 |
+
#include <ATen/ops/_validate_sparse_bsc_tensor_args_native.h>
|
| 396 |
+
#include <ATen/ops/_validate_sparse_bsr_tensor_args_native.h>
|
| 397 |
+
#include <ATen/ops/_validate_sparse_compressed_tensor_args_native.h>
|
| 398 |
+
#include <ATen/ops/_validate_sparse_coo_tensor_args_native.h>
|
| 399 |
+
#include <ATen/ops/_validate_sparse_csc_tensor_args_native.h>
|
| 400 |
+
#include <ATen/ops/_validate_sparse_csr_tensor_args_native.h>
|
| 401 |
+
#include <ATen/ops/_values_native.h>
|
| 402 |
+
#include <ATen/ops/_values_copy_native.h>
|
| 403 |
+
#include <ATen/ops/_version_native.h>
|
| 404 |
+
#include <ATen/ops/_weight_int4pack_mm_native.h>
|
| 405 |
+
#include <ATen/ops/_weight_int8pack_mm_native.h>
|
| 406 |
+
#include <ATen/ops/_weight_norm_native.h>
|
| 407 |
+
#include <ATen/ops/_weight_norm_differentiable_backward_native.h>
|
| 408 |
+
#include <ATen/ops/_weight_norm_interface_native.h>
|
| 409 |
+
#include <ATen/ops/_weight_norm_interface_backward_native.h>
|
| 410 |
+
#include <ATen/ops/_wrapped_linear_prepack_native.h>
|
| 411 |
+
#include <ATen/ops/_wrapped_quantized_linear_prepacked_native.h>
|
| 412 |
+
#include <ATen/ops/abs_native.h>
|
| 413 |
+
#include <ATen/ops/absolute_native.h>
|
| 414 |
+
#include <ATen/ops/acos_native.h>
|
| 415 |
+
#include <ATen/ops/acosh_native.h>
|
| 416 |
+
#include <ATen/ops/adaptive_avg_pool1d_native.h>
|
| 417 |
+
#include <ATen/ops/adaptive_avg_pool2d_native.h>
|
| 418 |
+
#include <ATen/ops/adaptive_avg_pool3d_native.h>
|
| 419 |
+
#include <ATen/ops/adaptive_avg_pool3d_backward_native.h>
|
| 420 |
+
#include <ATen/ops/adaptive_max_pool1d_native.h>
|
| 421 |
+
#include <ATen/ops/adaptive_max_pool2d_native.h>
|
| 422 |
+
#include <ATen/ops/adaptive_max_pool2d_backward_native.h>
|
| 423 |
+
#include <ATen/ops/adaptive_max_pool3d_native.h>
|
| 424 |
+
#include <ATen/ops/adaptive_max_pool3d_backward_native.h>
|
| 425 |
+
#include <ATen/ops/add_native.h>
|
| 426 |
+
#include <ATen/ops/addbmm_native.h>
|
| 427 |
+
#include <ATen/ops/addcdiv_native.h>
|
| 428 |
+
#include <ATen/ops/addcmul_native.h>
|
| 429 |
+
#include <ATen/ops/addmm_native.h>
|
| 430 |
+
#include <ATen/ops/addmv_native.h>
|
| 431 |
+
#include <ATen/ops/addr_native.h>
|
| 432 |
+
#include <ATen/ops/adjoint_native.h>
|
| 433 |
+
#include <ATen/ops/affine_grid_generator_native.h>
|
| 434 |
+
#include <ATen/ops/affine_grid_generator_backward_native.h>
|
| 435 |
+
#include <ATen/ops/alias_native.h>
|
| 436 |
+
#include <ATen/ops/alias_copy_native.h>
|
| 437 |
+
#include <ATen/ops/align_as_native.h>
|
| 438 |
+
#include <ATen/ops/align_tensors_native.h>
|
| 439 |
+
#include <ATen/ops/align_to_native.h>
|
| 440 |
+
#include <ATen/ops/all_native.h>
|
| 441 |
+
#include <ATen/ops/allclose_native.h>
|
| 442 |
+
#include <ATen/ops/alpha_dropout_native.h>
|
| 443 |
+
#include <ATen/ops/amax_native.h>
|
| 444 |
+
#include <ATen/ops/amin_native.h>
|
| 445 |
+
#include <ATen/ops/aminmax_native.h>
|
| 446 |
+
#include <ATen/ops/and_native.h>
|
| 447 |
+
#include <ATen/ops/angle_native.h>
|
| 448 |
+
#include <ATen/ops/any_native.h>
|
| 449 |
+
#include <ATen/ops/arange_native.h>
|
| 450 |
+
#include <ATen/ops/arccos_native.h>
|
| 451 |
+
#include <ATen/ops/arccosh_native.h>
|
| 452 |
+
#include <ATen/ops/arcsin_native.h>
|
| 453 |
+
#include <ATen/ops/arcsinh_native.h>
|
| 454 |
+
#include <ATen/ops/arctan_native.h>
|
| 455 |
+
#include <ATen/ops/arctan2_native.h>
|
| 456 |
+
#include <ATen/ops/arctanh_native.h>
|
| 457 |
+
#include <ATen/ops/argmax_native.h>
|
| 458 |
+
#include <ATen/ops/argmin_native.h>
|
| 459 |
+
#include <ATen/ops/argsort_native.h>
|
| 460 |
+
#include <ATen/ops/argwhere_native.h>
|
| 461 |
+
#include <ATen/ops/as_strided_native.h>
|
| 462 |
+
#include <ATen/ops/as_strided_copy_native.h>
|
| 463 |
+
#include <ATen/ops/as_strided_scatter_native.h>
|
| 464 |
+
#include <ATen/ops/asin_native.h>
|
| 465 |
+
#include <ATen/ops/asinh_native.h>
|
| 466 |
+
#include <ATen/ops/atan_native.h>
|
| 467 |
+
#include <ATen/ops/atan2_native.h>
|
| 468 |
+
#include <ATen/ops/atanh_native.h>
|
| 469 |
+
#include <ATen/ops/atleast_1d_native.h>
|
| 470 |
+
#include <ATen/ops/atleast_2d_native.h>
|
| 471 |
+
#include <ATen/ops/atleast_3d_native.h>
|
| 472 |
+
#include <ATen/ops/avg_pool1d_native.h>
|
| 473 |
+
#include <ATen/ops/avg_pool2d_native.h>
|
| 474 |
+
#include <ATen/ops/avg_pool2d_backward_native.h>
|
| 475 |
+
#include <ATen/ops/avg_pool3d_native.h>
|
| 476 |
+
#include <ATen/ops/avg_pool3d_backward_native.h>
|
| 477 |
+
#include <ATen/ops/baddbmm_native.h>
|
| 478 |
+
#include <ATen/ops/bartlett_window_native.h>
|
| 479 |
+
#include <ATen/ops/batch_norm_native.h>
|
| 480 |
+
#include <ATen/ops/batch_norm_backward_native.h>
|
| 481 |
+
#include <ATen/ops/batch_norm_backward_elemt_native.h>
|
| 482 |
+
#include <ATen/ops/batch_norm_backward_reduce_native.h>
|
| 483 |
+
#include <ATen/ops/batch_norm_elemt_native.h>
|
| 484 |
+
#include <ATen/ops/batch_norm_gather_stats_native.h>
|
| 485 |
+
#include <ATen/ops/batch_norm_gather_stats_with_counts_native.h>
|
| 486 |
+
#include <ATen/ops/batch_norm_stats_native.h>
|
| 487 |
+
#include <ATen/ops/batch_norm_update_stats_native.h>
|
| 488 |
+
#include <ATen/ops/bernoulli_native.h>
|
| 489 |
+
#include <ATen/ops/bilinear_native.h>
|
| 490 |
+
#include <ATen/ops/binary_cross_entropy_native.h>
|
| 491 |
+
#include <ATen/ops/binary_cross_entropy_backward_native.h>
|
| 492 |
+
#include <ATen/ops/binary_cross_entropy_with_logits_native.h>
|
| 493 |
+
#include <ATen/ops/bincount_native.h>
|
| 494 |
+
#include <ATen/ops/binomial_native.h>
|
| 495 |
+
#include <ATen/ops/bitwise_and_native.h>
|
| 496 |
+
#include <ATen/ops/bitwise_left_shift_native.h>
|
| 497 |
+
#include <ATen/ops/bitwise_not_native.h>
|
| 498 |
+
#include <ATen/ops/bitwise_or_native.h>
|
| 499 |
+
#include <ATen/ops/bitwise_right_shift_native.h>
|
| 500 |
+
#include <ATen/ops/bitwise_xor_native.h>
|
| 501 |
+
#include <ATen/ops/blackman_window_native.h>
|
| 502 |
+
#include <ATen/ops/block_diag_native.h>
|
| 503 |
+
#include <ATen/ops/bmm_native.h>
|
| 504 |
+
#include <ATen/ops/broadcast_tensors_native.h>
|
| 505 |
+
#include <ATen/ops/broadcast_to_native.h>
|
| 506 |
+
#include <ATen/ops/bucketize_native.h>
|
| 507 |
+
#include <ATen/ops/can_cast_native.h>
|
| 508 |
+
#include <ATen/ops/cartesian_prod_native.h>
|
| 509 |
+
#include <ATen/ops/cat_native.h>
|
| 510 |
+
#include <ATen/ops/cauchy_native.h>
|
| 511 |
+
#include <ATen/ops/ccol_indices_native.h>
|
| 512 |
+
#include <ATen/ops/ccol_indices_copy_native.h>
|
| 513 |
+
#include <ATen/ops/cdist_native.h>
|
| 514 |
+
#include <ATen/ops/ceil_native.h>
|
| 515 |
+
#include <ATen/ops/celu_native.h>
|
| 516 |
+
#include <ATen/ops/chain_matmul_native.h>
|
| 517 |
+
#include <ATen/ops/chalf_native.h>
|
| 518 |
+
#include <ATen/ops/channel_shuffle_native.h>
|
| 519 |
+
#include <ATen/ops/cholesky_native.h>
|
| 520 |
+
#include <ATen/ops/cholesky_inverse_native.h>
|
| 521 |
+
#include <ATen/ops/cholesky_solve_native.h>
|
| 522 |
+
#include <ATen/ops/choose_qparams_optimized_native.h>
|
| 523 |
+
#include <ATen/ops/chunk_native.h>
|
| 524 |
+
#include <ATen/ops/clamp_native.h>
|
| 525 |
+
#include <ATen/ops/clamp_max_native.h>
|
| 526 |
+
#include <ATen/ops/clamp_min_native.h>
|
| 527 |
+
#include <ATen/ops/clip_native.h>
|
| 528 |
+
#include <ATen/ops/clone_native.h>
|
| 529 |
+
#include <ATen/ops/coalesce_native.h>
|
| 530 |
+
#include <ATen/ops/col2im_native.h>
|
| 531 |
+
#include <ATen/ops/col_indices_native.h>
|
| 532 |
+
#include <ATen/ops/col_indices_copy_native.h>
|
| 533 |
+
#include <ATen/ops/column_stack_native.h>
|
| 534 |
+
#include <ATen/ops/combinations_native.h>
|
| 535 |
+
#include <ATen/ops/complex_native.h>
|
| 536 |
+
#include <ATen/ops/concat_native.h>
|
| 537 |
+
#include <ATen/ops/concatenate_native.h>
|
| 538 |
+
#include <ATen/ops/conj_native.h>
|
| 539 |
+
#include <ATen/ops/conj_physical_native.h>
|
| 540 |
+
#include <ATen/ops/constant_pad_nd_native.h>
|
| 541 |
+
#include <ATen/ops/contiguous_native.h>
|
| 542 |
+
#include <ATen/ops/conv1d_native.h>
|
| 543 |
+
#include <ATen/ops/conv2d_native.h>
|
| 544 |
+
#include <ATen/ops/conv3d_native.h>
|
| 545 |
+
#include <ATen/ops/conv_depthwise3d_native.h>
|
| 546 |
+
#include <ATen/ops/conv_tbc_native.h>
|
| 547 |
+
#include <ATen/ops/conv_tbc_backward_native.h>
|
| 548 |
+
#include <ATen/ops/conv_transpose1d_native.h>
|
| 549 |
+
#include <ATen/ops/conv_transpose2d_native.h>
|
| 550 |
+
#include <ATen/ops/conv_transpose3d_native.h>
|
| 551 |
+
#include <ATen/ops/convolution_native.h>
|
| 552 |
+
#include <ATen/ops/convolution_backward_native.h>
|
| 553 |
+
#include <ATen/ops/convolution_backward_overrideable_native.h>
|
| 554 |
+
#include <ATen/ops/convolution_overrideable_native.h>
|
| 555 |
+
#include <ATen/ops/copy_native.h>
|
| 556 |
+
#include <ATen/ops/copy_sparse_to_sparse_native.h>
|
| 557 |
+
#include <ATen/ops/copysign_native.h>
|
| 558 |
+
#include <ATen/ops/corrcoef_native.h>
|
| 559 |
+
#include <ATen/ops/cos_native.h>
|
| 560 |
+
#include <ATen/ops/cosh_native.h>
|
| 561 |
+
#include <ATen/ops/cosine_embedding_loss_native.h>
|
| 562 |
+
#include <ATen/ops/cosine_similarity_native.h>
|
| 563 |
+
#include <ATen/ops/count_nonzero_native.h>
|
| 564 |
+
#include <ATen/ops/cov_native.h>
|
| 565 |
+
#include <ATen/ops/cross_native.h>
|
| 566 |
+
#include <ATen/ops/cross_entropy_loss_native.h>
|
| 567 |
+
#include <ATen/ops/crow_indices_native.h>
|
| 568 |
+
#include <ATen/ops/crow_indices_copy_native.h>
|
| 569 |
+
#include <ATen/ops/ctc_loss_native.h>
|
| 570 |
+
#include <ATen/ops/cudnn_affine_grid_generator_native.h>
|
| 571 |
+
#include <ATen/ops/cudnn_affine_grid_generator_backward_native.h>
|
| 572 |
+
#include <ATen/ops/cudnn_batch_norm_native.h>
|
| 573 |
+
#include <ATen/ops/cudnn_batch_norm_backward_native.h>
|
| 574 |
+
#include <ATen/ops/cudnn_convolution_native.h>
|
| 575 |
+
#include <ATen/ops/cudnn_convolution_add_relu_native.h>
|
| 576 |
+
#include <ATen/ops/cudnn_convolution_relu_native.h>
|
| 577 |
+
#include <ATen/ops/cudnn_convolution_transpose_native.h>
|
| 578 |
+
#include <ATen/ops/cudnn_grid_sampler_native.h>
|
| 579 |
+
#include <ATen/ops/cudnn_grid_sampler_backward_native.h>
|
| 580 |
+
#include <ATen/ops/cudnn_is_acceptable_native.h>
|
| 581 |
+
#include <ATen/ops/cummax_native.h>
|
| 582 |
+
#include <ATen/ops/cummaxmin_backward_native.h>
|
| 583 |
+
#include <ATen/ops/cummin_native.h>
|
| 584 |
+
#include <ATen/ops/cumprod_native.h>
|
| 585 |
+
#include <ATen/ops/cumprod_backward_native.h>
|
| 586 |
+
#include <ATen/ops/cumsum_native.h>
|
| 587 |
+
#include <ATen/ops/cumulative_trapezoid_native.h>
|
| 588 |
+
#include <ATen/ops/data_native.h>
|
| 589 |
+
#include <ATen/ops/deg2rad_native.h>
|
| 590 |
+
#include <ATen/ops/dense_dim_native.h>
|
| 591 |
+
#include <ATen/ops/dequantize_native.h>
|
| 592 |
+
#include <ATen/ops/det_native.h>
|
| 593 |
+
#include <ATen/ops/detach_native.h>
|
| 594 |
+
#include <ATen/ops/detach_copy_native.h>
|
| 595 |
+
#include <ATen/ops/diag_native.h>
|
| 596 |
+
#include <ATen/ops/diag_embed_native.h>
|
| 597 |
+
#include <ATen/ops/diagflat_native.h>
|
| 598 |
+
#include <ATen/ops/diagonal_native.h>
|
| 599 |
+
#include <ATen/ops/diagonal_backward_native.h>
|
| 600 |
+
#include <ATen/ops/diagonal_copy_native.h>
|
| 601 |
+
#include <ATen/ops/diagonal_scatter_native.h>
|
| 602 |
+
#include <ATen/ops/diff_native.h>
|
| 603 |
+
#include <ATen/ops/digamma_native.h>
|
| 604 |
+
#include <ATen/ops/dist_native.h>
|
| 605 |
+
#include <ATen/ops/div_native.h>
|
| 606 |
+
#include <ATen/ops/divide_native.h>
|
| 607 |
+
#include <ATen/ops/dot_native.h>
|
| 608 |
+
#include <ATen/ops/dropout_native.h>
|
| 609 |
+
#include <ATen/ops/dsplit_native.h>
|
| 610 |
+
#include <ATen/ops/dstack_native.h>
|
| 611 |
+
#include <ATen/ops/einsum_native.h>
|
| 612 |
+
#include <ATen/ops/elu_native.h>
|
| 613 |
+
#include <ATen/ops/elu_backward_native.h>
|
| 614 |
+
#include <ATen/ops/embedding_native.h>
|
| 615 |
+
#include <ATen/ops/embedding_backward_native.h>
|
| 616 |
+
#include <ATen/ops/embedding_bag_native.h>
|
| 617 |
+
#include <ATen/ops/embedding_dense_backward_native.h>
|
| 618 |
+
#include <ATen/ops/embedding_renorm_native.h>
|
| 619 |
+
#include <ATen/ops/embedding_sparse_backward_native.h>
|
| 620 |
+
#include <ATen/ops/empty_native.h>
|
| 621 |
+
#include <ATen/ops/empty_like_native.h>
|
| 622 |
+
#include <ATen/ops/empty_permuted_native.h>
|
| 623 |
+
#include <ATen/ops/empty_quantized_native.h>
|
| 624 |
+
#include <ATen/ops/empty_strided_native.h>
|
| 625 |
+
#include <ATen/ops/eq_native.h>
|
| 626 |
+
#include <ATen/ops/equal_native.h>
|
| 627 |
+
#include <ATen/ops/erf_native.h>
|
| 628 |
+
#include <ATen/ops/erfc_native.h>
|
| 629 |
+
#include <ATen/ops/erfinv_native.h>
|
| 630 |
+
#include <ATen/ops/exp_native.h>
|
| 631 |
+
#include <ATen/ops/exp2_native.h>
|
| 632 |
+
#include <ATen/ops/expand_native.h>
|
| 633 |
+
#include <ATen/ops/expand_as_native.h>
|
| 634 |
+
#include <ATen/ops/expand_copy_native.h>
|
| 635 |
+
#include <ATen/ops/expm1_native.h>
|
| 636 |
+
#include <ATen/ops/exponential_native.h>
|
| 637 |
+
#include <ATen/ops/eye_native.h>
|
| 638 |
+
#include <ATen/ops/fake_quantize_per_channel_affine_native.h>
|
| 639 |
+
#include <ATen/ops/fake_quantize_per_channel_affine_cachemask_native.h>
|
| 640 |
+
#include <ATen/ops/fake_quantize_per_channel_affine_cachemask_backward_native.h>
|
| 641 |
+
#include <ATen/ops/fake_quantize_per_tensor_affine_native.h>
|
| 642 |
+
#include <ATen/ops/fake_quantize_per_tensor_affine_cachemask_native.h>
|
| 643 |
+
#include <ATen/ops/fake_quantize_per_tensor_affine_cachemask_backward_native.h>
|
| 644 |
+
#include <ATen/ops/fbgemm_linear_fp16_weight_native.h>
|
| 645 |
+
#include <ATen/ops/fbgemm_linear_fp16_weight_fp32_activation_native.h>
|
| 646 |
+
#include <ATen/ops/fbgemm_linear_int8_weight_native.h>
|
| 647 |
+
#include <ATen/ops/fbgemm_linear_int8_weight_fp32_activation_native.h>
|
| 648 |
+
#include <ATen/ops/fbgemm_linear_quantize_weight_native.h>
|
| 649 |
+
#include <ATen/ops/fbgemm_pack_gemm_matrix_fp16_native.h>
|
| 650 |
+
#include <ATen/ops/fbgemm_pack_quantized_matrix_native.h>
|
| 651 |
+
#include <ATen/ops/feature_alpha_dropout_native.h>
|
| 652 |
+
#include <ATen/ops/feature_dropout_native.h>
|
| 653 |
+
#include <ATen/ops/fft_fft_native.h>
|
| 654 |
+
#include <ATen/ops/fft_fft2_native.h>
|
| 655 |
+
#include <ATen/ops/fft_fftfreq_native.h>
|
| 656 |
+
#include <ATen/ops/fft_fftn_native.h>
|
| 657 |
+
#include <ATen/ops/fft_fftshift_native.h>
|
| 658 |
+
#include <ATen/ops/fft_hfft_native.h>
|
| 659 |
+
#include <ATen/ops/fft_hfft2_native.h>
|
| 660 |
+
#include <ATen/ops/fft_hfftn_native.h>
|
| 661 |
+
#include <ATen/ops/fft_ifft_native.h>
|
| 662 |
+
#include <ATen/ops/fft_ifft2_native.h>
|
| 663 |
+
#include <ATen/ops/fft_ifftn_native.h>
|
| 664 |
+
#include <ATen/ops/fft_ifftshift_native.h>
|
| 665 |
+
#include <ATen/ops/fft_ihfft_native.h>
|
| 666 |
+
#include <ATen/ops/fft_ihfft2_native.h>
|
| 667 |
+
#include <ATen/ops/fft_ihfftn_native.h>
|
| 668 |
+
#include <ATen/ops/fft_irfft_native.h>
|
| 669 |
+
#include <ATen/ops/fft_irfft2_native.h>
|
| 670 |
+
#include <ATen/ops/fft_irfftn_native.h>
|
| 671 |
+
#include <ATen/ops/fft_rfft_native.h>
|
| 672 |
+
#include <ATen/ops/fft_rfft2_native.h>
|
| 673 |
+
#include <ATen/ops/fft_rfftfreq_native.h>
|
| 674 |
+
#include <ATen/ops/fft_rfftn_native.h>
|
| 675 |
+
#include <ATen/ops/fill_native.h>
|
| 676 |
+
#include <ATen/ops/fill_diagonal_native.h>
|
| 677 |
+
#include <ATen/ops/fix_native.h>
|
| 678 |
+
#include <ATen/ops/flatten_native.h>
|
| 679 |
+
#include <ATen/ops/flatten_dense_tensors_native.h>
|
| 680 |
+
#include <ATen/ops/flip_native.h>
|
| 681 |
+
#include <ATen/ops/fliplr_native.h>
|
| 682 |
+
#include <ATen/ops/flipud_native.h>
|
| 683 |
+
#include <ATen/ops/float_power_native.h>
|
| 684 |
+
#include <ATen/ops/floor_native.h>
|
| 685 |
+
#include <ATen/ops/floor_divide_native.h>
|
| 686 |
+
#include <ATen/ops/fmax_native.h>
|
| 687 |
+
#include <ATen/ops/fmin_native.h>
|
| 688 |
+
#include <ATen/ops/fmod_native.h>
|
| 689 |
+
#include <ATen/ops/frac_native.h>
|
| 690 |
+
#include <ATen/ops/fractional_max_pool2d_native.h>
|
| 691 |
+
#include <ATen/ops/fractional_max_pool2d_backward_native.h>
|
| 692 |
+
#include <ATen/ops/fractional_max_pool3d_native.h>
|
| 693 |
+
#include <ATen/ops/fractional_max_pool3d_backward_native.h>
|
| 694 |
+
#include <ATen/ops/frexp_native.h>
|
| 695 |
+
#include <ATen/ops/frobenius_norm_native.h>
|
| 696 |
+
#include <ATen/ops/from_file_native.h>
|
| 697 |
+
#include <ATen/ops/full_native.h>
|
| 698 |
+
#include <ATen/ops/full_like_native.h>
|
| 699 |
+
#include <ATen/ops/fused_moving_avg_obs_fake_quant_native.h>
|
| 700 |
+
#include <ATen/ops/gather_native.h>
|
| 701 |
+
#include <ATen/ops/gather_backward_native.h>
|
| 702 |
+
#include <ATen/ops/gcd_native.h>
|
| 703 |
+
#include <ATen/ops/ge_native.h>
|
| 704 |
+
#include <ATen/ops/gelu_native.h>
|
| 705 |
+
#include <ATen/ops/gelu_backward_native.h>
|
| 706 |
+
#include <ATen/ops/geometric_native.h>
|
| 707 |
+
#include <ATen/ops/geqrf_native.h>
|
| 708 |
+
#include <ATen/ops/ger_native.h>
|
| 709 |
+
#include <ATen/ops/glu_native.h>
|
| 710 |
+
#include <ATen/ops/glu_backward_native.h>
|
| 711 |
+
#include <ATen/ops/glu_backward_jvp_native.h>
|
| 712 |
+
#include <ATen/ops/glu_jvp_native.h>
|
| 713 |
+
#include <ATen/ops/gradient_native.h>
|
| 714 |
+
#include <ATen/ops/greater_native.h>
|
| 715 |
+
#include <ATen/ops/greater_equal_native.h>
|
| 716 |
+
#include <ATen/ops/grid_sampler_native.h>
|
| 717 |
+
#include <ATen/ops/grid_sampler_2d_native.h>
|
| 718 |
+
#include <ATen/ops/grid_sampler_2d_backward_native.h>
|
| 719 |
+
#include <ATen/ops/grid_sampler_3d_native.h>
|
| 720 |
+
#include <ATen/ops/grid_sampler_3d_backward_native.h>
|
| 721 |
+
#include <ATen/ops/group_norm_native.h>
|
| 722 |
+
#include <ATen/ops/gru_native.h>
|
| 723 |
+
#include <ATen/ops/gru_cell_native.h>
|
| 724 |
+
#include <ATen/ops/gt_native.h>
|
| 725 |
+
#include <ATen/ops/hamming_window_native.h>
|
| 726 |
+
#include <ATen/ops/hann_window_native.h>
|
| 727 |
+
#include <ATen/ops/hardshrink_native.h>
|
| 728 |
+
#include <ATen/ops/hardshrink_backward_native.h>
|
| 729 |
+
#include <ATen/ops/hardsigmoid_native.h>
|
| 730 |
+
#include <ATen/ops/hardsigmoid_backward_native.h>
|
| 731 |
+
#include <ATen/ops/hardswish_native.h>
|
| 732 |
+
#include <ATen/ops/hardswish_backward_native.h>
|
| 733 |
+
#include <ATen/ops/hardtanh_native.h>
|
| 734 |
+
#include <ATen/ops/hardtanh_backward_native.h>
|
| 735 |
+
#include <ATen/ops/heaviside_native.h>
|
| 736 |
+
#include <ATen/ops/hinge_embedding_loss_native.h>
|
| 737 |
+
#include <ATen/ops/histc_native.h>
|
| 738 |
+
#include <ATen/ops/histogram_native.h>
|
| 739 |
+
#include <ATen/ops/histogramdd_native.h>
|
| 740 |
+
#include <ATen/ops/hsplit_native.h>
|
| 741 |
+
#include <ATen/ops/hspmm_native.h>
|
| 742 |
+
#include <ATen/ops/hstack_native.h>
|
| 743 |
+
#include <ATen/ops/huber_loss_native.h>
|
| 744 |
+
#include <ATen/ops/huber_loss_backward_native.h>
|
| 745 |
+
#include <ATen/ops/hypot_native.h>
|
| 746 |
+
#include <ATen/ops/i0_native.h>
|
| 747 |
+
#include <ATen/ops/igamma_native.h>
|
| 748 |
+
#include <ATen/ops/igammac_native.h>
|
| 749 |
+
#include <ATen/ops/im2col_native.h>
|
| 750 |
+
#include <ATen/ops/imag_native.h>
|
| 751 |
+
#include <ATen/ops/index_native.h>
|
| 752 |
+
#include <ATen/ops/index_add_native.h>
|
| 753 |
+
#include <ATen/ops/index_copy_native.h>
|
| 754 |
+
#include <ATen/ops/index_fill_native.h>
|
| 755 |
+
#include <ATen/ops/index_put_native.h>
|
| 756 |
+
#include <ATen/ops/index_reduce_native.h>
|
| 757 |
+
#include <ATen/ops/index_select_native.h>
|
| 758 |
+
#include <ATen/ops/index_select_backward_native.h>
|
| 759 |
+
#include <ATen/ops/indices_native.h>
|
| 760 |
+
#include <ATen/ops/indices_copy_native.h>
|
| 761 |
+
#include <ATen/ops/infinitely_differentiable_gelu_backward_native.h>
|
| 762 |
+
#include <ATen/ops/inner_native.h>
|
| 763 |
+
#include <ATen/ops/instance_norm_native.h>
|
| 764 |
+
#include <ATen/ops/int_repr_native.h>
|
| 765 |
+
#include <ATen/ops/inverse_native.h>
|
| 766 |
+
#include <ATen/ops/is_coalesced_native.h>
|
| 767 |
+
#include <ATen/ops/is_complex_native.h>
|
| 768 |
+
#include <ATen/ops/is_conj_native.h>
|
| 769 |
+
#include <ATen/ops/is_distributed_native.h>
|
| 770 |
+
#include <ATen/ops/is_floating_point_native.h>
|
| 771 |
+
#include <ATen/ops/is_inference_native.h>
|
| 772 |
+
#include <ATen/ops/is_leaf_native.h>
|
| 773 |
+
#include <ATen/ops/is_neg_native.h>
|
| 774 |
+
#include <ATen/ops/is_nonzero_native.h>
|
| 775 |
+
#include <ATen/ops/is_pinned_native.h>
|
| 776 |
+
#include <ATen/ops/is_same_size_native.h>
|
| 777 |
+
#include <ATen/ops/is_set_to_native.h>
|
| 778 |
+
#include <ATen/ops/is_signed_native.h>
|
| 779 |
+
#include <ATen/ops/is_vulkan_available_native.h>
|
| 780 |
+
#include <ATen/ops/isclose_native.h>
|
| 781 |
+
#include <ATen/ops/isfinite_native.h>
|
| 782 |
+
#include <ATen/ops/isin_native.h>
|
| 783 |
+
#include <ATen/ops/isinf_native.h>
|
| 784 |
+
#include <ATen/ops/isnan_native.h>
|
| 785 |
+
#include <ATen/ops/isneginf_native.h>
|
| 786 |
+
#include <ATen/ops/isposinf_native.h>
|
| 787 |
+
#include <ATen/ops/isreal_native.h>
|
| 788 |
+
#include <ATen/ops/istft_native.h>
|
| 789 |
+
#include <ATen/ops/item_native.h>
|
| 790 |
+
#include <ATen/ops/kaiser_window_native.h>
|
| 791 |
+
#include <ATen/ops/kl_div_native.h>
|
| 792 |
+
#include <ATen/ops/kron_native.h>
|
| 793 |
+
#include <ATen/ops/kthvalue_native.h>
|
| 794 |
+
#include <ATen/ops/l1_loss_native.h>
|
| 795 |
+
#include <ATen/ops/layer_norm_native.h>
|
| 796 |
+
#include <ATen/ops/lcm_native.h>
|
| 797 |
+
#include <ATen/ops/ldexp_native.h>
|
| 798 |
+
#include <ATen/ops/le_native.h>
|
| 799 |
+
#include <ATen/ops/leaky_relu_native.h>
|
| 800 |
+
#include <ATen/ops/leaky_relu_backward_native.h>
|
| 801 |
+
#include <ATen/ops/lerp_native.h>
|
| 802 |
+
#include <ATen/ops/less_native.h>
|
| 803 |
+
#include <ATen/ops/less_equal_native.h>
|
| 804 |
+
#include <ATen/ops/lgamma_native.h>
|
| 805 |
+
#include <ATen/ops/lift_native.h>
|
| 806 |
+
#include <ATen/ops/lift_fresh_native.h>
|
| 807 |
+
#include <ATen/ops/lift_fresh_copy_native.h>
|
| 808 |
+
#include <ATen/ops/linalg_cholesky_native.h>
|
| 809 |
+
#include <ATen/ops/linalg_cholesky_ex_native.h>
|
| 810 |
+
#include <ATen/ops/linalg_cond_native.h>
|
| 811 |
+
#include <ATen/ops/linalg_cross_native.h>
|
| 812 |
+
#include <ATen/ops/linalg_det_native.h>
|
| 813 |
+
#include <ATen/ops/linalg_diagonal_native.h>
|
| 814 |
+
#include <ATen/ops/linalg_eig_native.h>
|
| 815 |
+
#include <ATen/ops/linalg_eigh_native.h>
|
| 816 |
+
#include <ATen/ops/linalg_eigvals_native.h>
|
| 817 |
+
#include <ATen/ops/linalg_eigvalsh_native.h>
|
| 818 |
+
#include <ATen/ops/linalg_householder_product_native.h>
|
| 819 |
+
#include <ATen/ops/linalg_inv_native.h>
|
| 820 |
+
#include <ATen/ops/linalg_inv_ex_native.h>
|
| 821 |
+
#include <ATen/ops/linalg_ldl_factor_native.h>
|
| 822 |
+
#include <ATen/ops/linalg_ldl_factor_ex_native.h>
|
| 823 |
+
#include <ATen/ops/linalg_ldl_solve_native.h>
|
| 824 |
+
#include <ATen/ops/linalg_lstsq_native.h>
|
| 825 |
+
#include <ATen/ops/linalg_lu_native.h>
|
| 826 |
+
#include <ATen/ops/linalg_lu_factor_native.h>
|
| 827 |
+
#include <ATen/ops/linalg_lu_factor_ex_native.h>
|
| 828 |
+
#include <ATen/ops/linalg_lu_solve_native.h>
|
| 829 |
+
#include <ATen/ops/linalg_matmul_native.h>
|
| 830 |
+
#include <ATen/ops/linalg_matrix_exp_native.h>
|
| 831 |
+
#include <ATen/ops/linalg_matrix_norm_native.h>
|
| 832 |
+
#include <ATen/ops/linalg_matrix_power_native.h>
|
| 833 |
+
#include <ATen/ops/linalg_matrix_rank_native.h>
|
| 834 |
+
#include <ATen/ops/linalg_multi_dot_native.h>
|
| 835 |
+
#include <ATen/ops/linalg_norm_native.h>
|
| 836 |
+
#include <ATen/ops/linalg_pinv_native.h>
|
| 837 |
+
#include <ATen/ops/linalg_qr_native.h>
|
| 838 |
+
#include <ATen/ops/linalg_slogdet_native.h>
|
| 839 |
+
#include <ATen/ops/linalg_solve_native.h>
|
| 840 |
+
#include <ATen/ops/linalg_solve_ex_native.h>
|
| 841 |
+
#include <ATen/ops/linalg_solve_triangular_native.h>
|
| 842 |
+
#include <ATen/ops/linalg_svd_native.h>
|
| 843 |
+
#include <ATen/ops/linalg_svdvals_native.h>
|
| 844 |
+
#include <ATen/ops/linalg_tensorinv_native.h>
|
| 845 |
+
#include <ATen/ops/linalg_tensorsolve_native.h>
|
| 846 |
+
#include <ATen/ops/linalg_vander_native.h>
|
| 847 |
+
#include <ATen/ops/linalg_vecdot_native.h>
|
| 848 |
+
#include <ATen/ops/linalg_vector_norm_native.h>
|
| 849 |
+
#include <ATen/ops/linear_native.h>
|
| 850 |
+
#include <ATen/ops/linear_backward_native.h>
|
| 851 |
+
#include <ATen/ops/linspace_native.h>
|
| 852 |
+
#include <ATen/ops/log_native.h>
|
| 853 |
+
#include <ATen/ops/log10_native.h>
|
| 854 |
+
#include <ATen/ops/log1p_native.h>
|
| 855 |
+
#include <ATen/ops/log2_native.h>
|
| 856 |
+
#include <ATen/ops/log_normal_native.h>
|
| 857 |
+
#include <ATen/ops/log_sigmoid_native.h>
|
| 858 |
+
#include <ATen/ops/log_sigmoid_backward_native.h>
|
| 859 |
+
#include <ATen/ops/log_sigmoid_forward_native.h>
|
| 860 |
+
#include <ATen/ops/log_softmax_native.h>
|
| 861 |
+
#include <ATen/ops/logaddexp_native.h>
|
| 862 |
+
#include <ATen/ops/logaddexp2_native.h>
|
| 863 |
+
#include <ATen/ops/logcumsumexp_native.h>
|
| 864 |
+
#include <ATen/ops/logdet_native.h>
|
| 865 |
+
#include <ATen/ops/logical_and_native.h>
|
| 866 |
+
#include <ATen/ops/logical_not_native.h>
|
| 867 |
+
#include <ATen/ops/logical_or_native.h>
|
| 868 |
+
#include <ATen/ops/logical_xor_native.h>
|
| 869 |
+
#include <ATen/ops/logit_native.h>
|
| 870 |
+
#include <ATen/ops/logit_backward_native.h>
|
| 871 |
+
#include <ATen/ops/logspace_native.h>
|
| 872 |
+
#include <ATen/ops/logsumexp_native.h>
|
| 873 |
+
#include <ATen/ops/lshift_native.h>
|
| 874 |
+
#include <ATen/ops/lstm_native.h>
|
| 875 |
+
#include <ATen/ops/lstm_cell_native.h>
|
| 876 |
+
#include <ATen/ops/lstm_mps_backward_native.h>
|
| 877 |
+
#include <ATen/ops/lt_native.h>
|
| 878 |
+
#include <ATen/ops/lu_solve_native.h>
|
| 879 |
+
#include <ATen/ops/lu_unpack_native.h>
|
| 880 |
+
#include <ATen/ops/mH_native.h>
|
| 881 |
+
#include <ATen/ops/mT_native.h>
|
| 882 |
+
#include <ATen/ops/margin_ranking_loss_native.h>
|
| 883 |
+
#include <ATen/ops/masked_fill_native.h>
|
| 884 |
+
#include <ATen/ops/masked_scatter_native.h>
|
| 885 |
+
#include <ATen/ops/masked_scatter_backward_native.h>
|
| 886 |
+
#include <ATen/ops/masked_select_native.h>
|
| 887 |
+
#include <ATen/ops/masked_select_backward_native.h>
|
| 888 |
+
#include <ATen/ops/matmul_native.h>
|
| 889 |
+
#include <ATen/ops/matmul_backward_native.h>
|
| 890 |
+
#include <ATen/ops/matrix_H_native.h>
|
| 891 |
+
#include <ATen/ops/matrix_exp_native.h>
|
| 892 |
+
#include <ATen/ops/matrix_exp_backward_native.h>
|
| 893 |
+
#include <ATen/ops/matrix_power_native.h>
|
| 894 |
+
#include <ATen/ops/max_native.h>
|
| 895 |
+
#include <ATen/ops/max_pool1d_native.h>
|
| 896 |
+
#include <ATen/ops/max_pool1d_with_indices_native.h>
|
| 897 |
+
#include <ATen/ops/max_pool2d_native.h>
|
| 898 |
+
#include <ATen/ops/max_pool2d_backward_native.h>
|
| 899 |
+
#include <ATen/ops/max_pool2d_with_indices_native.h>
|
| 900 |
+
#include <ATen/ops/max_pool2d_with_indices_backward_native.h>
|
| 901 |
+
#include <ATen/ops/max_pool3d_native.h>
|
| 902 |
+
#include <ATen/ops/max_pool3d_with_indices_native.h>
|
| 903 |
+
#include <ATen/ops/max_pool3d_with_indices_backward_native.h>
|
| 904 |
+
#include <ATen/ops/max_unpool2d_native.h>
|
| 905 |
+
#include <ATen/ops/max_unpool3d_native.h>
|
| 906 |
+
#include <ATen/ops/maximum_native.h>
|
| 907 |
+
#include <ATen/ops/mean_native.h>
|
| 908 |
+
#include <ATen/ops/median_native.h>
|
| 909 |
+
#include <ATen/ops/meshgrid_native.h>
|
| 910 |
+
#include <ATen/ops/min_native.h>
|
| 911 |
+
#include <ATen/ops/minimum_native.h>
|
| 912 |
+
#include <ATen/ops/miopen_batch_norm_native.h>
|
| 913 |
+
#include <ATen/ops/miopen_batch_norm_backward_native.h>
|
| 914 |
+
#include <ATen/ops/miopen_convolution_native.h>
|
| 915 |
+
#include <ATen/ops/miopen_convolution_add_relu_native.h>
|
| 916 |
+
#include <ATen/ops/miopen_convolution_relu_native.h>
|
| 917 |
+
#include <ATen/ops/miopen_convolution_transpose_native.h>
|
| 918 |
+
#include <ATen/ops/miopen_depthwise_convolution_native.h>
|
| 919 |
+
#include <ATen/ops/miopen_rnn_native.h>
|
| 920 |
+
#include <ATen/ops/miopen_rnn_backward_native.h>
|
| 921 |
+
#include <ATen/ops/mish_native.h>
|
| 922 |
+
#include <ATen/ops/mish_backward_native.h>
|
| 923 |
+
#include <ATen/ops/mkldnn_adaptive_avg_pool2d_native.h>
|
| 924 |
+
#include <ATen/ops/mkldnn_adaptive_avg_pool2d_backward_native.h>
|
| 925 |
+
#include <ATen/ops/mkldnn_convolution_native.h>
|
| 926 |
+
#include <ATen/ops/mkldnn_linear_native.h>
|
| 927 |
+
#include <ATen/ops/mkldnn_linear_backward_native.h>
|
| 928 |
+
#include <ATen/ops/mkldnn_linear_backward_input_native.h>
|
| 929 |
+
#include <ATen/ops/mkldnn_linear_backward_weights_native.h>
|
| 930 |
+
#include <ATen/ops/mkldnn_max_pool2d_native.h>
|
| 931 |
+
#include <ATen/ops/mkldnn_max_pool2d_backward_native.h>
|
| 932 |
+
#include <ATen/ops/mkldnn_max_pool3d_native.h>
|
| 933 |
+
#include <ATen/ops/mkldnn_max_pool3d_backward_native.h>
|
| 934 |
+
#include <ATen/ops/mkldnn_reorder_conv2d_weight_native.h>
|
| 935 |
+
#include <ATen/ops/mkldnn_reorder_conv3d_weight_native.h>
|
| 936 |
+
#include <ATen/ops/mkldnn_rnn_layer_native.h>
|
| 937 |
+
#include <ATen/ops/mkldnn_rnn_layer_backward_native.h>
|
| 938 |
+
#include <ATen/ops/mm_native.h>
|
| 939 |
+
#include <ATen/ops/mode_native.h>
|
| 940 |
+
#include <ATen/ops/moveaxis_native.h>
|
| 941 |
+
#include <ATen/ops/movedim_native.h>
|
| 942 |
+
#include <ATen/ops/mps_convolution_backward_native.h>
|
| 943 |
+
#include <ATen/ops/mps_convolution_transpose_backward_native.h>
|
| 944 |
+
#include <ATen/ops/mse_loss_native.h>
|
| 945 |
+
#include <ATen/ops/mse_loss_backward_native.h>
|
| 946 |
+
#include <ATen/ops/msort_native.h>
|
| 947 |
+
#include <ATen/ops/mul_native.h>
|
| 948 |
+
#include <ATen/ops/multi_margin_loss_native.h>
|
| 949 |
+
#include <ATen/ops/multi_margin_loss_backward_native.h>
|
| 950 |
+
#include <ATen/ops/multilabel_margin_loss_native.h>
|
| 951 |
+
#include <ATen/ops/multilabel_margin_loss_backward_native.h>
|
| 952 |
+
#include <ATen/ops/multilabel_margin_loss_forward_native.h>
|
| 953 |
+
#include <ATen/ops/multinomial_native.h>
|
| 954 |
+
#include <ATen/ops/multiply_native.h>
|
| 955 |
+
#include <ATen/ops/mv_native.h>
|
| 956 |
+
#include <ATen/ops/mvlgamma_native.h>
|
| 957 |
+
#include <ATen/ops/nan_to_num_native.h>
|
| 958 |
+
#include <ATen/ops/nanmean_native.h>
|
| 959 |
+
#include <ATen/ops/nanmedian_native.h>
|
| 960 |
+
#include <ATen/ops/nanquantile_native.h>
|
| 961 |
+
#include <ATen/ops/nansum_native.h>
|
| 962 |
+
#include <ATen/ops/narrow_native.h>
|
| 963 |
+
#include <ATen/ops/narrow_copy_native.h>
|
| 964 |
+
#include <ATen/ops/native_batch_norm_native.h>
|
| 965 |
+
#include <ATen/ops/native_batch_norm_backward_native.h>
|
| 966 |
+
#include <ATen/ops/native_channel_shuffle_native.h>
|
| 967 |
+
#include <ATen/ops/native_dropout_native.h>
|
| 968 |
+
#include <ATen/ops/native_dropout_backward_native.h>
|
| 969 |
+
#include <ATen/ops/native_group_norm_native.h>
|
| 970 |
+
#include <ATen/ops/native_group_norm_backward_native.h>
|
| 971 |
+
#include <ATen/ops/native_layer_norm_native.h>
|
| 972 |
+
#include <ATen/ops/native_layer_norm_backward_native.h>
|
| 973 |
+
#include <ATen/ops/native_norm_native.h>
|
| 974 |
+
#include <ATen/ops/ne_native.h>
|
| 975 |
+
#include <ATen/ops/neg_native.h>
|
| 976 |
+
#include <ATen/ops/negative_native.h>
|
| 977 |
+
#include <ATen/ops/nested_to_padded_tensor_native.h>
|
| 978 |
+
#include <ATen/ops/new_empty_native.h>
|
| 979 |
+
#include <ATen/ops/new_empty_strided_native.h>
|
| 980 |
+
#include <ATen/ops/new_full_native.h>
|
| 981 |
+
#include <ATen/ops/new_ones_native.h>
|
| 982 |
+
#include <ATen/ops/new_zeros_native.h>
|
| 983 |
+
#include <ATen/ops/nextafter_native.h>
|
| 984 |
+
#include <ATen/ops/nll_loss_native.h>
|
| 985 |
+
#include <ATen/ops/nll_loss2d_native.h>
|
| 986 |
+
#include <ATen/ops/nll_loss2d_backward_native.h>
|
| 987 |
+
#include <ATen/ops/nll_loss2d_forward_native.h>
|
| 988 |
+
#include <ATen/ops/nll_loss_backward_native.h>
|
| 989 |
+
#include <ATen/ops/nll_loss_forward_native.h>
|
| 990 |
+
#include <ATen/ops/nll_loss_nd_native.h>
|
| 991 |
+
#include <ATen/ops/nonzero_native.h>
|
| 992 |
+
#include <ATen/ops/nonzero_numpy_native.h>
|
| 993 |
+
#include <ATen/ops/nonzero_static_native.h>
|
| 994 |
+
#include <ATen/ops/norm_native.h>
|
| 995 |
+
#include <ATen/ops/norm_except_dim_native.h>
|
| 996 |
+
#include <ATen/ops/normal_native.h>
|
| 997 |
+
#include <ATen/ops/not_equal_native.h>
|
| 998 |
+
#include <ATen/ops/nuclear_norm_native.h>
|
| 999 |
+
#include <ATen/ops/numpy_T_native.h>
|
| 1000 |
+
#include <ATen/ops/one_hot_native.h>
|
| 1001 |
+
#include <ATen/ops/ones_native.h>
|
| 1002 |
+
#include <ATen/ops/ones_like_native.h>
|
| 1003 |
+
#include <ATen/ops/or_native.h>
|
| 1004 |
+
#include <ATen/ops/orgqr_native.h>
|
| 1005 |
+
#include <ATen/ops/ormqr_native.h>
|
| 1006 |
+
#include <ATen/ops/outer_native.h>
|
| 1007 |
+
#include <ATen/ops/output_nr_native.h>
|
| 1008 |
+
#include <ATen/ops/pad_native.h>
|
| 1009 |
+
#include <ATen/ops/pad_sequence_native.h>
|
| 1010 |
+
#include <ATen/ops/pairwise_distance_native.h>
|
| 1011 |
+
#include <ATen/ops/pdist_native.h>
|
| 1012 |
+
#include <ATen/ops/permute_native.h>
|
| 1013 |
+
#include <ATen/ops/permute_copy_native.h>
|
| 1014 |
+
#include <ATen/ops/pin_memory_native.h>
|
| 1015 |
+
#include <ATen/ops/pinverse_native.h>
|
| 1016 |
+
#include <ATen/ops/pixel_shuffle_native.h>
|
| 1017 |
+
#include <ATen/ops/pixel_unshuffle_native.h>
|
| 1018 |
+
#include <ATen/ops/poisson_native.h>
|
| 1019 |
+
#include <ATen/ops/poisson_nll_loss_native.h>
|
| 1020 |
+
#include <ATen/ops/polar_native.h>
|
| 1021 |
+
#include <ATen/ops/polygamma_native.h>
|
| 1022 |
+
#include <ATen/ops/positive_native.h>
|
| 1023 |
+
#include <ATen/ops/pow_native.h>
|
| 1024 |
+
#include <ATen/ops/prelu_native.h>
|
| 1025 |
+
#include <ATen/ops/prod_native.h>
|
| 1026 |
+
#include <ATen/ops/promote_types_native.h>
|
| 1027 |
+
#include <ATen/ops/put_native.h>
|
| 1028 |
+
#include <ATen/ops/q_per_channel_axis_native.h>
|
| 1029 |
+
#include <ATen/ops/q_per_channel_scales_native.h>
|
| 1030 |
+
#include <ATen/ops/q_per_channel_zero_points_native.h>
|
| 1031 |
+
#include <ATen/ops/q_scale_native.h>
|
| 1032 |
+
#include <ATen/ops/q_zero_point_native.h>
|
| 1033 |
+
#include <ATen/ops/qr_native.h>
|
| 1034 |
+
#include <ATen/ops/qscheme_native.h>
|
| 1035 |
+
#include <ATen/ops/quantile_native.h>
|
| 1036 |
+
#include <ATen/ops/quantize_per_channel_native.h>
|
| 1037 |
+
#include <ATen/ops/quantize_per_tensor_native.h>
|
| 1038 |
+
#include <ATen/ops/quantize_per_tensor_dynamic_native.h>
|
| 1039 |
+
#include <ATen/ops/quantized_batch_norm_native.h>
|
| 1040 |
+
#include <ATen/ops/quantized_gru_cell_native.h>
|
| 1041 |
+
#include <ATen/ops/quantized_lstm_cell_native.h>
|
| 1042 |
+
#include <ATen/ops/quantized_max_pool1d_native.h>
|
| 1043 |
+
#include <ATen/ops/quantized_max_pool2d_native.h>
|
| 1044 |
+
#include <ATen/ops/quantized_max_pool3d_native.h>
|
| 1045 |
+
#include <ATen/ops/quantized_rnn_relu_cell_native.h>
|
| 1046 |
+
#include <ATen/ops/quantized_rnn_tanh_cell_native.h>
|
| 1047 |
+
#include <ATen/ops/rad2deg_native.h>
|
| 1048 |
+
#include <ATen/ops/rand_native.h>
|
| 1049 |
+
#include <ATen/ops/rand_like_native.h>
|
| 1050 |
+
#include <ATen/ops/randint_native.h>
|
| 1051 |
+
#include <ATen/ops/randint_like_native.h>
|
| 1052 |
+
#include <ATen/ops/randn_native.h>
|
| 1053 |
+
#include <ATen/ops/randn_like_native.h>
|
| 1054 |
+
#include <ATen/ops/random_native.h>
|
| 1055 |
+
#include <ATen/ops/randperm_native.h>
|
| 1056 |
+
#include <ATen/ops/range_native.h>
|
| 1057 |
+
#include <ATen/ops/ravel_native.h>
|
| 1058 |
+
#include <ATen/ops/real_native.h>
|
| 1059 |
+
#include <ATen/ops/reciprocal_native.h>
|
| 1060 |
+
#include <ATen/ops/record_stream_native.h>
|
| 1061 |
+
#include <ATen/ops/refine_names_native.h>
|
| 1062 |
+
#include <ATen/ops/reflection_pad1d_native.h>
|
| 1063 |
+
#include <ATen/ops/reflection_pad1d_backward_native.h>
|
| 1064 |
+
#include <ATen/ops/reflection_pad2d_native.h>
|
| 1065 |
+
#include <ATen/ops/reflection_pad2d_backward_native.h>
|
| 1066 |
+
#include <ATen/ops/reflection_pad3d_native.h>
|
| 1067 |
+
#include <ATen/ops/reflection_pad3d_backward_native.h>
|
| 1068 |
+
#include <ATen/ops/relu_native.h>
|
| 1069 |
+
#include <ATen/ops/relu6_native.h>
|
| 1070 |
+
#include <ATen/ops/remainder_native.h>
|
| 1071 |
+
#include <ATen/ops/rename_native.h>
|
| 1072 |
+
#include <ATen/ops/renorm_native.h>
|
| 1073 |
+
#include <ATen/ops/repeat_native.h>
|
| 1074 |
+
#include <ATen/ops/repeat_interleave_native.h>
|
| 1075 |
+
#include <ATen/ops/replication_pad1d_native.h>
|
| 1076 |
+
#include <ATen/ops/replication_pad1d_backward_native.h>
|
| 1077 |
+
#include <ATen/ops/replication_pad2d_native.h>
|
| 1078 |
+
#include <ATen/ops/replication_pad2d_backward_native.h>
|
| 1079 |
+
#include <ATen/ops/replication_pad3d_native.h>
|
| 1080 |
+
#include <ATen/ops/replication_pad3d_backward_native.h>
|
| 1081 |
+
#include <ATen/ops/requires_grad_native.h>
|
| 1082 |
+
#include <ATen/ops/reshape_native.h>
|
| 1083 |
+
#include <ATen/ops/reshape_as_native.h>
|
| 1084 |
+
#include <ATen/ops/resize_native.h>
|
| 1085 |
+
#include <ATen/ops/resize_as_native.h>
|
| 1086 |
+
#include <ATen/ops/resize_as_sparse_native.h>
|
| 1087 |
+
#include <ATen/ops/resolve_conj_native.h>
|
| 1088 |
+
#include <ATen/ops/resolve_neg_native.h>
|
| 1089 |
+
#include <ATen/ops/result_type_native.h>
|
| 1090 |
+
#include <ATen/ops/retain_grad_native.h>
|
| 1091 |
+
#include <ATen/ops/retains_grad_native.h>
|
| 1092 |
+
#include <ATen/ops/rms_norm_native.h>
|
| 1093 |
+
#include <ATen/ops/rnn_relu_native.h>
|
| 1094 |
+
#include <ATen/ops/rnn_relu_cell_native.h>
|
| 1095 |
+
#include <ATen/ops/rnn_tanh_native.h>
|
| 1096 |
+
#include <ATen/ops/rnn_tanh_cell_native.h>
|
| 1097 |
+
#include <ATen/ops/roll_native.h>
|
| 1098 |
+
#include <ATen/ops/rot90_native.h>
|
| 1099 |
+
#include <ATen/ops/round_native.h>
|
| 1100 |
+
#include <ATen/ops/row_indices_native.h>
|
| 1101 |
+
#include <ATen/ops/row_indices_copy_native.h>
|
| 1102 |
+
#include <ATen/ops/row_stack_native.h>
|
| 1103 |
+
#include <ATen/ops/rrelu_native.h>
|
| 1104 |
+
#include <ATen/ops/rrelu_with_noise_native.h>
|
| 1105 |
+
#include <ATen/ops/rrelu_with_noise_backward_native.h>
|
| 1106 |
+
#include <ATen/ops/rshift_native.h>
|
| 1107 |
+
#include <ATen/ops/rsqrt_native.h>
|
| 1108 |
+
#include <ATen/ops/rsub_native.h>
|
| 1109 |
+
#include <ATen/ops/scalar_tensor_native.h>
|
| 1110 |
+
#include <ATen/ops/scaled_dot_product_attention_native.h>
|
| 1111 |
+
#include <ATen/ops/scatter_native.h>
|
| 1112 |
+
#include <ATen/ops/scatter_add_native.h>
|
| 1113 |
+
#include <ATen/ops/scatter_reduce_native.h>
|
| 1114 |
+
#include <ATen/ops/searchsorted_native.h>
|
| 1115 |
+
#include <ATen/ops/segment_reduce_native.h>
|
| 1116 |
+
#include <ATen/ops/select_native.h>
|
| 1117 |
+
#include <ATen/ops/select_backward_native.h>
|
| 1118 |
+
#include <ATen/ops/select_copy_native.h>
|
| 1119 |
+
#include <ATen/ops/select_scatter_native.h>
|
| 1120 |
+
#include <ATen/ops/selu_native.h>
|
| 1121 |
+
#include <ATen/ops/set_native.h>
|
| 1122 |
+
#include <ATen/ops/set_data_native.h>
|
| 1123 |
+
#include <ATen/ops/sgn_native.h>
|
| 1124 |
+
#include <ATen/ops/sigmoid_native.h>
|
| 1125 |
+
#include <ATen/ops/sigmoid_backward_native.h>
|
| 1126 |
+
#include <ATen/ops/sign_native.h>
|
| 1127 |
+
#include <ATen/ops/signbit_native.h>
|
| 1128 |
+
#include <ATen/ops/silu_native.h>
|
| 1129 |
+
#include <ATen/ops/silu_backward_native.h>
|
| 1130 |
+
#include <ATen/ops/sin_native.h>
|
| 1131 |
+
#include <ATen/ops/sinc_native.h>
|
| 1132 |
+
#include <ATen/ops/sinh_native.h>
|
| 1133 |
+
#include <ATen/ops/size_native.h>
|
| 1134 |
+
#include <ATen/ops/slice_native.h>
|
| 1135 |
+
#include <ATen/ops/slice_backward_native.h>
|
| 1136 |
+
#include <ATen/ops/slice_copy_native.h>
|
| 1137 |
+
#include <ATen/ops/slice_inverse_native.h>
|
| 1138 |
+
#include <ATen/ops/slice_scatter_native.h>
|
| 1139 |
+
#include <ATen/ops/slogdet_native.h>
|
| 1140 |
+
#include <ATen/ops/slow_conv3d_native.h>
|
| 1141 |
+
#include <ATen/ops/slow_conv3d_forward_native.h>
|
| 1142 |
+
#include <ATen/ops/slow_conv_dilated2d_native.h>
|
| 1143 |
+
#include <ATen/ops/slow_conv_dilated3d_native.h>
|
| 1144 |
+
#include <ATen/ops/slow_conv_transpose2d_native.h>
|
| 1145 |
+
#include <ATen/ops/slow_conv_transpose3d_native.h>
|
| 1146 |
+
#include <ATen/ops/smm_native.h>
|
| 1147 |
+
#include <ATen/ops/smooth_l1_loss_native.h>
|
| 1148 |
+
#include <ATen/ops/smooth_l1_loss_backward_native.h>
|
| 1149 |
+
#include <ATen/ops/soft_margin_loss_native.h>
|
| 1150 |
+
#include <ATen/ops/soft_margin_loss_backward_native.h>
|
| 1151 |
+
#include <ATen/ops/softmax_native.h>
|
| 1152 |
+
#include <ATen/ops/softplus_native.h>
|
| 1153 |
+
#include <ATen/ops/softplus_backward_native.h>
|
| 1154 |
+
#include <ATen/ops/softshrink_native.h>
|
| 1155 |
+
#include <ATen/ops/softshrink_backward_native.h>
|
| 1156 |
+
#include <ATen/ops/sort_native.h>
|
| 1157 |
+
#include <ATen/ops/sparse_bsc_tensor_native.h>
|
| 1158 |
+
#include <ATen/ops/sparse_bsr_tensor_native.h>
|
| 1159 |
+
#include <ATen/ops/sparse_compressed_tensor_native.h>
|
| 1160 |
+
#include <ATen/ops/sparse_coo_tensor_native.h>
|
| 1161 |
+
#include <ATen/ops/sparse_csc_tensor_native.h>
|
| 1162 |
+
#include <ATen/ops/sparse_csr_tensor_native.h>
|
| 1163 |
+
#include <ATen/ops/sparse_dim_native.h>
|
| 1164 |
+
#include <ATen/ops/sparse_mask_native.h>
|
| 1165 |
+
#include <ATen/ops/sparse_resize_native.h>
|
| 1166 |
+
#include <ATen/ops/sparse_resize_and_clear_native.h>
|
| 1167 |
+
#include <ATen/ops/sparse_sampled_addmm_native.h>
|
| 1168 |
+
#include <ATen/ops/special_airy_ai_native.h>
|
| 1169 |
+
#include <ATen/ops/special_bessel_j0_native.h>
|
| 1170 |
+
#include <ATen/ops/special_bessel_j1_native.h>
|
| 1171 |
+
#include <ATen/ops/special_bessel_y0_native.h>
|
| 1172 |
+
#include <ATen/ops/special_bessel_y1_native.h>
|
| 1173 |
+
#include <ATen/ops/special_chebyshev_polynomial_t_native.h>
|
| 1174 |
+
#include <ATen/ops/special_chebyshev_polynomial_u_native.h>
|
| 1175 |
+
#include <ATen/ops/special_chebyshev_polynomial_v_native.h>
|
| 1176 |
+
#include <ATen/ops/special_chebyshev_polynomial_w_native.h>
|
| 1177 |
+
#include <ATen/ops/special_digamma_native.h>
|
| 1178 |
+
#include <ATen/ops/special_entr_native.h>
|
| 1179 |
+
#include <ATen/ops/special_erf_native.h>
|
| 1180 |
+
#include <ATen/ops/special_erfc_native.h>
|
| 1181 |
+
#include <ATen/ops/special_erfcx_native.h>
|
| 1182 |
+
#include <ATen/ops/special_erfinv_native.h>
|
| 1183 |
+
#include <ATen/ops/special_exp2_native.h>
|
| 1184 |
+
#include <ATen/ops/special_expit_native.h>
|
| 1185 |
+
#include <ATen/ops/special_expm1_native.h>
|
| 1186 |
+
#include <ATen/ops/special_gammainc_native.h>
|
| 1187 |
+
#include <ATen/ops/special_gammaincc_native.h>
|
| 1188 |
+
#include <ATen/ops/special_gammaln_native.h>
|
| 1189 |
+
#include <ATen/ops/special_hermite_polynomial_h_native.h>
|
| 1190 |
+
#include <ATen/ops/special_hermite_polynomial_he_native.h>
|
| 1191 |
+
#include <ATen/ops/special_i0_native.h>
|
| 1192 |
+
#include <ATen/ops/special_i0e_native.h>
|
| 1193 |
+
#include <ATen/ops/special_i1_native.h>
|
| 1194 |
+
#include <ATen/ops/special_i1e_native.h>
|
| 1195 |
+
#include <ATen/ops/special_laguerre_polynomial_l_native.h>
|
| 1196 |
+
#include <ATen/ops/special_legendre_polynomial_p_native.h>
|
| 1197 |
+
#include <ATen/ops/special_log1p_native.h>
|
| 1198 |
+
#include <ATen/ops/special_log_ndtr_native.h>
|
| 1199 |
+
#include <ATen/ops/special_log_softmax_native.h>
|
| 1200 |
+
#include <ATen/ops/special_logit_native.h>
|
| 1201 |
+
#include <ATen/ops/special_logsumexp_native.h>
|
| 1202 |
+
#include <ATen/ops/special_modified_bessel_i0_native.h>
|
| 1203 |
+
#include <ATen/ops/special_modified_bessel_i1_native.h>
|
| 1204 |
+
#include <ATen/ops/special_modified_bessel_k0_native.h>
|
| 1205 |
+
#include <ATen/ops/special_modified_bessel_k1_native.h>
|
| 1206 |
+
#include <ATen/ops/special_multigammaln_native.h>
|
| 1207 |
+
#include <ATen/ops/special_ndtr_native.h>
|
| 1208 |
+
#include <ATen/ops/special_ndtri_native.h>
|
| 1209 |
+
#include <ATen/ops/special_polygamma_native.h>
|
| 1210 |
+
#include <ATen/ops/special_psi_native.h>
|
| 1211 |
+
#include <ATen/ops/special_round_native.h>
|
| 1212 |
+
#include <ATen/ops/special_scaled_modified_bessel_k0_native.h>
|
| 1213 |
+
#include <ATen/ops/special_scaled_modified_bessel_k1_native.h>
|
| 1214 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_t_native.h>
|
| 1215 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_u_native.h>
|
| 1216 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_v_native.h>
|
| 1217 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_w_native.h>
|
| 1218 |
+
#include <ATen/ops/special_sinc_native.h>
|
| 1219 |
+
#include <ATen/ops/special_softmax_native.h>
|
| 1220 |
+
#include <ATen/ops/special_spherical_bessel_j0_native.h>
|
| 1221 |
+
#include <ATen/ops/special_xlog1py_native.h>
|
| 1222 |
+
#include <ATen/ops/special_xlogy_native.h>
|
| 1223 |
+
#include <ATen/ops/special_zeta_native.h>
|
| 1224 |
+
#include <ATen/ops/split_native.h>
|
| 1225 |
+
#include <ATen/ops/split_copy_native.h>
|
| 1226 |
+
#include <ATen/ops/split_with_sizes_native.h>
|
| 1227 |
+
#include <ATen/ops/split_with_sizes_copy_native.h>
|
| 1228 |
+
#include <ATen/ops/sqrt_native.h>
|
| 1229 |
+
#include <ATen/ops/square_native.h>
|
| 1230 |
+
#include <ATen/ops/squeeze_native.h>
|
| 1231 |
+
#include <ATen/ops/squeeze_copy_native.h>
|
| 1232 |
+
#include <ATen/ops/sspaddmm_native.h>
|
| 1233 |
+
#include <ATen/ops/stack_native.h>
|
| 1234 |
+
#include <ATen/ops/std_native.h>
|
| 1235 |
+
#include <ATen/ops/std_mean_native.h>
|
| 1236 |
+
#include <ATen/ops/stft_native.h>
|
| 1237 |
+
#include <ATen/ops/stride_native.h>
|
| 1238 |
+
#include <ATen/ops/sub_native.h>
|
| 1239 |
+
#include <ATen/ops/subtract_native.h>
|
| 1240 |
+
#include <ATen/ops/sum_native.h>
|
| 1241 |
+
#include <ATen/ops/sum_to_size_native.h>
|
| 1242 |
+
#include <ATen/ops/svd_native.h>
|
| 1243 |
+
#include <ATen/ops/swapaxes_native.h>
|
| 1244 |
+
#include <ATen/ops/swapdims_native.h>
|
| 1245 |
+
#include <ATen/ops/sym_constrain_range_native.h>
|
| 1246 |
+
#include <ATen/ops/sym_constrain_range_for_size_native.h>
|
| 1247 |
+
#include <ATen/ops/sym_numel_native.h>
|
| 1248 |
+
#include <ATen/ops/sym_size_native.h>
|
| 1249 |
+
#include <ATen/ops/sym_storage_offset_native.h>
|
| 1250 |
+
#include <ATen/ops/sym_stride_native.h>
|
| 1251 |
+
#include <ATen/ops/t_native.h>
|
| 1252 |
+
#include <ATen/ops/t_copy_native.h>
|
| 1253 |
+
#include <ATen/ops/take_native.h>
|
| 1254 |
+
#include <ATen/ops/take_along_dim_native.h>
|
| 1255 |
+
#include <ATen/ops/tan_native.h>
|
| 1256 |
+
#include <ATen/ops/tanh_native.h>
|
| 1257 |
+
#include <ATen/ops/tanh_backward_native.h>
|
| 1258 |
+
#include <ATen/ops/tensor_split_native.h>
|
| 1259 |
+
#include <ATen/ops/tensordot_native.h>
|
| 1260 |
+
#include <ATen/ops/thnn_conv2d_native.h>
|
| 1261 |
+
#include <ATen/ops/threshold_native.h>
|
| 1262 |
+
#include <ATen/ops/threshold_backward_native.h>
|
| 1263 |
+
#include <ATen/ops/tile_native.h>
|
| 1264 |
+
#include <ATen/ops/to_native.h>
|
| 1265 |
+
#include <ATen/ops/to_dense_native.h>
|
| 1266 |
+
#include <ATen/ops/to_dense_backward_native.h>
|
| 1267 |
+
#include <ATen/ops/to_mkldnn_native.h>
|
| 1268 |
+
#include <ATen/ops/to_mkldnn_backward_native.h>
|
| 1269 |
+
#include <ATen/ops/to_padded_tensor_native.h>
|
| 1270 |
+
#include <ATen/ops/to_sparse_native.h>
|
| 1271 |
+
#include <ATen/ops/to_sparse_bsc_native.h>
|
| 1272 |
+
#include <ATen/ops/to_sparse_bsr_native.h>
|
| 1273 |
+
#include <ATen/ops/to_sparse_csc_native.h>
|
| 1274 |
+
#include <ATen/ops/to_sparse_csr_native.h>
|
| 1275 |
+
#include <ATen/ops/topk_native.h>
|
| 1276 |
+
#include <ATen/ops/trace_native.h>
|
| 1277 |
+
#include <ATen/ops/trace_backward_native.h>
|
| 1278 |
+
#include <ATen/ops/transpose_native.h>
|
| 1279 |
+
#include <ATen/ops/transpose_copy_native.h>
|
| 1280 |
+
#include <ATen/ops/trapezoid_native.h>
|
| 1281 |
+
#include <ATen/ops/trapz_native.h>
|
| 1282 |
+
#include <ATen/ops/triangular_solve_native.h>
|
| 1283 |
+
#include <ATen/ops/tril_native.h>
|
| 1284 |
+
#include <ATen/ops/tril_indices_native.h>
|
| 1285 |
+
#include <ATen/ops/triplet_margin_loss_native.h>
|
| 1286 |
+
#include <ATen/ops/triu_native.h>
|
| 1287 |
+
#include <ATen/ops/triu_indices_native.h>
|
| 1288 |
+
#include <ATen/ops/true_divide_native.h>
|
| 1289 |
+
#include <ATen/ops/trunc_native.h>
|
| 1290 |
+
#include <ATen/ops/type_as_native.h>
|
| 1291 |
+
#include <ATen/ops/unbind_native.h>
|
| 1292 |
+
#include <ATen/ops/unbind_copy_native.h>
|
| 1293 |
+
#include <ATen/ops/unflatten_native.h>
|
| 1294 |
+
#include <ATen/ops/unflatten_dense_tensors_native.h>
|
| 1295 |
+
#include <ATen/ops/unfold_native.h>
|
| 1296 |
+
#include <ATen/ops/unfold_backward_native.h>
|
| 1297 |
+
#include <ATen/ops/unfold_copy_native.h>
|
| 1298 |
+
#include <ATen/ops/uniform_native.h>
|
| 1299 |
+
#include <ATen/ops/unique_consecutive_native.h>
|
| 1300 |
+
#include <ATen/ops/unique_dim_native.h>
|
| 1301 |
+
#include <ATen/ops/unique_dim_consecutive_native.h>
|
| 1302 |
+
#include <ATen/ops/unsafe_chunk_native.h>
|
| 1303 |
+
#include <ATen/ops/unsafe_split_native.h>
|
| 1304 |
+
#include <ATen/ops/unsafe_split_with_sizes_native.h>
|
| 1305 |
+
#include <ATen/ops/unsqueeze_native.h>
|
| 1306 |
+
#include <ATen/ops/unsqueeze_copy_native.h>
|
| 1307 |
+
#include <ATen/ops/upsample_bicubic2d_native.h>
|
| 1308 |
+
#include <ATen/ops/upsample_bicubic2d_backward_native.h>
|
| 1309 |
+
#include <ATen/ops/upsample_bilinear2d_native.h>
|
| 1310 |
+
#include <ATen/ops/upsample_bilinear2d_backward_native.h>
|
| 1311 |
+
#include <ATen/ops/upsample_linear1d_native.h>
|
| 1312 |
+
#include <ATen/ops/upsample_linear1d_backward_native.h>
|
| 1313 |
+
#include <ATen/ops/upsample_nearest1d_native.h>
|
| 1314 |
+
#include <ATen/ops/upsample_nearest1d_backward_native.h>
|
| 1315 |
+
#include <ATen/ops/upsample_nearest2d_native.h>
|
| 1316 |
+
#include <ATen/ops/upsample_nearest2d_backward_native.h>
|
| 1317 |
+
#include <ATen/ops/upsample_nearest3d_native.h>
|
| 1318 |
+
#include <ATen/ops/upsample_nearest3d_backward_native.h>
|
| 1319 |
+
#include <ATen/ops/upsample_trilinear3d_native.h>
|
| 1320 |
+
#include <ATen/ops/upsample_trilinear3d_backward_native.h>
|
| 1321 |
+
#include <ATen/ops/value_selecting_reduction_backward_native.h>
|
| 1322 |
+
#include <ATen/ops/values_native.h>
|
| 1323 |
+
#include <ATen/ops/values_copy_native.h>
|
| 1324 |
+
#include <ATen/ops/vander_native.h>
|
| 1325 |
+
#include <ATen/ops/var_native.h>
|
| 1326 |
+
#include <ATen/ops/var_mean_native.h>
|
| 1327 |
+
#include <ATen/ops/vdot_native.h>
|
| 1328 |
+
#include <ATen/ops/view_native.h>
|
| 1329 |
+
#include <ATen/ops/view_as_native.h>
|
| 1330 |
+
#include <ATen/ops/view_as_complex_native.h>
|
| 1331 |
+
#include <ATen/ops/view_as_complex_copy_native.h>
|
| 1332 |
+
#include <ATen/ops/view_as_real_native.h>
|
| 1333 |
+
#include <ATen/ops/view_as_real_copy_native.h>
|
| 1334 |
+
#include <ATen/ops/view_copy_native.h>
|
| 1335 |
+
#include <ATen/ops/vsplit_native.h>
|
| 1336 |
+
#include <ATen/ops/vstack_native.h>
|
| 1337 |
+
#include <ATen/ops/where_native.h>
|
| 1338 |
+
#include <ATen/ops/xlogy_native.h>
|
| 1339 |
+
#include <ATen/ops/xor_native.h>
|
| 1340 |
+
#include <ATen/ops/zero_native.h>
|
| 1341 |
+
#include <ATen/ops/zeros_native.h>
|
| 1342 |
+
#include <ATen/ops/zeros_like_native.h>
|
| 1343 |
+
|
| 1344 |
+
|
.venv/lib/python3.11/site-packages/torch/include/ATen/NestedTensorImpl.h
ADDED
|
@@ -0,0 +1,286 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <ATen/MemoryOverlap.h>
|
| 3 |
+
#include <ATen/Tensor.h>
|
| 4 |
+
#include <c10/core/DispatchKey.h>
|
| 5 |
+
#include <c10/core/DispatchKeySet.h>
|
| 6 |
+
#include <c10/core/MemoryFormat.h>
|
| 7 |
+
#include <c10/core/TensorImpl.h>
|
| 8 |
+
#include <c10/util/ArrayRef.h>
|
| 9 |
+
#include <c10/util/Exception.h>
|
| 10 |
+
#include <c10/util/Metaprogramming.h>
|
| 11 |
+
#include <c10/util/irange.h>
|
| 12 |
+
|
| 13 |
+
namespace at::native {
|
| 14 |
+
struct NestedTensorImpl;
|
| 15 |
+
inline bool nested_tensor_impl_is_contiguous(const NestedTensorImpl* nt);
|
| 16 |
+
int64_t get_numel_from_nested_size_tensor(const at::Tensor& tensor);
|
| 17 |
+
at::Tensor construct_nested_strides(const at::Tensor& nested_size);
|
| 18 |
+
at::Tensor construct_offsets(const at::Tensor& nested_size);
|
| 19 |
+
|
| 20 |
+
struct TORCH_API NestedTensorImpl : public c10::TensorImpl {
|
| 21 |
+
explicit NestedTensorImpl(
|
| 22 |
+
Storage storage,
|
| 23 |
+
c10::DispatchKeySet key_set,
|
| 24 |
+
const caffe2::TypeMeta data_type,
|
| 25 |
+
at::Tensor nested_sizes,
|
| 26 |
+
at::Tensor nested_strides,
|
| 27 |
+
at::Tensor storage_offsets);
|
| 28 |
+
|
| 29 |
+
explicit NestedTensorImpl(
|
| 30 |
+
const at::Tensor& buffer,
|
| 31 |
+
at::Tensor nested_sizes,
|
| 32 |
+
at::Tensor nested_strides,
|
| 33 |
+
at::Tensor storage_offsets);
|
| 34 |
+
// assume contiguous, `nested_strides` and `offsets`
|
| 35 |
+
// can be infered from `nested_sizes`
|
| 36 |
+
explicit NestedTensorImpl(
|
| 37 |
+
const at::Tensor& buffer,
|
| 38 |
+
const at::Tensor& nested_sizes);
|
| 39 |
+
|
| 40 |
+
// This constructor is used creating view tensors from nested tensors
|
| 41 |
+
explicit NestedTensorImpl(
|
| 42 |
+
c10::TensorImpl::ImplType impl_type,
|
| 43 |
+
const at::Tensor& base_tensor,
|
| 44 |
+
at::Tensor nested_sizes,
|
| 45 |
+
at::Tensor nested_strides,
|
| 46 |
+
at::Tensor storage_offsets);
|
| 47 |
+
|
| 48 |
+
// TODO: don't expose private implementation details like this; in
|
| 49 |
+
// particular, resizing this tensor will mess up our dim() and
|
| 50 |
+
// callers cannot fix it.
|
| 51 |
+
const Tensor& get_nested_sizes() const {
|
| 52 |
+
return nested_sizes_;
|
| 53 |
+
}
|
| 54 |
+
// TODO: don't expose private implementation details like this
|
| 55 |
+
const Tensor& get_nested_strides() const {
|
| 56 |
+
return nested_strides_;
|
| 57 |
+
}
|
| 58 |
+
const Tensor& get_storage_offsets() const {
|
| 59 |
+
return storage_offsets_;
|
| 60 |
+
}
|
| 61 |
+
// Returns nullopt if the ith dimension is irregular. The ith dimension
|
| 62 |
+
// of a NestedTensor is regular if the unbound tensors match in
|
| 63 |
+
// size at the (i-1)th dimension.
|
| 64 |
+
std::optional<int64_t> opt_size(int64_t d) const;
|
| 65 |
+
|
| 66 |
+
int64_t size(int64_t d) const {
|
| 67 |
+
std::optional<int64_t> optional_size = this->opt_size(d);
|
| 68 |
+
TORCH_CHECK(
|
| 69 |
+
optional_size.has_value(),
|
| 70 |
+
"Given dimension ",
|
| 71 |
+
d,
|
| 72 |
+
" is irregular and does not have a size.");
|
| 73 |
+
return *optional_size;
|
| 74 |
+
}
|
| 75 |
+
/**
|
| 76 |
+
* Return a view of the nested tensor as a 1 dimensional contiguous tensor.
|
| 77 |
+
*
|
| 78 |
+
* The buffer tensor created by this function shares the same storage_impl as
|
| 79 |
+
* the original nested tensor, and therefore can be seen as a view.
|
| 80 |
+
*
|
| 81 |
+
* @return A newly constructed view tensor
|
| 82 |
+
*/
|
| 83 |
+
at::Tensor get_buffer() const {
|
| 84 |
+
TORCH_CHECK(
|
| 85 |
+
nested_tensor_impl_is_contiguous(this),
|
| 86 |
+
"NestedTensor must be contiguous to get buffer.");
|
| 87 |
+
return get_unsafe_storage_as_tensor();
|
| 88 |
+
}
|
| 89 |
+
/**
|
| 90 |
+
* If possible use get_buffer() instead. This function returns the storage
|
| 91 |
+
* as a tensor directly, which is not safe to use in general. If using this
|
| 92 |
+
* function, The caller must ensure to account for nested_sizes,
|
| 93 |
+
* nested_strides and storage_offsets.
|
| 94 |
+
*
|
| 95 |
+
* @return A newly constructed view tensor
|
| 96 |
+
*/
|
| 97 |
+
at::Tensor get_unsafe_storage_as_tensor() const {
|
| 98 |
+
auto buffer_key_set_ = generate_buffer_key_set();
|
| 99 |
+
const auto buffer_size = get_buffer_size();
|
| 100 |
+
auto buffer_tensor_impl = c10::make_intrusive<TensorImpl>(
|
| 101 |
+
c10::TensorImpl::VIEW, Storage(storage_), buffer_key_set_, data_type_);
|
| 102 |
+
buffer_tensor_impl->set_sizes_contiguous(
|
| 103 |
+
c10::makeArrayRef(static_cast<int64_t>(buffer_size)));
|
| 104 |
+
return Tensor(buffer_tensor_impl);
|
| 105 |
+
}
|
| 106 |
+
|
| 107 |
+
size_t get_buffer_size() const {
|
| 108 |
+
return storage_.nbytes() / data_type_.itemsize();
|
| 109 |
+
}
|
| 110 |
+
|
| 111 |
+
protected:
|
| 112 |
+
const char* tensorimpl_type_name() const override;
|
| 113 |
+
|
| 114 |
+
// TODO: numel_custom and is_contiguous_custom can be profitably overridden
|
| 115 |
+
// with real implementations
|
| 116 |
+
int64_t numel_custom() const override;
|
| 117 |
+
c10::SymInt sym_numel_custom() const override;
|
| 118 |
+
bool is_contiguous_custom(MemoryFormat) const override;
|
| 119 |
+
int64_t size_custom(int64_t d) const override {
|
| 120 |
+
return this->size(d);
|
| 121 |
+
}
|
| 122 |
+
c10::SymInt sym_size_custom(int64_t d) const override {
|
| 123 |
+
return c10::SymInt{this->size(d)};
|
| 124 |
+
}
|
| 125 |
+
IntArrayRef sizes_custom() const override;
|
| 126 |
+
c10::SymIntArrayRef sym_sizes_custom() const override;
|
| 127 |
+
IntArrayRef strides_custom() const override;
|
| 128 |
+
c10::SymIntArrayRef sym_strides_custom() const override;
|
| 129 |
+
|
| 130 |
+
// this one is real
|
| 131 |
+
int64_t dim_custom() const override;
|
| 132 |
+
|
| 133 |
+
c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
|
| 134 |
+
const c10::VariableVersion& version_counter,
|
| 135 |
+
bool allow_tensor_metadata_change) const override;
|
| 136 |
+
|
| 137 |
+
c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
|
| 138 |
+
c10::VariableVersion&& version_counter,
|
| 139 |
+
bool allow_tensor_metadata_change) const override;
|
| 140 |
+
|
| 141 |
+
void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override {
|
| 142 |
+
copy_tensor_metadata(
|
| 143 |
+
/*src_impl=*/impl.get(),
|
| 144 |
+
/*dest_impl=*/this,
|
| 145 |
+
/*version_counter=*/version_counter(),
|
| 146 |
+
/*allow_tensor_metadata_change=*/allow_tensor_metadata_change());
|
| 147 |
+
}
|
| 148 |
+
|
| 149 |
+
private:
|
| 150 |
+
// Must be called after any changes to our dim() to sync the state
|
| 151 |
+
// to TensorImpl.
|
| 152 |
+
void refresh_dim();
|
| 153 |
+
|
| 154 |
+
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
|
| 155 |
+
const at::Tensor nested_sizes_, nested_strides_;
|
| 156 |
+
// The starting positions of the underlying tensors in contiguous buffer
|
| 157 |
+
// i.e. the buffer memory offsets to get the underlying tensors
|
| 158 |
+
// The reason to keep this metadata is that, without strong enough constraint
|
| 159 |
+
// it cannot be derived from `nested_sizes_`
|
| 160 |
+
// and `nested_strides_`:
|
| 161 |
+
// 1. when buffer has blanks, e.g. [tensor1, blank, tensor2]
|
| 162 |
+
// this can happen e.g. after slicing a nested tensor
|
| 163 |
+
// 2. when multiple tensors share a same memory
|
| 164 |
+
// 3. when the nesting ordering is changed, e.g. [tensor1, tensor3, tensor2]
|
| 165 |
+
// Some strong enough constraints are:
|
| 166 |
+
// 1. every underlying tensor is contiguous in memory
|
| 167 |
+
// && nesting in ascending order
|
| 168 |
+
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
|
| 169 |
+
const at::Tensor storage_offsets_;
|
| 170 |
+
// NOTE: -1 here means the size is missing
|
| 171 |
+
// Optional to allow it to be computed lazily from nested.
|
| 172 |
+
// TODO: maybe we can remove this metadata since
|
| 173 |
+
// we can compute it from `nested_sizes_`
|
| 174 |
+
mutable std::optional<std::vector<int64_t>> opt_sizes_;
|
| 175 |
+
|
| 176 |
+
template <typename VariableVersion>
|
| 177 |
+
c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach_core(
|
| 178 |
+
VariableVersion&& version_counter,
|
| 179 |
+
bool allow_tensor_metadata_change) const;
|
| 180 |
+
|
| 181 |
+
/**
|
| 182 |
+
* Generates a non-nested key_set from a nested tensor.
|
| 183 |
+
*
|
| 184 |
+
* For many nested tensor kernel implementations a buffer tensor
|
| 185 |
+
* is generated and redispatched to a non-nested kernel this function
|
| 186 |
+
* generates the key set used by that buffer tensor
|
| 187 |
+
*
|
| 188 |
+
* @return Appropriate key set for non-nested tensor
|
| 189 |
+
*/
|
| 190 |
+
inline c10::DispatchKeySet generate_buffer_key_set() const {
|
| 191 |
+
auto buffer_key_set = this->key_set();
|
| 192 |
+
const bool Autograd = buffer_key_set.has_any(c10::autograd_dispatch_keyset);
|
| 193 |
+
// Remove nested tensor specific keys
|
| 194 |
+
buffer_key_set = buffer_key_set -
|
| 195 |
+
c10::DispatchKeySet{
|
| 196 |
+
c10::DispatchKey::NestedTensor,
|
| 197 |
+
c10::DispatchKey::AutogradNestedTensor};
|
| 198 |
+
|
| 199 |
+
// Add dense tensor specific keys
|
| 200 |
+
buffer_key_set =
|
| 201 |
+
buffer_key_set | c10::DispatchKeySet{c10::DispatchKey::Dense};
|
| 202 |
+
buffer_key_set = Autograd
|
| 203 |
+
? c10::DispatchKeySet{c10::DispatchKey::Autograd} | buffer_key_set
|
| 204 |
+
: buffer_key_set;
|
| 205 |
+
|
| 206 |
+
return buffer_key_set;
|
| 207 |
+
}
|
| 208 |
+
};
|
| 209 |
+
|
| 210 |
+
inline NestedTensorImpl* get_nested_tensor_impl_or_null(
|
| 211 |
+
const at::Tensor& tensor) {
|
| 212 |
+
if (tensor.is_nested()) {
|
| 213 |
+
return static_cast<NestedTensorImpl*>(tensor.unsafeGetTensorImpl());
|
| 214 |
+
}
|
| 215 |
+
return nullptr;
|
| 216 |
+
}
|
| 217 |
+
|
| 218 |
+
inline NestedTensorImpl* get_nested_tensor_impl(const at::Tensor& tensor) {
|
| 219 |
+
TORCH_CHECK(
|
| 220 |
+
tensor.is_nested(), "get_nested_tensor_impl requires a NestedTensor.");
|
| 221 |
+
return static_cast<NestedTensorImpl*>(tensor.unsafeGetTensorImpl());
|
| 222 |
+
}
|
| 223 |
+
|
| 224 |
+
inline bool nested_tensor_impl_is_contiguous(const NestedTensorImpl* nt) {
|
| 225 |
+
int64_t ntensors = nt->size(0);
|
| 226 |
+
if (ntensors == 0) {
|
| 227 |
+
return true;
|
| 228 |
+
}
|
| 229 |
+
const Tensor &sizemat = nt->get_nested_sizes(),
|
| 230 |
+
&stridemat = nt->get_nested_strides();
|
| 231 |
+
const int64_t* offsets_ptr =
|
| 232 |
+
nt->get_storage_offsets().const_data_ptr<int64_t>();
|
| 233 |
+
int64_t orig_dim = sizemat.size(1);
|
| 234 |
+
// nesting scalars
|
| 235 |
+
if (orig_dim == 0) {
|
| 236 |
+
// each scalar must be contiguous
|
| 237 |
+
// if there is blank memory between underlying scalars
|
| 238 |
+
for (int64_t i = 0; i < ntensors; i++) {
|
| 239 |
+
if (offsets_ptr[i] != i) {
|
| 240 |
+
return false;
|
| 241 |
+
}
|
| 242 |
+
}
|
| 243 |
+
}
|
| 244 |
+
// nesting tensors
|
| 245 |
+
else {
|
| 246 |
+
// if any underlying tensor is non-contiguous
|
| 247 |
+
const int64_t *sizemat_ptr = sizemat.const_data_ptr<int64_t>(),
|
| 248 |
+
*stridemat_ptr = stridemat.const_data_ptr<int64_t>();
|
| 249 |
+
for (int64_t i = 0; i < ntensors; i++) {
|
| 250 |
+
if (stridemat_ptr[orig_dim - 1] != 1) {
|
| 251 |
+
return false;
|
| 252 |
+
}
|
| 253 |
+
int64_t product = sizemat_ptr[orig_dim - 1];
|
| 254 |
+
for (int64_t j = orig_dim - 2; j >= 0; j--) {
|
| 255 |
+
if (stridemat_ptr[j] != product) {
|
| 256 |
+
return false;
|
| 257 |
+
}
|
| 258 |
+
product *= sizemat_ptr[j];
|
| 259 |
+
}
|
| 260 |
+
sizemat_ptr += orig_dim;
|
| 261 |
+
stridemat_ptr += orig_dim;
|
| 262 |
+
}
|
| 263 |
+
// if there is blank memory between underlying tensors
|
| 264 |
+
if (offsets_ptr[0] != 0) {
|
| 265 |
+
return false;
|
| 266 |
+
}
|
| 267 |
+
sizemat_ptr = sizemat.const_data_ptr<int64_t>();
|
| 268 |
+
stridemat_ptr = stridemat.const_data_ptr<int64_t>();
|
| 269 |
+
for (int64_t i = 1; i < ntensors; i++) {
|
| 270 |
+
if (offsets_ptr[i] !=
|
| 271 |
+
offsets_ptr[i - 1] + *sizemat_ptr * *stridemat_ptr) {
|
| 272 |
+
return false;
|
| 273 |
+
}
|
| 274 |
+
sizemat_ptr += orig_dim;
|
| 275 |
+
stridemat_ptr += orig_dim;
|
| 276 |
+
}
|
| 277 |
+
}
|
| 278 |
+
// everything is fine
|
| 279 |
+
return true;
|
| 280 |
+
}
|
| 281 |
+
|
| 282 |
+
inline const at::Tensor& get_nested_sizes(const at::Tensor& tensor) {
|
| 283 |
+
return get_nested_tensor_impl(tensor)->get_nested_sizes();
|
| 284 |
+
}
|
| 285 |
+
|
| 286 |
+
} // namespace at::native
|
.venv/lib/python3.11/site-packages/torch/include/ATen/OpMathType.h
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <c10/core/ScalarType.h>
|
| 4 |
+
#include <c10/util/BFloat16.h>
|
| 5 |
+
#include <c10/util/Exception.h>
|
| 6 |
+
#include <c10/util/Float8_e4m3fn.h>
|
| 7 |
+
#include <c10/util/Float8_e4m3fnuz.h>
|
| 8 |
+
#include <c10/util/Float8_e5m2.h>
|
| 9 |
+
#include <c10/util/Float8_e5m2fnuz.h>
|
| 10 |
+
#include <c10/util/Half.h>
|
| 11 |
+
|
| 12 |
+
namespace at {
|
| 13 |
+
|
| 14 |
+
// For FP16 or BFloat16 inputs, ops should perform internal math in FP32.
|
| 15 |
+
template <typename scalar_t>
|
| 16 |
+
struct OpMathType {
|
| 17 |
+
using type = scalar_t;
|
| 18 |
+
};
|
| 19 |
+
template <>
|
| 20 |
+
struct OpMathType<at::Half> {
|
| 21 |
+
using type = float;
|
| 22 |
+
};
|
| 23 |
+
template <>
|
| 24 |
+
struct OpMathType<at::BFloat16> {
|
| 25 |
+
using type = float;
|
| 26 |
+
};
|
| 27 |
+
template <>
|
| 28 |
+
struct OpMathType<at::Float8_e5m2> {
|
| 29 |
+
using type = float;
|
| 30 |
+
};
|
| 31 |
+
template <>
|
| 32 |
+
struct OpMathType<at::Float8_e4m3fn> {
|
| 33 |
+
using type = float;
|
| 34 |
+
};
|
| 35 |
+
template <>
|
| 36 |
+
struct OpMathType<at::Float8_e5m2fnuz> {
|
| 37 |
+
using type = float;
|
| 38 |
+
};
|
| 39 |
+
template <>
|
| 40 |
+
struct OpMathType<at::Float8_e4m3fnuz> {
|
| 41 |
+
using type = float;
|
| 42 |
+
};
|
| 43 |
+
template <>
|
| 44 |
+
struct OpMathType<c10::complex<Half>> {
|
| 45 |
+
using type = c10::complex<float>;
|
| 46 |
+
};
|
| 47 |
+
|
| 48 |
+
template <typename T>
|
| 49 |
+
using opmath_type = typename OpMathType<T>::type;
|
| 50 |
+
|
| 51 |
+
namespace {
|
| 52 |
+
|
| 53 |
+
inline c10::ScalarType toOpMathType(const c10::ScalarType type) {
|
| 54 |
+
switch (type) {
|
| 55 |
+
#define DEFINE_CASE(scalar_t, TypeNum) \
|
| 56 |
+
case ScalarType::TypeNum: \
|
| 57 |
+
return CppTypeToScalarType<at::opmath_type<scalar_t>>::value;
|
| 58 |
+
|
| 59 |
+
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CASE)
|
| 60 |
+
#undef DEFINE_CASE
|
| 61 |
+
|
| 62 |
+
default:
|
| 63 |
+
TORCH_INTERNAL_ASSERT(false, "Unrecognized ScalarType: ", type);
|
| 64 |
+
}
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
} // namespace
|
| 68 |
+
|
| 69 |
+
} // namespace at
|
.venv/lib/python3.11/site-packages/torch/include/ATen/PadNd.h
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <c10/util/Exception.h>
|
| 3 |
+
#include <c10/util/string_view.h>
|
| 4 |
+
|
| 5 |
+
namespace at {
|
| 6 |
+
|
| 7 |
+
enum class padding_mode {
|
| 8 |
+
reflect,
|
| 9 |
+
replicate,
|
| 10 |
+
circular,
|
| 11 |
+
constant,
|
| 12 |
+
};
|
| 13 |
+
|
| 14 |
+
static inline c10::string_view padding_mode_string(padding_mode m) {
|
| 15 |
+
switch (m) {
|
| 16 |
+
case padding_mode::reflect:
|
| 17 |
+
return "reflect";
|
| 18 |
+
case padding_mode::replicate:
|
| 19 |
+
return "replicate";
|
| 20 |
+
case padding_mode::circular:
|
| 21 |
+
return "circular";
|
| 22 |
+
case padding_mode::constant:
|
| 23 |
+
return "constant";
|
| 24 |
+
}
|
| 25 |
+
TORCH_CHECK(false, "Invalid padding mode (", static_cast<int64_t>(m), ")");
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
} // namespace at
|
.venv/lib/python3.11/site-packages/torch/include/ATen/Parallel.h
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <ATen/Config.h>
|
| 3 |
+
#include <c10/macros/Macros.h>
|
| 4 |
+
#include <functional>
|
| 5 |
+
#include <string>
|
| 6 |
+
|
| 7 |
+
namespace at {
|
| 8 |
+
|
| 9 |
+
inline int64_t divup(int64_t x, int64_t y) {
|
| 10 |
+
return (x + y - 1) / y;
|
| 11 |
+
}
|
| 12 |
+
|
| 13 |
+
// Called during new thread initialization
|
| 14 |
+
TORCH_API void init_num_threads();
|
| 15 |
+
|
| 16 |
+
// Sets the number of threads to be used in parallel region
|
| 17 |
+
TORCH_API void set_num_threads(int);
|
| 18 |
+
|
| 19 |
+
// Returns the maximum number of threads that may be used in a parallel region
|
| 20 |
+
TORCH_API int get_num_threads();
|
| 21 |
+
|
| 22 |
+
// Returns the current thread number (starting from 0)
|
| 23 |
+
// in the current parallel region, or 0 in the sequential region
|
| 24 |
+
TORCH_API int get_thread_num();
|
| 25 |
+
|
| 26 |
+
// Checks whether the code runs in parallel region
|
| 27 |
+
TORCH_API bool in_parallel_region();
|
| 28 |
+
|
| 29 |
+
namespace internal {
|
| 30 |
+
|
| 31 |
+
// Initialise num_threads lazily at first parallel call
|
| 32 |
+
inline void lazy_init_num_threads() {
|
| 33 |
+
thread_local bool init = false;
|
| 34 |
+
if (C10_UNLIKELY(!init)) {
|
| 35 |
+
at::init_num_threads();
|
| 36 |
+
init = true;
|
| 37 |
+
}
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
TORCH_API void set_thread_num(int);
|
| 41 |
+
|
| 42 |
+
class TORCH_API ThreadIdGuard {
|
| 43 |
+
public:
|
| 44 |
+
ThreadIdGuard(int new_id) : old_id_(at::get_thread_num()) {
|
| 45 |
+
set_thread_num(new_id);
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
~ThreadIdGuard() {
|
| 49 |
+
set_thread_num(old_id_);
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
private:
|
| 53 |
+
int old_id_;
|
| 54 |
+
};
|
| 55 |
+
|
| 56 |
+
} // namespace internal
|
| 57 |
+
|
| 58 |
+
/*
|
| 59 |
+
parallel_for
|
| 60 |
+
|
| 61 |
+
begin: index at which to start applying user function
|
| 62 |
+
|
| 63 |
+
end: index at which to stop applying user function
|
| 64 |
+
|
| 65 |
+
grain_size: number of elements per chunk. impacts the degree of parallelization
|
| 66 |
+
|
| 67 |
+
f: user function applied in parallel to the chunks, signature:
|
| 68 |
+
void f(int64_t begin, int64_t end)
|
| 69 |
+
|
| 70 |
+
Warning: parallel_for does NOT copy thread local
|
| 71 |
+
states from the current thread to the worker threads.
|
| 72 |
+
This means for example that Tensor operations CANNOT be used in the
|
| 73 |
+
body of your function, only data pointers.
|
| 74 |
+
*/
|
| 75 |
+
template <class F>
|
| 76 |
+
inline void parallel_for(
|
| 77 |
+
const int64_t begin,
|
| 78 |
+
const int64_t end,
|
| 79 |
+
const int64_t grain_size,
|
| 80 |
+
const F& f);
|
| 81 |
+
|
| 82 |
+
/*
|
| 83 |
+
parallel_reduce
|
| 84 |
+
|
| 85 |
+
begin: index at which to start applying reduction
|
| 86 |
+
|
| 87 |
+
end: index at which to stop applying reduction
|
| 88 |
+
|
| 89 |
+
grain_size: number of elements per chunk. impacts number of elements in
|
| 90 |
+
intermediate results tensor and degree of parallelization.
|
| 91 |
+
|
| 92 |
+
ident: identity for binary combination function sf. sf(ident, x) needs to return
|
| 93 |
+
x.
|
| 94 |
+
|
| 95 |
+
f: function for reduction over a chunk. f needs to be of signature scalar_t
|
| 96 |
+
f(int64_t partial_begin, int64_t partial_end, scalar_t identifiy)
|
| 97 |
+
|
| 98 |
+
sf: function to combine two partial results. sf needs to be of signature
|
| 99 |
+
scalar_t sf(scalar_t x, scalar_t y)
|
| 100 |
+
|
| 101 |
+
For example, you might have a tensor of 10000 entires and want to sum together
|
| 102 |
+
all the elements. Parallel_reduce with a grain_size of 2500 will then allocate
|
| 103 |
+
an intermediate result tensor with 4 elements. Then it will execute the function
|
| 104 |
+
"f" you provide and pass the beginning and end index of these chunks, so
|
| 105 |
+
0-2499, 2500-4999, etc. and the combination identity. It will then write out
|
| 106 |
+
the result from each of these chunks into the intermediate result tensor. After
|
| 107 |
+
that it'll reduce the partial results from each chunk into a single number using
|
| 108 |
+
the combination function sf and the identity ident. For a total summation this
|
| 109 |
+
would be "+" and 0 respectively. This is similar to tbb's approach [1], where
|
| 110 |
+
you need to provide a function to accumulate a subrange, a function to combine
|
| 111 |
+
two partial results and an identity.
|
| 112 |
+
|
| 113 |
+
Warning: parallel_reduce does NOT copy thread local
|
| 114 |
+
states from the current thread to the worker threads.
|
| 115 |
+
This means for example that Tensor operations CANNOT be used in the
|
| 116 |
+
body of your function, only data pointers.
|
| 117 |
+
|
| 118 |
+
[1] https://software.intel.com/en-us/node/506154
|
| 119 |
+
*/
|
| 120 |
+
template <class scalar_t, class F, class SF>
|
| 121 |
+
inline scalar_t parallel_reduce(
|
| 122 |
+
const int64_t begin,
|
| 123 |
+
const int64_t end,
|
| 124 |
+
const int64_t grain_size,
|
| 125 |
+
const scalar_t ident,
|
| 126 |
+
const F& f,
|
| 127 |
+
const SF& sf);
|
| 128 |
+
|
| 129 |
+
// Returns a detailed string describing parallelization settings
|
| 130 |
+
TORCH_API std::string get_parallel_info();
|
| 131 |
+
|
| 132 |
+
// Sets number of threads used for inter-op parallelism
|
| 133 |
+
TORCH_API void set_num_interop_threads(int);
|
| 134 |
+
|
| 135 |
+
// Returns the number of threads used for inter-op parallelism
|
| 136 |
+
TORCH_API int get_num_interop_threads();
|
| 137 |
+
|
| 138 |
+
// Launches inter-op parallel task
|
| 139 |
+
TORCH_API void launch(std::function<void()> func);
|
| 140 |
+
namespace internal {
|
| 141 |
+
void launch_no_thread_state(std::function<void()> fn);
|
| 142 |
+
} // namespace internal
|
| 143 |
+
|
| 144 |
+
// Launches intra-op parallel task
|
| 145 |
+
TORCH_API void intraop_launch(std::function<void()> func);
|
| 146 |
+
|
| 147 |
+
// Returns number of intra-op threads used by default
|
| 148 |
+
TORCH_API int intraop_default_num_threads();
|
| 149 |
+
|
| 150 |
+
} // namespace at
|
| 151 |
+
|
| 152 |
+
#if AT_PARALLEL_OPENMP
|
| 153 |
+
#include <ATen/ParallelOpenMP.h> // IWYU pragma: keep
|
| 154 |
+
#elif AT_PARALLEL_NATIVE
|
| 155 |
+
#include <ATen/ParallelNative.h> // IWYU pragma: keep
|
| 156 |
+
#endif
|
| 157 |
+
|
| 158 |
+
#include <ATen/Parallel-inl.h> // IWYU pragma: keep
|
.venv/lib/python3.11/site-packages/torch/include/ATen/ParallelFuture.h
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/core/ivalue.h>
|
| 4 |
+
#include <c10/macros/Macros.h>
|
| 5 |
+
#include <functional>
|
| 6 |
+
|
| 7 |
+
namespace at {
|
| 8 |
+
|
| 9 |
+
// Launches intra-op parallel task, returns a future
|
| 10 |
+
TORCH_API c10::intrusive_ptr<c10::ivalue::Future> intraop_launch_future(
|
| 11 |
+
std::function<void()> func);
|
| 12 |
+
|
| 13 |
+
} // namespace at
|
.venv/lib/python3.11/site-packages/torch/include/ATen/RegistrationDeclarations.h
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
.venv/lib/python3.11/site-packages/torch/include/ATen/SavedTensorHooks.h
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <c10/core/SafePyObject.h>
|
| 4 |
+
#include <c10/macros/Export.h>
|
| 5 |
+
#include <c10/util/python_stub.h>
|
| 6 |
+
#include <optional>
|
| 7 |
+
#include <stack>
|
| 8 |
+
#include <string>
|
| 9 |
+
|
| 10 |
+
#include <utility>
|
| 11 |
+
|
| 12 |
+
namespace at {
|
| 13 |
+
|
| 14 |
+
namespace impl {
|
| 15 |
+
|
| 16 |
+
struct TORCH_API SavedTensorDefaultHooksTLS {
|
| 17 |
+
// PyObject is defined in c10/util/python_stub.h
|
| 18 |
+
std::stack<std::pair<c10::SafePyObject, c10::SafePyObject>> stack;
|
| 19 |
+
|
| 20 |
+
// See NOTE: [Disabling SavedTensorDefaultHooks] for context
|
| 21 |
+
// NOTE: [disabled_error_message invariant]
|
| 22 |
+
// disabled_error_message is nullopt IFF Saved Tensor hooks is enabled
|
| 23 |
+
// We did this for efficiency (so we didn't have to keep a separate bool
|
| 24 |
+
// around)
|
| 25 |
+
std::optional<std::string> disabled_error_message;
|
| 26 |
+
|
| 27 |
+
// See NOTE: [Deferring tensor pack/unpack hooks until runtime]
|
| 28 |
+
bool is_tracing = false;
|
| 29 |
+
};
|
| 30 |
+
|
| 31 |
+
} // namespace impl
|
| 32 |
+
|
| 33 |
+
struct TORCH_API SavedTensorDefaultHooks {
|
| 34 |
+
static void push_hooks(
|
| 35 |
+
c10::SafePyObject pack_hook,
|
| 36 |
+
c10::SafePyObject unpack_hook);
|
| 37 |
+
static std::pair<c10::SafePyObject, c10::SafePyObject> pop_hooks();
|
| 38 |
+
static std::optional<std::pair<c10::SafePyObject, c10::SafePyObject>>
|
| 39 |
+
get_hooks();
|
| 40 |
+
static void lazy_initialize();
|
| 41 |
+
|
| 42 |
+
static const impl::SavedTensorDefaultHooksTLS& get_tls_state();
|
| 43 |
+
static void set_tls_state(const impl::SavedTensorDefaultHooksTLS& tls);
|
| 44 |
+
|
| 45 |
+
// NOTE: [Disabling SavedTensorDefaultHooks]
|
| 46 |
+
// A developer of a PyTorch feature may choose to disable SavedTensorDefault
|
| 47 |
+
// hooks, especially if their feature does not work with it. If they are
|
| 48 |
+
// disabled, then the following will raise an error:
|
| 49 |
+
// - Attempting to push_hooks
|
| 50 |
+
// - calling disable(message) with a non-zero stack (hooks) size
|
| 51 |
+
static void disable(const std::string& error_message);
|
| 52 |
+
static void enable();
|
| 53 |
+
static bool is_enabled();
|
| 54 |
+
static const std::optional<std::string>& get_disabled_error_message();
|
| 55 |
+
|
| 56 |
+
// NOTE: [Deferring tensor pack/unpack hooks until runtime]
|
| 57 |
+
// To preserve eager semantics of pack/unpack hooks firing only once per saved
|
| 58 |
+
// variable, Dynamo/AOTAutograd need to defer hook firing until runtime. Using
|
| 59 |
+
// disable() would loud error at trace time, and pushing a no-op hook would
|
| 60 |
+
// fail when the traced code is wrapped in a disable_saved_tensors_hooks ctx.
|
| 61 |
+
// To do so, we disable these hooks during tracing. See
|
| 62 |
+
// https://github.com/pytorch/pytorch/issues/113263.
|
| 63 |
+
static bool set_tracing(bool is_tracing);
|
| 64 |
+
};
|
| 65 |
+
|
| 66 |
+
} // namespace at
|
.venv/lib/python3.11/site-packages/torch/include/ATen/Scalar.h
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/core/Scalar.h>
|
.venv/lib/python3.11/site-packages/torch/include/ATen/ScalarOps.h
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/Tensor.h>
|
| 4 |
+
#include <c10/core/Scalar.h>
|
| 5 |
+
|
| 6 |
+
#ifndef AT_PER_OPERATOR_HEADERS
|
| 7 |
+
#include <ATen/Functions.h>
|
| 8 |
+
#else
|
| 9 |
+
#include <ATen/ops/scalar_tensor.h>
|
| 10 |
+
#endif
|
| 11 |
+
|
| 12 |
+
namespace at::detail {
|
| 13 |
+
// When filling a number to 1-element CPU tensor, we want to skip
|
| 14 |
+
// everything but manipulate data ptr directly.
|
| 15 |
+
// Ideally this fast pass should be implemented in TensorIterator,
|
| 16 |
+
// but we also want to skip compute_types which in not avoidable
|
| 17 |
+
// in TensorIterator for now.
|
| 18 |
+
Tensor& scalar_fill(Tensor& self, const Scalar& value);
|
| 19 |
+
TORCH_API Tensor scalar_tensor_static(
|
| 20 |
+
const Scalar& s,
|
| 21 |
+
std::optional<ScalarType> dtype_opt,
|
| 22 |
+
std::optional<Device> device_opt);
|
| 23 |
+
} // namespace at::detail
|
| 24 |
+
|
| 25 |
+
// This is in the c10 namespace because we use ADL to find the functions in it.
|
| 26 |
+
namespace c10 {
|
| 27 |
+
|
| 28 |
+
// FIXME: this should be (and was) Scalar::toTensor, but there is currently no
|
| 29 |
+
// way to implement this without going through Derived Types (which are not part
|
| 30 |
+
// of core).
|
| 31 |
+
inline at::Tensor scalar_to_tensor(
|
| 32 |
+
const Scalar& s,
|
| 33 |
+
const Device device = at::kCPU) {
|
| 34 |
+
// This is the fast track we have for CPU scalar tensors.
|
| 35 |
+
if (device == at::kCPU) {
|
| 36 |
+
return at::detail::scalar_tensor_static(s, s.type(), at::kCPU);
|
| 37 |
+
}
|
| 38 |
+
return at::scalar_tensor(s, at::device(device).dtype(s.type()));
|
| 39 |
+
}
|
| 40 |
+
|
| 41 |
+
} // namespace c10
|
| 42 |
+
|
| 43 |
+
namespace at::native {
|
| 44 |
+
|
| 45 |
+
inline Tensor wrapped_scalar_tensor(
|
| 46 |
+
const Scalar& scalar,
|
| 47 |
+
const Device device = at::kCPU) {
|
| 48 |
+
auto tensor = scalar_to_tensor(scalar, device);
|
| 49 |
+
tensor.unsafeGetTensorImpl()->set_wrapped_number(true);
|
| 50 |
+
return tensor;
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
} // namespace at::native
|
.venv/lib/python3.11/site-packages/torch/include/ATen/ScalarType.h
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <ATen/core/ATenGeneral.h> // for BC reasons
|
| 3 |
+
#include <c10/core/Backend.h>
|
| 4 |
+
#include <c10/core/ScalarType.h>
|
.venv/lib/python3.11/site-packages/torch/include/ATen/SparseCsrTensorImpl.h
ADDED
|
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/Tensor.h>
|
| 4 |
+
#include <c10/core/TensorImpl.h>
|
| 5 |
+
#include <c10/core/impl/TorchDispatchModeTLS.h>
|
| 6 |
+
#include <c10/util/Exception.h>
|
| 7 |
+
namespace at {
|
| 8 |
+
|
| 9 |
+
// Struct implementing a sparse CSR tensor. It uses three 1-D tensors for
|
| 10 |
+
// denoting the data: `crow_indices_`, `col_indices_` and `values_`.
|
| 11 |
+
// The `crow_indices_` tensor is a integer tensor of shape `(size(0) + 1)`
|
| 12 |
+
// that represents the compressed row indices of the CSR tensor. The
|
| 13 |
+
// `col_indices_` tensor is an integer tensor of shape `(nnz())`
|
| 14 |
+
// that explicitly stores the column indices of each value of the sparse
|
| 15 |
+
// tensor. The `values_` tensor can be of any pytorch-supported data type
|
| 16 |
+
// and has shape `(nnz())`.
|
| 17 |
+
//
|
| 18 |
+
// Since the main advantage of the CSR format over the COO format is speed of
|
| 19 |
+
// computation, care must be taken to facilitate smooth interfacing of
|
| 20 |
+
// these data structures with optimized libraries such as MKL and MAGMA.
|
| 21 |
+
// Since the MKL interface for pytorch currently uses indexing with int32
|
| 22 |
+
// type, it is important to make sure that the `crow_indices` and `col_indices`
|
| 23 |
+
// are of type int32 when calling MKL routines such as SPMM or SPMV.
|
| 24 |
+
//
|
| 25 |
+
// If not calling MKL, it should be alright to use 64 bit integer tensors
|
| 26 |
+
// for indexing.
|
| 27 |
+
struct TORCH_API SparseCsrTensorImpl : public TensorImpl {
|
| 28 |
+
Tensor crow_indices_;
|
| 29 |
+
Tensor col_indices_;
|
| 30 |
+
Tensor values_;
|
| 31 |
+
Layout layout_;
|
| 32 |
+
|
| 33 |
+
public:
|
| 34 |
+
explicit SparseCsrTensorImpl(
|
| 35 |
+
at::DispatchKeySet,
|
| 36 |
+
at::Device device,
|
| 37 |
+
Layout layout,
|
| 38 |
+
const caffe2::TypeMeta);
|
| 39 |
+
|
| 40 |
+
void resize_(int64_t nnz, IntArrayRef size);
|
| 41 |
+
void resize_and_clear_(
|
| 42 |
+
int64_t sparse_dim,
|
| 43 |
+
int64_t dense_dim,
|
| 44 |
+
IntArrayRef size);
|
| 45 |
+
void resize_as_sparse_compressed_tensor_(const Tensor& src);
|
| 46 |
+
void set_member_tensors(
|
| 47 |
+
const Tensor& crow_indices,
|
| 48 |
+
const Tensor& col_indices,
|
| 49 |
+
const Tensor& values,
|
| 50 |
+
c10::SymIntArrayRef size);
|
| 51 |
+
void set_member_tensors(
|
| 52 |
+
const Tensor& crow_indices,
|
| 53 |
+
const Tensor& col_indices,
|
| 54 |
+
const Tensor& values,
|
| 55 |
+
IntArrayRef size);
|
| 56 |
+
const Tensor& compressed_indices() const {
|
| 57 |
+
return crow_indices_;
|
| 58 |
+
}
|
| 59 |
+
const Tensor& plain_indices() const {
|
| 60 |
+
return col_indices_;
|
| 61 |
+
}
|
| 62 |
+
const Tensor& values() const {
|
| 63 |
+
return values_;
|
| 64 |
+
}
|
| 65 |
+
int64_t nnz() {
|
| 66 |
+
return col_indices_.size(-1);
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
inline int64_t batch_dim() const noexcept {
|
| 70 |
+
return crow_indices_.dim() - 1;
|
| 71 |
+
}
|
| 72 |
+
|
| 73 |
+
inline int64_t sparse_dim() const noexcept {
|
| 74 |
+
return 2;
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
inline int64_t dense_dim() const noexcept {
|
| 78 |
+
return values_.dim() - batch_dim() - block_dim() - 1;
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
private:
|
| 82 |
+
inline int64_t block_dim() const noexcept {
|
| 83 |
+
return (layout_ == kSparseBsr || layout_ == kSparseBsc ? 2 : 0);
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
protected:
|
| 87 |
+
IntArrayRef strides_custom() const override;
|
| 88 |
+
SymIntArrayRef sym_strides_custom() const override;
|
| 89 |
+
bool is_contiguous_custom(MemoryFormat) const override;
|
| 90 |
+
|
| 91 |
+
public:
|
| 92 |
+
void set_size(int64_t dim, int64_t new_size) override;
|
| 93 |
+
void set_stride(int64_t dim, int64_t new_stride) override;
|
| 94 |
+
void set_storage_offset(int64_t storage_offset) override;
|
| 95 |
+
Layout layout_impl() const override {
|
| 96 |
+
return layout_;
|
| 97 |
+
}
|
| 98 |
+
void set_layout(Layout layout) {
|
| 99 |
+
switch (layout) {
|
| 100 |
+
case kSparseCsr:
|
| 101 |
+
case kSparseCsc:
|
| 102 |
+
case kSparseBsr:
|
| 103 |
+
case kSparseBsc:
|
| 104 |
+
layout_ = layout;
|
| 105 |
+
break;
|
| 106 |
+
default:
|
| 107 |
+
TORCH_CHECK(false, "unsupported layout ", layout);
|
| 108 |
+
}
|
| 109 |
+
}
|
| 110 |
+
|
| 111 |
+
template <typename VariableVersion>
|
| 112 |
+
c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach_core(
|
| 113 |
+
VariableVersion&& version_counter,
|
| 114 |
+
bool allow_tensor_metadata_change) const {
|
| 115 |
+
const auto mode_stack_len = c10::impl::TorchDispatchModeTLS::stack_len();
|
| 116 |
+
c10::impl::PyInterpreter&& interpreter = nullptr;
|
| 117 |
+
if (mode_stack_len > 0 &&
|
| 118 |
+
!c10::impl::tls_is_dispatch_key_excluded(DispatchKey::Python)) {
|
| 119 |
+
const auto& cur_torch_dispatch_mode_state =
|
| 120 |
+
c10::impl::TorchDispatchModeTLS::get_stack_at(mode_stack_len - 1);
|
| 121 |
+
interpreter = cur_torch_dispatch_mode_state->pyinterpreter();
|
| 122 |
+
} else if (
|
| 123 |
+
key_set_.has(DispatchKey::Python) &&
|
| 124 |
+
!c10::impl::tls_is_dispatch_key_excluded(DispatchKey::Python)) {
|
| 125 |
+
interpreter = pyobj_slot_.load_pyobj_interpreter();
|
| 126 |
+
} else {
|
| 127 |
+
// otherwise just copy the SparseTensorImpl and not the PyObject.
|
| 128 |
+
auto impl = c10::make_intrusive<SparseCsrTensorImpl>(
|
| 129 |
+
key_set(), device(), layout_impl(), dtype());
|
| 130 |
+
copy_tensor_metadata(
|
| 131 |
+
/*src_sparse_impl=*/this,
|
| 132 |
+
/*dest_sparse_impl=*/impl.get(),
|
| 133 |
+
/*version_counter=*/version_counter,
|
| 134 |
+
/*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
|
| 135 |
+
impl->refresh_numel();
|
| 136 |
+
return impl;
|
| 137 |
+
}
|
| 138 |
+
auto r = interpreter->detach(this);
|
| 139 |
+
r->set_version_counter(std::forward<VariableVersion>(version_counter));
|
| 140 |
+
r->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
|
| 141 |
+
return r;
|
| 142 |
+
}
|
| 143 |
+
|
| 144 |
+
/**
|
| 145 |
+
* Return a TensorImpl that is a shallow-copy of this TensorImpl.
|
| 146 |
+
*
|
| 147 |
+
* For usage of `version_counter` and `allow_tensor_metadata_change`,
|
| 148 |
+
* see NOTE [ TensorImpl Shallow-Copying ].
|
| 149 |
+
*/
|
| 150 |
+
c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
|
| 151 |
+
const c10::VariableVersion& version_counter,
|
| 152 |
+
bool allow_tensor_metadata_change) const override {
|
| 153 |
+
return shallow_copy_and_detach_core(
|
| 154 |
+
version_counter, allow_tensor_metadata_change);
|
| 155 |
+
}
|
| 156 |
+
|
| 157 |
+
/**
|
| 158 |
+
* Return a TensorImpl that is a shallow-copy of this TensorImpl.
|
| 159 |
+
*
|
| 160 |
+
* For usage of `version_counter` and `allow_tensor_metadata_change`,
|
| 161 |
+
* see NOTE [ TensorImpl Shallow-Copying ].
|
| 162 |
+
*/
|
| 163 |
+
c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
|
| 164 |
+
c10::VariableVersion&& version_counter,
|
| 165 |
+
bool allow_tensor_metadata_change) const override {
|
| 166 |
+
return shallow_copy_and_detach_core(
|
| 167 |
+
std::move(version_counter), allow_tensor_metadata_change);
|
| 168 |
+
}
|
| 169 |
+
|
| 170 |
+
private:
|
| 171 |
+
explicit SparseCsrTensorImpl(
|
| 172 |
+
at::DispatchKeySet key_set,
|
| 173 |
+
const caffe2::TypeMeta data_type,
|
| 174 |
+
at::Tensor crow_indices,
|
| 175 |
+
at::Tensor col_indices,
|
| 176 |
+
at::Tensor values,
|
| 177 |
+
at::Layout layout);
|
| 178 |
+
|
| 179 |
+
const char* tensorimpl_type_name() const override;
|
| 180 |
+
|
| 181 |
+
/**
|
| 182 |
+
* Copy the tensor metadata fields (e.g. sizes / strides / storage pointer /
|
| 183 |
+
* storage_offset) from one TensorImpl to another TensorImpl.
|
| 184 |
+
*
|
| 185 |
+
* For usage of `version_counter` and `allow_tensor_metadata_change`, see NOTE
|
| 186 |
+
* [ TensorImpl Shallow-Copying ].
|
| 187 |
+
*/
|
| 188 |
+
static void copy_tensor_metadata(
|
| 189 |
+
const SparseCsrTensorImpl* src_sparse_impl,
|
| 190 |
+
SparseCsrTensorImpl* dest_sparse_impl,
|
| 191 |
+
c10::VariableVersion version_counter,
|
| 192 |
+
bool allow_tensor_metadata_change) {
|
| 193 |
+
TensorImpl::copy_tensor_metadata(
|
| 194 |
+
src_sparse_impl,
|
| 195 |
+
dest_sparse_impl,
|
| 196 |
+
std::move(version_counter),
|
| 197 |
+
allow_tensor_metadata_change);
|
| 198 |
+
|
| 199 |
+
// Sparse-specific fields
|
| 200 |
+
dest_sparse_impl->crow_indices_ = src_sparse_impl->compressed_indices();
|
| 201 |
+
dest_sparse_impl->col_indices_ = src_sparse_impl->plain_indices();
|
| 202 |
+
dest_sparse_impl->values_ = src_sparse_impl->values();
|
| 203 |
+
dest_sparse_impl->layout_ = src_sparse_impl->layout_impl();
|
| 204 |
+
}
|
| 205 |
+
};
|
| 206 |
+
} // namespace at
|
.venv/lib/python3.11/site-packages/torch/include/ATen/SparseCsrTensorUtils.h
ADDED
|
@@ -0,0 +1,441 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/SparseCsrTensorImpl.h>
|
| 4 |
+
#include <ATen/SparseTensorImpl.h>
|
| 5 |
+
#include <ATen/core/Tensor.h>
|
| 6 |
+
|
| 7 |
+
#ifndef AT_PER_OPERATOR_HEADERS
|
| 8 |
+
#include <ATen/Functions.h>
|
| 9 |
+
#include <ATen/NativeFunctions.h>
|
| 10 |
+
#include <ATen/Operators.h>
|
| 11 |
+
#else
|
| 12 |
+
#include <ATen/ops/_sparse_compressed_tensor_unsafe.h>
|
| 13 |
+
#include <ATen/ops/resize_as_sparse_native.h>
|
| 14 |
+
#endif
|
| 15 |
+
|
| 16 |
+
#define AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(LAYOUT, NAME, ...) \
|
| 17 |
+
[&] { \
|
| 18 |
+
const auto& the_layout = LAYOUT; \
|
| 19 |
+
switch (the_layout) { \
|
| 20 |
+
case kSparseCsr: \
|
| 21 |
+
case kSparseCsc: \
|
| 22 |
+
case kSparseBsr: \
|
| 23 |
+
case kSparseBsc: \
|
| 24 |
+
return __VA_ARGS__(); \
|
| 25 |
+
default: \
|
| 26 |
+
AT_ERROR( \
|
| 27 |
+
NAME, \
|
| 28 |
+
" expected sparse compressed tensor layout but got ", \
|
| 29 |
+
the_layout); \
|
| 30 |
+
} \
|
| 31 |
+
}()
|
| 32 |
+
|
| 33 |
+
#define AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS( \
|
| 34 |
+
LAYOUT, NAME, ROW_DIM_ACTION, COLUMN_DIM_ACTION) \
|
| 35 |
+
[&]() { \
|
| 36 |
+
const auto& the_layout = LAYOUT; \
|
| 37 |
+
switch (the_layout) { \
|
| 38 |
+
case kSparseCsr: \
|
| 39 |
+
case kSparseBsr: \
|
| 40 |
+
return (ROW_DIM_ACTION)(); \
|
| 41 |
+
case kSparseCsc: \
|
| 42 |
+
case kSparseBsc: \
|
| 43 |
+
return (COLUMN_DIM_ACTION)(); \
|
| 44 |
+
default: \
|
| 45 |
+
AT_ERROR( \
|
| 46 |
+
NAME, \
|
| 47 |
+
" expected sparse compressed tensor layout but got ", \
|
| 48 |
+
the_layout); \
|
| 49 |
+
} \
|
| 50 |
+
}()
|
| 51 |
+
|
| 52 |
+
#define AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS( \
|
| 53 |
+
LAYOUT, NAME, NO_BLOCK_ACTION, BLOCK_ACTION) \
|
| 54 |
+
[&]() { \
|
| 55 |
+
const auto& the_layout = LAYOUT; \
|
| 56 |
+
switch (the_layout) { \
|
| 57 |
+
case kSparseCsr: \
|
| 58 |
+
case kSparseCsc: \
|
| 59 |
+
return (NO_BLOCK_ACTION)(); \
|
| 60 |
+
case kSparseBsr: \
|
| 61 |
+
case kSparseBsc: \
|
| 62 |
+
return (BLOCK_ACTION)(); \
|
| 63 |
+
default: \
|
| 64 |
+
AT_ERROR( \
|
| 65 |
+
NAME, \
|
| 66 |
+
" expected sparse compressed tensor layout but got ", \
|
| 67 |
+
the_layout); \
|
| 68 |
+
} \
|
| 69 |
+
}()
|
| 70 |
+
|
| 71 |
+
#define AT_DISPATCH_SPARSE_ROW_COMPRESSED_LAYOUTS( \
|
| 72 |
+
LAYOUT, NAME, ROW_DIM_ACTION) \
|
| 73 |
+
[&]() { \
|
| 74 |
+
const auto& the_layout = LAYOUT; \
|
| 75 |
+
switch (the_layout) { \
|
| 76 |
+
case kSparseCsr: \
|
| 77 |
+
case kSparseBsr: \
|
| 78 |
+
return (ROW_DIM_ACTION)(); \
|
| 79 |
+
default: \
|
| 80 |
+
AT_ERROR( \
|
| 81 |
+
NAME, \
|
| 82 |
+
" expected sparse row compressed tensor layout but got ", \
|
| 83 |
+
the_layout); \
|
| 84 |
+
} \
|
| 85 |
+
}()
|
| 86 |
+
|
| 87 |
+
#define AT_DISPATCH_SPARSE_COL_COMPRESSED_LAYOUTS( \
|
| 88 |
+
LAYOUT, NAME, COL_DIM_ACTION) \
|
| 89 |
+
[&]() { \
|
| 90 |
+
const auto& the_layout = LAYOUT; \
|
| 91 |
+
switch (the_layout) { \
|
| 92 |
+
case kSparseCsc: \
|
| 93 |
+
case kSparseBsc: \
|
| 94 |
+
return (COL_DIM_ACTION)(); \
|
| 95 |
+
default: \
|
| 96 |
+
AT_ERROR( \
|
| 97 |
+
NAME, \
|
| 98 |
+
" expected sparse column compressed tensor layout but got ", \
|
| 99 |
+
the_layout); \
|
| 100 |
+
} \
|
| 101 |
+
}()
|
| 102 |
+
|
| 103 |
+
#define AT_DISPATCH_SPARSE_COMPRESSED_NONBLOCK_LAYOUTS(LAYOUT, NAME, ACTION) \
|
| 104 |
+
[&]() { \
|
| 105 |
+
const auto& the_layout = LAYOUT; \
|
| 106 |
+
switch (the_layout) { \
|
| 107 |
+
case kSparseCsr: \
|
| 108 |
+
case kSparseCsc: \
|
| 109 |
+
return (ACTION)(); \
|
| 110 |
+
default: \
|
| 111 |
+
AT_ERROR( \
|
| 112 |
+
NAME, \
|
| 113 |
+
" expected sparse compressed (non-block) tensor layout but got ", \
|
| 114 |
+
the_layout); \
|
| 115 |
+
} \
|
| 116 |
+
}()
|
| 117 |
+
|
| 118 |
+
#define AT_DISPATCH_SPARSE_COMPRESSED_BLOCK_LAYOUTS(LAYOUT, NAME, ACTION) \
|
| 119 |
+
[&]() { \
|
| 120 |
+
const auto& the_layout = LAYOUT; \
|
| 121 |
+
switch (the_layout) { \
|
| 122 |
+
case kSparseBsr: \
|
| 123 |
+
case kSparseBsc: \
|
| 124 |
+
return (ACTION)(); \
|
| 125 |
+
default: \
|
| 126 |
+
AT_ERROR( \
|
| 127 |
+
NAME, \
|
| 128 |
+
" expected sparse compressed block tensor layout but got ", \
|
| 129 |
+
the_layout); \
|
| 130 |
+
} \
|
| 131 |
+
}()
|
| 132 |
+
|
| 133 |
+
#define AT_DISPATCH_SPARSE_VALUE_TYPES(TYPE, NAME, ...) \
|
| 134 |
+
AT_DISPATCH_SWITCH( \
|
| 135 |
+
TYPE, \
|
| 136 |
+
NAME, \
|
| 137 |
+
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND4( \
|
| 138 |
+
kComplexHalf, kHalf, kBool, kBFloat16, __VA_ARGS__))
|
| 139 |
+
|
| 140 |
+
namespace at::sparse_csr {
|
| 141 |
+
|
| 142 |
+
// Implements RAII object to manage checking sparse tensor invariants:
|
| 143 |
+
class CheckSparseTensorInvariants {
|
| 144 |
+
bool old_state;
|
| 145 |
+
|
| 146 |
+
public:
|
| 147 |
+
CheckSparseTensorInvariants(bool state) {
|
| 148 |
+
old_state = at::globalContext().checkSparseTensorInvariants();
|
| 149 |
+
at::globalContext().setCheckSparseTensorInvariants(state);
|
| 150 |
+
}
|
| 151 |
+
|
| 152 |
+
~CheckSparseTensorInvariants() {
|
| 153 |
+
at::globalContext().setCheckSparseTensorInvariants(old_state);
|
| 154 |
+
}
|
| 155 |
+
};
|
| 156 |
+
|
| 157 |
+
using SparseCsrTensor = Tensor;
|
| 158 |
+
|
| 159 |
+
inline bool is_sparse_compressed(const Layout& layout) {
|
| 160 |
+
switch (layout) {
|
| 161 |
+
case kSparseCsr:
|
| 162 |
+
case kSparseCsc:
|
| 163 |
+
case kSparseBsr:
|
| 164 |
+
case kSparseBsc:
|
| 165 |
+
return true;
|
| 166 |
+
default:;
|
| 167 |
+
}
|
| 168 |
+
return false;
|
| 169 |
+
}
|
| 170 |
+
|
| 171 |
+
inline bool is_sparse_compressed(const Tensor& self) {
|
| 172 |
+
return is_sparse_compressed(self.layout());
|
| 173 |
+
}
|
| 174 |
+
|
| 175 |
+
inline SparseCsrTensorImpl* get_sparse_csr_impl(const SparseCsrTensor& self) {
|
| 176 |
+
AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(
|
| 177 |
+
self.layout(), "get_sparse_csr_impl", [&] {});
|
| 178 |
+
return static_cast<SparseCsrTensorImpl*>(self.unsafeGetTensorImpl());
|
| 179 |
+
}
|
| 180 |
+
|
| 181 |
+
inline std::string layoutToString(
|
| 182 |
+
Layout layout,
|
| 183 |
+
bool upper = false,
|
| 184 |
+
bool lower = false) {
|
| 185 |
+
switch (layout) {
|
| 186 |
+
case kSparseCsr:
|
| 187 |
+
return (upper ? "CSR" : (lower ? "csr" : "Csr"));
|
| 188 |
+
case kSparseCsc:
|
| 189 |
+
return (upper ? "CSC" : (lower ? "csc" : "Csc"));
|
| 190 |
+
case kSparseBsr:
|
| 191 |
+
return (upper ? "BSR" : (lower ? "bsr" : "Bsr"));
|
| 192 |
+
case kSparseBsc:
|
| 193 |
+
return (upper ? "BSC" : (lower ? "bsc" : "Bsc"));
|
| 194 |
+
default:
|
| 195 |
+
TORCH_CHECK(false, "Not a sparse compressed layout:", layout);
|
| 196 |
+
return "";
|
| 197 |
+
}
|
| 198 |
+
}
|
| 199 |
+
|
| 200 |
+
inline bool isCompressedRow(Layout layout) {
|
| 201 |
+
return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
|
| 202 |
+
layout, "isCompressedRow", [&] { return true; }, [&] { return false; });
|
| 203 |
+
}
|
| 204 |
+
|
| 205 |
+
inline bool isCompressedColumn(Layout layout) {
|
| 206 |
+
return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
|
| 207 |
+
layout,
|
| 208 |
+
"isCompressedColumn",
|
| 209 |
+
[&] { return false; },
|
| 210 |
+
[&] { return true; });
|
| 211 |
+
}
|
| 212 |
+
|
| 213 |
+
inline std::string compressedIndicesName(Layout layout) {
|
| 214 |
+
return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
|
| 215 |
+
layout,
|
| 216 |
+
"compressedIndicesName",
|
| 217 |
+
[&] { return "crow_indices"; },
|
| 218 |
+
[&] { return "ccol_indices"; });
|
| 219 |
+
}
|
| 220 |
+
|
| 221 |
+
inline std::string plainIndicesName(Layout layout) {
|
| 222 |
+
return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
|
| 223 |
+
layout,
|
| 224 |
+
"plainIndicesName",
|
| 225 |
+
[&] { return "col_indices"; },
|
| 226 |
+
[&] { return "row_indices"; });
|
| 227 |
+
}
|
| 228 |
+
|
| 229 |
+
inline std::string compressedDimName(Layout layout) {
|
| 230 |
+
switch (layout) {
|
| 231 |
+
case kSparseCsr:
|
| 232 |
+
return "row";
|
| 233 |
+
case kSparseCsc:
|
| 234 |
+
return "column";
|
| 235 |
+
case kSparseBsr:
|
| 236 |
+
return "row block";
|
| 237 |
+
case kSparseBsc:
|
| 238 |
+
return "column block";
|
| 239 |
+
default:
|
| 240 |
+
TORCH_CHECK(false, "Not a sparse compressed layout:", layout);
|
| 241 |
+
return "";
|
| 242 |
+
}
|
| 243 |
+
}
|
| 244 |
+
|
| 245 |
+
inline std::string plainDimName(Layout layout) {
|
| 246 |
+
switch (layout) {
|
| 247 |
+
case kSparseCsr:
|
| 248 |
+
return "column";
|
| 249 |
+
case kSparseCsc:
|
| 250 |
+
return "row";
|
| 251 |
+
case kSparseBsr:
|
| 252 |
+
return "column block";
|
| 253 |
+
case kSparseBsc:
|
| 254 |
+
return "row block";
|
| 255 |
+
default:
|
| 256 |
+
TORCH_CHECK(false, "Not a sparse compressed layout:", layout);
|
| 257 |
+
return "";
|
| 258 |
+
}
|
| 259 |
+
}
|
| 260 |
+
|
| 261 |
+
inline size_t rowDimension(Layout layout, IntArrayRef size) {
|
| 262 |
+
return size.size() - (isCompressedRow(layout) ? 2 : 1);
|
| 263 |
+
}
|
| 264 |
+
|
| 265 |
+
inline size_t columnDimension(Layout layout, IntArrayRef size) {
|
| 266 |
+
return size.size() - (isCompressedColumn(layout) ? 2 : 1);
|
| 267 |
+
}
|
| 268 |
+
|
| 269 |
+
inline size_t compressedDimension(
|
| 270 |
+
Layout layout,
|
| 271 |
+
IntArrayRef size,
|
| 272 |
+
size_t dense_ndim = 0) {
|
| 273 |
+
return size.size() - dense_ndim - (isCompressedRow(layout) ? 2 : 1);
|
| 274 |
+
}
|
| 275 |
+
|
| 276 |
+
inline size_t plainDimension(
|
| 277 |
+
Layout layout,
|
| 278 |
+
IntArrayRef size,
|
| 279 |
+
size_t dense_ndim = 0) {
|
| 280 |
+
return size.size() - dense_ndim - (isCompressedRow(layout) ? 1 : 2);
|
| 281 |
+
}
|
| 282 |
+
|
| 283 |
+
inline int64_t numBatchDimensions(Tensor const& self) {
|
| 284 |
+
return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
|
| 285 |
+
self.layout(),
|
| 286 |
+
"numBatchDimensions",
|
| 287 |
+
[&self] { return self.crow_indices().dim() - 1; },
|
| 288 |
+
[&self] { return self.ccol_indices().dim() - 1; });
|
| 289 |
+
}
|
| 290 |
+
|
| 291 |
+
inline std::pair<Tensor, Tensor> getCompressedPlainIndices(Tensor const& self) {
|
| 292 |
+
return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
|
| 293 |
+
self.layout(),
|
| 294 |
+
"getCompressedPlainIndices",
|
| 295 |
+
[&self] {
|
| 296 |
+
return std::make_pair(self.crow_indices(), self.col_indices());
|
| 297 |
+
},
|
| 298 |
+
[&self] {
|
| 299 |
+
return std::make_pair(self.ccol_indices(), self.row_indices());
|
| 300 |
+
});
|
| 301 |
+
}
|
| 302 |
+
|
| 303 |
+
inline ScalarType getIndexDtype(Tensor const& self) {
|
| 304 |
+
switch (self.layout()) {
|
| 305 |
+
case kSparseCsr:
|
| 306 |
+
case kSparseBsr:
|
| 307 |
+
return self.crow_indices().scalar_type();
|
| 308 |
+
case kSparseCsc:
|
| 309 |
+
case kSparseBsc:
|
| 310 |
+
return self.ccol_indices().scalar_type();
|
| 311 |
+
case kSparse:
|
| 312 |
+
return self._indices().scalar_type();
|
| 313 |
+
default:
|
| 314 |
+
return ScalarType::Long;
|
| 315 |
+
}
|
| 316 |
+
}
|
| 317 |
+
|
| 318 |
+
inline Layout flip_compressed_layout(Layout layout) {
|
| 319 |
+
switch (layout) {
|
| 320 |
+
case kSparseCsr:
|
| 321 |
+
return kSparseCsc;
|
| 322 |
+
case kSparseCsc:
|
| 323 |
+
return kSparseCsr;
|
| 324 |
+
case kSparseBsr:
|
| 325 |
+
return kSparseBsc;
|
| 326 |
+
case kSparseBsc:
|
| 327 |
+
return kSparseBsr;
|
| 328 |
+
default:
|
| 329 |
+
TORCH_CHECK(false, "Not a sparse compressed layout:", layout);
|
| 330 |
+
return kSparseCsr;
|
| 331 |
+
}
|
| 332 |
+
}
|
| 333 |
+
|
| 334 |
+
inline DimVector getBlockSize(Tensor const& self) {
|
| 335 |
+
int64_t n_batch = numBatchDimensions(self);
|
| 336 |
+
return at::DimVector(self.values().sizes().slice(n_batch + 1, 2));
|
| 337 |
+
}
|
| 338 |
+
|
| 339 |
+
inline at::OptionalArray<at::SymInt> getSymIntBlockSize(Tensor const& self) {
|
| 340 |
+
if (self.layout() == at::kSparseBsr || self.layout() == at::kSparseBsc) {
|
| 341 |
+
int64_t n_batch = numBatchDimensions(self);
|
| 342 |
+
return self.values().sym_sizes().slice(n_batch + 1, 2).vec();
|
| 343 |
+
} else {
|
| 344 |
+
return {};
|
| 345 |
+
}
|
| 346 |
+
}
|
| 347 |
+
|
| 348 |
+
template <typename binary_op_t, typename binary_op_out_t>
|
| 349 |
+
inline bool only_sparse_compressed_binary_op_trivial_cases(
|
| 350 |
+
const Tensor& self,
|
| 351 |
+
const Tensor& other,
|
| 352 |
+
const Scalar& alpha,
|
| 353 |
+
Tensor& out,
|
| 354 |
+
const binary_op_t& binary_op,
|
| 355 |
+
const binary_op_out_t& binary_op_out) {
|
| 356 |
+
// Only sparse compressed! Just like the name says :)
|
| 357 |
+
TORCH_INTERNAL_ASSERT(at::sparse_csr::is_sparse_compressed(self));
|
| 358 |
+
TORCH_INTERNAL_ASSERT(at::sparse_csr::is_sparse_compressed(other));
|
| 359 |
+
TORCH_INTERNAL_ASSERT(at::sparse_csr::is_sparse_compressed(out));
|
| 360 |
+
|
| 361 |
+
// Bypass BLAS if there are matches in (self, other, out)
|
| 362 |
+
if (self.is_same(out) && self.is_same(other)) {
|
| 363 |
+
binary_op_out(self.values(), other.values(), alpha);
|
| 364 |
+
return true;
|
| 365 |
+
}
|
| 366 |
+
if (self.is_same(other)) {
|
| 367 |
+
auto [compressed_indices, plain_indices] =
|
| 368 |
+
at::sparse_csr::getCompressedPlainIndices(self);
|
| 369 |
+
static_cast<SparseCsrTensorImpl*>(out.unsafeGetTensorImpl())
|
| 370 |
+
->set_member_tensors(
|
| 371 |
+
compressed_indices,
|
| 372 |
+
plain_indices,
|
| 373 |
+
binary_op(self.values(), other.values(), alpha),
|
| 374 |
+
self.sizes());
|
| 375 |
+
return true;
|
| 376 |
+
}
|
| 377 |
+
return false;
|
| 378 |
+
}
|
| 379 |
+
|
| 380 |
+
inline bool only_sparse_compressed_add_trivial_cases(
|
| 381 |
+
const Tensor& self,
|
| 382 |
+
const Tensor& other,
|
| 383 |
+
const Scalar& alpha,
|
| 384 |
+
Tensor& out) {
|
| 385 |
+
return only_sparse_compressed_binary_op_trivial_cases(
|
| 386 |
+
self,
|
| 387 |
+
other,
|
| 388 |
+
alpha,
|
| 389 |
+
out,
|
| 390 |
+
[](const Tensor& v1, const Tensor& v2, const Scalar& alpha) {
|
| 391 |
+
return v1.add(v2, alpha);
|
| 392 |
+
},
|
| 393 |
+
[](const Tensor& v1, const Tensor& v2, const Scalar& alpha) {
|
| 394 |
+
return v1.add_(v2, alpha);
|
| 395 |
+
});
|
| 396 |
+
}
|
| 397 |
+
|
| 398 |
+
inline Tensor to_type(const Tensor& input, ScalarType dtype) {
|
| 399 |
+
auto [compressed_indices, plain_indices] =
|
| 400 |
+
at::sparse_csr::getCompressedPlainIndices(input);
|
| 401 |
+
return at::_sparse_compressed_tensor_unsafe(
|
| 402 |
+
compressed_indices,
|
| 403 |
+
plain_indices,
|
| 404 |
+
std::move(input.values()).to(dtype),
|
| 405 |
+
input.sizes(),
|
| 406 |
+
dtype,
|
| 407 |
+
input.layout(),
|
| 408 |
+
input.device(),
|
| 409 |
+
input.options().pinned_memory_opt());
|
| 410 |
+
}
|
| 411 |
+
|
| 412 |
+
template <typename acc_t, typename scalar_t>
|
| 413 |
+
inline std::tuple<Tensor, Tensor> create_acc_buffer(
|
| 414 |
+
TensorOptions option,
|
| 415 |
+
ScalarType type,
|
| 416 |
+
int64_t nnz = -1) {
|
| 417 |
+
Tensor new_values, new_values_acc;
|
| 418 |
+
constexpr bool need_acc = !std::is_same_v<scalar_t, acc_t>;
|
| 419 |
+
bool is_integral = at::isIntegralType(type, /*includeBool=*/true);
|
| 420 |
+
if constexpr (need_acc) {
|
| 421 |
+
auto acc_dtype = CppTypeToScalarType<acc_t>::value;
|
| 422 |
+
new_values_acc = at::empty({}, option.dtype(acc_dtype));
|
| 423 |
+
new_values = is_integral ? new_values_acc : at::empty({}, option);
|
| 424 |
+
} else {
|
| 425 |
+
new_values = new_values_acc = at::empty({}, option);
|
| 426 |
+
}
|
| 427 |
+
if (nnz != -1) {
|
| 428 |
+
return std::make_tuple(
|
| 429 |
+
new_values.resize_(nnz), new_values_acc.resize_(nnz));
|
| 430 |
+
} else {
|
| 431 |
+
return std::make_tuple(new_values, new_values_acc);
|
| 432 |
+
}
|
| 433 |
+
}
|
| 434 |
+
|
| 435 |
+
inline void copy_from_acc_buffer(Tensor& new_values, Tensor& new_values_acc) {
|
| 436 |
+
if (!new_values_acc.is_same(new_values)) {
|
| 437 |
+
new_values.copy_(new_values_acc);
|
| 438 |
+
}
|
| 439 |
+
}
|
| 440 |
+
|
| 441 |
+
} // namespace at::sparse_csr
|
.venv/lib/python3.11/site-packages/torch/include/ATen/Storage.h
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <c10/core/Storage.h>
|
.venv/lib/python3.11/site-packages/torch/include/ATen/Tensor.h
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/core/Tensor.h>
|