Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .venv/lib/python3.11/site-packages/torch/include/ATen/ATen.h +37 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/Backtrace.h +2 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/BlasBackend.h +27 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/CPUFunctions_inl.h +540 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/CUDAFunctions.h +29 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/CUDAFunctions_inl.h +623 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/CachedTensorUtils.h +24 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/CollapseDims.h +94 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/CompositeExplicitAutogradFunctions.h +29 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/CompositeExplicitAutogradNonFunctionalFunctions.h +29 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/CompositeExplicitAutogradNonFunctionalFunctions_inl.h +323 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/CompositeImplicitAutogradFunctions_inl.h +502 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/DLConvertor.h +25 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/DeviceGuard.h +41 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/Dispatch_v2.h +186 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/DynamicLibrary.h +34 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/ExpandUtils.h +527 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/Functions.h +1454 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/Generator.h +2 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/LinalgBackend.h +31 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/MemoryOverlap.h +42 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/NativeMetaFunctions.h +1330 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/NumericUtils.h +203 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/OpaqueTensorImpl.h +187 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/Operators.h +1385 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/Parallel-inl.h +93 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/ParallelNative.h +15 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/ParallelOpenMP.h +54 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/PythonTorchFunctionTLS.h +36 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/RedispatchFunctions.h +0 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/SmallVector.h +2 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/SparseTensorImpl.h +421 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/StorageUtils.h +49 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/TensorAccessor.h +2 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/TensorIndexing.h +737 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/TensorIteratorInternal.h +72 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/TensorOptions.h +2 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/TensorSubclassLikeUtils.h +88 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/TensorUtils.h +190 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/TypeDefault.h +30 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/Utils.h +134 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/VmapGeneratedPlumbing.h +0 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/cpp_custom_type_hack.h +110 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/cuda/ATenCUDAGeneral.h +9 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/cuda/ApplyGridUtils.cuh +47 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/cuda/Atomic.cuh +514 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDAApplyUtils.cuh +537 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDABlas.h +358 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDAConfig.h +19 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDAContext.h +9 -0
.venv/lib/python3.11/site-packages/torch/include/ATen/ATen.h
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#if !defined(_MSC_VER) && __cplusplus < 201703L
|
| 4 |
+
#error C++17 or later compatible compiler is required to use ATen.
|
| 5 |
+
#endif
|
| 6 |
+
|
| 7 |
+
#include <ATen/Context.h>
|
| 8 |
+
#include <ATen/Device.h>
|
| 9 |
+
#include <ATen/DeviceGuard.h>
|
| 10 |
+
#include <ATen/DimVector.h>
|
| 11 |
+
#include <ATen/Dispatch.h>
|
| 12 |
+
#include <ATen/Formatting.h>
|
| 13 |
+
#include <ATen/Functions.h>
|
| 14 |
+
#include <ATen/NamedTensor.h>
|
| 15 |
+
#include <ATen/ScalarOps.h>
|
| 16 |
+
#include <ATen/Tensor.h>
|
| 17 |
+
#include <ATen/TensorGeometry.h>
|
| 18 |
+
#include <ATen/TensorIndexing.h>
|
| 19 |
+
#include <ATen/TensorOperators.h>
|
| 20 |
+
#include <ATen/Version.h>
|
| 21 |
+
#include <ATen/core/ATenGeneral.h>
|
| 22 |
+
#include <ATen/core/Generator.h>
|
| 23 |
+
#include <ATen/core/Reduction.h>
|
| 24 |
+
#include <ATen/core/Scalar.h>
|
| 25 |
+
#include <ATen/core/UnsafeFromTH.h>
|
| 26 |
+
#include <ATen/core/ivalue.h>
|
| 27 |
+
#include <ATen/core/jit_type.h>
|
| 28 |
+
#include <c10/core/Allocator.h>
|
| 29 |
+
#include <c10/core/InferenceMode.h>
|
| 30 |
+
#include <c10/core/Layout.h>
|
| 31 |
+
#include <c10/core/Storage.h>
|
| 32 |
+
#include <c10/core/TensorOptions.h>
|
| 33 |
+
#include <c10/util/Exception.h>
|
| 34 |
+
|
| 35 |
+
// TODO: try to remove this
|
| 36 |
+
// There is some back story, see https://github.com/pytorch/pytorch/issues/48684
|
| 37 |
+
#include <ATen/NativeFunctions.h>
|
.venv/lib/python3.11/site-packages/torch/include/ATen/Backtrace.h
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <ATen/core/Backtrace.h>
|
.venv/lib/python3.11/site-packages/torch/include/ATen/BlasBackend.h
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <c10/util/Exception.h>
|
| 4 |
+
|
| 5 |
+
#include <ostream>
|
| 6 |
+
#include <string>
|
| 7 |
+
|
| 8 |
+
namespace at {
|
| 9 |
+
|
| 10 |
+
enum class BlasBackend : int8_t { Cublas, Cublaslt };
|
| 11 |
+
|
| 12 |
+
inline std::string BlasBackendToString(at::BlasBackend backend) {
|
| 13 |
+
switch (backend) {
|
| 14 |
+
case BlasBackend::Cublas:
|
| 15 |
+
return "at::BlasBackend::Cublas";
|
| 16 |
+
case BlasBackend::Cublaslt:
|
| 17 |
+
return "at::BlasBackend::Cublaslt";
|
| 18 |
+
default:
|
| 19 |
+
TORCH_CHECK(false, "Unknown blas backend");
|
| 20 |
+
}
|
| 21 |
+
}
|
| 22 |
+
|
| 23 |
+
inline std::ostream& operator<<(std::ostream& stream, at::BlasBackend backend) {
|
| 24 |
+
return stream << BlasBackendToString(backend);
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
} // namespace at
|
.venv/lib/python3.11/site-packages/torch/include/ATen/CPUFunctions_inl.h
ADDED
|
@@ -0,0 +1,540 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
// @generated by torchgen/gen.py from DispatchKeyFunctions_inl.h
|
| 3 |
+
|
| 4 |
+
// NB: The implementing C++ file is RegisterDispatchKey.cpp
|
| 5 |
+
|
| 6 |
+
// The only #includes we need are for custom classes that have defaults in the C++ API
|
| 7 |
+
#include <c10/core/MemoryFormat.h>
|
| 8 |
+
#include <c10/core/Scalar.h>
|
| 9 |
+
#include <ATen/core/Reduction.h>
|
| 10 |
+
|
| 11 |
+
#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS)
|
| 12 |
+
#error This change adds a dependency on all pytorch operators, meaning the \
|
| 13 |
+
file will need to be re-compiled every time an operator is changed or added. \
|
| 14 |
+
Consider including a specific operator from \
|
| 15 |
+
<ATen/ops/{my_operator}_cpu_dispatch.h>. \
|
| 16 |
+
See NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS].
|
| 17 |
+
#endif
|
| 18 |
+
|
| 19 |
+
#include <ATen/ops/_adaptive_avg_pool2d_cpu_dispatch.h>
|
| 20 |
+
#include <ATen/ops/_adaptive_avg_pool2d_backward_cpu_dispatch.h>
|
| 21 |
+
#include <ATen/ops/_adaptive_avg_pool3d_cpu_dispatch.h>
|
| 22 |
+
#include <ATen/ops/_adaptive_avg_pool3d_backward_cpu_dispatch.h>
|
| 23 |
+
#include <ATen/ops/_add_relu_cpu_dispatch.h>
|
| 24 |
+
#include <ATen/ops/_addmm_activation_cpu_dispatch.h>
|
| 25 |
+
#include <ATen/ops/_aminmax_cpu_dispatch.h>
|
| 26 |
+
#include <ATen/ops/_amp_foreach_non_finite_check_and_unscale_cpu_dispatch.h>
|
| 27 |
+
#include <ATen/ops/_amp_update_scale_cpu_dispatch.h>
|
| 28 |
+
#include <ATen/ops/_assert_async_cpu_dispatch.h>
|
| 29 |
+
#include <ATen/ops/_batch_norm_with_update_cpu_dispatch.h>
|
| 30 |
+
#include <ATen/ops/_cdist_backward_cpu_dispatch.h>
|
| 31 |
+
#include <ATen/ops/_cdist_forward_cpu_dispatch.h>
|
| 32 |
+
#include <ATen/ops/_cholesky_solve_helper_cpu_dispatch.h>
|
| 33 |
+
#include <ATen/ops/_compute_linear_combination_cpu_dispatch.h>
|
| 34 |
+
#include <ATen/ops/_convert_indices_from_coo_to_csr_cpu_dispatch.h>
|
| 35 |
+
#include <ATen/ops/_convert_indices_from_csr_to_coo_cpu_dispatch.h>
|
| 36 |
+
#include <ATen/ops/_convert_weight_to_int4pack_cpu_dispatch.h>
|
| 37 |
+
#include <ATen/ops/_ctc_loss_cpu_dispatch.h>
|
| 38 |
+
#include <ATen/ops/_ctc_loss_backward_cpu_dispatch.h>
|
| 39 |
+
#include <ATen/ops/_cummax_helper_cpu_dispatch.h>
|
| 40 |
+
#include <ATen/ops/_cummin_helper_cpu_dispatch.h>
|
| 41 |
+
#include <ATen/ops/_dirichlet_grad_cpu_dispatch.h>
|
| 42 |
+
#include <ATen/ops/_efficientzerotensor_cpu_dispatch.h>
|
| 43 |
+
#include <ATen/ops/_embedding_bag_cpu_dispatch.h>
|
| 44 |
+
#include <ATen/ops/_embedding_bag_backward_cpu_dispatch.h>
|
| 45 |
+
#include <ATen/ops/_embedding_bag_dense_backward_cpu_dispatch.h>
|
| 46 |
+
#include <ATen/ops/_embedding_bag_forward_only_cpu_dispatch.h>
|
| 47 |
+
#include <ATen/ops/_embedding_bag_per_sample_weights_backward_cpu_dispatch.h>
|
| 48 |
+
#include <ATen/ops/_empty_affine_quantized_cpu_dispatch.h>
|
| 49 |
+
#include <ATen/ops/_empty_per_channel_affine_quantized_cpu_dispatch.h>
|
| 50 |
+
#include <ATen/ops/_fake_quantize_learnable_per_channel_affine_cpu_dispatch.h>
|
| 51 |
+
#include <ATen/ops/_fake_quantize_learnable_per_channel_affine_backward_cpu_dispatch.h>
|
| 52 |
+
#include <ATen/ops/_fake_quantize_learnable_per_tensor_affine_cpu_dispatch.h>
|
| 53 |
+
#include <ATen/ops/_fake_quantize_learnable_per_tensor_affine_backward_cpu_dispatch.h>
|
| 54 |
+
#include <ATen/ops/_fake_quantize_per_tensor_affine_cachemask_tensor_qparams_cpu_dispatch.h>
|
| 55 |
+
#include <ATen/ops/_fft_c2c_cpu_dispatch.h>
|
| 56 |
+
#include <ATen/ops/_fft_c2r_cpu_dispatch.h>
|
| 57 |
+
#include <ATen/ops/_fft_r2c_cpu_dispatch.h>
|
| 58 |
+
#include <ATen/ops/_foobar_cpu_dispatch.h>
|
| 59 |
+
#include <ATen/ops/_functional_assert_async_cpu_dispatch.h>
|
| 60 |
+
#include <ATen/ops/_fused_adagrad_cpu_dispatch.h>
|
| 61 |
+
#include <ATen/ops/_fused_adam_cpu_dispatch.h>
|
| 62 |
+
#include <ATen/ops/_fused_adamw_cpu_dispatch.h>
|
| 63 |
+
#include <ATen/ops/_fused_moving_avg_obs_fq_helper_cpu_dispatch.h>
|
| 64 |
+
#include <ATen/ops/_fused_sdp_choice_cpu_dispatch.h>
|
| 65 |
+
#include <ATen/ops/_fused_sgd_cpu_dispatch.h>
|
| 66 |
+
#include <ATen/ops/_histogramdd_bin_edges_cpu_dispatch.h>
|
| 67 |
+
#include <ATen/ops/_histogramdd_from_bin_cts_cpu_dispatch.h>
|
| 68 |
+
#include <ATen/ops/_histogramdd_from_bin_tensors_cpu_dispatch.h>
|
| 69 |
+
#include <ATen/ops/_index_put_impl_cpu_dispatch.h>
|
| 70 |
+
#include <ATen/ops/_int_mm_cpu_dispatch.h>
|
| 71 |
+
#include <ATen/ops/_jagged_to_padded_dense_forward_cpu_dispatch.h>
|
| 72 |
+
#include <ATen/ops/_linalg_det_cpu_dispatch.h>
|
| 73 |
+
#include <ATen/ops/_linalg_eigh_cpu_dispatch.h>
|
| 74 |
+
#include <ATen/ops/_linalg_eigvals_cpu_dispatch.h>
|
| 75 |
+
#include <ATen/ops/_linalg_slogdet_cpu_dispatch.h>
|
| 76 |
+
#include <ATen/ops/_linalg_solve_ex_cpu_dispatch.h>
|
| 77 |
+
#include <ATen/ops/_linalg_svd_cpu_dispatch.h>
|
| 78 |
+
#include <ATen/ops/_local_scalar_dense_cpu_dispatch.h>
|
| 79 |
+
#include <ATen/ops/_log_softmax_cpu_dispatch.h>
|
| 80 |
+
#include <ATen/ops/_log_softmax_backward_data_cpu_dispatch.h>
|
| 81 |
+
#include <ATen/ops/_logcumsumexp_cpu_dispatch.h>
|
| 82 |
+
#include <ATen/ops/_make_dep_token_cpu_dispatch.h>
|
| 83 |
+
#include <ATen/ops/_make_per_channel_quantized_tensor_cpu_dispatch.h>
|
| 84 |
+
#include <ATen/ops/_make_per_tensor_quantized_tensor_cpu_dispatch.h>
|
| 85 |
+
#include <ATen/ops/_masked_softmax_cpu_dispatch.h>
|
| 86 |
+
#include <ATen/ops/_masked_softmax_backward_cpu_dispatch.h>
|
| 87 |
+
#include <ATen/ops/_native_batch_norm_legit_cpu_dispatch.h>
|
| 88 |
+
#include <ATen/ops/_native_multi_head_attention_cpu_dispatch.h>
|
| 89 |
+
#include <ATen/ops/_nested_compute_contiguous_strides_offsets_cpu_dispatch.h>
|
| 90 |
+
#include <ATen/ops/_nested_from_padded_cpu_dispatch.h>
|
| 91 |
+
#include <ATen/ops/_nested_tensor_from_mask_cpu_dispatch.h>
|
| 92 |
+
#include <ATen/ops/_nested_tensor_from_mask_left_aligned_cpu_dispatch.h>
|
| 93 |
+
#include <ATen/ops/_nested_view_from_buffer_cpu_dispatch.h>
|
| 94 |
+
#include <ATen/ops/_padded_dense_to_jagged_forward_cpu_dispatch.h>
|
| 95 |
+
#include <ATen/ops/_pdist_backward_cpu_dispatch.h>
|
| 96 |
+
#include <ATen/ops/_pdist_forward_cpu_dispatch.h>
|
| 97 |
+
#include <ATen/ops/_prelu_kernel_cpu_dispatch.h>
|
| 98 |
+
#include <ATen/ops/_prelu_kernel_backward_cpu_dispatch.h>
|
| 99 |
+
#include <ATen/ops/_reshape_alias_cpu_dispatch.h>
|
| 100 |
+
#include <ATen/ops/_sample_dirichlet_cpu_dispatch.h>
|
| 101 |
+
#include <ATen/ops/_scaled_dot_product_flash_attention_for_cpu_cpu_dispatch.h>
|
| 102 |
+
#include <ATen/ops/_scaled_dot_product_flash_attention_for_cpu_backward_cpu_dispatch.h>
|
| 103 |
+
#include <ATen/ops/_segment_reduce_backward_cpu_dispatch.h>
|
| 104 |
+
#include <ATen/ops/_slow_conv2d_backward_cpu_dispatch.h>
|
| 105 |
+
#include <ATen/ops/_slow_conv2d_forward_cpu_dispatch.h>
|
| 106 |
+
#include <ATen/ops/_softmax_cpu_dispatch.h>
|
| 107 |
+
#include <ATen/ops/_softmax_backward_data_cpu_dispatch.h>
|
| 108 |
+
#include <ATen/ops/_spdiags_cpu_dispatch.h>
|
| 109 |
+
#include <ATen/ops/_stack_cpu_dispatch.h>
|
| 110 |
+
#include <ATen/ops/_standard_gamma_cpu_dispatch.h>
|
| 111 |
+
#include <ATen/ops/_standard_gamma_grad_cpu_dispatch.h>
|
| 112 |
+
#include <ATen/ops/_test_functorch_fallback_cpu_dispatch.h>
|
| 113 |
+
#include <ATen/ops/_test_optional_filled_intlist_cpu_dispatch.h>
|
| 114 |
+
#include <ATen/ops/_test_optional_floatlist_cpu_dispatch.h>
|
| 115 |
+
#include <ATen/ops/_test_optional_intlist_cpu_dispatch.h>
|
| 116 |
+
#include <ATen/ops/_to_sparse_cpu_dispatch.h>
|
| 117 |
+
#include <ATen/ops/_to_sparse_bsc_cpu_dispatch.h>
|
| 118 |
+
#include <ATen/ops/_to_sparse_bsr_cpu_dispatch.h>
|
| 119 |
+
#include <ATen/ops/_to_sparse_csc_cpu_dispatch.h>
|
| 120 |
+
#include <ATen/ops/_to_sparse_csr_cpu_dispatch.h>
|
| 121 |
+
#include <ATen/ops/_transform_bias_rescale_qkv_cpu_dispatch.h>
|
| 122 |
+
#include <ATen/ops/_transformer_encoder_layer_fwd_cpu_dispatch.h>
|
| 123 |
+
#include <ATen/ops/_unique_cpu_dispatch.h>
|
| 124 |
+
#include <ATen/ops/_unique2_cpu_dispatch.h>
|
| 125 |
+
#include <ATen/ops/_upsample_bicubic2d_aa_cpu_dispatch.h>
|
| 126 |
+
#include <ATen/ops/_upsample_bicubic2d_aa_backward_cpu_dispatch.h>
|
| 127 |
+
#include <ATen/ops/_upsample_bilinear2d_aa_cpu_dispatch.h>
|
| 128 |
+
#include <ATen/ops/_upsample_bilinear2d_aa_backward_cpu_dispatch.h>
|
| 129 |
+
#include <ATen/ops/_upsample_nearest_exact1d_cpu_dispatch.h>
|
| 130 |
+
#include <ATen/ops/_upsample_nearest_exact1d_backward_cpu_dispatch.h>
|
| 131 |
+
#include <ATen/ops/_upsample_nearest_exact2d_cpu_dispatch.h>
|
| 132 |
+
#include <ATen/ops/_upsample_nearest_exact2d_backward_cpu_dispatch.h>
|
| 133 |
+
#include <ATen/ops/_upsample_nearest_exact3d_cpu_dispatch.h>
|
| 134 |
+
#include <ATen/ops/_upsample_nearest_exact3d_backward_cpu_dispatch.h>
|
| 135 |
+
#include <ATen/ops/_validate_compressed_sparse_indices_cpu_dispatch.h>
|
| 136 |
+
#include <ATen/ops/_weight_int4pack_mm_cpu_dispatch.h>
|
| 137 |
+
#include <ATen/ops/_weight_int8pack_mm_cpu_dispatch.h>
|
| 138 |
+
#include <ATen/ops/_weight_norm_interface_cpu_dispatch.h>
|
| 139 |
+
#include <ATen/ops/_weight_norm_interface_backward_cpu_dispatch.h>
|
| 140 |
+
#include <ATen/ops/abs_cpu_dispatch.h>
|
| 141 |
+
#include <ATen/ops/acos_cpu_dispatch.h>
|
| 142 |
+
#include <ATen/ops/acosh_cpu_dispatch.h>
|
| 143 |
+
#include <ATen/ops/adaptive_avg_pool2d_cpu_dispatch.h>
|
| 144 |
+
#include <ATen/ops/adaptive_avg_pool3d_cpu_dispatch.h>
|
| 145 |
+
#include <ATen/ops/adaptive_avg_pool3d_backward_cpu_dispatch.h>
|
| 146 |
+
#include <ATen/ops/adaptive_max_pool2d_cpu_dispatch.h>
|
| 147 |
+
#include <ATen/ops/adaptive_max_pool2d_backward_cpu_dispatch.h>
|
| 148 |
+
#include <ATen/ops/adaptive_max_pool3d_cpu_dispatch.h>
|
| 149 |
+
#include <ATen/ops/adaptive_max_pool3d_backward_cpu_dispatch.h>
|
| 150 |
+
#include <ATen/ops/add_cpu_dispatch.h>
|
| 151 |
+
#include <ATen/ops/addbmm_cpu_dispatch.h>
|
| 152 |
+
#include <ATen/ops/addcdiv_cpu_dispatch.h>
|
| 153 |
+
#include <ATen/ops/addcmul_cpu_dispatch.h>
|
| 154 |
+
#include <ATen/ops/addmm_cpu_dispatch.h>
|
| 155 |
+
#include <ATen/ops/addmv_cpu_dispatch.h>
|
| 156 |
+
#include <ATen/ops/addr_cpu_dispatch.h>
|
| 157 |
+
#include <ATen/ops/all_cpu_dispatch.h>
|
| 158 |
+
#include <ATen/ops/amax_cpu_dispatch.h>
|
| 159 |
+
#include <ATen/ops/amin_cpu_dispatch.h>
|
| 160 |
+
#include <ATen/ops/aminmax_cpu_dispatch.h>
|
| 161 |
+
#include <ATen/ops/angle_cpu_dispatch.h>
|
| 162 |
+
#include <ATen/ops/any_cpu_dispatch.h>
|
| 163 |
+
#include <ATen/ops/arange_cpu_dispatch.h>
|
| 164 |
+
#include <ATen/ops/argmax_cpu_dispatch.h>
|
| 165 |
+
#include <ATen/ops/argmin_cpu_dispatch.h>
|
| 166 |
+
#include <ATen/ops/as_strided_cpu_dispatch.h>
|
| 167 |
+
#include <ATen/ops/asin_cpu_dispatch.h>
|
| 168 |
+
#include <ATen/ops/asinh_cpu_dispatch.h>
|
| 169 |
+
#include <ATen/ops/atan_cpu_dispatch.h>
|
| 170 |
+
#include <ATen/ops/atan2_cpu_dispatch.h>
|
| 171 |
+
#include <ATen/ops/atanh_cpu_dispatch.h>
|
| 172 |
+
#include <ATen/ops/avg_pool2d_cpu_dispatch.h>
|
| 173 |
+
#include <ATen/ops/avg_pool2d_backward_cpu_dispatch.h>
|
| 174 |
+
#include <ATen/ops/avg_pool3d_cpu_dispatch.h>
|
| 175 |
+
#include <ATen/ops/avg_pool3d_backward_cpu_dispatch.h>
|
| 176 |
+
#include <ATen/ops/baddbmm_cpu_dispatch.h>
|
| 177 |
+
#include <ATen/ops/batch_norm_backward_cpu_dispatch.h>
|
| 178 |
+
#include <ATen/ops/batch_norm_update_stats_cpu_dispatch.h>
|
| 179 |
+
#include <ATen/ops/bernoulli_cpu_dispatch.h>
|
| 180 |
+
#include <ATen/ops/binary_cross_entropy_cpu_dispatch.h>
|
| 181 |
+
#include <ATen/ops/binary_cross_entropy_backward_cpu_dispatch.h>
|
| 182 |
+
#include <ATen/ops/bincount_cpu_dispatch.h>
|
| 183 |
+
#include <ATen/ops/binomial_cpu_dispatch.h>
|
| 184 |
+
#include <ATen/ops/bitwise_and_cpu_dispatch.h>
|
| 185 |
+
#include <ATen/ops/bitwise_left_shift_cpu_dispatch.h>
|
| 186 |
+
#include <ATen/ops/bitwise_not_cpu_dispatch.h>
|
| 187 |
+
#include <ATen/ops/bitwise_or_cpu_dispatch.h>
|
| 188 |
+
#include <ATen/ops/bitwise_right_shift_cpu_dispatch.h>
|
| 189 |
+
#include <ATen/ops/bitwise_xor_cpu_dispatch.h>
|
| 190 |
+
#include <ATen/ops/bmm_cpu_dispatch.h>
|
| 191 |
+
#include <ATen/ops/bucketize_cpu_dispatch.h>
|
| 192 |
+
#include <ATen/ops/cat_cpu_dispatch.h>
|
| 193 |
+
#include <ATen/ops/cauchy_cpu_dispatch.h>
|
| 194 |
+
#include <ATen/ops/ceil_cpu_dispatch.h>
|
| 195 |
+
#include <ATen/ops/channel_shuffle_cpu_dispatch.h>
|
| 196 |
+
#include <ATen/ops/cholesky_cpu_dispatch.h>
|
| 197 |
+
#include <ATen/ops/cholesky_inverse_cpu_dispatch.h>
|
| 198 |
+
#include <ATen/ops/clamp_cpu_dispatch.h>
|
| 199 |
+
#include <ATen/ops/clamp_max_cpu_dispatch.h>
|
| 200 |
+
#include <ATen/ops/clamp_min_cpu_dispatch.h>
|
| 201 |
+
#include <ATen/ops/col2im_cpu_dispatch.h>
|
| 202 |
+
#include <ATen/ops/complex_cpu_dispatch.h>
|
| 203 |
+
#include <ATen/ops/conj_physical_cpu_dispatch.h>
|
| 204 |
+
#include <ATen/ops/copysign_cpu_dispatch.h>
|
| 205 |
+
#include <ATen/ops/cos_cpu_dispatch.h>
|
| 206 |
+
#include <ATen/ops/cosh_cpu_dispatch.h>
|
| 207 |
+
#include <ATen/ops/count_nonzero_cpu_dispatch.h>
|
| 208 |
+
#include <ATen/ops/cumprod_cpu_dispatch.h>
|
| 209 |
+
#include <ATen/ops/cumsum_cpu_dispatch.h>
|
| 210 |
+
#include <ATen/ops/dequantize_cpu_dispatch.h>
|
| 211 |
+
#include <ATen/ops/digamma_cpu_dispatch.h>
|
| 212 |
+
#include <ATen/ops/div_cpu_dispatch.h>
|
| 213 |
+
#include <ATen/ops/dot_cpu_dispatch.h>
|
| 214 |
+
#include <ATen/ops/elu_cpu_dispatch.h>
|
| 215 |
+
#include <ATen/ops/elu_backward_cpu_dispatch.h>
|
| 216 |
+
#include <ATen/ops/embedding_dense_backward_cpu_dispatch.h>
|
| 217 |
+
#include <ATen/ops/embedding_renorm_cpu_dispatch.h>
|
| 218 |
+
#include <ATen/ops/empty_cpu_dispatch.h>
|
| 219 |
+
#include <ATen/ops/empty_strided_cpu_dispatch.h>
|
| 220 |
+
#include <ATen/ops/eq_cpu_dispatch.h>
|
| 221 |
+
#include <ATen/ops/equal_cpu_dispatch.h>
|
| 222 |
+
#include <ATen/ops/erf_cpu_dispatch.h>
|
| 223 |
+
#include <ATen/ops/erfc_cpu_dispatch.h>
|
| 224 |
+
#include <ATen/ops/erfinv_cpu_dispatch.h>
|
| 225 |
+
#include <ATen/ops/exp_cpu_dispatch.h>
|
| 226 |
+
#include <ATen/ops/exp2_cpu_dispatch.h>
|
| 227 |
+
#include <ATen/ops/expm1_cpu_dispatch.h>
|
| 228 |
+
#include <ATen/ops/exponential_cpu_dispatch.h>
|
| 229 |
+
#include <ATen/ops/eye_cpu_dispatch.h>
|
| 230 |
+
#include <ATen/ops/fake_quantize_per_channel_affine_cachemask_cpu_dispatch.h>
|
| 231 |
+
#include <ATen/ops/fake_quantize_per_tensor_affine_cachemask_cpu_dispatch.h>
|
| 232 |
+
#include <ATen/ops/fill_cpu_dispatch.h>
|
| 233 |
+
#include <ATen/ops/flip_cpu_dispatch.h>
|
| 234 |
+
#include <ATen/ops/floor_cpu_dispatch.h>
|
| 235 |
+
#include <ATen/ops/floor_divide_cpu_dispatch.h>
|
| 236 |
+
#include <ATen/ops/fmax_cpu_dispatch.h>
|
| 237 |
+
#include <ATen/ops/fmin_cpu_dispatch.h>
|
| 238 |
+
#include <ATen/ops/fmod_cpu_dispatch.h>
|
| 239 |
+
#include <ATen/ops/frac_cpu_dispatch.h>
|
| 240 |
+
#include <ATen/ops/fractional_max_pool2d_cpu_dispatch.h>
|
| 241 |
+
#include <ATen/ops/fractional_max_pool2d_backward_cpu_dispatch.h>
|
| 242 |
+
#include <ATen/ops/fractional_max_pool3d_cpu_dispatch.h>
|
| 243 |
+
#include <ATen/ops/fractional_max_pool3d_backward_cpu_dispatch.h>
|
| 244 |
+
#include <ATen/ops/frexp_cpu_dispatch.h>
|
| 245 |
+
#include <ATen/ops/from_file_cpu_dispatch.h>
|
| 246 |
+
#include <ATen/ops/gather_cpu_dispatch.h>
|
| 247 |
+
#include <ATen/ops/gcd_cpu_dispatch.h>
|
| 248 |
+
#include <ATen/ops/ge_cpu_dispatch.h>
|
| 249 |
+
#include <ATen/ops/gelu_cpu_dispatch.h>
|
| 250 |
+
#include <ATen/ops/gelu_backward_cpu_dispatch.h>
|
| 251 |
+
#include <ATen/ops/geometric_cpu_dispatch.h>
|
| 252 |
+
#include <ATen/ops/geqrf_cpu_dispatch.h>
|
| 253 |
+
#include <ATen/ops/glu_cpu_dispatch.h>
|
| 254 |
+
#include <ATen/ops/glu_backward_cpu_dispatch.h>
|
| 255 |
+
#include <ATen/ops/glu_backward_jvp_cpu_dispatch.h>
|
| 256 |
+
#include <ATen/ops/glu_jvp_cpu_dispatch.h>
|
| 257 |
+
#include <ATen/ops/grid_sampler_2d_cpu_dispatch.h>
|
| 258 |
+
#include <ATen/ops/grid_sampler_2d_backward_cpu_dispatch.h>
|
| 259 |
+
#include <ATen/ops/grid_sampler_3d_cpu_dispatch.h>
|
| 260 |
+
#include <ATen/ops/grid_sampler_3d_backward_cpu_dispatch.h>
|
| 261 |
+
#include <ATen/ops/gt_cpu_dispatch.h>
|
| 262 |
+
#include <ATen/ops/hardshrink_cpu_dispatch.h>
|
| 263 |
+
#include <ATen/ops/hardshrink_backward_cpu_dispatch.h>
|
| 264 |
+
#include <ATen/ops/hardsigmoid_cpu_dispatch.h>
|
| 265 |
+
#include <ATen/ops/hardsigmoid_backward_cpu_dispatch.h>
|
| 266 |
+
#include <ATen/ops/hardswish_cpu_dispatch.h>
|
| 267 |
+
#include <ATen/ops/hardswish_backward_cpu_dispatch.h>
|
| 268 |
+
#include <ATen/ops/hardtanh_cpu_dispatch.h>
|
| 269 |
+
#include <ATen/ops/hardtanh_backward_cpu_dispatch.h>
|
| 270 |
+
#include <ATen/ops/heaviside_cpu_dispatch.h>
|
| 271 |
+
#include <ATen/ops/histc_cpu_dispatch.h>
|
| 272 |
+
#include <ATen/ops/histogram_cpu_dispatch.h>
|
| 273 |
+
#include <ATen/ops/huber_loss_cpu_dispatch.h>
|
| 274 |
+
#include <ATen/ops/huber_loss_backward_cpu_dispatch.h>
|
| 275 |
+
#include <ATen/ops/hypot_cpu_dispatch.h>
|
| 276 |
+
#include <ATen/ops/i0_cpu_dispatch.h>
|
| 277 |
+
#include <ATen/ops/igamma_cpu_dispatch.h>
|
| 278 |
+
#include <ATen/ops/igammac_cpu_dispatch.h>
|
| 279 |
+
#include <ATen/ops/im2col_cpu_dispatch.h>
|
| 280 |
+
#include <ATen/ops/index_cpu_dispatch.h>
|
| 281 |
+
#include <ATen/ops/index_add_cpu_dispatch.h>
|
| 282 |
+
#include <ATen/ops/index_copy_cpu_dispatch.h>
|
| 283 |
+
#include <ATen/ops/index_fill_cpu_dispatch.h>
|
| 284 |
+
#include <ATen/ops/index_reduce_cpu_dispatch.h>
|
| 285 |
+
#include <ATen/ops/index_select_cpu_dispatch.h>
|
| 286 |
+
#include <ATen/ops/is_set_to_cpu_dispatch.h>
|
| 287 |
+
#include <ATen/ops/isin_cpu_dispatch.h>
|
| 288 |
+
#include <ATen/ops/isnan_cpu_dispatch.h>
|
| 289 |
+
#include <ATen/ops/isneginf_cpu_dispatch.h>
|
| 290 |
+
#include <ATen/ops/isposinf_cpu_dispatch.h>
|
| 291 |
+
#include <ATen/ops/kthvalue_cpu_dispatch.h>
|
| 292 |
+
#include <ATen/ops/lcm_cpu_dispatch.h>
|
| 293 |
+
#include <ATen/ops/le_cpu_dispatch.h>
|
| 294 |
+
#include <ATen/ops/leaky_relu_cpu_dispatch.h>
|
| 295 |
+
#include <ATen/ops/leaky_relu_backward_cpu_dispatch.h>
|
| 296 |
+
#include <ATen/ops/lerp_cpu_dispatch.h>
|
| 297 |
+
#include <ATen/ops/lgamma_cpu_dispatch.h>
|
| 298 |
+
#include <ATen/ops/linalg_cholesky_ex_cpu_dispatch.h>
|
| 299 |
+
#include <ATen/ops/linalg_cross_cpu_dispatch.h>
|
| 300 |
+
#include <ATen/ops/linalg_eig_cpu_dispatch.h>
|
| 301 |
+
#include <ATen/ops/linalg_eigvals_cpu_dispatch.h>
|
| 302 |
+
#include <ATen/ops/linalg_householder_product_cpu_dispatch.h>
|
| 303 |
+
#include <ATen/ops/linalg_inv_ex_cpu_dispatch.h>
|
| 304 |
+
#include <ATen/ops/linalg_ldl_factor_ex_cpu_dispatch.h>
|
| 305 |
+
#include <ATen/ops/linalg_ldl_solve_cpu_dispatch.h>
|
| 306 |
+
#include <ATen/ops/linalg_lstsq_cpu_dispatch.h>
|
| 307 |
+
#include <ATen/ops/linalg_lu_cpu_dispatch.h>
|
| 308 |
+
#include <ATen/ops/linalg_lu_factor_ex_cpu_dispatch.h>
|
| 309 |
+
#include <ATen/ops/linalg_lu_solve_cpu_dispatch.h>
|
| 310 |
+
#include <ATen/ops/linalg_matrix_exp_cpu_dispatch.h>
|
| 311 |
+
#include <ATen/ops/linalg_qr_cpu_dispatch.h>
|
| 312 |
+
#include <ATen/ops/linalg_solve_triangular_cpu_dispatch.h>
|
| 313 |
+
#include <ATen/ops/linalg_vector_norm_cpu_dispatch.h>
|
| 314 |
+
#include <ATen/ops/linspace_cpu_dispatch.h>
|
| 315 |
+
#include <ATen/ops/log_cpu_dispatch.h>
|
| 316 |
+
#include <ATen/ops/log10_cpu_dispatch.h>
|
| 317 |
+
#include <ATen/ops/log1p_cpu_dispatch.h>
|
| 318 |
+
#include <ATen/ops/log2_cpu_dispatch.h>
|
| 319 |
+
#include <ATen/ops/log_normal_cpu_dispatch.h>
|
| 320 |
+
#include <ATen/ops/log_sigmoid_backward_cpu_dispatch.h>
|
| 321 |
+
#include <ATen/ops/log_sigmoid_forward_cpu_dispatch.h>
|
| 322 |
+
#include <ATen/ops/logaddexp_cpu_dispatch.h>
|
| 323 |
+
#include <ATen/ops/logaddexp2_cpu_dispatch.h>
|
| 324 |
+
#include <ATen/ops/logical_and_cpu_dispatch.h>
|
| 325 |
+
#include <ATen/ops/logical_not_cpu_dispatch.h>
|
| 326 |
+
#include <ATen/ops/logical_or_cpu_dispatch.h>
|
| 327 |
+
#include <ATen/ops/logical_xor_cpu_dispatch.h>
|
| 328 |
+
#include <ATen/ops/logit_cpu_dispatch.h>
|
| 329 |
+
#include <ATen/ops/logit_backward_cpu_dispatch.h>
|
| 330 |
+
#include <ATen/ops/logspace_cpu_dispatch.h>
|
| 331 |
+
#include <ATen/ops/lshift_cpu_dispatch.h>
|
| 332 |
+
#include <ATen/ops/lt_cpu_dispatch.h>
|
| 333 |
+
#include <ATen/ops/lu_unpack_cpu_dispatch.h>
|
| 334 |
+
#include <ATen/ops/masked_fill_cpu_dispatch.h>
|
| 335 |
+
#include <ATen/ops/masked_scatter_cpu_dispatch.h>
|
| 336 |
+
#include <ATen/ops/masked_select_cpu_dispatch.h>
|
| 337 |
+
#include <ATen/ops/max_cpu_dispatch.h>
|
| 338 |
+
#include <ATen/ops/max_pool2d_with_indices_cpu_dispatch.h>
|
| 339 |
+
#include <ATen/ops/max_pool2d_with_indices_backward_cpu_dispatch.h>
|
| 340 |
+
#include <ATen/ops/max_pool3d_with_indices_cpu_dispatch.h>
|
| 341 |
+
#include <ATen/ops/max_pool3d_with_indices_backward_cpu_dispatch.h>
|
| 342 |
+
#include <ATen/ops/max_unpool2d_cpu_dispatch.h>
|
| 343 |
+
#include <ATen/ops/max_unpool3d_cpu_dispatch.h>
|
| 344 |
+
#include <ATen/ops/maximum_cpu_dispatch.h>
|
| 345 |
+
#include <ATen/ops/mean_cpu_dispatch.h>
|
| 346 |
+
#include <ATen/ops/median_cpu_dispatch.h>
|
| 347 |
+
#include <ATen/ops/min_cpu_dispatch.h>
|
| 348 |
+
#include <ATen/ops/minimum_cpu_dispatch.h>
|
| 349 |
+
#include <ATen/ops/mish_cpu_dispatch.h>
|
| 350 |
+
#include <ATen/ops/mish_backward_cpu_dispatch.h>
|
| 351 |
+
#include <ATen/ops/mkldnn_rnn_layer_cpu_dispatch.h>
|
| 352 |
+
#include <ATen/ops/mkldnn_rnn_layer_backward_cpu_dispatch.h>
|
| 353 |
+
#include <ATen/ops/mm_cpu_dispatch.h>
|
| 354 |
+
#include <ATen/ops/mode_cpu_dispatch.h>
|
| 355 |
+
#include <ATen/ops/mse_loss_cpu_dispatch.h>
|
| 356 |
+
#include <ATen/ops/mse_loss_backward_cpu_dispatch.h>
|
| 357 |
+
#include <ATen/ops/mul_cpu_dispatch.h>
|
| 358 |
+
#include <ATen/ops/multi_margin_loss_cpu_dispatch.h>
|
| 359 |
+
#include <ATen/ops/multi_margin_loss_backward_cpu_dispatch.h>
|
| 360 |
+
#include <ATen/ops/multilabel_margin_loss_backward_cpu_dispatch.h>
|
| 361 |
+
#include <ATen/ops/multilabel_margin_loss_forward_cpu_dispatch.h>
|
| 362 |
+
#include <ATen/ops/multinomial_cpu_dispatch.h>
|
| 363 |
+
#include <ATen/ops/mvlgamma_cpu_dispatch.h>
|
| 364 |
+
#include <ATen/ops/nan_to_num_cpu_dispatch.h>
|
| 365 |
+
#include <ATen/ops/nanmedian_cpu_dispatch.h>
|
| 366 |
+
#include <ATen/ops/nansum_cpu_dispatch.h>
|
| 367 |
+
#include <ATen/ops/narrow_copy_cpu_dispatch.h>
|
| 368 |
+
#include <ATen/ops/native_batch_norm_cpu_dispatch.h>
|
| 369 |
+
#include <ATen/ops/native_batch_norm_backward_cpu_dispatch.h>
|
| 370 |
+
#include <ATen/ops/native_channel_shuffle_cpu_dispatch.h>
|
| 371 |
+
#include <ATen/ops/native_dropout_cpu_dispatch.h>
|
| 372 |
+
#include <ATen/ops/native_dropout_backward_cpu_dispatch.h>
|
| 373 |
+
#include <ATen/ops/native_group_norm_cpu_dispatch.h>
|
| 374 |
+
#include <ATen/ops/native_group_norm_backward_cpu_dispatch.h>
|
| 375 |
+
#include <ATen/ops/native_layer_norm_cpu_dispatch.h>
|
| 376 |
+
#include <ATen/ops/native_layer_norm_backward_cpu_dispatch.h>
|
| 377 |
+
#include <ATen/ops/ne_cpu_dispatch.h>
|
| 378 |
+
#include <ATen/ops/neg_cpu_dispatch.h>
|
| 379 |
+
#include <ATen/ops/nextafter_cpu_dispatch.h>
|
| 380 |
+
#include <ATen/ops/nll_loss2d_backward_cpu_dispatch.h>
|
| 381 |
+
#include <ATen/ops/nll_loss2d_forward_cpu_dispatch.h>
|
| 382 |
+
#include <ATen/ops/nll_loss_backward_cpu_dispatch.h>
|
| 383 |
+
#include <ATen/ops/nll_loss_forward_cpu_dispatch.h>
|
| 384 |
+
#include <ATen/ops/nonzero_cpu_dispatch.h>
|
| 385 |
+
#include <ATen/ops/nonzero_static_cpu_dispatch.h>
|
| 386 |
+
#include <ATen/ops/norm_cpu_dispatch.h>
|
| 387 |
+
#include <ATen/ops/normal_cpu_dispatch.h>
|
| 388 |
+
#include <ATen/ops/ormqr_cpu_dispatch.h>
|
| 389 |
+
#include <ATen/ops/pixel_shuffle_cpu_dispatch.h>
|
| 390 |
+
#include <ATen/ops/pixel_unshuffle_cpu_dispatch.h>
|
| 391 |
+
#include <ATen/ops/poisson_cpu_dispatch.h>
|
| 392 |
+
#include <ATen/ops/polar_cpu_dispatch.h>
|
| 393 |
+
#include <ATen/ops/polygamma_cpu_dispatch.h>
|
| 394 |
+
#include <ATen/ops/pow_cpu_dispatch.h>
|
| 395 |
+
#include <ATen/ops/prod_cpu_dispatch.h>
|
| 396 |
+
#include <ATen/ops/put_cpu_dispatch.h>
|
| 397 |
+
#include <ATen/ops/quantize_per_channel_cpu_dispatch.h>
|
| 398 |
+
#include <ATen/ops/quantize_per_tensor_cpu_dispatch.h>
|
| 399 |
+
#include <ATen/ops/quantize_per_tensor_dynamic_cpu_dispatch.h>
|
| 400 |
+
#include <ATen/ops/random_cpu_dispatch.h>
|
| 401 |
+
#include <ATen/ops/randperm_cpu_dispatch.h>
|
| 402 |
+
#include <ATen/ops/range_cpu_dispatch.h>
|
| 403 |
+
#include <ATen/ops/reciprocal_cpu_dispatch.h>
|
| 404 |
+
#include <ATen/ops/reflection_pad1d_cpu_dispatch.h>
|
| 405 |
+
#include <ATen/ops/reflection_pad1d_backward_cpu_dispatch.h>
|
| 406 |
+
#include <ATen/ops/reflection_pad2d_cpu_dispatch.h>
|
| 407 |
+
#include <ATen/ops/reflection_pad2d_backward_cpu_dispatch.h>
|
| 408 |
+
#include <ATen/ops/reflection_pad3d_cpu_dispatch.h>
|
| 409 |
+
#include <ATen/ops/reflection_pad3d_backward_cpu_dispatch.h>
|
| 410 |
+
#include <ATen/ops/relu_cpu_dispatch.h>
|
| 411 |
+
#include <ATen/ops/remainder_cpu_dispatch.h>
|
| 412 |
+
#include <ATen/ops/renorm_cpu_dispatch.h>
|
| 413 |
+
#include <ATen/ops/repeat_interleave_cpu_dispatch.h>
|
| 414 |
+
#include <ATen/ops/replication_pad1d_cpu_dispatch.h>
|
| 415 |
+
#include <ATen/ops/replication_pad1d_backward_cpu_dispatch.h>
|
| 416 |
+
#include <ATen/ops/replication_pad2d_cpu_dispatch.h>
|
| 417 |
+
#include <ATen/ops/replication_pad2d_backward_cpu_dispatch.h>
|
| 418 |
+
#include <ATen/ops/replication_pad3d_cpu_dispatch.h>
|
| 419 |
+
#include <ATen/ops/replication_pad3d_backward_cpu_dispatch.h>
|
| 420 |
+
#include <ATen/ops/resize_cpu_dispatch.h>
|
| 421 |
+
#include <ATen/ops/roll_cpu_dispatch.h>
|
| 422 |
+
#include <ATen/ops/round_cpu_dispatch.h>
|
| 423 |
+
#include <ATen/ops/rrelu_with_noise_cpu_dispatch.h>
|
| 424 |
+
#include <ATen/ops/rshift_cpu_dispatch.h>
|
| 425 |
+
#include <ATen/ops/rsqrt_cpu_dispatch.h>
|
| 426 |
+
#include <ATen/ops/rsub_cpu_dispatch.h>
|
| 427 |
+
#include <ATen/ops/scatter_cpu_dispatch.h>
|
| 428 |
+
#include <ATen/ops/scatter_add_cpu_dispatch.h>
|
| 429 |
+
#include <ATen/ops/scatter_reduce_cpu_dispatch.h>
|
| 430 |
+
#include <ATen/ops/searchsorted_cpu_dispatch.h>
|
| 431 |
+
#include <ATen/ops/segment_reduce_cpu_dispatch.h>
|
| 432 |
+
#include <ATen/ops/set_cpu_dispatch.h>
|
| 433 |
+
#include <ATen/ops/sgn_cpu_dispatch.h>
|
| 434 |
+
#include <ATen/ops/sigmoid_cpu_dispatch.h>
|
| 435 |
+
#include <ATen/ops/sigmoid_backward_cpu_dispatch.h>
|
| 436 |
+
#include <ATen/ops/sign_cpu_dispatch.h>
|
| 437 |
+
#include <ATen/ops/signbit_cpu_dispatch.h>
|
| 438 |
+
#include <ATen/ops/silu_cpu_dispatch.h>
|
| 439 |
+
#include <ATen/ops/silu_backward_cpu_dispatch.h>
|
| 440 |
+
#include <ATen/ops/sin_cpu_dispatch.h>
|
| 441 |
+
#include <ATen/ops/sinc_cpu_dispatch.h>
|
| 442 |
+
#include <ATen/ops/sinh_cpu_dispatch.h>
|
| 443 |
+
#include <ATen/ops/slow_conv3d_forward_cpu_dispatch.h>
|
| 444 |
+
#include <ATen/ops/slow_conv_dilated2d_cpu_dispatch.h>
|
| 445 |
+
#include <ATen/ops/slow_conv_dilated3d_cpu_dispatch.h>
|
| 446 |
+
#include <ATen/ops/slow_conv_transpose2d_cpu_dispatch.h>
|
| 447 |
+
#include <ATen/ops/slow_conv_transpose3d_cpu_dispatch.h>
|
| 448 |
+
#include <ATen/ops/smooth_l1_loss_cpu_dispatch.h>
|
| 449 |
+
#include <ATen/ops/smooth_l1_loss_backward_cpu_dispatch.h>
|
| 450 |
+
#include <ATen/ops/softplus_cpu_dispatch.h>
|
| 451 |
+
#include <ATen/ops/softplus_backward_cpu_dispatch.h>
|
| 452 |
+
#include <ATen/ops/softshrink_cpu_dispatch.h>
|
| 453 |
+
#include <ATen/ops/softshrink_backward_cpu_dispatch.h>
|
| 454 |
+
#include <ATen/ops/sort_cpu_dispatch.h>
|
| 455 |
+
#include <ATen/ops/special_airy_ai_cpu_dispatch.h>
|
| 456 |
+
#include <ATen/ops/special_bessel_j0_cpu_dispatch.h>
|
| 457 |
+
#include <ATen/ops/special_bessel_j1_cpu_dispatch.h>
|
| 458 |
+
#include <ATen/ops/special_bessel_y0_cpu_dispatch.h>
|
| 459 |
+
#include <ATen/ops/special_bessel_y1_cpu_dispatch.h>
|
| 460 |
+
#include <ATen/ops/special_chebyshev_polynomial_t_cpu_dispatch.h>
|
| 461 |
+
#include <ATen/ops/special_chebyshev_polynomial_u_cpu_dispatch.h>
|
| 462 |
+
#include <ATen/ops/special_chebyshev_polynomial_v_cpu_dispatch.h>
|
| 463 |
+
#include <ATen/ops/special_chebyshev_polynomial_w_cpu_dispatch.h>
|
| 464 |
+
#include <ATen/ops/special_entr_cpu_dispatch.h>
|
| 465 |
+
#include <ATen/ops/special_erfcx_cpu_dispatch.h>
|
| 466 |
+
#include <ATen/ops/special_hermite_polynomial_h_cpu_dispatch.h>
|
| 467 |
+
#include <ATen/ops/special_hermite_polynomial_he_cpu_dispatch.h>
|
| 468 |
+
#include <ATen/ops/special_i0e_cpu_dispatch.h>
|
| 469 |
+
#include <ATen/ops/special_i1_cpu_dispatch.h>
|
| 470 |
+
#include <ATen/ops/special_i1e_cpu_dispatch.h>
|
| 471 |
+
#include <ATen/ops/special_laguerre_polynomial_l_cpu_dispatch.h>
|
| 472 |
+
#include <ATen/ops/special_legendre_polynomial_p_cpu_dispatch.h>
|
| 473 |
+
#include <ATen/ops/special_log_ndtr_cpu_dispatch.h>
|
| 474 |
+
#include <ATen/ops/special_modified_bessel_i0_cpu_dispatch.h>
|
| 475 |
+
#include <ATen/ops/special_modified_bessel_i1_cpu_dispatch.h>
|
| 476 |
+
#include <ATen/ops/special_modified_bessel_k0_cpu_dispatch.h>
|
| 477 |
+
#include <ATen/ops/special_modified_bessel_k1_cpu_dispatch.h>
|
| 478 |
+
#include <ATen/ops/special_ndtri_cpu_dispatch.h>
|
| 479 |
+
#include <ATen/ops/special_scaled_modified_bessel_k0_cpu_dispatch.h>
|
| 480 |
+
#include <ATen/ops/special_scaled_modified_bessel_k1_cpu_dispatch.h>
|
| 481 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_t_cpu_dispatch.h>
|
| 482 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_u_cpu_dispatch.h>
|
| 483 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_v_cpu_dispatch.h>
|
| 484 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_w_cpu_dispatch.h>
|
| 485 |
+
#include <ATen/ops/special_spherical_bessel_j0_cpu_dispatch.h>
|
| 486 |
+
#include <ATen/ops/special_xlog1py_cpu_dispatch.h>
|
| 487 |
+
#include <ATen/ops/special_zeta_cpu_dispatch.h>
|
| 488 |
+
#include <ATen/ops/sqrt_cpu_dispatch.h>
|
| 489 |
+
#include <ATen/ops/sspaddmm_cpu_dispatch.h>
|
| 490 |
+
#include <ATen/ops/std_cpu_dispatch.h>
|
| 491 |
+
#include <ATen/ops/std_mean_cpu_dispatch.h>
|
| 492 |
+
#include <ATen/ops/sub_cpu_dispatch.h>
|
| 493 |
+
#include <ATen/ops/sum_cpu_dispatch.h>
|
| 494 |
+
#include <ATen/ops/take_cpu_dispatch.h>
|
| 495 |
+
#include <ATen/ops/tan_cpu_dispatch.h>
|
| 496 |
+
#include <ATen/ops/tanh_cpu_dispatch.h>
|
| 497 |
+
#include <ATen/ops/tanh_backward_cpu_dispatch.h>
|
| 498 |
+
#include <ATen/ops/threshold_cpu_dispatch.h>
|
| 499 |
+
#include <ATen/ops/threshold_backward_cpu_dispatch.h>
|
| 500 |
+
#include <ATen/ops/to_mkldnn_cpu_dispatch.h>
|
| 501 |
+
#include <ATen/ops/topk_cpu_dispatch.h>
|
| 502 |
+
#include <ATen/ops/trace_cpu_dispatch.h>
|
| 503 |
+
#include <ATen/ops/triangular_solve_cpu_dispatch.h>
|
| 504 |
+
#include <ATen/ops/tril_cpu_dispatch.h>
|
| 505 |
+
#include <ATen/ops/tril_indices_cpu_dispatch.h>
|
| 506 |
+
#include <ATen/ops/triu_cpu_dispatch.h>
|
| 507 |
+
#include <ATen/ops/triu_indices_cpu_dispatch.h>
|
| 508 |
+
#include <ATen/ops/trunc_cpu_dispatch.h>
|
| 509 |
+
#include <ATen/ops/unfold_cpu_dispatch.h>
|
| 510 |
+
#include <ATen/ops/unfold_backward_cpu_dispatch.h>
|
| 511 |
+
#include <ATen/ops/uniform_cpu_dispatch.h>
|
| 512 |
+
#include <ATen/ops/unique_consecutive_cpu_dispatch.h>
|
| 513 |
+
#include <ATen/ops/unique_dim_cpu_dispatch.h>
|
| 514 |
+
#include <ATen/ops/unique_dim_consecutive_cpu_dispatch.h>
|
| 515 |
+
#include <ATen/ops/upsample_bicubic2d_cpu_dispatch.h>
|
| 516 |
+
#include <ATen/ops/upsample_bicubic2d_backward_cpu_dispatch.h>
|
| 517 |
+
#include <ATen/ops/upsample_bilinear2d_cpu_dispatch.h>
|
| 518 |
+
#include <ATen/ops/upsample_bilinear2d_backward_cpu_dispatch.h>
|
| 519 |
+
#include <ATen/ops/upsample_linear1d_cpu_dispatch.h>
|
| 520 |
+
#include <ATen/ops/upsample_linear1d_backward_cpu_dispatch.h>
|
| 521 |
+
#include <ATen/ops/upsample_nearest1d_cpu_dispatch.h>
|
| 522 |
+
#include <ATen/ops/upsample_nearest1d_backward_cpu_dispatch.h>
|
| 523 |
+
#include <ATen/ops/upsample_nearest2d_cpu_dispatch.h>
|
| 524 |
+
#include <ATen/ops/upsample_nearest2d_backward_cpu_dispatch.h>
|
| 525 |
+
#include <ATen/ops/upsample_nearest3d_cpu_dispatch.h>
|
| 526 |
+
#include <ATen/ops/upsample_nearest3d_backward_cpu_dispatch.h>
|
| 527 |
+
#include <ATen/ops/upsample_trilinear3d_cpu_dispatch.h>
|
| 528 |
+
#include <ATen/ops/upsample_trilinear3d_backward_cpu_dispatch.h>
|
| 529 |
+
#include <ATen/ops/var_cpu_dispatch.h>
|
| 530 |
+
#include <ATen/ops/var_mean_cpu_dispatch.h>
|
| 531 |
+
#include <ATen/ops/vdot_cpu_dispatch.h>
|
| 532 |
+
#include <ATen/ops/view_cpu_dispatch.h>
|
| 533 |
+
#include <ATen/ops/view_as_complex_cpu_dispatch.h>
|
| 534 |
+
#include <ATen/ops/view_as_real_cpu_dispatch.h>
|
| 535 |
+
#include <ATen/ops/where_cpu_dispatch.h>
|
| 536 |
+
#include <ATen/ops/xlogy_cpu_dispatch.h>
|
| 537 |
+
#include <ATen/ops/zero_cpu_dispatch.h>
|
| 538 |
+
|
| 539 |
+
|
| 540 |
+
|
.venv/lib/python3.11/site-packages/torch/include/ATen/CUDAFunctions.h
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <ATen/core/TensorBody.h>
|
| 2 |
+
|
| 3 |
+
// TODO Undo all logic introduced for Note [Avoiding Include Cycles In Static Dispatch]
|
| 4 |
+
// Code introduced to avoid cyclic dependency in static dispatch is no longer
|
| 5 |
+
// needed as static dispatch logic is moved from TensorBody.h, which caused cycles in the first place,
|
| 6 |
+
// to Operators.cpp for supporting multiple backends with multiple kernels.
|
| 7 |
+
//
|
| 8 |
+
// Note [Avoiding Include Cycles In Static Dispatch]
|
| 9 |
+
// In order to avoid #include cycles in the static dispatch build, we've carefully split out
|
| 10 |
+
// the static function definition files into {DispatchKey}Functions.h and {DispatchKey}Functions_inl.h.
|
| 11 |
+
//
|
| 12 |
+
// Without this split, the include cycle looks like TensorBody.h -> CPUFunctions.h -> TensorBody.h.
|
| 13 |
+
// - TensorBody.h #includes CPUFunctions.h in the static dispatch build, because the tensor methods
|
| 14 |
+
// all need to call into the fastpath C++ API defined in CPUFunctions.h. The methods are also all
|
| 15 |
+
// directly inlined into TensorBody.h.
|
| 16 |
+
// - CPUFunctions.h #includes TensorBody.h because it contains function declarations for the entire C++ API,
|
| 17 |
+
// which include functions that have defaultable std::optional<Tensor> arguments.
|
| 18 |
+
// That requires knowing the full Tensor class definition.
|
| 19 |
+
//
|
| 20 |
+
// We break the cycle by doing the following:
|
| 21 |
+
// - Split out CPUFunction.h into two files: CPUFunctions.h and CPUFunctions_inl.h
|
| 22 |
+
// - CPUFunction.h is a dummy file that just includes the Tensor class and includes CPUFunctions_inl.,
|
| 23 |
+
// - CPUFunctions_inl.h includes everything else
|
| 24 |
+
// - (only in the static dispatch build) TensorBody.h makes sure to finish defining the Tensor class,
|
| 25 |
+
// and then it includes CPUFunctions_inl.h.
|
| 26 |
+
// - All other files that want the cpu fastpath functions can include CPUFunctions.h directly.
|
| 27 |
+
// - This also means that static dispatch build, CPUFunctions.h only needs to
|
| 28 |
+
// #include TensorBody.h, and it will automatically bring in CPUFunctions_inl.h.
|
| 29 |
+
#include <ATen/CUDAFunctions_inl.h>
|
.venv/lib/python3.11/site-packages/torch/include/ATen/CUDAFunctions_inl.h
ADDED
|
@@ -0,0 +1,623 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
// @generated by torchgen/gen.py from DispatchKeyFunctions_inl.h
|
| 3 |
+
|
| 4 |
+
// NB: The implementing C++ file is RegisterDispatchKey.cpp
|
| 5 |
+
|
| 6 |
+
// The only #includes we need are for custom classes that have defaults in the C++ API
|
| 7 |
+
#include <c10/core/MemoryFormat.h>
|
| 8 |
+
#include <c10/core/Scalar.h>
|
| 9 |
+
#include <ATen/core/Reduction.h>
|
| 10 |
+
|
| 11 |
+
#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS)
|
| 12 |
+
#error This change adds a dependency on all pytorch operators, meaning the \
|
| 13 |
+
file will need to be re-compiled every time an operator is changed or added. \
|
| 14 |
+
Consider including a specific operator from \
|
| 15 |
+
<ATen/ops/{my_operator}_cuda_dispatch.h>. \
|
| 16 |
+
See NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS].
|
| 17 |
+
#endif
|
| 18 |
+
|
| 19 |
+
#include <ATen/ops/_adaptive_avg_pool2d_cuda_dispatch.h>
|
| 20 |
+
#include <ATen/ops/_adaptive_avg_pool2d_backward_cuda_dispatch.h>
|
| 21 |
+
#include <ATen/ops/_adaptive_avg_pool3d_cuda_dispatch.h>
|
| 22 |
+
#include <ATen/ops/_adaptive_avg_pool3d_backward_cuda_dispatch.h>
|
| 23 |
+
#include <ATen/ops/_addmm_activation_cuda_dispatch.h>
|
| 24 |
+
#include <ATen/ops/_aminmax_cuda_dispatch.h>
|
| 25 |
+
#include <ATen/ops/_amp_foreach_non_finite_check_and_unscale_cuda_dispatch.h>
|
| 26 |
+
#include <ATen/ops/_amp_update_scale_cuda_dispatch.h>
|
| 27 |
+
#include <ATen/ops/_assert_async_cuda_dispatch.h>
|
| 28 |
+
#include <ATen/ops/_batch_norm_with_update_cuda_dispatch.h>
|
| 29 |
+
#include <ATen/ops/_cdist_backward_cuda_dispatch.h>
|
| 30 |
+
#include <ATen/ops/_cdist_forward_cuda_dispatch.h>
|
| 31 |
+
#include <ATen/ops/_cholesky_solve_helper_cuda_dispatch.h>
|
| 32 |
+
#include <ATen/ops/_chunk_cat_cuda_dispatch.h>
|
| 33 |
+
#include <ATen/ops/_compute_linear_combination_cuda_dispatch.h>
|
| 34 |
+
#include <ATen/ops/_conv_depthwise2d_cuda_dispatch.h>
|
| 35 |
+
#include <ATen/ops/_convert_indices_from_coo_to_csr_cuda_dispatch.h>
|
| 36 |
+
#include <ATen/ops/_convert_indices_from_csr_to_coo_cuda_dispatch.h>
|
| 37 |
+
#include <ATen/ops/_convert_weight_to_int4pack_cuda_dispatch.h>
|
| 38 |
+
#include <ATen/ops/_cslt_compress_cuda_dispatch.h>
|
| 39 |
+
#include <ATen/ops/_cslt_sparse_mm_cuda_dispatch.h>
|
| 40 |
+
#include <ATen/ops/_cslt_sparse_mm_search_cuda_dispatch.h>
|
| 41 |
+
#include <ATen/ops/_ctc_loss_cuda_dispatch.h>
|
| 42 |
+
#include <ATen/ops/_ctc_loss_backward_cuda_dispatch.h>
|
| 43 |
+
#include <ATen/ops/_cudnn_ctc_loss_cuda_dispatch.h>
|
| 44 |
+
#include <ATen/ops/_cudnn_init_dropout_state_cuda_dispatch.h>
|
| 45 |
+
#include <ATen/ops/_cudnn_rnn_cuda_dispatch.h>
|
| 46 |
+
#include <ATen/ops/_cudnn_rnn_backward_cuda_dispatch.h>
|
| 47 |
+
#include <ATen/ops/_cudnn_rnn_flatten_weight_cuda_dispatch.h>
|
| 48 |
+
#include <ATen/ops/_cummax_helper_cuda_dispatch.h>
|
| 49 |
+
#include <ATen/ops/_cummin_helper_cuda_dispatch.h>
|
| 50 |
+
#include <ATen/ops/_dirichlet_grad_cuda_dispatch.h>
|
| 51 |
+
#include <ATen/ops/_efficient_attention_backward_cuda_dispatch.h>
|
| 52 |
+
#include <ATen/ops/_efficient_attention_forward_cuda_dispatch.h>
|
| 53 |
+
#include <ATen/ops/_efficientzerotensor_cuda_dispatch.h>
|
| 54 |
+
#include <ATen/ops/_embedding_bag_cuda_dispatch.h>
|
| 55 |
+
#include <ATen/ops/_embedding_bag_backward_cuda_dispatch.h>
|
| 56 |
+
#include <ATen/ops/_embedding_bag_dense_backward_cuda_dispatch.h>
|
| 57 |
+
#include <ATen/ops/_embedding_bag_forward_only_cuda_dispatch.h>
|
| 58 |
+
#include <ATen/ops/_embedding_bag_per_sample_weights_backward_cuda_dispatch.h>
|
| 59 |
+
#include <ATen/ops/_fake_quantize_learnable_per_channel_affine_cuda_dispatch.h>
|
| 60 |
+
#include <ATen/ops/_fake_quantize_learnable_per_channel_affine_backward_cuda_dispatch.h>
|
| 61 |
+
#include <ATen/ops/_fake_quantize_learnable_per_tensor_affine_cuda_dispatch.h>
|
| 62 |
+
#include <ATen/ops/_fake_quantize_learnable_per_tensor_affine_backward_cuda_dispatch.h>
|
| 63 |
+
#include <ATen/ops/_fake_quantize_per_tensor_affine_cachemask_tensor_qparams_cuda_dispatch.h>
|
| 64 |
+
#include <ATen/ops/_fft_c2c_cuda_dispatch.h>
|
| 65 |
+
#include <ATen/ops/_fft_c2r_cuda_dispatch.h>
|
| 66 |
+
#include <ATen/ops/_fft_r2c_cuda_dispatch.h>
|
| 67 |
+
#include <ATen/ops/_fill_mem_eff_dropout_mask_cuda_dispatch.h>
|
| 68 |
+
#include <ATen/ops/_flash_attention_backward_cuda_dispatch.h>
|
| 69 |
+
#include <ATen/ops/_flash_attention_forward_cuda_dispatch.h>
|
| 70 |
+
#include <ATen/ops/_foreach_abs_cuda_dispatch.h>
|
| 71 |
+
#include <ATen/ops/_foreach_acos_cuda_dispatch.h>
|
| 72 |
+
#include <ATen/ops/_foreach_add_cuda_dispatch.h>
|
| 73 |
+
#include <ATen/ops/_foreach_addcdiv_cuda_dispatch.h>
|
| 74 |
+
#include <ATen/ops/_foreach_addcmul_cuda_dispatch.h>
|
| 75 |
+
#include <ATen/ops/_foreach_asin_cuda_dispatch.h>
|
| 76 |
+
#include <ATen/ops/_foreach_atan_cuda_dispatch.h>
|
| 77 |
+
#include <ATen/ops/_foreach_ceil_cuda_dispatch.h>
|
| 78 |
+
#include <ATen/ops/_foreach_clamp_max_cuda_dispatch.h>
|
| 79 |
+
#include <ATen/ops/_foreach_clamp_min_cuda_dispatch.h>
|
| 80 |
+
#include <ATen/ops/_foreach_copy_cuda_dispatch.h>
|
| 81 |
+
#include <ATen/ops/_foreach_cos_cuda_dispatch.h>
|
| 82 |
+
#include <ATen/ops/_foreach_cosh_cuda_dispatch.h>
|
| 83 |
+
#include <ATen/ops/_foreach_div_cuda_dispatch.h>
|
| 84 |
+
#include <ATen/ops/_foreach_erf_cuda_dispatch.h>
|
| 85 |
+
#include <ATen/ops/_foreach_erfc_cuda_dispatch.h>
|
| 86 |
+
#include <ATen/ops/_foreach_exp_cuda_dispatch.h>
|
| 87 |
+
#include <ATen/ops/_foreach_expm1_cuda_dispatch.h>
|
| 88 |
+
#include <ATen/ops/_foreach_floor_cuda_dispatch.h>
|
| 89 |
+
#include <ATen/ops/_foreach_frac_cuda_dispatch.h>
|
| 90 |
+
#include <ATen/ops/_foreach_lerp_cuda_dispatch.h>
|
| 91 |
+
#include <ATen/ops/_foreach_lgamma_cuda_dispatch.h>
|
| 92 |
+
#include <ATen/ops/_foreach_log_cuda_dispatch.h>
|
| 93 |
+
#include <ATen/ops/_foreach_log10_cuda_dispatch.h>
|
| 94 |
+
#include <ATen/ops/_foreach_log1p_cuda_dispatch.h>
|
| 95 |
+
#include <ATen/ops/_foreach_log2_cuda_dispatch.h>
|
| 96 |
+
#include <ATen/ops/_foreach_max_cuda_dispatch.h>
|
| 97 |
+
#include <ATen/ops/_foreach_maximum_cuda_dispatch.h>
|
| 98 |
+
#include <ATen/ops/_foreach_minimum_cuda_dispatch.h>
|
| 99 |
+
#include <ATen/ops/_foreach_mul_cuda_dispatch.h>
|
| 100 |
+
#include <ATen/ops/_foreach_neg_cuda_dispatch.h>
|
| 101 |
+
#include <ATen/ops/_foreach_norm_cuda_dispatch.h>
|
| 102 |
+
#include <ATen/ops/_foreach_pow_cuda_dispatch.h>
|
| 103 |
+
#include <ATen/ops/_foreach_reciprocal_cuda_dispatch.h>
|
| 104 |
+
#include <ATen/ops/_foreach_round_cuda_dispatch.h>
|
| 105 |
+
#include <ATen/ops/_foreach_sigmoid_cuda_dispatch.h>
|
| 106 |
+
#include <ATen/ops/_foreach_sign_cuda_dispatch.h>
|
| 107 |
+
#include <ATen/ops/_foreach_sin_cuda_dispatch.h>
|
| 108 |
+
#include <ATen/ops/_foreach_sinh_cuda_dispatch.h>
|
| 109 |
+
#include <ATen/ops/_foreach_sqrt_cuda_dispatch.h>
|
| 110 |
+
#include <ATen/ops/_foreach_sub_cuda_dispatch.h>
|
| 111 |
+
#include <ATen/ops/_foreach_tan_cuda_dispatch.h>
|
| 112 |
+
#include <ATen/ops/_foreach_tanh_cuda_dispatch.h>
|
| 113 |
+
#include <ATen/ops/_foreach_trunc_cuda_dispatch.h>
|
| 114 |
+
#include <ATen/ops/_foreach_zero_cuda_dispatch.h>
|
| 115 |
+
#include <ATen/ops/_fused_adam_cuda_dispatch.h>
|
| 116 |
+
#include <ATen/ops/_fused_adamw_cuda_dispatch.h>
|
| 117 |
+
#include <ATen/ops/_fused_dropout_cuda_dispatch.h>
|
| 118 |
+
#include <ATen/ops/_fused_moving_avg_obs_fq_helper_cuda_dispatch.h>
|
| 119 |
+
#include <ATen/ops/_fused_sdp_choice_cuda_dispatch.h>
|
| 120 |
+
#include <ATen/ops/_fused_sgd_cuda_dispatch.h>
|
| 121 |
+
#include <ATen/ops/_index_put_impl_cuda_dispatch.h>
|
| 122 |
+
#include <ATen/ops/_int_mm_cuda_dispatch.h>
|
| 123 |
+
#include <ATen/ops/_jagged_to_padded_dense_forward_cuda_dispatch.h>
|
| 124 |
+
#include <ATen/ops/_linalg_det_cuda_dispatch.h>
|
| 125 |
+
#include <ATen/ops/_linalg_eigh_cuda_dispatch.h>
|
| 126 |
+
#include <ATen/ops/_linalg_eigvals_cuda_dispatch.h>
|
| 127 |
+
#include <ATen/ops/_linalg_slogdet_cuda_dispatch.h>
|
| 128 |
+
#include <ATen/ops/_linalg_solve_ex_cuda_dispatch.h>
|
| 129 |
+
#include <ATen/ops/_linalg_svd_cuda_dispatch.h>
|
| 130 |
+
#include <ATen/ops/_local_scalar_dense_cuda_dispatch.h>
|
| 131 |
+
#include <ATen/ops/_log_softmax_cuda_dispatch.h>
|
| 132 |
+
#include <ATen/ops/_log_softmax_backward_data_cuda_dispatch.h>
|
| 133 |
+
#include <ATen/ops/_logcumsumexp_cuda_dispatch.h>
|
| 134 |
+
#include <ATen/ops/_make_per_channel_quantized_tensor_cuda_dispatch.h>
|
| 135 |
+
#include <ATen/ops/_make_per_tensor_quantized_tensor_cuda_dispatch.h>
|
| 136 |
+
#include <ATen/ops/_masked_scale_cuda_dispatch.h>
|
| 137 |
+
#include <ATen/ops/_masked_softmax_cuda_dispatch.h>
|
| 138 |
+
#include <ATen/ops/_masked_softmax_backward_cuda_dispatch.h>
|
| 139 |
+
#include <ATen/ops/_mixed_dtypes_linear_cuda_dispatch.h>
|
| 140 |
+
#include <ATen/ops/_native_batch_norm_legit_cuda_dispatch.h>
|
| 141 |
+
#include <ATen/ops/_native_multi_head_attention_cuda_dispatch.h>
|
| 142 |
+
#include <ATen/ops/_nested_compute_contiguous_strides_offsets_cuda_dispatch.h>
|
| 143 |
+
#include <ATen/ops/_nested_from_padded_cuda_dispatch.h>
|
| 144 |
+
#include <ATen/ops/_nested_tensor_from_mask_cuda_dispatch.h>
|
| 145 |
+
#include <ATen/ops/_nested_tensor_from_mask_left_aligned_cuda_dispatch.h>
|
| 146 |
+
#include <ATen/ops/_nested_view_from_buffer_cuda_dispatch.h>
|
| 147 |
+
#include <ATen/ops/_padded_dense_to_jagged_forward_cuda_dispatch.h>
|
| 148 |
+
#include <ATen/ops/_pdist_backward_cuda_dispatch.h>
|
| 149 |
+
#include <ATen/ops/_pdist_forward_cuda_dispatch.h>
|
| 150 |
+
#include <ATen/ops/_prelu_kernel_cuda_dispatch.h>
|
| 151 |
+
#include <ATen/ops/_prelu_kernel_backward_cuda_dispatch.h>
|
| 152 |
+
#include <ATen/ops/_reshape_alias_cuda_dispatch.h>
|
| 153 |
+
#include <ATen/ops/_sample_dirichlet_cuda_dispatch.h>
|
| 154 |
+
#include <ATen/ops/_scaled_dot_product_cudnn_attention_cuda_dispatch.h>
|
| 155 |
+
#include <ATen/ops/_scaled_dot_product_cudnn_attention_backward_cuda_dispatch.h>
|
| 156 |
+
#include <ATen/ops/_scaled_dot_product_efficient_attention_cuda_dispatch.h>
|
| 157 |
+
#include <ATen/ops/_scaled_dot_product_efficient_attention_backward_cuda_dispatch.h>
|
| 158 |
+
#include <ATen/ops/_scaled_dot_product_flash_attention_cuda_dispatch.h>
|
| 159 |
+
#include <ATen/ops/_scaled_dot_product_flash_attention_backward_cuda_dispatch.h>
|
| 160 |
+
#include <ATen/ops/_scaled_mm_cuda_dispatch.h>
|
| 161 |
+
#include <ATen/ops/_segment_reduce_backward_cuda_dispatch.h>
|
| 162 |
+
#include <ATen/ops/_slow_conv2d_backward_cuda_dispatch.h>
|
| 163 |
+
#include <ATen/ops/_slow_conv2d_forward_cuda_dispatch.h>
|
| 164 |
+
#include <ATen/ops/_softmax_cuda_dispatch.h>
|
| 165 |
+
#include <ATen/ops/_softmax_backward_data_cuda_dispatch.h>
|
| 166 |
+
#include <ATen/ops/_sparse_semi_structured_addmm_cuda_dispatch.h>
|
| 167 |
+
#include <ATen/ops/_sparse_semi_structured_apply_cuda_dispatch.h>
|
| 168 |
+
#include <ATen/ops/_sparse_semi_structured_apply_dense_cuda_dispatch.h>
|
| 169 |
+
#include <ATen/ops/_sparse_semi_structured_linear_cuda_dispatch.h>
|
| 170 |
+
#include <ATen/ops/_sparse_semi_structured_mm_cuda_dispatch.h>
|
| 171 |
+
#include <ATen/ops/_sparse_semi_structured_tile_cuda_dispatch.h>
|
| 172 |
+
#include <ATen/ops/_standard_gamma_cuda_dispatch.h>
|
| 173 |
+
#include <ATen/ops/_standard_gamma_grad_cuda_dispatch.h>
|
| 174 |
+
#include <ATen/ops/_thnn_fused_gru_cell_cuda_dispatch.h>
|
| 175 |
+
#include <ATen/ops/_thnn_fused_gru_cell_backward_cuda_dispatch.h>
|
| 176 |
+
#include <ATen/ops/_thnn_fused_lstm_cell_cuda_dispatch.h>
|
| 177 |
+
#include <ATen/ops/_thnn_fused_lstm_cell_backward_impl_cuda_dispatch.h>
|
| 178 |
+
#include <ATen/ops/_to_sparse_cuda_dispatch.h>
|
| 179 |
+
#include <ATen/ops/_to_sparse_bsc_cuda_dispatch.h>
|
| 180 |
+
#include <ATen/ops/_to_sparse_bsr_cuda_dispatch.h>
|
| 181 |
+
#include <ATen/ops/_to_sparse_csc_cuda_dispatch.h>
|
| 182 |
+
#include <ATen/ops/_to_sparse_csr_cuda_dispatch.h>
|
| 183 |
+
#include <ATen/ops/_to_sparse_semi_structured_cuda_dispatch.h>
|
| 184 |
+
#include <ATen/ops/_transform_bias_rescale_qkv_cuda_dispatch.h>
|
| 185 |
+
#include <ATen/ops/_transformer_encoder_layer_fwd_cuda_dispatch.h>
|
| 186 |
+
#include <ATen/ops/_triton_multi_head_attention_cuda_dispatch.h>
|
| 187 |
+
#include <ATen/ops/_triton_scaled_dot_attention_cuda_dispatch.h>
|
| 188 |
+
#include <ATen/ops/_unique_cuda_dispatch.h>
|
| 189 |
+
#include <ATen/ops/_unique2_cuda_dispatch.h>
|
| 190 |
+
#include <ATen/ops/_upsample_bicubic2d_aa_cuda_dispatch.h>
|
| 191 |
+
#include <ATen/ops/_upsample_bicubic2d_aa_backward_cuda_dispatch.h>
|
| 192 |
+
#include <ATen/ops/_upsample_bilinear2d_aa_cuda_dispatch.h>
|
| 193 |
+
#include <ATen/ops/_upsample_bilinear2d_aa_backward_cuda_dispatch.h>
|
| 194 |
+
#include <ATen/ops/_upsample_nearest_exact1d_cuda_dispatch.h>
|
| 195 |
+
#include <ATen/ops/_upsample_nearest_exact1d_backward_cuda_dispatch.h>
|
| 196 |
+
#include <ATen/ops/_upsample_nearest_exact2d_cuda_dispatch.h>
|
| 197 |
+
#include <ATen/ops/_upsample_nearest_exact2d_backward_cuda_dispatch.h>
|
| 198 |
+
#include <ATen/ops/_upsample_nearest_exact3d_cuda_dispatch.h>
|
| 199 |
+
#include <ATen/ops/_upsample_nearest_exact3d_backward_cuda_dispatch.h>
|
| 200 |
+
#include <ATen/ops/_use_cudnn_ctc_loss_cuda_dispatch.h>
|
| 201 |
+
#include <ATen/ops/_validate_compressed_sparse_indices_cuda_dispatch.h>
|
| 202 |
+
#include <ATen/ops/_weight_int4pack_mm_cuda_dispatch.h>
|
| 203 |
+
#include <ATen/ops/_weight_norm_interface_cuda_dispatch.h>
|
| 204 |
+
#include <ATen/ops/_weight_norm_interface_backward_cuda_dispatch.h>
|
| 205 |
+
#include <ATen/ops/abs_cuda_dispatch.h>
|
| 206 |
+
#include <ATen/ops/acos_cuda_dispatch.h>
|
| 207 |
+
#include <ATen/ops/acosh_cuda_dispatch.h>
|
| 208 |
+
#include <ATen/ops/adaptive_avg_pool2d_cuda_dispatch.h>
|
| 209 |
+
#include <ATen/ops/adaptive_avg_pool3d_cuda_dispatch.h>
|
| 210 |
+
#include <ATen/ops/adaptive_avg_pool3d_backward_cuda_dispatch.h>
|
| 211 |
+
#include <ATen/ops/adaptive_max_pool2d_cuda_dispatch.h>
|
| 212 |
+
#include <ATen/ops/adaptive_max_pool2d_backward_cuda_dispatch.h>
|
| 213 |
+
#include <ATen/ops/adaptive_max_pool3d_cuda_dispatch.h>
|
| 214 |
+
#include <ATen/ops/adaptive_max_pool3d_backward_cuda_dispatch.h>
|
| 215 |
+
#include <ATen/ops/add_cuda_dispatch.h>
|
| 216 |
+
#include <ATen/ops/addbmm_cuda_dispatch.h>
|
| 217 |
+
#include <ATen/ops/addcdiv_cuda_dispatch.h>
|
| 218 |
+
#include <ATen/ops/addcmul_cuda_dispatch.h>
|
| 219 |
+
#include <ATen/ops/addmm_cuda_dispatch.h>
|
| 220 |
+
#include <ATen/ops/addmv_cuda_dispatch.h>
|
| 221 |
+
#include <ATen/ops/addr_cuda_dispatch.h>
|
| 222 |
+
#include <ATen/ops/all_cuda_dispatch.h>
|
| 223 |
+
#include <ATen/ops/amax_cuda_dispatch.h>
|
| 224 |
+
#include <ATen/ops/amin_cuda_dispatch.h>
|
| 225 |
+
#include <ATen/ops/aminmax_cuda_dispatch.h>
|
| 226 |
+
#include <ATen/ops/angle_cuda_dispatch.h>
|
| 227 |
+
#include <ATen/ops/any_cuda_dispatch.h>
|
| 228 |
+
#include <ATen/ops/arange_cuda_dispatch.h>
|
| 229 |
+
#include <ATen/ops/argmax_cuda_dispatch.h>
|
| 230 |
+
#include <ATen/ops/argmin_cuda_dispatch.h>
|
| 231 |
+
#include <ATen/ops/as_strided_cuda_dispatch.h>
|
| 232 |
+
#include <ATen/ops/asin_cuda_dispatch.h>
|
| 233 |
+
#include <ATen/ops/asinh_cuda_dispatch.h>
|
| 234 |
+
#include <ATen/ops/atan_cuda_dispatch.h>
|
| 235 |
+
#include <ATen/ops/atan2_cuda_dispatch.h>
|
| 236 |
+
#include <ATen/ops/atanh_cuda_dispatch.h>
|
| 237 |
+
#include <ATen/ops/avg_pool2d_cuda_dispatch.h>
|
| 238 |
+
#include <ATen/ops/avg_pool2d_backward_cuda_dispatch.h>
|
| 239 |
+
#include <ATen/ops/avg_pool3d_cuda_dispatch.h>
|
| 240 |
+
#include <ATen/ops/avg_pool3d_backward_cuda_dispatch.h>
|
| 241 |
+
#include <ATen/ops/baddbmm_cuda_dispatch.h>
|
| 242 |
+
#include <ATen/ops/batch_norm_backward_cuda_dispatch.h>
|
| 243 |
+
#include <ATen/ops/batch_norm_backward_elemt_cuda_dispatch.h>
|
| 244 |
+
#include <ATen/ops/batch_norm_backward_reduce_cuda_dispatch.h>
|
| 245 |
+
#include <ATen/ops/batch_norm_elemt_cuda_dispatch.h>
|
| 246 |
+
#include <ATen/ops/batch_norm_gather_stats_cuda_dispatch.h>
|
| 247 |
+
#include <ATen/ops/batch_norm_gather_stats_with_counts_cuda_dispatch.h>
|
| 248 |
+
#include <ATen/ops/batch_norm_stats_cuda_dispatch.h>
|
| 249 |
+
#include <ATen/ops/batch_norm_update_stats_cuda_dispatch.h>
|
| 250 |
+
#include <ATen/ops/bernoulli_cuda_dispatch.h>
|
| 251 |
+
#include <ATen/ops/binary_cross_entropy_cuda_dispatch.h>
|
| 252 |
+
#include <ATen/ops/binary_cross_entropy_backward_cuda_dispatch.h>
|
| 253 |
+
#include <ATen/ops/bincount_cuda_dispatch.h>
|
| 254 |
+
#include <ATen/ops/binomial_cuda_dispatch.h>
|
| 255 |
+
#include <ATen/ops/bitwise_and_cuda_dispatch.h>
|
| 256 |
+
#include <ATen/ops/bitwise_left_shift_cuda_dispatch.h>
|
| 257 |
+
#include <ATen/ops/bitwise_not_cuda_dispatch.h>
|
| 258 |
+
#include <ATen/ops/bitwise_or_cuda_dispatch.h>
|
| 259 |
+
#include <ATen/ops/bitwise_right_shift_cuda_dispatch.h>
|
| 260 |
+
#include <ATen/ops/bitwise_xor_cuda_dispatch.h>
|
| 261 |
+
#include <ATen/ops/bmm_cuda_dispatch.h>
|
| 262 |
+
#include <ATen/ops/bucketize_cuda_dispatch.h>
|
| 263 |
+
#include <ATen/ops/cat_cuda_dispatch.h>
|
| 264 |
+
#include <ATen/ops/cauchy_cuda_dispatch.h>
|
| 265 |
+
#include <ATen/ops/ceil_cuda_dispatch.h>
|
| 266 |
+
#include <ATen/ops/channel_shuffle_cuda_dispatch.h>
|
| 267 |
+
#include <ATen/ops/cholesky_cuda_dispatch.h>
|
| 268 |
+
#include <ATen/ops/cholesky_inverse_cuda_dispatch.h>
|
| 269 |
+
#include <ATen/ops/clamp_cuda_dispatch.h>
|
| 270 |
+
#include <ATen/ops/clamp_max_cuda_dispatch.h>
|
| 271 |
+
#include <ATen/ops/clamp_min_cuda_dispatch.h>
|
| 272 |
+
#include <ATen/ops/col2im_cuda_dispatch.h>
|
| 273 |
+
#include <ATen/ops/complex_cuda_dispatch.h>
|
| 274 |
+
#include <ATen/ops/conj_physical_cuda_dispatch.h>
|
| 275 |
+
#include <ATen/ops/conv_depthwise3d_cuda_dispatch.h>
|
| 276 |
+
#include <ATen/ops/convolution_backward_cuda_dispatch.h>
|
| 277 |
+
#include <ATen/ops/copysign_cuda_dispatch.h>
|
| 278 |
+
#include <ATen/ops/cos_cuda_dispatch.h>
|
| 279 |
+
#include <ATen/ops/cosh_cuda_dispatch.h>
|
| 280 |
+
#include <ATen/ops/count_nonzero_cuda_dispatch.h>
|
| 281 |
+
#include <ATen/ops/cudnn_affine_grid_generator_cuda_dispatch.h>
|
| 282 |
+
#include <ATen/ops/cudnn_affine_grid_generator_backward_cuda_dispatch.h>
|
| 283 |
+
#include <ATen/ops/cudnn_batch_norm_cuda_dispatch.h>
|
| 284 |
+
#include <ATen/ops/cudnn_batch_norm_backward_cuda_dispatch.h>
|
| 285 |
+
#include <ATen/ops/cudnn_convolution_cuda_dispatch.h>
|
| 286 |
+
#include <ATen/ops/cudnn_convolution_add_relu_cuda_dispatch.h>
|
| 287 |
+
#include <ATen/ops/cudnn_convolution_relu_cuda_dispatch.h>
|
| 288 |
+
#include <ATen/ops/cudnn_convolution_transpose_cuda_dispatch.h>
|
| 289 |
+
#include <ATen/ops/cudnn_grid_sampler_cuda_dispatch.h>
|
| 290 |
+
#include <ATen/ops/cudnn_grid_sampler_backward_cuda_dispatch.h>
|
| 291 |
+
#include <ATen/ops/cumprod_cuda_dispatch.h>
|
| 292 |
+
#include <ATen/ops/cumsum_cuda_dispatch.h>
|
| 293 |
+
#include <ATen/ops/dequantize_cuda_dispatch.h>
|
| 294 |
+
#include <ATen/ops/digamma_cuda_dispatch.h>
|
| 295 |
+
#include <ATen/ops/div_cuda_dispatch.h>
|
| 296 |
+
#include <ATen/ops/dot_cuda_dispatch.h>
|
| 297 |
+
#include <ATen/ops/elu_cuda_dispatch.h>
|
| 298 |
+
#include <ATen/ops/elu_backward_cuda_dispatch.h>
|
| 299 |
+
#include <ATen/ops/embedding_dense_backward_cuda_dispatch.h>
|
| 300 |
+
#include <ATen/ops/embedding_renorm_cuda_dispatch.h>
|
| 301 |
+
#include <ATen/ops/empty_cuda_dispatch.h>
|
| 302 |
+
#include <ATen/ops/empty_strided_cuda_dispatch.h>
|
| 303 |
+
#include <ATen/ops/eq_cuda_dispatch.h>
|
| 304 |
+
#include <ATen/ops/equal_cuda_dispatch.h>
|
| 305 |
+
#include <ATen/ops/erf_cuda_dispatch.h>
|
| 306 |
+
#include <ATen/ops/erfc_cuda_dispatch.h>
|
| 307 |
+
#include <ATen/ops/erfinv_cuda_dispatch.h>
|
| 308 |
+
#include <ATen/ops/exp_cuda_dispatch.h>
|
| 309 |
+
#include <ATen/ops/exp2_cuda_dispatch.h>
|
| 310 |
+
#include <ATen/ops/expm1_cuda_dispatch.h>
|
| 311 |
+
#include <ATen/ops/exponential_cuda_dispatch.h>
|
| 312 |
+
#include <ATen/ops/eye_cuda_dispatch.h>
|
| 313 |
+
#include <ATen/ops/fake_quantize_per_channel_affine_cachemask_cuda_dispatch.h>
|
| 314 |
+
#include <ATen/ops/fake_quantize_per_tensor_affine_cachemask_cuda_dispatch.h>
|
| 315 |
+
#include <ATen/ops/fill_cuda_dispatch.h>
|
| 316 |
+
#include <ATen/ops/flip_cuda_dispatch.h>
|
| 317 |
+
#include <ATen/ops/floor_cuda_dispatch.h>
|
| 318 |
+
#include <ATen/ops/floor_divide_cuda_dispatch.h>
|
| 319 |
+
#include <ATen/ops/fmax_cuda_dispatch.h>
|
| 320 |
+
#include <ATen/ops/fmin_cuda_dispatch.h>
|
| 321 |
+
#include <ATen/ops/fmod_cuda_dispatch.h>
|
| 322 |
+
#include <ATen/ops/frac_cuda_dispatch.h>
|
| 323 |
+
#include <ATen/ops/fractional_max_pool2d_cuda_dispatch.h>
|
| 324 |
+
#include <ATen/ops/fractional_max_pool2d_backward_cuda_dispatch.h>
|
| 325 |
+
#include <ATen/ops/fractional_max_pool3d_cuda_dispatch.h>
|
| 326 |
+
#include <ATen/ops/fractional_max_pool3d_backward_cuda_dispatch.h>
|
| 327 |
+
#include <ATen/ops/frexp_cuda_dispatch.h>
|
| 328 |
+
#include <ATen/ops/gather_cuda_dispatch.h>
|
| 329 |
+
#include <ATen/ops/gcd_cuda_dispatch.h>
|
| 330 |
+
#include <ATen/ops/ge_cuda_dispatch.h>
|
| 331 |
+
#include <ATen/ops/gelu_cuda_dispatch.h>
|
| 332 |
+
#include <ATen/ops/gelu_backward_cuda_dispatch.h>
|
| 333 |
+
#include <ATen/ops/geometric_cuda_dispatch.h>
|
| 334 |
+
#include <ATen/ops/geqrf_cuda_dispatch.h>
|
| 335 |
+
#include <ATen/ops/glu_cuda_dispatch.h>
|
| 336 |
+
#include <ATen/ops/glu_backward_cuda_dispatch.h>
|
| 337 |
+
#include <ATen/ops/glu_backward_jvp_cuda_dispatch.h>
|
| 338 |
+
#include <ATen/ops/glu_jvp_cuda_dispatch.h>
|
| 339 |
+
#include <ATen/ops/grid_sampler_2d_cuda_dispatch.h>
|
| 340 |
+
#include <ATen/ops/grid_sampler_2d_backward_cuda_dispatch.h>
|
| 341 |
+
#include <ATen/ops/grid_sampler_3d_cuda_dispatch.h>
|
| 342 |
+
#include <ATen/ops/grid_sampler_3d_backward_cuda_dispatch.h>
|
| 343 |
+
#include <ATen/ops/gt_cuda_dispatch.h>
|
| 344 |
+
#include <ATen/ops/hardshrink_cuda_dispatch.h>
|
| 345 |
+
#include <ATen/ops/hardshrink_backward_cuda_dispatch.h>
|
| 346 |
+
#include <ATen/ops/hardsigmoid_cuda_dispatch.h>
|
| 347 |
+
#include <ATen/ops/hardsigmoid_backward_cuda_dispatch.h>
|
| 348 |
+
#include <ATen/ops/hardswish_cuda_dispatch.h>
|
| 349 |
+
#include <ATen/ops/hardswish_backward_cuda_dispatch.h>
|
| 350 |
+
#include <ATen/ops/hardtanh_cuda_dispatch.h>
|
| 351 |
+
#include <ATen/ops/hardtanh_backward_cuda_dispatch.h>
|
| 352 |
+
#include <ATen/ops/heaviside_cuda_dispatch.h>
|
| 353 |
+
#include <ATen/ops/histc_cuda_dispatch.h>
|
| 354 |
+
#include <ATen/ops/huber_loss_cuda_dispatch.h>
|
| 355 |
+
#include <ATen/ops/huber_loss_backward_cuda_dispatch.h>
|
| 356 |
+
#include <ATen/ops/hypot_cuda_dispatch.h>
|
| 357 |
+
#include <ATen/ops/i0_cuda_dispatch.h>
|
| 358 |
+
#include <ATen/ops/igamma_cuda_dispatch.h>
|
| 359 |
+
#include <ATen/ops/igammac_cuda_dispatch.h>
|
| 360 |
+
#include <ATen/ops/im2col_cuda_dispatch.h>
|
| 361 |
+
#include <ATen/ops/index_cuda_dispatch.h>
|
| 362 |
+
#include <ATen/ops/index_add_cuda_dispatch.h>
|
| 363 |
+
#include <ATen/ops/index_copy_cuda_dispatch.h>
|
| 364 |
+
#include <ATen/ops/index_fill_cuda_dispatch.h>
|
| 365 |
+
#include <ATen/ops/index_reduce_cuda_dispatch.h>
|
| 366 |
+
#include <ATen/ops/index_select_cuda_dispatch.h>
|
| 367 |
+
#include <ATen/ops/is_set_to_cuda_dispatch.h>
|
| 368 |
+
#include <ATen/ops/isin_cuda_dispatch.h>
|
| 369 |
+
#include <ATen/ops/isnan_cuda_dispatch.h>
|
| 370 |
+
#include <ATen/ops/isneginf_cuda_dispatch.h>
|
| 371 |
+
#include <ATen/ops/isposinf_cuda_dispatch.h>
|
| 372 |
+
#include <ATen/ops/kthvalue_cuda_dispatch.h>
|
| 373 |
+
#include <ATen/ops/lcm_cuda_dispatch.h>
|
| 374 |
+
#include <ATen/ops/le_cuda_dispatch.h>
|
| 375 |
+
#include <ATen/ops/leaky_relu_cuda_dispatch.h>
|
| 376 |
+
#include <ATen/ops/leaky_relu_backward_cuda_dispatch.h>
|
| 377 |
+
#include <ATen/ops/lerp_cuda_dispatch.h>
|
| 378 |
+
#include <ATen/ops/lgamma_cuda_dispatch.h>
|
| 379 |
+
#include <ATen/ops/linalg_cholesky_ex_cuda_dispatch.h>
|
| 380 |
+
#include <ATen/ops/linalg_cross_cuda_dispatch.h>
|
| 381 |
+
#include <ATen/ops/linalg_eig_cuda_dispatch.h>
|
| 382 |
+
#include <ATen/ops/linalg_eigvals_cuda_dispatch.h>
|
| 383 |
+
#include <ATen/ops/linalg_householder_product_cuda_dispatch.h>
|
| 384 |
+
#include <ATen/ops/linalg_inv_ex_cuda_dispatch.h>
|
| 385 |
+
#include <ATen/ops/linalg_ldl_factor_ex_cuda_dispatch.h>
|
| 386 |
+
#include <ATen/ops/linalg_ldl_solve_cuda_dispatch.h>
|
| 387 |
+
#include <ATen/ops/linalg_lstsq_cuda_dispatch.h>
|
| 388 |
+
#include <ATen/ops/linalg_lu_cuda_dispatch.h>
|
| 389 |
+
#include <ATen/ops/linalg_lu_factor_ex_cuda_dispatch.h>
|
| 390 |
+
#include <ATen/ops/linalg_lu_solve_cuda_dispatch.h>
|
| 391 |
+
#include <ATen/ops/linalg_matrix_exp_cuda_dispatch.h>
|
| 392 |
+
#include <ATen/ops/linalg_qr_cuda_dispatch.h>
|
| 393 |
+
#include <ATen/ops/linalg_solve_triangular_cuda_dispatch.h>
|
| 394 |
+
#include <ATen/ops/linalg_vector_norm_cuda_dispatch.h>
|
| 395 |
+
#include <ATen/ops/linspace_cuda_dispatch.h>
|
| 396 |
+
#include <ATen/ops/log_cuda_dispatch.h>
|
| 397 |
+
#include <ATen/ops/log10_cuda_dispatch.h>
|
| 398 |
+
#include <ATen/ops/log1p_cuda_dispatch.h>
|
| 399 |
+
#include <ATen/ops/log2_cuda_dispatch.h>
|
| 400 |
+
#include <ATen/ops/log_normal_cuda_dispatch.h>
|
| 401 |
+
#include <ATen/ops/log_sigmoid_backward_cuda_dispatch.h>
|
| 402 |
+
#include <ATen/ops/log_sigmoid_forward_cuda_dispatch.h>
|
| 403 |
+
#include <ATen/ops/logaddexp_cuda_dispatch.h>
|
| 404 |
+
#include <ATen/ops/logaddexp2_cuda_dispatch.h>
|
| 405 |
+
#include <ATen/ops/logical_and_cuda_dispatch.h>
|
| 406 |
+
#include <ATen/ops/logical_not_cuda_dispatch.h>
|
| 407 |
+
#include <ATen/ops/logical_or_cuda_dispatch.h>
|
| 408 |
+
#include <ATen/ops/logical_xor_cuda_dispatch.h>
|
| 409 |
+
#include <ATen/ops/logit_cuda_dispatch.h>
|
| 410 |
+
#include <ATen/ops/logit_backward_cuda_dispatch.h>
|
| 411 |
+
#include <ATen/ops/logspace_cuda_dispatch.h>
|
| 412 |
+
#include <ATen/ops/lshift_cuda_dispatch.h>
|
| 413 |
+
#include <ATen/ops/lt_cuda_dispatch.h>
|
| 414 |
+
#include <ATen/ops/lu_unpack_cuda_dispatch.h>
|
| 415 |
+
#include <ATen/ops/masked_fill_cuda_dispatch.h>
|
| 416 |
+
#include <ATen/ops/masked_scatter_cuda_dispatch.h>
|
| 417 |
+
#include <ATen/ops/masked_select_cuda_dispatch.h>
|
| 418 |
+
#include <ATen/ops/max_cuda_dispatch.h>
|
| 419 |
+
#include <ATen/ops/max_pool2d_with_indices_cuda_dispatch.h>
|
| 420 |
+
#include <ATen/ops/max_pool2d_with_indices_backward_cuda_dispatch.h>
|
| 421 |
+
#include <ATen/ops/max_pool3d_with_indices_cuda_dispatch.h>
|
| 422 |
+
#include <ATen/ops/max_pool3d_with_indices_backward_cuda_dispatch.h>
|
| 423 |
+
#include <ATen/ops/max_unpool2d_cuda_dispatch.h>
|
| 424 |
+
#include <ATen/ops/max_unpool3d_cuda_dispatch.h>
|
| 425 |
+
#include <ATen/ops/maximum_cuda_dispatch.h>
|
| 426 |
+
#include <ATen/ops/mean_cuda_dispatch.h>
|
| 427 |
+
#include <ATen/ops/median_cuda_dispatch.h>
|
| 428 |
+
#include <ATen/ops/min_cuda_dispatch.h>
|
| 429 |
+
#include <ATen/ops/minimum_cuda_dispatch.h>
|
| 430 |
+
#include <ATen/ops/miopen_batch_norm_cuda_dispatch.h>
|
| 431 |
+
#include <ATen/ops/miopen_batch_norm_backward_cuda_dispatch.h>
|
| 432 |
+
#include <ATen/ops/miopen_convolution_cuda_dispatch.h>
|
| 433 |
+
#include <ATen/ops/miopen_convolution_add_relu_cuda_dispatch.h>
|
| 434 |
+
#include <ATen/ops/miopen_convolution_relu_cuda_dispatch.h>
|
| 435 |
+
#include <ATen/ops/miopen_convolution_transpose_cuda_dispatch.h>
|
| 436 |
+
#include <ATen/ops/miopen_depthwise_convolution_cuda_dispatch.h>
|
| 437 |
+
#include <ATen/ops/miopen_rnn_cuda_dispatch.h>
|
| 438 |
+
#include <ATen/ops/miopen_rnn_backward_cuda_dispatch.h>
|
| 439 |
+
#include <ATen/ops/mish_cuda_dispatch.h>
|
| 440 |
+
#include <ATen/ops/mish_backward_cuda_dispatch.h>
|
| 441 |
+
#include <ATen/ops/mm_cuda_dispatch.h>
|
| 442 |
+
#include <ATen/ops/mode_cuda_dispatch.h>
|
| 443 |
+
#include <ATen/ops/mse_loss_cuda_dispatch.h>
|
| 444 |
+
#include <ATen/ops/mse_loss_backward_cuda_dispatch.h>
|
| 445 |
+
#include <ATen/ops/mul_cuda_dispatch.h>
|
| 446 |
+
#include <ATen/ops/multi_margin_loss_cuda_dispatch.h>
|
| 447 |
+
#include <ATen/ops/multi_margin_loss_backward_cuda_dispatch.h>
|
| 448 |
+
#include <ATen/ops/multilabel_margin_loss_backward_cuda_dispatch.h>
|
| 449 |
+
#include <ATen/ops/multilabel_margin_loss_forward_cuda_dispatch.h>
|
| 450 |
+
#include <ATen/ops/multinomial_cuda_dispatch.h>
|
| 451 |
+
#include <ATen/ops/mvlgamma_cuda_dispatch.h>
|
| 452 |
+
#include <ATen/ops/nan_to_num_cuda_dispatch.h>
|
| 453 |
+
#include <ATen/ops/nanmedian_cuda_dispatch.h>
|
| 454 |
+
#include <ATen/ops/nansum_cuda_dispatch.h>
|
| 455 |
+
#include <ATen/ops/native_batch_norm_cuda_dispatch.h>
|
| 456 |
+
#include <ATen/ops/native_batch_norm_backward_cuda_dispatch.h>
|
| 457 |
+
#include <ATen/ops/native_dropout_cuda_dispatch.h>
|
| 458 |
+
#include <ATen/ops/native_dropout_backward_cuda_dispatch.h>
|
| 459 |
+
#include <ATen/ops/native_group_norm_cuda_dispatch.h>
|
| 460 |
+
#include <ATen/ops/native_group_norm_backward_cuda_dispatch.h>
|
| 461 |
+
#include <ATen/ops/native_layer_norm_cuda_dispatch.h>
|
| 462 |
+
#include <ATen/ops/native_layer_norm_backward_cuda_dispatch.h>
|
| 463 |
+
#include <ATen/ops/ne_cuda_dispatch.h>
|
| 464 |
+
#include <ATen/ops/neg_cuda_dispatch.h>
|
| 465 |
+
#include <ATen/ops/nextafter_cuda_dispatch.h>
|
| 466 |
+
#include <ATen/ops/nll_loss2d_backward_cuda_dispatch.h>
|
| 467 |
+
#include <ATen/ops/nll_loss2d_forward_cuda_dispatch.h>
|
| 468 |
+
#include <ATen/ops/nll_loss_backward_cuda_dispatch.h>
|
| 469 |
+
#include <ATen/ops/nll_loss_forward_cuda_dispatch.h>
|
| 470 |
+
#include <ATen/ops/nonzero_cuda_dispatch.h>
|
| 471 |
+
#include <ATen/ops/norm_cuda_dispatch.h>
|
| 472 |
+
#include <ATen/ops/normal_cuda_dispatch.h>
|
| 473 |
+
#include <ATen/ops/ormqr_cuda_dispatch.h>
|
| 474 |
+
#include <ATen/ops/poisson_cuda_dispatch.h>
|
| 475 |
+
#include <ATen/ops/polar_cuda_dispatch.h>
|
| 476 |
+
#include <ATen/ops/polygamma_cuda_dispatch.h>
|
| 477 |
+
#include <ATen/ops/pow_cuda_dispatch.h>
|
| 478 |
+
#include <ATen/ops/prod_cuda_dispatch.h>
|
| 479 |
+
#include <ATen/ops/put_cuda_dispatch.h>
|
| 480 |
+
#include <ATen/ops/quantize_per_channel_cuda_dispatch.h>
|
| 481 |
+
#include <ATen/ops/quantize_per_tensor_cuda_dispatch.h>
|
| 482 |
+
#include <ATen/ops/quantize_per_tensor_dynamic_cuda_dispatch.h>
|
| 483 |
+
#include <ATen/ops/random_cuda_dispatch.h>
|
| 484 |
+
#include <ATen/ops/randperm_cuda_dispatch.h>
|
| 485 |
+
#include <ATen/ops/range_cuda_dispatch.h>
|
| 486 |
+
#include <ATen/ops/reciprocal_cuda_dispatch.h>
|
| 487 |
+
#include <ATen/ops/record_stream_cuda_dispatch.h>
|
| 488 |
+
#include <ATen/ops/reflection_pad1d_cuda_dispatch.h>
|
| 489 |
+
#include <ATen/ops/reflection_pad1d_backward_cuda_dispatch.h>
|
| 490 |
+
#include <ATen/ops/reflection_pad2d_cuda_dispatch.h>
|
| 491 |
+
#include <ATen/ops/reflection_pad2d_backward_cuda_dispatch.h>
|
| 492 |
+
#include <ATen/ops/reflection_pad3d_cuda_dispatch.h>
|
| 493 |
+
#include <ATen/ops/reflection_pad3d_backward_cuda_dispatch.h>
|
| 494 |
+
#include <ATen/ops/relu_cuda_dispatch.h>
|
| 495 |
+
#include <ATen/ops/remainder_cuda_dispatch.h>
|
| 496 |
+
#include <ATen/ops/renorm_cuda_dispatch.h>
|
| 497 |
+
#include <ATen/ops/repeat_interleave_cuda_dispatch.h>
|
| 498 |
+
#include <ATen/ops/replication_pad1d_cuda_dispatch.h>
|
| 499 |
+
#include <ATen/ops/replication_pad1d_backward_cuda_dispatch.h>
|
| 500 |
+
#include <ATen/ops/replication_pad2d_cuda_dispatch.h>
|
| 501 |
+
#include <ATen/ops/replication_pad2d_backward_cuda_dispatch.h>
|
| 502 |
+
#include <ATen/ops/replication_pad3d_cuda_dispatch.h>
|
| 503 |
+
#include <ATen/ops/replication_pad3d_backward_cuda_dispatch.h>
|
| 504 |
+
#include <ATen/ops/resize_cuda_dispatch.h>
|
| 505 |
+
#include <ATen/ops/roll_cuda_dispatch.h>
|
| 506 |
+
#include <ATen/ops/round_cuda_dispatch.h>
|
| 507 |
+
#include <ATen/ops/rrelu_with_noise_cuda_dispatch.h>
|
| 508 |
+
#include <ATen/ops/rshift_cuda_dispatch.h>
|
| 509 |
+
#include <ATen/ops/rsqrt_cuda_dispatch.h>
|
| 510 |
+
#include <ATen/ops/rsub_cuda_dispatch.h>
|
| 511 |
+
#include <ATen/ops/scatter_cuda_dispatch.h>
|
| 512 |
+
#include <ATen/ops/scatter_add_cuda_dispatch.h>
|
| 513 |
+
#include <ATen/ops/scatter_reduce_cuda_dispatch.h>
|
| 514 |
+
#include <ATen/ops/searchsorted_cuda_dispatch.h>
|
| 515 |
+
#include <ATen/ops/segment_reduce_cuda_dispatch.h>
|
| 516 |
+
#include <ATen/ops/set_cuda_dispatch.h>
|
| 517 |
+
#include <ATen/ops/sgn_cuda_dispatch.h>
|
| 518 |
+
#include <ATen/ops/sigmoid_cuda_dispatch.h>
|
| 519 |
+
#include <ATen/ops/sigmoid_backward_cuda_dispatch.h>
|
| 520 |
+
#include <ATen/ops/sign_cuda_dispatch.h>
|
| 521 |
+
#include <ATen/ops/signbit_cuda_dispatch.h>
|
| 522 |
+
#include <ATen/ops/silu_cuda_dispatch.h>
|
| 523 |
+
#include <ATen/ops/silu_backward_cuda_dispatch.h>
|
| 524 |
+
#include <ATen/ops/sin_cuda_dispatch.h>
|
| 525 |
+
#include <ATen/ops/sinc_cuda_dispatch.h>
|
| 526 |
+
#include <ATen/ops/sinh_cuda_dispatch.h>
|
| 527 |
+
#include <ATen/ops/slow_conv_dilated2d_cuda_dispatch.h>
|
| 528 |
+
#include <ATen/ops/slow_conv_dilated3d_cuda_dispatch.h>
|
| 529 |
+
#include <ATen/ops/slow_conv_transpose2d_cuda_dispatch.h>
|
| 530 |
+
#include <ATen/ops/slow_conv_transpose3d_cuda_dispatch.h>
|
| 531 |
+
#include <ATen/ops/smooth_l1_loss_cuda_dispatch.h>
|
| 532 |
+
#include <ATen/ops/smooth_l1_loss_backward_cuda_dispatch.h>
|
| 533 |
+
#include <ATen/ops/softplus_cuda_dispatch.h>
|
| 534 |
+
#include <ATen/ops/softplus_backward_cuda_dispatch.h>
|
| 535 |
+
#include <ATen/ops/softshrink_cuda_dispatch.h>
|
| 536 |
+
#include <ATen/ops/softshrink_backward_cuda_dispatch.h>
|
| 537 |
+
#include <ATen/ops/sort_cuda_dispatch.h>
|
| 538 |
+
#include <ATen/ops/special_airy_ai_cuda_dispatch.h>
|
| 539 |
+
#include <ATen/ops/special_bessel_j0_cuda_dispatch.h>
|
| 540 |
+
#include <ATen/ops/special_bessel_j1_cuda_dispatch.h>
|
| 541 |
+
#include <ATen/ops/special_bessel_y0_cuda_dispatch.h>
|
| 542 |
+
#include <ATen/ops/special_bessel_y1_cuda_dispatch.h>
|
| 543 |
+
#include <ATen/ops/special_chebyshev_polynomial_t_cuda_dispatch.h>
|
| 544 |
+
#include <ATen/ops/special_chebyshev_polynomial_u_cuda_dispatch.h>
|
| 545 |
+
#include <ATen/ops/special_chebyshev_polynomial_v_cuda_dispatch.h>
|
| 546 |
+
#include <ATen/ops/special_chebyshev_polynomial_w_cuda_dispatch.h>
|
| 547 |
+
#include <ATen/ops/special_entr_cuda_dispatch.h>
|
| 548 |
+
#include <ATen/ops/special_erfcx_cuda_dispatch.h>
|
| 549 |
+
#include <ATen/ops/special_hermite_polynomial_h_cuda_dispatch.h>
|
| 550 |
+
#include <ATen/ops/special_hermite_polynomial_he_cuda_dispatch.h>
|
| 551 |
+
#include <ATen/ops/special_i0e_cuda_dispatch.h>
|
| 552 |
+
#include <ATen/ops/special_i1_cuda_dispatch.h>
|
| 553 |
+
#include <ATen/ops/special_i1e_cuda_dispatch.h>
|
| 554 |
+
#include <ATen/ops/special_laguerre_polynomial_l_cuda_dispatch.h>
|
| 555 |
+
#include <ATen/ops/special_legendre_polynomial_p_cuda_dispatch.h>
|
| 556 |
+
#include <ATen/ops/special_log_ndtr_cuda_dispatch.h>
|
| 557 |
+
#include <ATen/ops/special_modified_bessel_i0_cuda_dispatch.h>
|
| 558 |
+
#include <ATen/ops/special_modified_bessel_i1_cuda_dispatch.h>
|
| 559 |
+
#include <ATen/ops/special_modified_bessel_k0_cuda_dispatch.h>
|
| 560 |
+
#include <ATen/ops/special_modified_bessel_k1_cuda_dispatch.h>
|
| 561 |
+
#include <ATen/ops/special_ndtri_cuda_dispatch.h>
|
| 562 |
+
#include <ATen/ops/special_scaled_modified_bessel_k0_cuda_dispatch.h>
|
| 563 |
+
#include <ATen/ops/special_scaled_modified_bessel_k1_cuda_dispatch.h>
|
| 564 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_t_cuda_dispatch.h>
|
| 565 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_u_cuda_dispatch.h>
|
| 566 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_v_cuda_dispatch.h>
|
| 567 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_w_cuda_dispatch.h>
|
| 568 |
+
#include <ATen/ops/special_spherical_bessel_j0_cuda_dispatch.h>
|
| 569 |
+
#include <ATen/ops/special_xlog1py_cuda_dispatch.h>
|
| 570 |
+
#include <ATen/ops/special_zeta_cuda_dispatch.h>
|
| 571 |
+
#include <ATen/ops/split_with_sizes_copy_cuda_dispatch.h>
|
| 572 |
+
#include <ATen/ops/sqrt_cuda_dispatch.h>
|
| 573 |
+
#include <ATen/ops/sspaddmm_cuda_dispatch.h>
|
| 574 |
+
#include <ATen/ops/std_cuda_dispatch.h>
|
| 575 |
+
#include <ATen/ops/std_mean_cuda_dispatch.h>
|
| 576 |
+
#include <ATen/ops/sub_cuda_dispatch.h>
|
| 577 |
+
#include <ATen/ops/sum_cuda_dispatch.h>
|
| 578 |
+
#include <ATen/ops/take_cuda_dispatch.h>
|
| 579 |
+
#include <ATen/ops/tan_cuda_dispatch.h>
|
| 580 |
+
#include <ATen/ops/tanh_cuda_dispatch.h>
|
| 581 |
+
#include <ATen/ops/tanh_backward_cuda_dispatch.h>
|
| 582 |
+
#include <ATen/ops/threshold_cuda_dispatch.h>
|
| 583 |
+
#include <ATen/ops/threshold_backward_cuda_dispatch.h>
|
| 584 |
+
#include <ATen/ops/topk_cuda_dispatch.h>
|
| 585 |
+
#include <ATen/ops/trace_cuda_dispatch.h>
|
| 586 |
+
#include <ATen/ops/triangular_solve_cuda_dispatch.h>
|
| 587 |
+
#include <ATen/ops/tril_cuda_dispatch.h>
|
| 588 |
+
#include <ATen/ops/tril_indices_cuda_dispatch.h>
|
| 589 |
+
#include <ATen/ops/triu_cuda_dispatch.h>
|
| 590 |
+
#include <ATen/ops/triu_indices_cuda_dispatch.h>
|
| 591 |
+
#include <ATen/ops/trunc_cuda_dispatch.h>
|
| 592 |
+
#include <ATen/ops/unfold_cuda_dispatch.h>
|
| 593 |
+
#include <ATen/ops/unfold_backward_cuda_dispatch.h>
|
| 594 |
+
#include <ATen/ops/uniform_cuda_dispatch.h>
|
| 595 |
+
#include <ATen/ops/unique_consecutive_cuda_dispatch.h>
|
| 596 |
+
#include <ATen/ops/unique_dim_cuda_dispatch.h>
|
| 597 |
+
#include <ATen/ops/unique_dim_consecutive_cuda_dispatch.h>
|
| 598 |
+
#include <ATen/ops/upsample_bicubic2d_cuda_dispatch.h>
|
| 599 |
+
#include <ATen/ops/upsample_bicubic2d_backward_cuda_dispatch.h>
|
| 600 |
+
#include <ATen/ops/upsample_bilinear2d_cuda_dispatch.h>
|
| 601 |
+
#include <ATen/ops/upsample_bilinear2d_backward_cuda_dispatch.h>
|
| 602 |
+
#include <ATen/ops/upsample_linear1d_cuda_dispatch.h>
|
| 603 |
+
#include <ATen/ops/upsample_linear1d_backward_cuda_dispatch.h>
|
| 604 |
+
#include <ATen/ops/upsample_nearest1d_cuda_dispatch.h>
|
| 605 |
+
#include <ATen/ops/upsample_nearest1d_backward_cuda_dispatch.h>
|
| 606 |
+
#include <ATen/ops/upsample_nearest2d_cuda_dispatch.h>
|
| 607 |
+
#include <ATen/ops/upsample_nearest2d_backward_cuda_dispatch.h>
|
| 608 |
+
#include <ATen/ops/upsample_nearest3d_cuda_dispatch.h>
|
| 609 |
+
#include <ATen/ops/upsample_nearest3d_backward_cuda_dispatch.h>
|
| 610 |
+
#include <ATen/ops/upsample_trilinear3d_cuda_dispatch.h>
|
| 611 |
+
#include <ATen/ops/upsample_trilinear3d_backward_cuda_dispatch.h>
|
| 612 |
+
#include <ATen/ops/var_cuda_dispatch.h>
|
| 613 |
+
#include <ATen/ops/var_mean_cuda_dispatch.h>
|
| 614 |
+
#include <ATen/ops/vdot_cuda_dispatch.h>
|
| 615 |
+
#include <ATen/ops/view_cuda_dispatch.h>
|
| 616 |
+
#include <ATen/ops/view_as_complex_cuda_dispatch.h>
|
| 617 |
+
#include <ATen/ops/view_as_real_cuda_dispatch.h>
|
| 618 |
+
#include <ATen/ops/where_cuda_dispatch.h>
|
| 619 |
+
#include <ATen/ops/xlogy_cuda_dispatch.h>
|
| 620 |
+
#include <ATen/ops/zero_cuda_dispatch.h>
|
| 621 |
+
|
| 622 |
+
|
| 623 |
+
|
.venv/lib/python3.11/site-packages/torch/include/ATen/CachedTensorUtils.h
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/ATen.h>
|
| 4 |
+
|
| 5 |
+
namespace at::caching {
|
| 6 |
+
|
| 7 |
+
// Some systems (just cudagraphs currently) will persist a static tensor output
|
| 8 |
+
// whose TensorImpl does not change across iterations. For these tensors caching
|
| 9 |
+
// dtype conversions is invalid. Additionally, there will be an extra reference
|
| 10 |
+
// count to these cached tensors that would prevent buffer inplacing and other
|
| 11 |
+
// checks on tensor uniqueness. If we are not using these systems the enabled
|
| 12 |
+
// flag will be false and we will avoid the hash lookup.
|
| 13 |
+
|
| 14 |
+
TORCH_API bool is_cached_tensor(const at::Tensor& t);
|
| 15 |
+
TORCH_API void add_cached_tensor(const at::Tensor& t);
|
| 16 |
+
TORCH_API void remove_cached_tensor(const at::Tensor& t);
|
| 17 |
+
TORCH_API void set_cached_tensors_enabled(bool enable);
|
| 18 |
+
|
| 19 |
+
// For gradient buffer stealing we will adjust the use count of tensors
|
| 20 |
+
// which are persisted by cudagraphs, just as we need to adjust reference
|
| 21 |
+
// count of tensors with hooks.
|
| 22 |
+
TORCH_API size_t adjusted_use_count(const at::Tensor& t);
|
| 23 |
+
|
| 24 |
+
} // namespace at::caching
|
.venv/lib/python3.11/site-packages/torch/include/ATen/CollapseDims.h
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <c10/util/Exception.h>
|
| 2 |
+
#include <utility>
|
| 3 |
+
|
| 4 |
+
namespace at {
|
| 5 |
+
|
| 6 |
+
/*
|
| 7 |
+
[collapse dims] Updates sizes, and strides to reflect a "collapse" of
|
| 8 |
+
the info, possibly excluding the optional excludeDim. A "collapsed" version
|
| 9 |
+
of the info is the fewest dims that order the tensor's elements in the same
|
| 10 |
+
way as the original info. If excludeDim is specified, the collapse is the
|
| 11 |
+
fewest dims that order the tensor's elements as the original and preserve the
|
| 12 |
+
excluded dimension, unless the tensor collapses to a point.
|
| 13 |
+
|
| 14 |
+
This function returns a pair of values.
|
| 15 |
+
|
| 16 |
+
1) The (new) index of the preserved dimension if excludeDim is
|
| 17 |
+
specified. 0 if the tensor is collapsed to a point. -1
|
| 18 |
+
otherwise.
|
| 19 |
+
|
| 20 |
+
2) The new number of dimensions.
|
| 21 |
+
*/
|
| 22 |
+
template <typename T>
|
| 23 |
+
inline std::pair<int64_t, int64_t> collapse_dims(
|
| 24 |
+
T* sizes,
|
| 25 |
+
T* strides,
|
| 26 |
+
int64_t dims,
|
| 27 |
+
const int excludeDim = -1) {
|
| 28 |
+
TORCH_CHECK(
|
| 29 |
+
excludeDim >= -1 && excludeDim < dims,
|
| 30 |
+
"expected excluded dim between -1 and dims - 1");
|
| 31 |
+
|
| 32 |
+
int64_t stopDim = (excludeDim == -1) ? dims : excludeDim;
|
| 33 |
+
int64_t newIndex = -1;
|
| 34 |
+
int64_t oldIndex = 0;
|
| 35 |
+
int64_t remappedExcludedDim = -1;
|
| 36 |
+
|
| 37 |
+
while (oldIndex < dims) {
|
| 38 |
+
// Finds a dimension to collapse into
|
| 39 |
+
for (; oldIndex < stopDim; ++oldIndex) {
|
| 40 |
+
if (sizes[oldIndex] == 1) {
|
| 41 |
+
continue;
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
++newIndex;
|
| 45 |
+
sizes[newIndex] = sizes[oldIndex];
|
| 46 |
+
strides[newIndex] = strides[oldIndex];
|
| 47 |
+
++oldIndex;
|
| 48 |
+
break;
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
// Collapses dims
|
| 52 |
+
for (; oldIndex < stopDim; ++oldIndex) {
|
| 53 |
+
if (sizes[oldIndex] == 1) {
|
| 54 |
+
continue;
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
if (strides[newIndex] == sizes[oldIndex] * strides[oldIndex]) {
|
| 58 |
+
sizes[newIndex] *= sizes[oldIndex];
|
| 59 |
+
strides[newIndex] = strides[oldIndex];
|
| 60 |
+
} else {
|
| 61 |
+
++newIndex;
|
| 62 |
+
sizes[newIndex] = sizes[oldIndex];
|
| 63 |
+
strides[newIndex] = strides[oldIndex];
|
| 64 |
+
}
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
// Handles excludeDim being set (oldIndex == excludeDim)
|
| 68 |
+
if (oldIndex != dims) {
|
| 69 |
+
// Preserves excluded dimension
|
| 70 |
+
++newIndex;
|
| 71 |
+
sizes[newIndex] = sizes[oldIndex];
|
| 72 |
+
strides[newIndex] = strides[oldIndex];
|
| 73 |
+
remappedExcludedDim = newIndex;
|
| 74 |
+
|
| 75 |
+
// Restarts iteration after excludeDim
|
| 76 |
+
++oldIndex;
|
| 77 |
+
stopDim = dims;
|
| 78 |
+
}
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
// Handles special case of all dims size 1
|
| 82 |
+
if (newIndex == -1 || (newIndex == 0 && sizes[0] == 1)) {
|
| 83 |
+
dims = 1;
|
| 84 |
+
sizes[0] = 1;
|
| 85 |
+
strides[0] = 1;
|
| 86 |
+
|
| 87 |
+
return std::pair<int64_t, int64_t>(0, 1);
|
| 88 |
+
}
|
| 89 |
+
|
| 90 |
+
dims = newIndex + 1;
|
| 91 |
+
return std::pair<int64_t, int64_t>(remappedExcludedDim, dims);
|
| 92 |
+
}
|
| 93 |
+
|
| 94 |
+
} // namespace at
|
.venv/lib/python3.11/site-packages/torch/include/ATen/CompositeExplicitAutogradFunctions.h
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <ATen/core/TensorBody.h>
|
| 2 |
+
|
| 3 |
+
// TODO Undo all logic introduced for Note [Avoiding Include Cycles In Static Dispatch]
|
| 4 |
+
// Code introduced to avoid cyclic dependency in static dispatch is no longer
|
| 5 |
+
// needed as static dispatch logic is moved from TensorBody.h, which caused cycles in the first place,
|
| 6 |
+
// to Operators.cpp for supporting multiple backends with multiple kernels.
|
| 7 |
+
//
|
| 8 |
+
// Note [Avoiding Include Cycles In Static Dispatch]
|
| 9 |
+
// In order to avoid #include cycles in the static dispatch build, we've carefully split out
|
| 10 |
+
// the static function definition files into {DispatchKey}Functions.h and {DispatchKey}Functions_inl.h.
|
| 11 |
+
//
|
| 12 |
+
// Without this split, the include cycle looks like TensorBody.h -> CPUFunctions.h -> TensorBody.h.
|
| 13 |
+
// - TensorBody.h #includes CPUFunctions.h in the static dispatch build, because the tensor methods
|
| 14 |
+
// all need to call into the fastpath C++ API defined in CPUFunctions.h. The methods are also all
|
| 15 |
+
// directly inlined into TensorBody.h.
|
| 16 |
+
// - CPUFunctions.h #includes TensorBody.h because it contains function declarations for the entire C++ API,
|
| 17 |
+
// which include functions that have defaultable std::optional<Tensor> arguments.
|
| 18 |
+
// That requires knowing the full Tensor class definition.
|
| 19 |
+
//
|
| 20 |
+
// We break the cycle by doing the following:
|
| 21 |
+
// - Split out CPUFunction.h into two files: CPUFunctions.h and CPUFunctions_inl.h
|
| 22 |
+
// - CPUFunction.h is a dummy file that just includes the Tensor class and includes CPUFunctions_inl.,
|
| 23 |
+
// - CPUFunctions_inl.h includes everything else
|
| 24 |
+
// - (only in the static dispatch build) TensorBody.h makes sure to finish defining the Tensor class,
|
| 25 |
+
// and then it includes CPUFunctions_inl.h.
|
| 26 |
+
// - All other files that want the cpu fastpath functions can include CPUFunctions.h directly.
|
| 27 |
+
// - This also means that static dispatch build, CPUFunctions.h only needs to
|
| 28 |
+
// #include TensorBody.h, and it will automatically bring in CPUFunctions_inl.h.
|
| 29 |
+
#include <ATen/CompositeExplicitAutogradFunctions_inl.h>
|
.venv/lib/python3.11/site-packages/torch/include/ATen/CompositeExplicitAutogradNonFunctionalFunctions.h
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <ATen/core/TensorBody.h>
|
| 2 |
+
|
| 3 |
+
// TODO Undo all logic introduced for Note [Avoiding Include Cycles In Static Dispatch]
|
| 4 |
+
// Code introduced to avoid cyclic dependency in static dispatch is no longer
|
| 5 |
+
// needed as static dispatch logic is moved from TensorBody.h, which caused cycles in the first place,
|
| 6 |
+
// to Operators.cpp for supporting multiple backends with multiple kernels.
|
| 7 |
+
//
|
| 8 |
+
// Note [Avoiding Include Cycles In Static Dispatch]
|
| 9 |
+
// In order to avoid #include cycles in the static dispatch build, we've carefully split out
|
| 10 |
+
// the static function definition files into {DispatchKey}Functions.h and {DispatchKey}Functions_inl.h.
|
| 11 |
+
//
|
| 12 |
+
// Without this split, the include cycle looks like TensorBody.h -> CPUFunctions.h -> TensorBody.h.
|
| 13 |
+
// - TensorBody.h #includes CPUFunctions.h in the static dispatch build, because the tensor methods
|
| 14 |
+
// all need to call into the fastpath C++ API defined in CPUFunctions.h. The methods are also all
|
| 15 |
+
// directly inlined into TensorBody.h.
|
| 16 |
+
// - CPUFunctions.h #includes TensorBody.h because it contains function declarations for the entire C++ API,
|
| 17 |
+
// which include functions that have defaultable std::optional<Tensor> arguments.
|
| 18 |
+
// That requires knowing the full Tensor class definition.
|
| 19 |
+
//
|
| 20 |
+
// We break the cycle by doing the following:
|
| 21 |
+
// - Split out CPUFunction.h into two files: CPUFunctions.h and CPUFunctions_inl.h
|
| 22 |
+
// - CPUFunction.h is a dummy file that just includes the Tensor class and includes CPUFunctions_inl.,
|
| 23 |
+
// - CPUFunctions_inl.h includes everything else
|
| 24 |
+
// - (only in the static dispatch build) TensorBody.h makes sure to finish defining the Tensor class,
|
| 25 |
+
// and then it includes CPUFunctions_inl.h.
|
| 26 |
+
// - All other files that want the cpu fastpath functions can include CPUFunctions.h directly.
|
| 27 |
+
// - This also means that static dispatch build, CPUFunctions.h only needs to
|
| 28 |
+
// #include TensorBody.h, and it will automatically bring in CPUFunctions_inl.h.
|
| 29 |
+
#include <ATen/CompositeExplicitAutogradNonFunctionalFunctions_inl.h>
|
.venv/lib/python3.11/site-packages/torch/include/ATen/CompositeExplicitAutogradNonFunctionalFunctions_inl.h
ADDED
|
@@ -0,0 +1,323 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
// @generated by torchgen/gen.py from DispatchKeyFunctions_inl.h
|
| 3 |
+
|
| 4 |
+
// NB: The implementing C++ file is RegisterDispatchKey.cpp
|
| 5 |
+
|
| 6 |
+
// The only #includes we need are for custom classes that have defaults in the C++ API
|
| 7 |
+
#include <c10/core/MemoryFormat.h>
|
| 8 |
+
#include <c10/core/Scalar.h>
|
| 9 |
+
#include <ATen/core/Reduction.h>
|
| 10 |
+
|
| 11 |
+
#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS)
|
| 12 |
+
#error This change adds a dependency on all pytorch operators, meaning the \
|
| 13 |
+
file will need to be re-compiled every time an operator is changed or added. \
|
| 14 |
+
Consider including a specific operator from \
|
| 15 |
+
<ATen/ops/{my_operator}_compositeexplicitautogradnonfunctional_dispatch.h>. \
|
| 16 |
+
See NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS].
|
| 17 |
+
#endif
|
| 18 |
+
|
| 19 |
+
#include <ATen/ops/_addmm_activation_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 20 |
+
#include <ATen/ops/_conj_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 21 |
+
#include <ATen/ops/_convert_indices_from_coo_to_csr_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 22 |
+
#include <ATen/ops/_convert_indices_from_csr_to_coo_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 23 |
+
#include <ATen/ops/_fw_primal_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 24 |
+
#include <ATen/ops/_indices_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 25 |
+
#include <ATen/ops/_linalg_det_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 26 |
+
#include <ATen/ops/_linalg_eigh_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 27 |
+
#include <ATen/ops/_linalg_slogdet_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 28 |
+
#include <ATen/ops/_linalg_solve_ex_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 29 |
+
#include <ATen/ops/_linalg_svd_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 30 |
+
#include <ATen/ops/_log_softmax_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 31 |
+
#include <ATen/ops/_log_softmax_backward_data_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 32 |
+
#include <ATen/ops/_make_dual_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 33 |
+
#include <ATen/ops/_neg_view_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 34 |
+
#include <ATen/ops/_nested_get_values_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 35 |
+
#include <ATen/ops/_nested_view_from_buffer_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 36 |
+
#include <ATen/ops/_nested_view_from_jagged_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 37 |
+
#include <ATen/ops/_reshape_alias_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 38 |
+
#include <ATen/ops/_softmax_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 39 |
+
#include <ATen/ops/_softmax_backward_data_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 40 |
+
#include <ATen/ops/_sparse_broadcast_to_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 41 |
+
#include <ATen/ops/_test_autograd_multiple_dispatch_view_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 42 |
+
#include <ATen/ops/_trilinear_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 43 |
+
#include <ATen/ops/_upsample_bicubic2d_aa_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 44 |
+
#include <ATen/ops/_upsample_bicubic2d_aa_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 45 |
+
#include <ATen/ops/_upsample_bilinear2d_aa_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 46 |
+
#include <ATen/ops/_upsample_bilinear2d_aa_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 47 |
+
#include <ATen/ops/_upsample_nearest_exact1d_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 48 |
+
#include <ATen/ops/_upsample_nearest_exact1d_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 49 |
+
#include <ATen/ops/_upsample_nearest_exact2d_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 50 |
+
#include <ATen/ops/_upsample_nearest_exact2d_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 51 |
+
#include <ATen/ops/_upsample_nearest_exact3d_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 52 |
+
#include <ATen/ops/_upsample_nearest_exact3d_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 53 |
+
#include <ATen/ops/_values_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 54 |
+
#include <ATen/ops/acos_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 55 |
+
#include <ATen/ops/acosh_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 56 |
+
#include <ATen/ops/adaptive_max_pool2d_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 57 |
+
#include <ATen/ops/adaptive_max_pool2d_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 58 |
+
#include <ATen/ops/adaptive_max_pool3d_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 59 |
+
#include <ATen/ops/adaptive_max_pool3d_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 60 |
+
#include <ATen/ops/add_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 61 |
+
#include <ATen/ops/addcdiv_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 62 |
+
#include <ATen/ops/addcmul_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 63 |
+
#include <ATen/ops/addmm_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 64 |
+
#include <ATen/ops/addmv_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 65 |
+
#include <ATen/ops/alias_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 66 |
+
#include <ATen/ops/all_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 67 |
+
#include <ATen/ops/amax_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 68 |
+
#include <ATen/ops/amin_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 69 |
+
#include <ATen/ops/aminmax_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 70 |
+
#include <ATen/ops/any_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 71 |
+
#include <ATen/ops/argmax_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 72 |
+
#include <ATen/ops/argmin_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 73 |
+
#include <ATen/ops/as_strided_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 74 |
+
#include <ATen/ops/as_strided_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 75 |
+
#include <ATen/ops/as_strided_scatter_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 76 |
+
#include <ATen/ops/asin_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 77 |
+
#include <ATen/ops/asinh_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 78 |
+
#include <ATen/ops/atan_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 79 |
+
#include <ATen/ops/atan2_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 80 |
+
#include <ATen/ops/atanh_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 81 |
+
#include <ATen/ops/avg_pool2d_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 82 |
+
#include <ATen/ops/avg_pool2d_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 83 |
+
#include <ATen/ops/avg_pool3d_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 84 |
+
#include <ATen/ops/avg_pool3d_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 85 |
+
#include <ATen/ops/baddbmm_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 86 |
+
#include <ATen/ops/bernoulli_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 87 |
+
#include <ATen/ops/bitwise_and_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 88 |
+
#include <ATen/ops/bitwise_left_shift_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 89 |
+
#include <ATen/ops/bitwise_not_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 90 |
+
#include <ATen/ops/bitwise_or_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 91 |
+
#include <ATen/ops/bitwise_right_shift_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 92 |
+
#include <ATen/ops/bitwise_xor_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 93 |
+
#include <ATen/ops/bmm_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 94 |
+
#include <ATen/ops/cat_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 95 |
+
#include <ATen/ops/ccol_indices_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 96 |
+
#include <ATen/ops/ceil_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 97 |
+
#include <ATen/ops/clamp_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 98 |
+
#include <ATen/ops/clamp_max_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 99 |
+
#include <ATen/ops/clamp_min_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 100 |
+
#include <ATen/ops/col_indices_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 101 |
+
#include <ATen/ops/copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 102 |
+
#include <ATen/ops/copysign_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 103 |
+
#include <ATen/ops/cos_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 104 |
+
#include <ATen/ops/cosh_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 105 |
+
#include <ATen/ops/crow_indices_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 106 |
+
#include <ATen/ops/cumprod_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 107 |
+
#include <ATen/ops/cumsum_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 108 |
+
#include <ATen/ops/detach_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 109 |
+
#include <ATen/ops/diag_embed_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 110 |
+
#include <ATen/ops/diagonal_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 111 |
+
#include <ATen/ops/diagonal_scatter_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 112 |
+
#include <ATen/ops/digamma_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 113 |
+
#include <ATen/ops/div_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 114 |
+
#include <ATen/ops/elu_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 115 |
+
#include <ATen/ops/elu_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 116 |
+
#include <ATen/ops/eq_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 117 |
+
#include <ATen/ops/erf_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 118 |
+
#include <ATen/ops/erfc_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 119 |
+
#include <ATen/ops/erfinv_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 120 |
+
#include <ATen/ops/exp_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 121 |
+
#include <ATen/ops/exp2_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 122 |
+
#include <ATen/ops/expand_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 123 |
+
#include <ATen/ops/expm1_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 124 |
+
#include <ATen/ops/floor_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 125 |
+
#include <ATen/ops/fmax_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 126 |
+
#include <ATen/ops/fmin_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 127 |
+
#include <ATen/ops/fmod_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 128 |
+
#include <ATen/ops/frac_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 129 |
+
#include <ATen/ops/fractional_max_pool2d_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 130 |
+
#include <ATen/ops/fractional_max_pool2d_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 131 |
+
#include <ATen/ops/fractional_max_pool3d_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 132 |
+
#include <ATen/ops/gather_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 133 |
+
#include <ATen/ops/gcd_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 134 |
+
#include <ATen/ops/ge_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 135 |
+
#include <ATen/ops/gelu_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 136 |
+
#include <ATen/ops/gelu_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 137 |
+
#include <ATen/ops/glu_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 138 |
+
#include <ATen/ops/gt_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 139 |
+
#include <ATen/ops/hardshrink_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 140 |
+
#include <ATen/ops/hardshrink_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 141 |
+
#include <ATen/ops/hardsigmoid_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 142 |
+
#include <ATen/ops/hardsigmoid_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 143 |
+
#include <ATen/ops/heaviside_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 144 |
+
#include <ATen/ops/hypot_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 145 |
+
#include <ATen/ops/i0_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 146 |
+
#include <ATen/ops/igamma_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 147 |
+
#include <ATen/ops/igammac_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 148 |
+
#include <ATen/ops/index_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 149 |
+
#include <ATen/ops/index_add_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 150 |
+
#include <ATen/ops/index_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 151 |
+
#include <ATen/ops/index_reduce_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 152 |
+
#include <ATen/ops/indices_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 153 |
+
#include <ATen/ops/isin_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 154 |
+
#include <ATen/ops/isneginf_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 155 |
+
#include <ATen/ops/isposinf_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 156 |
+
#include <ATen/ops/lcm_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 157 |
+
#include <ATen/ops/le_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 158 |
+
#include <ATen/ops/leaky_relu_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 159 |
+
#include <ATen/ops/leaky_relu_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 160 |
+
#include <ATen/ops/lerp_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 161 |
+
#include <ATen/ops/lgamma_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 162 |
+
#include <ATen/ops/lift_fresh_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 163 |
+
#include <ATen/ops/linalg_cholesky_ex_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 164 |
+
#include <ATen/ops/linalg_cross_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 165 |
+
#include <ATen/ops/linalg_inv_ex_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 166 |
+
#include <ATen/ops/linalg_ldl_factor_ex_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 167 |
+
#include <ATen/ops/linalg_ldl_solve_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 168 |
+
#include <ATen/ops/linalg_lu_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 169 |
+
#include <ATen/ops/linalg_lu_factor_ex_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 170 |
+
#include <ATen/ops/linalg_lu_solve_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 171 |
+
#include <ATen/ops/linalg_pinv_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 172 |
+
#include <ATen/ops/linalg_qr_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 173 |
+
#include <ATen/ops/linalg_vector_norm_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 174 |
+
#include <ATen/ops/log_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 175 |
+
#include <ATen/ops/log10_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 176 |
+
#include <ATen/ops/log1p_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 177 |
+
#include <ATen/ops/log2_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 178 |
+
#include <ATen/ops/logaddexp_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 179 |
+
#include <ATen/ops/logaddexp2_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 180 |
+
#include <ATen/ops/logit_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 181 |
+
#include <ATen/ops/logsumexp_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 182 |
+
#include <ATen/ops/lt_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 183 |
+
#include <ATen/ops/lu_unpack_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 184 |
+
#include <ATen/ops/max_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 185 |
+
#include <ATen/ops/max_pool2d_with_indices_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 186 |
+
#include <ATen/ops/max_pool2d_with_indices_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 187 |
+
#include <ATen/ops/maximum_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 188 |
+
#include <ATen/ops/mean_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 189 |
+
#include <ATen/ops/min_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 190 |
+
#include <ATen/ops/minimum_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 191 |
+
#include <ATen/ops/mish_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 192 |
+
#include <ATen/ops/mm_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 193 |
+
#include <ATen/ops/mse_loss_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 194 |
+
#include <ATen/ops/mul_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 195 |
+
#include <ATen/ops/narrow_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 196 |
+
#include <ATen/ops/ne_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 197 |
+
#include <ATen/ops/neg_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 198 |
+
#include <ATen/ops/new_empty_strided_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 199 |
+
#include <ATen/ops/nextafter_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 200 |
+
#include <ATen/ops/nll_loss_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 201 |
+
#include <ATen/ops/nll_loss_forward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 202 |
+
#include <ATen/ops/norm_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 203 |
+
#include <ATen/ops/permute_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 204 |
+
#include <ATen/ops/pixel_shuffle_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 205 |
+
#include <ATen/ops/pixel_unshuffle_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 206 |
+
#include <ATen/ops/polygamma_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 207 |
+
#include <ATen/ops/pow_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 208 |
+
#include <ATen/ops/prod_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 209 |
+
#include <ATen/ops/reciprocal_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 210 |
+
#include <ATen/ops/reflection_pad1d_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 211 |
+
#include <ATen/ops/reflection_pad1d_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 212 |
+
#include <ATen/ops/reflection_pad3d_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 213 |
+
#include <ATen/ops/reflection_pad3d_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 214 |
+
#include <ATen/ops/remainder_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 215 |
+
#include <ATen/ops/renorm_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 216 |
+
#include <ATen/ops/replication_pad1d_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 217 |
+
#include <ATen/ops/replication_pad1d_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 218 |
+
#include <ATen/ops/replication_pad2d_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 219 |
+
#include <ATen/ops/replication_pad3d_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 220 |
+
#include <ATen/ops/round_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 221 |
+
#include <ATen/ops/row_indices_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 222 |
+
#include <ATen/ops/rsqrt_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 223 |
+
#include <ATen/ops/scatter_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 224 |
+
#include <ATen/ops/scatter_add_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 225 |
+
#include <ATen/ops/scatter_reduce_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 226 |
+
#include <ATen/ops/select_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 227 |
+
#include <ATen/ops/select_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 228 |
+
#include <ATen/ops/select_scatter_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 229 |
+
#include <ATen/ops/sgn_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 230 |
+
#include <ATen/ops/sigmoid_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 231 |
+
#include <ATen/ops/sigmoid_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 232 |
+
#include <ATen/ops/sign_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 233 |
+
#include <ATen/ops/signbit_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 234 |
+
#include <ATen/ops/silu_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 235 |
+
#include <ATen/ops/silu_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 236 |
+
#include <ATen/ops/sin_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 237 |
+
#include <ATen/ops/sinc_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 238 |
+
#include <ATen/ops/sinh_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 239 |
+
#include <ATen/ops/slice_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 240 |
+
#include <ATen/ops/slice_scatter_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 241 |
+
#include <ATen/ops/slow_conv_transpose2d_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 242 |
+
#include <ATen/ops/smooth_l1_loss_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 243 |
+
#include <ATen/ops/softplus_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 244 |
+
#include <ATen/ops/softplus_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 245 |
+
#include <ATen/ops/softshrink_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 246 |
+
#include <ATen/ops/softshrink_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 247 |
+
#include <ATen/ops/sort_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 248 |
+
#include <ATen/ops/special_airy_ai_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 249 |
+
#include <ATen/ops/special_bessel_j0_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 250 |
+
#include <ATen/ops/special_bessel_j1_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 251 |
+
#include <ATen/ops/special_bessel_y0_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 252 |
+
#include <ATen/ops/special_bessel_y1_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 253 |
+
#include <ATen/ops/special_chebyshev_polynomial_t_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 254 |
+
#include <ATen/ops/special_chebyshev_polynomial_u_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 255 |
+
#include <ATen/ops/special_chebyshev_polynomial_v_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 256 |
+
#include <ATen/ops/special_chebyshev_polynomial_w_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 257 |
+
#include <ATen/ops/special_entr_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 258 |
+
#include <ATen/ops/special_erfcx_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 259 |
+
#include <ATen/ops/special_hermite_polynomial_h_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 260 |
+
#include <ATen/ops/special_hermite_polynomial_he_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 261 |
+
#include <ATen/ops/special_i0e_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 262 |
+
#include <ATen/ops/special_i1_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 263 |
+
#include <ATen/ops/special_i1e_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 264 |
+
#include <ATen/ops/special_laguerre_polynomial_l_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 265 |
+
#include <ATen/ops/special_legendre_polynomial_p_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 266 |
+
#include <ATen/ops/special_log_ndtr_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 267 |
+
#include <ATen/ops/special_modified_bessel_i0_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 268 |
+
#include <ATen/ops/special_modified_bessel_i1_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 269 |
+
#include <ATen/ops/special_modified_bessel_k0_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 270 |
+
#include <ATen/ops/special_modified_bessel_k1_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 271 |
+
#include <ATen/ops/special_ndtri_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 272 |
+
#include <ATen/ops/special_scaled_modified_bessel_k0_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 273 |
+
#include <ATen/ops/special_scaled_modified_bessel_k1_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 274 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_t_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 275 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_u_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 276 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_v_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 277 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_w_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 278 |
+
#include <ATen/ops/special_spherical_bessel_j0_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 279 |
+
#include <ATen/ops/special_xlog1py_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 280 |
+
#include <ATen/ops/special_zeta_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 281 |
+
#include <ATen/ops/split_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 282 |
+
#include <ATen/ops/split_with_sizes_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 283 |
+
#include <ATen/ops/sqrt_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 284 |
+
#include <ATen/ops/squeeze_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 285 |
+
#include <ATen/ops/sub_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 286 |
+
#include <ATen/ops/sum_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 287 |
+
#include <ATen/ops/t_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 288 |
+
#include <ATen/ops/tan_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 289 |
+
#include <ATen/ops/tanh_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 290 |
+
#include <ATen/ops/tanh_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 291 |
+
#include <ATen/ops/threshold_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 292 |
+
#include <ATen/ops/threshold_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 293 |
+
#include <ATen/ops/topk_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 294 |
+
#include <ATen/ops/transpose_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 295 |
+
#include <ATen/ops/triangular_solve_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 296 |
+
#include <ATen/ops/tril_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 297 |
+
#include <ATen/ops/triu_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 298 |
+
#include <ATen/ops/trunc_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 299 |
+
#include <ATen/ops/unbind_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 300 |
+
#include <ATen/ops/unfold_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 301 |
+
#include <ATen/ops/unsqueeze_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 302 |
+
#include <ATen/ops/upsample_bicubic2d_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 303 |
+
#include <ATen/ops/upsample_bicubic2d_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 304 |
+
#include <ATen/ops/upsample_bilinear2d_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 305 |
+
#include <ATen/ops/upsample_bilinear2d_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 306 |
+
#include <ATen/ops/upsample_linear1d_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 307 |
+
#include <ATen/ops/upsample_linear1d_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 308 |
+
#include <ATen/ops/upsample_nearest1d_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 309 |
+
#include <ATen/ops/upsample_nearest1d_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 310 |
+
#include <ATen/ops/upsample_nearest2d_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 311 |
+
#include <ATen/ops/upsample_nearest2d_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 312 |
+
#include <ATen/ops/upsample_nearest3d_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 313 |
+
#include <ATen/ops/upsample_nearest3d_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 314 |
+
#include <ATen/ops/upsample_trilinear3d_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 315 |
+
#include <ATen/ops/upsample_trilinear3d_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 316 |
+
#include <ATen/ops/values_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 317 |
+
#include <ATen/ops/view_as_complex_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 318 |
+
#include <ATen/ops/view_as_real_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 319 |
+
#include <ATen/ops/view_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 320 |
+
#include <ATen/ops/xlogy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
|
.venv/lib/python3.11/site-packages/torch/include/ATen/CompositeImplicitAutogradFunctions_inl.h
ADDED
|
@@ -0,0 +1,502 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
// @generated by torchgen/gen.py from DispatchKeyFunctions_inl.h
|
| 3 |
+
|
| 4 |
+
// NB: The implementing C++ file is RegisterDispatchKey.cpp
|
| 5 |
+
|
| 6 |
+
// The only #includes we need are for custom classes that have defaults in the C++ API
|
| 7 |
+
#include <c10/core/MemoryFormat.h>
|
| 8 |
+
#include <c10/core/Scalar.h>
|
| 9 |
+
#include <ATen/core/Reduction.h>
|
| 10 |
+
|
| 11 |
+
#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS)
|
| 12 |
+
#error This change adds a dependency on all pytorch operators, meaning the \
|
| 13 |
+
file will need to be re-compiled every time an operator is changed or added. \
|
| 14 |
+
Consider including a specific operator from \
|
| 15 |
+
<ATen/ops/{my_operator}_compositeimplicitautograd_dispatch.h>. \
|
| 16 |
+
See NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS].
|
| 17 |
+
#endif
|
| 18 |
+
|
| 19 |
+
#include <ATen/ops/_add_batch_dim_compositeimplicitautograd_dispatch.h>
|
| 20 |
+
#include <ATen/ops/_assert_tensor_metadata_compositeimplicitautograd_dispatch.h>
|
| 21 |
+
#include <ATen/ops/_autocast_to_full_precision_compositeimplicitautograd_dispatch.h>
|
| 22 |
+
#include <ATen/ops/_autocast_to_reduced_precision_compositeimplicitautograd_dispatch.h>
|
| 23 |
+
#include <ATen/ops/_backward_compositeimplicitautograd_dispatch.h>
|
| 24 |
+
#include <ATen/ops/_batch_norm_impl_index_compositeimplicitautograd_dispatch.h>
|
| 25 |
+
#include <ATen/ops/_batch_norm_impl_index_backward_compositeimplicitautograd_dispatch.h>
|
| 26 |
+
#include <ATen/ops/_cast_Byte_compositeimplicitautograd_dispatch.h>
|
| 27 |
+
#include <ATen/ops/_cast_Char_compositeimplicitautograd_dispatch.h>
|
| 28 |
+
#include <ATen/ops/_cast_Double_compositeimplicitautograd_dispatch.h>
|
| 29 |
+
#include <ATen/ops/_cast_Float_compositeimplicitautograd_dispatch.h>
|
| 30 |
+
#include <ATen/ops/_cast_Half_compositeimplicitautograd_dispatch.h>
|
| 31 |
+
#include <ATen/ops/_cast_Int_compositeimplicitautograd_dispatch.h>
|
| 32 |
+
#include <ATen/ops/_cast_Long_compositeimplicitautograd_dispatch.h>
|
| 33 |
+
#include <ATen/ops/_cast_Short_compositeimplicitautograd_dispatch.h>
|
| 34 |
+
#include <ATen/ops/_choose_qparams_per_tensor_compositeimplicitautograd_dispatch.h>
|
| 35 |
+
#include <ATen/ops/_convolution_compositeimplicitautograd_dispatch.h>
|
| 36 |
+
#include <ATen/ops/_convolution_double_backward_compositeimplicitautograd_dispatch.h>
|
| 37 |
+
#include <ATen/ops/_convolution_mode_compositeimplicitautograd_dispatch.h>
|
| 38 |
+
#include <ATen/ops/_cufft_clear_plan_cache_compositeimplicitautograd_dispatch.h>
|
| 39 |
+
#include <ATen/ops/_cufft_get_plan_cache_max_size_compositeimplicitautograd_dispatch.h>
|
| 40 |
+
#include <ATen/ops/_cufft_get_plan_cache_size_compositeimplicitautograd_dispatch.h>
|
| 41 |
+
#include <ATen/ops/_cufft_set_plan_cache_max_size_compositeimplicitautograd_dispatch.h>
|
| 42 |
+
#include <ATen/ops/_debug_has_internal_overlap_compositeimplicitautograd_dispatch.h>
|
| 43 |
+
#include <ATen/ops/_dim_arange_compositeimplicitautograd_dispatch.h>
|
| 44 |
+
#include <ATen/ops/_embedding_bag_sparse_backward_compositeimplicitautograd_dispatch.h>
|
| 45 |
+
#include <ATen/ops/_gather_sparse_backward_compositeimplicitautograd_dispatch.h>
|
| 46 |
+
#include <ATen/ops/_grid_sampler_2d_cpu_fallback_backward_compositeimplicitautograd_dispatch.h>
|
| 47 |
+
#include <ATen/ops/_has_compatible_shallow_copy_type_compositeimplicitautograd_dispatch.h>
|
| 48 |
+
#include <ATen/ops/_is_zerotensor_compositeimplicitautograd_dispatch.h>
|
| 49 |
+
#include <ATen/ops/_lu_with_info_compositeimplicitautograd_dispatch.h>
|
| 50 |
+
#include <ATen/ops/_nnpack_available_compositeimplicitautograd_dispatch.h>
|
| 51 |
+
#include <ATen/ops/_pack_padded_sequence_backward_compositeimplicitautograd_dispatch.h>
|
| 52 |
+
#include <ATen/ops/_pad_circular_compositeimplicitautograd_dispatch.h>
|
| 53 |
+
#include <ATen/ops/_pad_enum_compositeimplicitautograd_dispatch.h>
|
| 54 |
+
#include <ATen/ops/_pad_packed_sequence_compositeimplicitautograd_dispatch.h>
|
| 55 |
+
#include <ATen/ops/_propagate_xla_data_compositeimplicitautograd_dispatch.h>
|
| 56 |
+
#include <ATen/ops/_remove_batch_dim_compositeimplicitautograd_dispatch.h>
|
| 57 |
+
#include <ATen/ops/_reshape_from_tensor_compositeimplicitautograd_dispatch.h>
|
| 58 |
+
#include <ATen/ops/_rowwise_prune_compositeimplicitautograd_dispatch.h>
|
| 59 |
+
#include <ATen/ops/_saturate_weight_to_fp16_compositeimplicitautograd_dispatch.h>
|
| 60 |
+
#include <ATen/ops/_scaled_dot_product_attention_math_compositeimplicitautograd_dispatch.h>
|
| 61 |
+
#include <ATen/ops/_shape_as_tensor_compositeimplicitautograd_dispatch.h>
|
| 62 |
+
#include <ATen/ops/_sobol_engine_draw_compositeimplicitautograd_dispatch.h>
|
| 63 |
+
#include <ATen/ops/_sobol_engine_ff_compositeimplicitautograd_dispatch.h>
|
| 64 |
+
#include <ATen/ops/_sobol_engine_initialize_state_compositeimplicitautograd_dispatch.h>
|
| 65 |
+
#include <ATen/ops/_sobol_engine_scramble_compositeimplicitautograd_dispatch.h>
|
| 66 |
+
#include <ATen/ops/_sparse_bsc_tensor_unsafe_compositeimplicitautograd_dispatch.h>
|
| 67 |
+
#include <ATen/ops/_sparse_bsr_tensor_unsafe_compositeimplicitautograd_dispatch.h>
|
| 68 |
+
#include <ATen/ops/_sparse_compressed_tensor_unsafe_compositeimplicitautograd_dispatch.h>
|
| 69 |
+
#include <ATen/ops/_sparse_coo_tensor_unsafe_compositeimplicitautograd_dispatch.h>
|
| 70 |
+
#include <ATen/ops/_sparse_csc_tensor_unsafe_compositeimplicitautograd_dispatch.h>
|
| 71 |
+
#include <ATen/ops/_sparse_csr_tensor_unsafe_compositeimplicitautograd_dispatch.h>
|
| 72 |
+
#include <ATen/ops/_sparse_log_softmax_compositeimplicitautograd_dispatch.h>
|
| 73 |
+
#include <ATen/ops/_sparse_mm_compositeimplicitautograd_dispatch.h>
|
| 74 |
+
#include <ATen/ops/_sparse_softmax_compositeimplicitautograd_dispatch.h>
|
| 75 |
+
#include <ATen/ops/_sparse_sum_compositeimplicitautograd_dispatch.h>
|
| 76 |
+
#include <ATen/ops/_test_ambiguous_defaults_compositeimplicitautograd_dispatch.h>
|
| 77 |
+
#include <ATen/ops/_test_autograd_multiple_dispatch_compositeimplicitautograd_dispatch.h>
|
| 78 |
+
#include <ATen/ops/_test_check_tensor_compositeimplicitautograd_dispatch.h>
|
| 79 |
+
#include <ATen/ops/_test_serialization_subcmul_compositeimplicitautograd_dispatch.h>
|
| 80 |
+
#include <ATen/ops/_test_string_default_compositeimplicitautograd_dispatch.h>
|
| 81 |
+
#include <ATen/ops/_thnn_differentiable_gru_cell_backward_compositeimplicitautograd_dispatch.h>
|
| 82 |
+
#include <ATen/ops/_thnn_differentiable_lstm_cell_backward_compositeimplicitautograd_dispatch.h>
|
| 83 |
+
#include <ATen/ops/_thnn_fused_lstm_cell_backward_compositeimplicitautograd_dispatch.h>
|
| 84 |
+
#include <ATen/ops/_to_cpu_compositeimplicitautograd_dispatch.h>
|
| 85 |
+
#include <ATen/ops/_unpack_dual_compositeimplicitautograd_dispatch.h>
|
| 86 |
+
#include <ATen/ops/_upsample_bicubic2d_aa_compositeimplicitautograd_dispatch.h>
|
| 87 |
+
#include <ATen/ops/_upsample_bilinear2d_aa_compositeimplicitautograd_dispatch.h>
|
| 88 |
+
#include <ATen/ops/_upsample_nearest_exact1d_compositeimplicitautograd_dispatch.h>
|
| 89 |
+
#include <ATen/ops/_upsample_nearest_exact2d_compositeimplicitautograd_dispatch.h>
|
| 90 |
+
#include <ATen/ops/_upsample_nearest_exact3d_compositeimplicitautograd_dispatch.h>
|
| 91 |
+
#include <ATen/ops/_use_cudnn_rnn_flatten_weight_compositeimplicitautograd_dispatch.h>
|
| 92 |
+
#include <ATen/ops/_validate_sparse_bsc_tensor_args_compositeimplicitautograd_dispatch.h>
|
| 93 |
+
#include <ATen/ops/_validate_sparse_bsr_tensor_args_compositeimplicitautograd_dispatch.h>
|
| 94 |
+
#include <ATen/ops/_validate_sparse_compressed_tensor_args_compositeimplicitautograd_dispatch.h>
|
| 95 |
+
#include <ATen/ops/_validate_sparse_coo_tensor_args_compositeimplicitautograd_dispatch.h>
|
| 96 |
+
#include <ATen/ops/_validate_sparse_csc_tensor_args_compositeimplicitautograd_dispatch.h>
|
| 97 |
+
#include <ATen/ops/_validate_sparse_csr_tensor_args_compositeimplicitautograd_dispatch.h>
|
| 98 |
+
#include <ATen/ops/_version_compositeimplicitautograd_dispatch.h>
|
| 99 |
+
#include <ATen/ops/_weight_norm_compositeimplicitautograd_dispatch.h>
|
| 100 |
+
#include <ATen/ops/_weight_norm_differentiable_backward_compositeimplicitautograd_dispatch.h>
|
| 101 |
+
#include <ATen/ops/_wrapped_linear_prepack_compositeimplicitautograd_dispatch.h>
|
| 102 |
+
#include <ATen/ops/_wrapped_quantized_linear_prepacked_compositeimplicitautograd_dispatch.h>
|
| 103 |
+
#include <ATen/ops/absolute_compositeimplicitautograd_dispatch.h>
|
| 104 |
+
#include <ATen/ops/adaptive_avg_pool1d_compositeimplicitautograd_dispatch.h>
|
| 105 |
+
#include <ATen/ops/adaptive_avg_pool2d_compositeimplicitautograd_dispatch.h>
|
| 106 |
+
#include <ATen/ops/adaptive_avg_pool3d_compositeimplicitautograd_dispatch.h>
|
| 107 |
+
#include <ATen/ops/adaptive_max_pool1d_compositeimplicitautograd_dispatch.h>
|
| 108 |
+
#include <ATen/ops/adjoint_compositeimplicitautograd_dispatch.h>
|
| 109 |
+
#include <ATen/ops/affine_grid_generator_backward_compositeimplicitautograd_dispatch.h>
|
| 110 |
+
#include <ATen/ops/align_as_compositeimplicitautograd_dispatch.h>
|
| 111 |
+
#include <ATen/ops/align_tensors_compositeimplicitautograd_dispatch.h>
|
| 112 |
+
#include <ATen/ops/align_to_compositeimplicitautograd_dispatch.h>
|
| 113 |
+
#include <ATen/ops/all_compositeimplicitautograd_dispatch.h>
|
| 114 |
+
#include <ATen/ops/alpha_dropout_compositeimplicitautograd_dispatch.h>
|
| 115 |
+
#include <ATen/ops/and_compositeimplicitautograd_dispatch.h>
|
| 116 |
+
#include <ATen/ops/any_compositeimplicitautograd_dispatch.h>
|
| 117 |
+
#include <ATen/ops/arccos_compositeimplicitautograd_dispatch.h>
|
| 118 |
+
#include <ATen/ops/arccosh_compositeimplicitautograd_dispatch.h>
|
| 119 |
+
#include <ATen/ops/arcsin_compositeimplicitautograd_dispatch.h>
|
| 120 |
+
#include <ATen/ops/arcsinh_compositeimplicitautograd_dispatch.h>
|
| 121 |
+
#include <ATen/ops/arctan_compositeimplicitautograd_dispatch.h>
|
| 122 |
+
#include <ATen/ops/arctan2_compositeimplicitautograd_dispatch.h>
|
| 123 |
+
#include <ATen/ops/arctanh_compositeimplicitautograd_dispatch.h>
|
| 124 |
+
#include <ATen/ops/argsort_compositeimplicitautograd_dispatch.h>
|
| 125 |
+
#include <ATen/ops/argwhere_compositeimplicitautograd_dispatch.h>
|
| 126 |
+
#include <ATen/ops/atleast_1d_compositeimplicitautograd_dispatch.h>
|
| 127 |
+
#include <ATen/ops/atleast_2d_compositeimplicitautograd_dispatch.h>
|
| 128 |
+
#include <ATen/ops/atleast_3d_compositeimplicitautograd_dispatch.h>
|
| 129 |
+
#include <ATen/ops/avg_pool1d_compositeimplicitautograd_dispatch.h>
|
| 130 |
+
#include <ATen/ops/batch_norm_compositeimplicitautograd_dispatch.h>
|
| 131 |
+
#include <ATen/ops/bilinear_compositeimplicitautograd_dispatch.h>
|
| 132 |
+
#include <ATen/ops/broadcast_tensors_compositeimplicitautograd_dispatch.h>
|
| 133 |
+
#include <ATen/ops/broadcast_to_compositeimplicitautograd_dispatch.h>
|
| 134 |
+
#include <ATen/ops/can_cast_compositeimplicitautograd_dispatch.h>
|
| 135 |
+
#include <ATen/ops/cartesian_prod_compositeimplicitautograd_dispatch.h>
|
| 136 |
+
#include <ATen/ops/cat_compositeimplicitautograd_dispatch.h>
|
| 137 |
+
#include <ATen/ops/cdist_compositeimplicitautograd_dispatch.h>
|
| 138 |
+
#include <ATen/ops/chain_matmul_compositeimplicitautograd_dispatch.h>
|
| 139 |
+
#include <ATen/ops/chalf_compositeimplicitautograd_dispatch.h>
|
| 140 |
+
#include <ATen/ops/choose_qparams_optimized_compositeimplicitautograd_dispatch.h>
|
| 141 |
+
#include <ATen/ops/chunk_compositeimplicitautograd_dispatch.h>
|
| 142 |
+
#include <ATen/ops/clip_compositeimplicitautograd_dispatch.h>
|
| 143 |
+
#include <ATen/ops/coalesce_compositeimplicitautograd_dispatch.h>
|
| 144 |
+
#include <ATen/ops/column_stack_compositeimplicitautograd_dispatch.h>
|
| 145 |
+
#include <ATen/ops/combinations_compositeimplicitautograd_dispatch.h>
|
| 146 |
+
#include <ATen/ops/concat_compositeimplicitautograd_dispatch.h>
|
| 147 |
+
#include <ATen/ops/concatenate_compositeimplicitautograd_dispatch.h>
|
| 148 |
+
#include <ATen/ops/conj_compositeimplicitautograd_dispatch.h>
|
| 149 |
+
#include <ATen/ops/conj_physical_compositeimplicitautograd_dispatch.h>
|
| 150 |
+
#include <ATen/ops/contiguous_compositeimplicitautograd_dispatch.h>
|
| 151 |
+
#include <ATen/ops/conv1d_compositeimplicitautograd_dispatch.h>
|
| 152 |
+
#include <ATen/ops/conv2d_compositeimplicitautograd_dispatch.h>
|
| 153 |
+
#include <ATen/ops/conv3d_compositeimplicitautograd_dispatch.h>
|
| 154 |
+
#include <ATen/ops/conv_tbc_backward_compositeimplicitautograd_dispatch.h>
|
| 155 |
+
#include <ATen/ops/conv_transpose1d_compositeimplicitautograd_dispatch.h>
|
| 156 |
+
#include <ATen/ops/conv_transpose2d_compositeimplicitautograd_dispatch.h>
|
| 157 |
+
#include <ATen/ops/conv_transpose3d_compositeimplicitautograd_dispatch.h>
|
| 158 |
+
#include <ATen/ops/corrcoef_compositeimplicitautograd_dispatch.h>
|
| 159 |
+
#include <ATen/ops/cosine_embedding_loss_compositeimplicitautograd_dispatch.h>
|
| 160 |
+
#include <ATen/ops/cosine_similarity_compositeimplicitautograd_dispatch.h>
|
| 161 |
+
#include <ATen/ops/cov_compositeimplicitautograd_dispatch.h>
|
| 162 |
+
#include <ATen/ops/cross_compositeimplicitautograd_dispatch.h>
|
| 163 |
+
#include <ATen/ops/cross_entropy_loss_compositeimplicitautograd_dispatch.h>
|
| 164 |
+
#include <ATen/ops/ctc_loss_compositeimplicitautograd_dispatch.h>
|
| 165 |
+
#include <ATen/ops/cudnn_is_acceptable_compositeimplicitautograd_dispatch.h>
|
| 166 |
+
#include <ATen/ops/cummax_compositeimplicitautograd_dispatch.h>
|
| 167 |
+
#include <ATen/ops/cummaxmin_backward_compositeimplicitautograd_dispatch.h>
|
| 168 |
+
#include <ATen/ops/cummin_compositeimplicitautograd_dispatch.h>
|
| 169 |
+
#include <ATen/ops/cumprod_compositeimplicitautograd_dispatch.h>
|
| 170 |
+
#include <ATen/ops/cumprod_backward_compositeimplicitautograd_dispatch.h>
|
| 171 |
+
#include <ATen/ops/cumsum_compositeimplicitautograd_dispatch.h>
|
| 172 |
+
#include <ATen/ops/cumulative_trapezoid_compositeimplicitautograd_dispatch.h>
|
| 173 |
+
#include <ATen/ops/data_compositeimplicitautograd_dispatch.h>
|
| 174 |
+
#include <ATen/ops/det_compositeimplicitautograd_dispatch.h>
|
| 175 |
+
#include <ATen/ops/diag_compositeimplicitautograd_dispatch.h>
|
| 176 |
+
#include <ATen/ops/diagflat_compositeimplicitautograd_dispatch.h>
|
| 177 |
+
#include <ATen/ops/diagonal_compositeimplicitautograd_dispatch.h>
|
| 178 |
+
#include <ATen/ops/diff_compositeimplicitautograd_dispatch.h>
|
| 179 |
+
#include <ATen/ops/divide_compositeimplicitautograd_dispatch.h>
|
| 180 |
+
#include <ATen/ops/dropout_compositeimplicitautograd_dispatch.h>
|
| 181 |
+
#include <ATen/ops/dsplit_compositeimplicitautograd_dispatch.h>
|
| 182 |
+
#include <ATen/ops/dstack_compositeimplicitautograd_dispatch.h>
|
| 183 |
+
#include <ATen/ops/einsum_compositeimplicitautograd_dispatch.h>
|
| 184 |
+
#include <ATen/ops/embedding_backward_compositeimplicitautograd_dispatch.h>
|
| 185 |
+
#include <ATen/ops/embedding_bag_compositeimplicitautograd_dispatch.h>
|
| 186 |
+
#include <ATen/ops/embedding_sparse_backward_compositeimplicitautograd_dispatch.h>
|
| 187 |
+
#include <ATen/ops/empty_compositeimplicitautograd_dispatch.h>
|
| 188 |
+
#include <ATen/ops/expand_as_compositeimplicitautograd_dispatch.h>
|
| 189 |
+
#include <ATen/ops/fake_quantize_per_channel_affine_compositeimplicitautograd_dispatch.h>
|
| 190 |
+
#include <ATen/ops/fake_quantize_per_channel_affine_cachemask_backward_compositeimplicitautograd_dispatch.h>
|
| 191 |
+
#include <ATen/ops/fake_quantize_per_tensor_affine_compositeimplicitautograd_dispatch.h>
|
| 192 |
+
#include <ATen/ops/fake_quantize_per_tensor_affine_cachemask_backward_compositeimplicitautograd_dispatch.h>
|
| 193 |
+
#include <ATen/ops/fbgemm_linear_fp16_weight_compositeimplicitautograd_dispatch.h>
|
| 194 |
+
#include <ATen/ops/fbgemm_linear_fp16_weight_fp32_activation_compositeimplicitautograd_dispatch.h>
|
| 195 |
+
#include <ATen/ops/fbgemm_linear_int8_weight_compositeimplicitautograd_dispatch.h>
|
| 196 |
+
#include <ATen/ops/fbgemm_linear_int8_weight_fp32_activation_compositeimplicitautograd_dispatch.h>
|
| 197 |
+
#include <ATen/ops/fbgemm_linear_quantize_weight_compositeimplicitautograd_dispatch.h>
|
| 198 |
+
#include <ATen/ops/fbgemm_pack_gemm_matrix_fp16_compositeimplicitautograd_dispatch.h>
|
| 199 |
+
#include <ATen/ops/fbgemm_pack_quantized_matrix_compositeimplicitautograd_dispatch.h>
|
| 200 |
+
#include <ATen/ops/feature_alpha_dropout_compositeimplicitautograd_dispatch.h>
|
| 201 |
+
#include <ATen/ops/feature_dropout_compositeimplicitautograd_dispatch.h>
|
| 202 |
+
#include <ATen/ops/fft_fft_compositeimplicitautograd_dispatch.h>
|
| 203 |
+
#include <ATen/ops/fft_fft2_compositeimplicitautograd_dispatch.h>
|
| 204 |
+
#include <ATen/ops/fft_fftn_compositeimplicitautograd_dispatch.h>
|
| 205 |
+
#include <ATen/ops/fft_fftshift_compositeimplicitautograd_dispatch.h>
|
| 206 |
+
#include <ATen/ops/fft_hfft_compositeimplicitautograd_dispatch.h>
|
| 207 |
+
#include <ATen/ops/fft_hfft2_compositeimplicitautograd_dispatch.h>
|
| 208 |
+
#include <ATen/ops/fft_hfftn_compositeimplicitautograd_dispatch.h>
|
| 209 |
+
#include <ATen/ops/fft_ifft_compositeimplicitautograd_dispatch.h>
|
| 210 |
+
#include <ATen/ops/fft_ifft2_compositeimplicitautograd_dispatch.h>
|
| 211 |
+
#include <ATen/ops/fft_ifftn_compositeimplicitautograd_dispatch.h>
|
| 212 |
+
#include <ATen/ops/fft_ifftshift_compositeimplicitautograd_dispatch.h>
|
| 213 |
+
#include <ATen/ops/fft_ihfft_compositeimplicitautograd_dispatch.h>
|
| 214 |
+
#include <ATen/ops/fft_ihfft2_compositeimplicitautograd_dispatch.h>
|
| 215 |
+
#include <ATen/ops/fft_ihfftn_compositeimplicitautograd_dispatch.h>
|
| 216 |
+
#include <ATen/ops/fft_irfft_compositeimplicitautograd_dispatch.h>
|
| 217 |
+
#include <ATen/ops/fft_irfft2_compositeimplicitautograd_dispatch.h>
|
| 218 |
+
#include <ATen/ops/fft_irfftn_compositeimplicitautograd_dispatch.h>
|
| 219 |
+
#include <ATen/ops/fft_rfft_compositeimplicitautograd_dispatch.h>
|
| 220 |
+
#include <ATen/ops/fft_rfft2_compositeimplicitautograd_dispatch.h>
|
| 221 |
+
#include <ATen/ops/fft_rfftn_compositeimplicitautograd_dispatch.h>
|
| 222 |
+
#include <ATen/ops/fill_diagonal_compositeimplicitautograd_dispatch.h>
|
| 223 |
+
#include <ATen/ops/fix_compositeimplicitautograd_dispatch.h>
|
| 224 |
+
#include <ATen/ops/flatten_compositeimplicitautograd_dispatch.h>
|
| 225 |
+
#include <ATen/ops/flatten_dense_tensors_compositeimplicitautograd_dispatch.h>
|
| 226 |
+
#include <ATen/ops/fliplr_compositeimplicitautograd_dispatch.h>
|
| 227 |
+
#include <ATen/ops/flipud_compositeimplicitautograd_dispatch.h>
|
| 228 |
+
#include <ATen/ops/float_power_compositeimplicitautograd_dispatch.h>
|
| 229 |
+
#include <ATen/ops/frobenius_norm_compositeimplicitautograd_dispatch.h>
|
| 230 |
+
#include <ATen/ops/fused_moving_avg_obs_fake_quant_compositeimplicitautograd_dispatch.h>
|
| 231 |
+
#include <ATen/ops/gather_compositeimplicitautograd_dispatch.h>
|
| 232 |
+
#include <ATen/ops/gather_backward_compositeimplicitautograd_dispatch.h>
|
| 233 |
+
#include <ATen/ops/ger_compositeimplicitautograd_dispatch.h>
|
| 234 |
+
#include <ATen/ops/gradient_compositeimplicitautograd_dispatch.h>
|
| 235 |
+
#include <ATen/ops/greater_compositeimplicitautograd_dispatch.h>
|
| 236 |
+
#include <ATen/ops/greater_equal_compositeimplicitautograd_dispatch.h>
|
| 237 |
+
#include <ATen/ops/grid_sampler_compositeimplicitautograd_dispatch.h>
|
| 238 |
+
#include <ATen/ops/group_norm_compositeimplicitautograd_dispatch.h>
|
| 239 |
+
#include <ATen/ops/gru_compositeimplicitautograd_dispatch.h>
|
| 240 |
+
#include <ATen/ops/gru_cell_compositeimplicitautograd_dispatch.h>
|
| 241 |
+
#include <ATen/ops/hinge_embedding_loss_compositeimplicitautograd_dispatch.h>
|
| 242 |
+
#include <ATen/ops/histogramdd_compositeimplicitautograd_dispatch.h>
|
| 243 |
+
#include <ATen/ops/hsplit_compositeimplicitautograd_dispatch.h>
|
| 244 |
+
#include <ATen/ops/hstack_compositeimplicitautograd_dispatch.h>
|
| 245 |
+
#include <ATen/ops/imag_compositeimplicitautograd_dispatch.h>
|
| 246 |
+
#include <ATen/ops/index_add_compositeimplicitautograd_dispatch.h>
|
| 247 |
+
#include <ATen/ops/index_copy_compositeimplicitautograd_dispatch.h>
|
| 248 |
+
#include <ATen/ops/index_fill_compositeimplicitautograd_dispatch.h>
|
| 249 |
+
#include <ATen/ops/index_select_compositeimplicitautograd_dispatch.h>
|
| 250 |
+
#include <ATen/ops/index_select_backward_compositeimplicitautograd_dispatch.h>
|
| 251 |
+
#include <ATen/ops/infinitely_differentiable_gelu_backward_compositeimplicitautograd_dispatch.h>
|
| 252 |
+
#include <ATen/ops/inner_compositeimplicitautograd_dispatch.h>
|
| 253 |
+
#include <ATen/ops/instance_norm_compositeimplicitautograd_dispatch.h>
|
| 254 |
+
#include <ATen/ops/inverse_compositeimplicitautograd_dispatch.h>
|
| 255 |
+
#include <ATen/ops/is_complex_compositeimplicitautograd_dispatch.h>
|
| 256 |
+
#include <ATen/ops/is_conj_compositeimplicitautograd_dispatch.h>
|
| 257 |
+
#include <ATen/ops/is_distributed_compositeimplicitautograd_dispatch.h>
|
| 258 |
+
#include <ATen/ops/is_floating_point_compositeimplicitautograd_dispatch.h>
|
| 259 |
+
#include <ATen/ops/is_inference_compositeimplicitautograd_dispatch.h>
|
| 260 |
+
#include <ATen/ops/is_leaf_compositeimplicitautograd_dispatch.h>
|
| 261 |
+
#include <ATen/ops/is_neg_compositeimplicitautograd_dispatch.h>
|
| 262 |
+
#include <ATen/ops/is_nonzero_compositeimplicitautograd_dispatch.h>
|
| 263 |
+
#include <ATen/ops/is_signed_compositeimplicitautograd_dispatch.h>
|
| 264 |
+
#include <ATen/ops/is_vulkan_available_compositeimplicitautograd_dispatch.h>
|
| 265 |
+
#include <ATen/ops/isclose_compositeimplicitautograd_dispatch.h>
|
| 266 |
+
#include <ATen/ops/isfinite_compositeimplicitautograd_dispatch.h>
|
| 267 |
+
#include <ATen/ops/isreal_compositeimplicitautograd_dispatch.h>
|
| 268 |
+
#include <ATen/ops/istft_compositeimplicitautograd_dispatch.h>
|
| 269 |
+
#include <ATen/ops/item_compositeimplicitautograd_dispatch.h>
|
| 270 |
+
#include <ATen/ops/kl_div_compositeimplicitautograd_dispatch.h>
|
| 271 |
+
#include <ATen/ops/kron_compositeimplicitautograd_dispatch.h>
|
| 272 |
+
#include <ATen/ops/kthvalue_compositeimplicitautograd_dispatch.h>
|
| 273 |
+
#include <ATen/ops/l1_loss_compositeimplicitautograd_dispatch.h>
|
| 274 |
+
#include <ATen/ops/layer_norm_compositeimplicitautograd_dispatch.h>
|
| 275 |
+
#include <ATen/ops/ldexp_compositeimplicitautograd_dispatch.h>
|
| 276 |
+
#include <ATen/ops/less_compositeimplicitautograd_dispatch.h>
|
| 277 |
+
#include <ATen/ops/less_equal_compositeimplicitautograd_dispatch.h>
|
| 278 |
+
#include <ATen/ops/linalg_cholesky_compositeimplicitautograd_dispatch.h>
|
| 279 |
+
#include <ATen/ops/linalg_cond_compositeimplicitautograd_dispatch.h>
|
| 280 |
+
#include <ATen/ops/linalg_det_compositeimplicitautograd_dispatch.h>
|
| 281 |
+
#include <ATen/ops/linalg_diagonal_compositeimplicitautograd_dispatch.h>
|
| 282 |
+
#include <ATen/ops/linalg_eigh_compositeimplicitautograd_dispatch.h>
|
| 283 |
+
#include <ATen/ops/linalg_eigvals_compositeimplicitautograd_dispatch.h>
|
| 284 |
+
#include <ATen/ops/linalg_eigvalsh_compositeimplicitautograd_dispatch.h>
|
| 285 |
+
#include <ATen/ops/linalg_inv_compositeimplicitautograd_dispatch.h>
|
| 286 |
+
#include <ATen/ops/linalg_ldl_factor_compositeimplicitautograd_dispatch.h>
|
| 287 |
+
#include <ATen/ops/linalg_lu_factor_compositeimplicitautograd_dispatch.h>
|
| 288 |
+
#include <ATen/ops/linalg_matmul_compositeimplicitautograd_dispatch.h>
|
| 289 |
+
#include <ATen/ops/linalg_matrix_norm_compositeimplicitautograd_dispatch.h>
|
| 290 |
+
#include <ATen/ops/linalg_matrix_power_compositeimplicitautograd_dispatch.h>
|
| 291 |
+
#include <ATen/ops/linalg_matrix_rank_compositeimplicitautograd_dispatch.h>
|
| 292 |
+
#include <ATen/ops/linalg_multi_dot_compositeimplicitautograd_dispatch.h>
|
| 293 |
+
#include <ATen/ops/linalg_norm_compositeimplicitautograd_dispatch.h>
|
| 294 |
+
#include <ATen/ops/linalg_pinv_compositeimplicitautograd_dispatch.h>
|
| 295 |
+
#include <ATen/ops/linalg_slogdet_compositeimplicitautograd_dispatch.h>
|
| 296 |
+
#include <ATen/ops/linalg_solve_compositeimplicitautograd_dispatch.h>
|
| 297 |
+
#include <ATen/ops/linalg_solve_ex_compositeimplicitautograd_dispatch.h>
|
| 298 |
+
#include <ATen/ops/linalg_svd_compositeimplicitautograd_dispatch.h>
|
| 299 |
+
#include <ATen/ops/linalg_svdvals_compositeimplicitautograd_dispatch.h>
|
| 300 |
+
#include <ATen/ops/linalg_tensorinv_compositeimplicitautograd_dispatch.h>
|
| 301 |
+
#include <ATen/ops/linalg_tensorsolve_compositeimplicitautograd_dispatch.h>
|
| 302 |
+
#include <ATen/ops/linalg_vander_compositeimplicitautograd_dispatch.h>
|
| 303 |
+
#include <ATen/ops/linalg_vecdot_compositeimplicitautograd_dispatch.h>
|
| 304 |
+
#include <ATen/ops/linear_compositeimplicitautograd_dispatch.h>
|
| 305 |
+
#include <ATen/ops/log_sigmoid_compositeimplicitautograd_dispatch.h>
|
| 306 |
+
#include <ATen/ops/log_softmax_compositeimplicitautograd_dispatch.h>
|
| 307 |
+
#include <ATen/ops/logcumsumexp_compositeimplicitautograd_dispatch.h>
|
| 308 |
+
#include <ATen/ops/logdet_compositeimplicitautograd_dispatch.h>
|
| 309 |
+
#include <ATen/ops/logsumexp_compositeimplicitautograd_dispatch.h>
|
| 310 |
+
#include <ATen/ops/lstm_compositeimplicitautograd_dispatch.h>
|
| 311 |
+
#include <ATen/ops/lstm_cell_compositeimplicitautograd_dispatch.h>
|
| 312 |
+
#include <ATen/ops/lu_solve_compositeimplicitautograd_dispatch.h>
|
| 313 |
+
#include <ATen/ops/mH_compositeimplicitautograd_dispatch.h>
|
| 314 |
+
#include <ATen/ops/mT_compositeimplicitautograd_dispatch.h>
|
| 315 |
+
#include <ATen/ops/margin_ranking_loss_compositeimplicitautograd_dispatch.h>
|
| 316 |
+
#include <ATen/ops/masked_select_backward_compositeimplicitautograd_dispatch.h>
|
| 317 |
+
#include <ATen/ops/matmul_compositeimplicitautograd_dispatch.h>
|
| 318 |
+
#include <ATen/ops/matrix_H_compositeimplicitautograd_dispatch.h>
|
| 319 |
+
#include <ATen/ops/matrix_exp_compositeimplicitautograd_dispatch.h>
|
| 320 |
+
#include <ATen/ops/matrix_exp_backward_compositeimplicitautograd_dispatch.h>
|
| 321 |
+
#include <ATen/ops/matrix_power_compositeimplicitautograd_dispatch.h>
|
| 322 |
+
#include <ATen/ops/max_compositeimplicitautograd_dispatch.h>
|
| 323 |
+
#include <ATen/ops/max_pool1d_compositeimplicitautograd_dispatch.h>
|
| 324 |
+
#include <ATen/ops/max_pool1d_with_indices_compositeimplicitautograd_dispatch.h>
|
| 325 |
+
#include <ATen/ops/max_pool2d_compositeimplicitautograd_dispatch.h>
|
| 326 |
+
#include <ATen/ops/max_pool3d_compositeimplicitautograd_dispatch.h>
|
| 327 |
+
#include <ATen/ops/mean_compositeimplicitautograd_dispatch.h>
|
| 328 |
+
#include <ATen/ops/median_compositeimplicitautograd_dispatch.h>
|
| 329 |
+
#include <ATen/ops/meshgrid_compositeimplicitautograd_dispatch.h>
|
| 330 |
+
#include <ATen/ops/min_compositeimplicitautograd_dispatch.h>
|
| 331 |
+
#include <ATen/ops/mish_backward_compositeimplicitautograd_dispatch.h>
|
| 332 |
+
#include <ATen/ops/mode_compositeimplicitautograd_dispatch.h>
|
| 333 |
+
#include <ATen/ops/moveaxis_compositeimplicitautograd_dispatch.h>
|
| 334 |
+
#include <ATen/ops/movedim_compositeimplicitautograd_dispatch.h>
|
| 335 |
+
#include <ATen/ops/msort_compositeimplicitautograd_dispatch.h>
|
| 336 |
+
#include <ATen/ops/multilabel_margin_loss_compositeimplicitautograd_dispatch.h>
|
| 337 |
+
#include <ATen/ops/multiply_compositeimplicitautograd_dispatch.h>
|
| 338 |
+
#include <ATen/ops/nanmean_compositeimplicitautograd_dispatch.h>
|
| 339 |
+
#include <ATen/ops/nanmedian_compositeimplicitautograd_dispatch.h>
|
| 340 |
+
#include <ATen/ops/nanquantile_compositeimplicitautograd_dispatch.h>
|
| 341 |
+
#include <ATen/ops/narrow_compositeimplicitautograd_dispatch.h>
|
| 342 |
+
#include <ATen/ops/native_channel_shuffle_compositeimplicitautograd_dispatch.h>
|
| 343 |
+
#include <ATen/ops/negative_compositeimplicitautograd_dispatch.h>
|
| 344 |
+
#include <ATen/ops/nested_to_padded_tensor_compositeimplicitautograd_dispatch.h>
|
| 345 |
+
#include <ATen/ops/nll_loss_compositeimplicitautograd_dispatch.h>
|
| 346 |
+
#include <ATen/ops/nll_loss2d_compositeimplicitautograd_dispatch.h>
|
| 347 |
+
#include <ATen/ops/nll_loss_nd_compositeimplicitautograd_dispatch.h>
|
| 348 |
+
#include <ATen/ops/nonzero_numpy_compositeimplicitautograd_dispatch.h>
|
| 349 |
+
#include <ATen/ops/norm_compositeimplicitautograd_dispatch.h>
|
| 350 |
+
#include <ATen/ops/norm_except_dim_compositeimplicitautograd_dispatch.h>
|
| 351 |
+
#include <ATen/ops/not_equal_compositeimplicitautograd_dispatch.h>
|
| 352 |
+
#include <ATen/ops/nuclear_norm_compositeimplicitautograd_dispatch.h>
|
| 353 |
+
#include <ATen/ops/numpy_T_compositeimplicitautograd_dispatch.h>
|
| 354 |
+
#include <ATen/ops/one_hot_compositeimplicitautograd_dispatch.h>
|
| 355 |
+
#include <ATen/ops/or_compositeimplicitautograd_dispatch.h>
|
| 356 |
+
#include <ATen/ops/orgqr_compositeimplicitautograd_dispatch.h>
|
| 357 |
+
#include <ATen/ops/outer_compositeimplicitautograd_dispatch.h>
|
| 358 |
+
#include <ATen/ops/output_nr_compositeimplicitautograd_dispatch.h>
|
| 359 |
+
#include <ATen/ops/pad_compositeimplicitautograd_dispatch.h>
|
| 360 |
+
#include <ATen/ops/pad_sequence_compositeimplicitautograd_dispatch.h>
|
| 361 |
+
#include <ATen/ops/pairwise_distance_compositeimplicitautograd_dispatch.h>
|
| 362 |
+
#include <ATen/ops/pdist_compositeimplicitautograd_dispatch.h>
|
| 363 |
+
#include <ATen/ops/pin_memory_compositeimplicitautograd_dispatch.h>
|
| 364 |
+
#include <ATen/ops/pinverse_compositeimplicitautograd_dispatch.h>
|
| 365 |
+
#include <ATen/ops/poisson_nll_loss_compositeimplicitautograd_dispatch.h>
|
| 366 |
+
#include <ATen/ops/positive_compositeimplicitautograd_dispatch.h>
|
| 367 |
+
#include <ATen/ops/prelu_compositeimplicitautograd_dispatch.h>
|
| 368 |
+
#include <ATen/ops/prod_compositeimplicitautograd_dispatch.h>
|
| 369 |
+
#include <ATen/ops/promote_types_compositeimplicitautograd_dispatch.h>
|
| 370 |
+
#include <ATen/ops/qr_compositeimplicitautograd_dispatch.h>
|
| 371 |
+
#include <ATen/ops/quantile_compositeimplicitautograd_dispatch.h>
|
| 372 |
+
#include <ATen/ops/quantized_gru_cell_compositeimplicitautograd_dispatch.h>
|
| 373 |
+
#include <ATen/ops/quantized_lstm_cell_compositeimplicitautograd_dispatch.h>
|
| 374 |
+
#include <ATen/ops/quantized_rnn_relu_cell_compositeimplicitautograd_dispatch.h>
|
| 375 |
+
#include <ATen/ops/quantized_rnn_tanh_cell_compositeimplicitautograd_dispatch.h>
|
| 376 |
+
#include <ATen/ops/rand_compositeimplicitautograd_dispatch.h>
|
| 377 |
+
#include <ATen/ops/randn_compositeimplicitautograd_dispatch.h>
|
| 378 |
+
#include <ATen/ops/ravel_compositeimplicitautograd_dispatch.h>
|
| 379 |
+
#include <ATen/ops/real_compositeimplicitautograd_dispatch.h>
|
| 380 |
+
#include <ATen/ops/refine_names_compositeimplicitautograd_dispatch.h>
|
| 381 |
+
#include <ATen/ops/relu6_compositeimplicitautograd_dispatch.h>
|
| 382 |
+
#include <ATen/ops/rename_compositeimplicitautograd_dispatch.h>
|
| 383 |
+
#include <ATen/ops/repeat_interleave_compositeimplicitautograd_dispatch.h>
|
| 384 |
+
#include <ATen/ops/requires_grad_compositeimplicitautograd_dispatch.h>
|
| 385 |
+
#include <ATen/ops/reshape_compositeimplicitautograd_dispatch.h>
|
| 386 |
+
#include <ATen/ops/reshape_as_compositeimplicitautograd_dispatch.h>
|
| 387 |
+
#include <ATen/ops/resolve_conj_compositeimplicitautograd_dispatch.h>
|
| 388 |
+
#include <ATen/ops/resolve_neg_compositeimplicitautograd_dispatch.h>
|
| 389 |
+
#include <ATen/ops/result_type_compositeimplicitautograd_dispatch.h>
|
| 390 |
+
#include <ATen/ops/retain_grad_compositeimplicitautograd_dispatch.h>
|
| 391 |
+
#include <ATen/ops/retains_grad_compositeimplicitautograd_dispatch.h>
|
| 392 |
+
#include <ATen/ops/rms_norm_compositeimplicitautograd_dispatch.h>
|
| 393 |
+
#include <ATen/ops/rnn_relu_compositeimplicitautograd_dispatch.h>
|
| 394 |
+
#include <ATen/ops/rnn_relu_cell_compositeimplicitautograd_dispatch.h>
|
| 395 |
+
#include <ATen/ops/rnn_tanh_compositeimplicitautograd_dispatch.h>
|
| 396 |
+
#include <ATen/ops/rnn_tanh_cell_compositeimplicitautograd_dispatch.h>
|
| 397 |
+
#include <ATen/ops/row_stack_compositeimplicitautograd_dispatch.h>
|
| 398 |
+
#include <ATen/ops/rrelu_compositeimplicitautograd_dispatch.h>
|
| 399 |
+
#include <ATen/ops/scaled_dot_product_attention_compositeimplicitautograd_dispatch.h>
|
| 400 |
+
#include <ATen/ops/scatter_compositeimplicitautograd_dispatch.h>
|
| 401 |
+
#include <ATen/ops/scatter_add_compositeimplicitautograd_dispatch.h>
|
| 402 |
+
#include <ATen/ops/select_compositeimplicitautograd_dispatch.h>
|
| 403 |
+
#include <ATen/ops/selu_compositeimplicitautograd_dispatch.h>
|
| 404 |
+
#include <ATen/ops/set_compositeimplicitautograd_dispatch.h>
|
| 405 |
+
#include <ATen/ops/set_data_compositeimplicitautograd_dispatch.h>
|
| 406 |
+
#include <ATen/ops/silu_backward_compositeimplicitautograd_dispatch.h>
|
| 407 |
+
#include <ATen/ops/size_compositeimplicitautograd_dispatch.h>
|
| 408 |
+
#include <ATen/ops/slogdet_compositeimplicitautograd_dispatch.h>
|
| 409 |
+
#include <ATen/ops/slow_conv3d_compositeimplicitautograd_dispatch.h>
|
| 410 |
+
#include <ATen/ops/smm_compositeimplicitautograd_dispatch.h>
|
| 411 |
+
#include <ATen/ops/softmax_compositeimplicitautograd_dispatch.h>
|
| 412 |
+
#include <ATen/ops/sort_compositeimplicitautograd_dispatch.h>
|
| 413 |
+
#include <ATen/ops/sparse_bsc_tensor_compositeimplicitautograd_dispatch.h>
|
| 414 |
+
#include <ATen/ops/sparse_bsr_tensor_compositeimplicitautograd_dispatch.h>
|
| 415 |
+
#include <ATen/ops/sparse_coo_tensor_compositeimplicitautograd_dispatch.h>
|
| 416 |
+
#include <ATen/ops/sparse_csc_tensor_compositeimplicitautograd_dispatch.h>
|
| 417 |
+
#include <ATen/ops/sparse_csr_tensor_compositeimplicitautograd_dispatch.h>
|
| 418 |
+
#include <ATen/ops/special_digamma_compositeimplicitautograd_dispatch.h>
|
| 419 |
+
#include <ATen/ops/special_erf_compositeimplicitautograd_dispatch.h>
|
| 420 |
+
#include <ATen/ops/special_erfc_compositeimplicitautograd_dispatch.h>
|
| 421 |
+
#include <ATen/ops/special_erfinv_compositeimplicitautograd_dispatch.h>
|
| 422 |
+
#include <ATen/ops/special_exp2_compositeimplicitautograd_dispatch.h>
|
| 423 |
+
#include <ATen/ops/special_expit_compositeimplicitautograd_dispatch.h>
|
| 424 |
+
#include <ATen/ops/special_expm1_compositeimplicitautograd_dispatch.h>
|
| 425 |
+
#include <ATen/ops/special_gammainc_compositeimplicitautograd_dispatch.h>
|
| 426 |
+
#include <ATen/ops/special_gammaincc_compositeimplicitautograd_dispatch.h>
|
| 427 |
+
#include <ATen/ops/special_gammaln_compositeimplicitautograd_dispatch.h>
|
| 428 |
+
#include <ATen/ops/special_i0_compositeimplicitautograd_dispatch.h>
|
| 429 |
+
#include <ATen/ops/special_log1p_compositeimplicitautograd_dispatch.h>
|
| 430 |
+
#include <ATen/ops/special_log_softmax_compositeimplicitautograd_dispatch.h>
|
| 431 |
+
#include <ATen/ops/special_logit_compositeimplicitautograd_dispatch.h>
|
| 432 |
+
#include <ATen/ops/special_logsumexp_compositeimplicitautograd_dispatch.h>
|
| 433 |
+
#include <ATen/ops/special_multigammaln_compositeimplicitautograd_dispatch.h>
|
| 434 |
+
#include <ATen/ops/special_ndtr_compositeimplicitautograd_dispatch.h>
|
| 435 |
+
#include <ATen/ops/special_polygamma_compositeimplicitautograd_dispatch.h>
|
| 436 |
+
#include <ATen/ops/special_psi_compositeimplicitautograd_dispatch.h>
|
| 437 |
+
#include <ATen/ops/special_round_compositeimplicitautograd_dispatch.h>
|
| 438 |
+
#include <ATen/ops/special_sinc_compositeimplicitautograd_dispatch.h>
|
| 439 |
+
#include <ATen/ops/special_softmax_compositeimplicitautograd_dispatch.h>
|
| 440 |
+
#include <ATen/ops/special_xlogy_compositeimplicitautograd_dispatch.h>
|
| 441 |
+
#include <ATen/ops/split_compositeimplicitautograd_dispatch.h>
|
| 442 |
+
#include <ATen/ops/square_compositeimplicitautograd_dispatch.h>
|
| 443 |
+
#include <ATen/ops/squeeze_compositeimplicitautograd_dispatch.h>
|
| 444 |
+
#include <ATen/ops/sspaddmm_compositeimplicitautograd_dispatch.h>
|
| 445 |
+
#include <ATen/ops/std_compositeimplicitautograd_dispatch.h>
|
| 446 |
+
#include <ATen/ops/std_mean_compositeimplicitautograd_dispatch.h>
|
| 447 |
+
#include <ATen/ops/stft_compositeimplicitautograd_dispatch.h>
|
| 448 |
+
#include <ATen/ops/stride_compositeimplicitautograd_dispatch.h>
|
| 449 |
+
#include <ATen/ops/subtract_compositeimplicitautograd_dispatch.h>
|
| 450 |
+
#include <ATen/ops/sum_compositeimplicitautograd_dispatch.h>
|
| 451 |
+
#include <ATen/ops/sum_to_size_compositeimplicitautograd_dispatch.h>
|
| 452 |
+
#include <ATen/ops/svd_compositeimplicitautograd_dispatch.h>
|
| 453 |
+
#include <ATen/ops/swapaxes_compositeimplicitautograd_dispatch.h>
|
| 454 |
+
#include <ATen/ops/swapdims_compositeimplicitautograd_dispatch.h>
|
| 455 |
+
#include <ATen/ops/sym_numel_compositeimplicitautograd_dispatch.h>
|
| 456 |
+
#include <ATen/ops/sym_size_compositeimplicitautograd_dispatch.h>
|
| 457 |
+
#include <ATen/ops/sym_storage_offset_compositeimplicitautograd_dispatch.h>
|
| 458 |
+
#include <ATen/ops/sym_stride_compositeimplicitautograd_dispatch.h>
|
| 459 |
+
#include <ATen/ops/take_along_dim_compositeimplicitautograd_dispatch.h>
|
| 460 |
+
#include <ATen/ops/tensor_split_compositeimplicitautograd_dispatch.h>
|
| 461 |
+
#include <ATen/ops/tensordot_compositeimplicitautograd_dispatch.h>
|
| 462 |
+
#include <ATen/ops/thnn_conv2d_compositeimplicitautograd_dispatch.h>
|
| 463 |
+
#include <ATen/ops/tile_compositeimplicitautograd_dispatch.h>
|
| 464 |
+
#include <ATen/ops/to_compositeimplicitautograd_dispatch.h>
|
| 465 |
+
#include <ATen/ops/to_dense_compositeimplicitautograd_dispatch.h>
|
| 466 |
+
#include <ATen/ops/to_dense_backward_compositeimplicitautograd_dispatch.h>
|
| 467 |
+
#include <ATen/ops/to_mkldnn_backward_compositeimplicitautograd_dispatch.h>
|
| 468 |
+
#include <ATen/ops/to_sparse_compositeimplicitautograd_dispatch.h>
|
| 469 |
+
#include <ATen/ops/to_sparse_bsc_compositeimplicitautograd_dispatch.h>
|
| 470 |
+
#include <ATen/ops/to_sparse_bsr_compositeimplicitautograd_dispatch.h>
|
| 471 |
+
#include <ATen/ops/to_sparse_csc_compositeimplicitautograd_dispatch.h>
|
| 472 |
+
#include <ATen/ops/to_sparse_csr_compositeimplicitautograd_dispatch.h>
|
| 473 |
+
#include <ATen/ops/trace_backward_compositeimplicitautograd_dispatch.h>
|
| 474 |
+
#include <ATen/ops/transpose_compositeimplicitautograd_dispatch.h>
|
| 475 |
+
#include <ATen/ops/trapezoid_compositeimplicitautograd_dispatch.h>
|
| 476 |
+
#include <ATen/ops/trapz_compositeimplicitautograd_dispatch.h>
|
| 477 |
+
#include <ATen/ops/triplet_margin_loss_compositeimplicitautograd_dispatch.h>
|
| 478 |
+
#include <ATen/ops/true_divide_compositeimplicitautograd_dispatch.h>
|
| 479 |
+
#include <ATen/ops/type_as_compositeimplicitautograd_dispatch.h>
|
| 480 |
+
#include <ATen/ops/unbind_compositeimplicitautograd_dispatch.h>
|
| 481 |
+
#include <ATen/ops/unflatten_compositeimplicitautograd_dispatch.h>
|
| 482 |
+
#include <ATen/ops/unflatten_dense_tensors_compositeimplicitautograd_dispatch.h>
|
| 483 |
+
#include <ATen/ops/unsafe_chunk_compositeimplicitautograd_dispatch.h>
|
| 484 |
+
#include <ATen/ops/upsample_bicubic2d_compositeimplicitautograd_dispatch.h>
|
| 485 |
+
#include <ATen/ops/upsample_bilinear2d_compositeimplicitautograd_dispatch.h>
|
| 486 |
+
#include <ATen/ops/upsample_linear1d_compositeimplicitautograd_dispatch.h>
|
| 487 |
+
#include <ATen/ops/upsample_nearest1d_compositeimplicitautograd_dispatch.h>
|
| 488 |
+
#include <ATen/ops/upsample_nearest2d_compositeimplicitautograd_dispatch.h>
|
| 489 |
+
#include <ATen/ops/upsample_nearest3d_compositeimplicitautograd_dispatch.h>
|
| 490 |
+
#include <ATen/ops/upsample_trilinear3d_compositeimplicitautograd_dispatch.h>
|
| 491 |
+
#include <ATen/ops/value_selecting_reduction_backward_compositeimplicitautograd_dispatch.h>
|
| 492 |
+
#include <ATen/ops/vander_compositeimplicitautograd_dispatch.h>
|
| 493 |
+
#include <ATen/ops/var_compositeimplicitautograd_dispatch.h>
|
| 494 |
+
#include <ATen/ops/var_mean_compositeimplicitautograd_dispatch.h>
|
| 495 |
+
#include <ATen/ops/view_as_compositeimplicitautograd_dispatch.h>
|
| 496 |
+
#include <ATen/ops/vsplit_compositeimplicitautograd_dispatch.h>
|
| 497 |
+
#include <ATen/ops/vstack_compositeimplicitautograd_dispatch.h>
|
| 498 |
+
#include <ATen/ops/where_compositeimplicitautograd_dispatch.h>
|
| 499 |
+
#include <ATen/ops/xor_compositeimplicitautograd_dispatch.h>
|
| 500 |
+
|
| 501 |
+
|
| 502 |
+
|
.venv/lib/python3.11/site-packages/torch/include/ATen/DLConvertor.h
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/ATen.h>
|
| 4 |
+
#include <ATen/Tensor.h>
|
| 5 |
+
#include <ATen/dlpack.h>
|
| 6 |
+
|
| 7 |
+
// this convertor will:
|
| 8 |
+
// 1) take a Tensor object and wrap it in the DLPack tensor
|
| 9 |
+
// 2) take a dlpack tensor and convert it to the ATen Tensor
|
| 10 |
+
|
| 11 |
+
namespace at {
|
| 12 |
+
|
| 13 |
+
TORCH_API ScalarType toScalarType(const DLDataType& dtype);
|
| 14 |
+
TORCH_API DLManagedTensor* toDLPack(const Tensor& src);
|
| 15 |
+
TORCH_API Tensor fromDLPack(DLManagedTensor* src);
|
| 16 |
+
C10_DEPRECATED_MESSAGE("Please migrate to a non-const variant")
|
| 17 |
+
inline Tensor fromDLPack(const DLManagedTensor* src) {
|
| 18 |
+
return fromDLPack(const_cast<DLManagedTensor*>(src));
|
| 19 |
+
}
|
| 20 |
+
TORCH_API Tensor
|
| 21 |
+
fromDLPack(DLManagedTensor* src, std::function<void(void*)> deleter);
|
| 22 |
+
TORCH_API DLDataType getDLDataType(const Tensor& t);
|
| 23 |
+
TORCH_API DLDevice getDLContext(const Tensor& tensor, const int64_t& device_id);
|
| 24 |
+
|
| 25 |
+
} // namespace at
|
.venv/lib/python3.11/site-packages/torch/include/ATen/DeviceGuard.h
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/core/IListRef.h>
|
| 4 |
+
#include <ATen/core/Tensor.h>
|
| 5 |
+
#include <c10/core/DeviceGuard.h>
|
| 6 |
+
#include <c10/core/ScalarType.h> // TensorList whyyyyy
|
| 7 |
+
|
| 8 |
+
namespace at {
|
| 9 |
+
|
| 10 |
+
// Are you here because you're wondering why DeviceGuard(tensor) no
|
| 11 |
+
// longer works? For code organization reasons, we have temporarily(?)
|
| 12 |
+
// removed this constructor from DeviceGuard. The new way to
|
| 13 |
+
// spell it is:
|
| 14 |
+
//
|
| 15 |
+
// OptionalDeviceGuard guard(device_of(tensor));
|
| 16 |
+
|
| 17 |
+
/// Return the Device of a Tensor, if the Tensor is defined.
|
| 18 |
+
inline std::optional<Device> device_of(const Tensor& t) {
|
| 19 |
+
if (t.defined()) {
|
| 20 |
+
return std::make_optional(t.device());
|
| 21 |
+
} else {
|
| 22 |
+
return std::nullopt;
|
| 23 |
+
}
|
| 24 |
+
}
|
| 25 |
+
|
| 26 |
+
inline std::optional<Device> device_of(const std::optional<Tensor>& t) {
|
| 27 |
+
return t.has_value() ? device_of(t.value()) : std::nullopt;
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
/// Return the Device of a TensorList, if the list is non-empty and
|
| 31 |
+
/// the first Tensor is defined. (This function implicitly assumes
|
| 32 |
+
/// that all tensors in the list have the same device.)
|
| 33 |
+
inline std::optional<Device> device_of(ITensorListRef t) {
|
| 34 |
+
if (!t.empty()) {
|
| 35 |
+
return device_of(t.front());
|
| 36 |
+
} else {
|
| 37 |
+
return std::nullopt;
|
| 38 |
+
}
|
| 39 |
+
}
|
| 40 |
+
|
| 41 |
+
} // namespace at
|
.venv/lib/python3.11/site-packages/torch/include/ATen/Dispatch_v2.h
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <ATen/Dispatch.h>
|
| 2 |
+
|
| 3 |
+
// This is a new implementation of the AT_DISPATCH macro family from
|
| 4 |
+
// ATen/Dispatch.h
|
| 5 |
+
//
|
| 6 |
+
// The intended usage is:
|
| 7 |
+
//
|
| 8 |
+
// ScalarType scalar_type;
|
| 9 |
+
//
|
| 10 |
+
// AT_DISPATCH_V2(
|
| 11 |
+
// scalar_type,
|
| 12 |
+
// "debug string",
|
| 13 |
+
// AT_WRAP([&] {
|
| 14 |
+
// ... code to specialize with scalar_t ...
|
| 15 |
+
// }),
|
| 16 |
+
// kHalf,
|
| 17 |
+
// AT_EXPAND(AT_ALL_TYPES),
|
| 18 |
+
// ... as many types arguments as needed ...
|
| 19 |
+
// )
|
| 20 |
+
//
|
| 21 |
+
// For example, given an old style:
|
| 22 |
+
//
|
| 23 |
+
// AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
|
| 24 |
+
// kComplexHalf,
|
| 25 |
+
// kHalf,
|
| 26 |
+
// self.scalar_type(),
|
| 27 |
+
// "_local_scalar_dense_cpu",
|
| 28 |
+
// [&] {
|
| 29 |
+
// scalar_t value = *self.data_ptr<scalar_t>();
|
| 30 |
+
// r = Scalar(value);
|
| 31 |
+
// }
|
| 32 |
+
// )
|
| 33 |
+
//
|
| 34 |
+
// You now write:
|
| 35 |
+
//
|
| 36 |
+
// AT_DISPATCH_V2(
|
| 37 |
+
// self.scalar_type(),
|
| 38 |
+
// "_local_scalar_dense_cpu",
|
| 39 |
+
// AT_WRAP([&] {
|
| 40 |
+
// scalar_t value = *self.data_ptr<scalar_t>();
|
| 41 |
+
// r = Scalar(value);
|
| 42 |
+
// }),
|
| 43 |
+
// AT_EXPAND(AT_ALL_TYPES),
|
| 44 |
+
// AT_EXPAND(AT_COMPLEX_TYPES),
|
| 45 |
+
// kComplexHalf,
|
| 46 |
+
// kHalf,
|
| 47 |
+
// )
|
| 48 |
+
//
|
| 49 |
+
// Notably, it sports the following improvements:
|
| 50 |
+
//
|
| 51 |
+
// - It is not necessary to specify the arity (e.g.,
|
| 52 |
+
// AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND{2,3,4,...})
|
| 53 |
+
// when using the macro
|
| 54 |
+
//
|
| 55 |
+
// - It is not necessary to specify each dtype individually; if
|
| 56 |
+
// there is a set of related dtypes and you want to dispatch
|
| 57 |
+
// over all of them, you can simply say, e.g., AT_EXPAND(AT_INTEGRAL_TYPES)
|
| 58 |
+
// in your argument list.
|
| 59 |
+
//
|
| 60 |
+
// However, you must remember to wrap the payload body in AT_WRAP, or commas
|
| 61 |
+
// inside your lambda will be improperly handled. Furthermore, if you more
|
| 62 |
+
// entries to ScalarType than can be supported by this macro, it will fail
|
| 63 |
+
// with an obscure error (due to attempting to concatenate AT_AP with
|
| 64 |
+
// something that is not a number).
|
| 65 |
+
//
|
| 66 |
+
// The implementation strategy is to use the count arguments trick
|
| 67 |
+
// (e.g., as described in https://stackoverflow.com/a/2124385/23845)
|
| 68 |
+
// to discover how many dtypes have been passed, and then dispatch to a
|
| 69 |
+
// hand-written macro for each arity that applies as many DISPATCH_CASE as
|
| 70 |
+
// necessary. The hand-written macros can be regenerated for other arities
|
| 71 |
+
// with the script below.
|
| 72 |
+
//
|
| 73 |
+
// There is some delicacy in the implementation in controlling when
|
| 74 |
+
// macro expansion occurs, mediated with AT_EXPAND and AT_GUARD. I mostly
|
| 75 |
+
// relied on GPT4 to help me get it right.
|
| 76 |
+
|
| 77 |
+
// Public API macros
|
| 78 |
+
|
| 79 |
+
// See documentation above
|
| 80 |
+
#define AT_DISPATCH_V2(TYPE, NAME, BODY, ...) \
|
| 81 |
+
AT_DISPATCH_SWITCH(TYPE, NAME, AT_AP_VAR(AT_WRAP(BODY), TYPE, __VA_ARGS__))
|
| 82 |
+
|
| 83 |
+
// This macro lets you pass an arbitrary expression that may contain internal
|
| 84 |
+
// commas to another macro without having the commas causing the expression
|
| 85 |
+
// to be interpreted as being multiple arguments
|
| 86 |
+
#define AT_WRAP(...) __VA_ARGS__
|
| 87 |
+
|
| 88 |
+
#define AT_FLOAT8_TYPES \
|
| 89 |
+
c10::kFloat8_e5m2, c10::kFloat8_e5m2fnuz, c10::kFloat8_e4m3fn, \
|
| 90 |
+
c10::kFloat8_e4m3fnuz
|
| 91 |
+
|
| 92 |
+
#define AT_INTEGRAL_TYPES \
|
| 93 |
+
c10::kByte, c10::kChar, c10::kInt, c10::kLong, c10::kShort
|
| 94 |
+
#define AT_FLOATING_TYPES c10::kDouble, c10::kFloat
|
| 95 |
+
#define AT_BAREBONES_UNSIGNED_TYPES c10::kUInt16, c10::kUInt32, c10::kUInt64
|
| 96 |
+
#define AT_INTEGRAL_TYPES_V2 \
|
| 97 |
+
AT_EXPAND(AT_INTEGRAL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)
|
| 98 |
+
#define AT_COMPLEX_TYPES c10::kComplexDouble, c10::kComplexFloat
|
| 99 |
+
#define AT_QINT_TYPES c10::kQInt8, c10::kQUInt8, c10::kQInt32
|
| 100 |
+
// NB: not *actually* all types
|
| 101 |
+
#define AT_ALL_TYPES AT_EXPAND(AT_INTEGRAL_TYPES), AT_EXPAND(AT_FLOATING_TYPES)
|
| 102 |
+
#define AT_ALL_TYPES_AND_COMPLEX \
|
| 103 |
+
AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_COMPLEX_TYPES)
|
| 104 |
+
|
| 105 |
+
// Helper macros
|
| 106 |
+
|
| 107 |
+
#define AT_AP_VAR(N, T, ...) \
|
| 108 |
+
AT_EXPAND(AT_CONCAT(AT_AP, AT_NUM_ARGS(__VA_ARGS__))(AT_WRAP(N), __VA_ARGS__))
|
| 109 |
+
#define AT_CONCAT(a, b) AT_CONCAT_AUX(a, b)
|
| 110 |
+
#define AT_CONCAT_AUX(a, b) a##b
|
| 111 |
+
#define AT_EXPAND(X) X
|
| 112 |
+
|
| 113 |
+
// Ensure we never have too many scalar types for the expansion here to
|
| 114 |
+
// support. To bump this, you must regenerate the macros below.
|
| 115 |
+
static_assert(static_cast<int>(c10::ScalarType::NumOptions) < 45);
|
| 116 |
+
|
| 117 |
+
// Python code to regenerate generate code below:
|
| 118 |
+
#if 0
|
| 119 |
+
|
| 120 |
+
num_args = 45
|
| 121 |
+
|
| 122 |
+
nums = ', '.join(str(i) for i in reversed(range(num_args+1)))
|
| 123 |
+
args = ', '.join(f'_{i}' for i in range(1, num_args+1))
|
| 124 |
+
|
| 125 |
+
print(f'#define AT_NUM_ARGS(...) AT_EXPAND(AT_NUM_ARGS_AUX(__VA_ARGS__, {nums}))')
|
| 126 |
+
print(f'#define AT_NUM_ARGS_AUX({args}, N, ...) N')
|
| 127 |
+
|
| 128 |
+
for i in range(1, num_args+1):
|
| 129 |
+
args = ', '.join(f'_{i}' for i in range(1, i+1))
|
| 130 |
+
cases = ' '.join([f'AT_DISPATCH_CASE(_{j}, N)' for j in range(1, i+1)])
|
| 131 |
+
print(f'#define AT_AP{i}(N, {args}) {cases}')
|
| 132 |
+
|
| 133 |
+
#endif
|
| 134 |
+
|
| 135 |
+
// Begin generated code
|
| 136 |
+
// clang-format off
|
| 137 |
+
|
| 138 |
+
#define AT_NUM_ARGS(...) AT_EXPAND(AT_NUM_ARGS_AUX(__VA_ARGS__, 45, 44, 43, 42, 41, 40, 39, 38, 37, 36, 35, 34, 33, 32, 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0))
|
| 139 |
+
#define AT_NUM_ARGS_AUX(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, N, ...) N
|
| 140 |
+
#define AT_AP1(N, _1) AT_DISPATCH_CASE(_1, N)
|
| 141 |
+
#define AT_AP2(N, _1, _2) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N)
|
| 142 |
+
#define AT_AP3(N, _1, _2, _3) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N)
|
| 143 |
+
#define AT_AP4(N, _1, _2, _3, _4) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N)
|
| 144 |
+
#define AT_AP5(N, _1, _2, _3, _4, _5) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N)
|
| 145 |
+
#define AT_AP6(N, _1, _2, _3, _4, _5, _6) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N)
|
| 146 |
+
#define AT_AP7(N, _1, _2, _3, _4, _5, _6, _7) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N)
|
| 147 |
+
#define AT_AP8(N, _1, _2, _3, _4, _5, _6, _7, _8) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N)
|
| 148 |
+
#define AT_AP9(N, _1, _2, _3, _4, _5, _6, _7, _8, _9) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N)
|
| 149 |
+
#define AT_AP10(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N)
|
| 150 |
+
#define AT_AP11(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N)
|
| 151 |
+
#define AT_AP12(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N)
|
| 152 |
+
#define AT_AP13(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N)
|
| 153 |
+
#define AT_AP14(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N)
|
| 154 |
+
#define AT_AP15(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N)
|
| 155 |
+
#define AT_AP16(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N)
|
| 156 |
+
#define AT_AP17(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N)
|
| 157 |
+
#define AT_AP18(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N)
|
| 158 |
+
#define AT_AP19(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N)
|
| 159 |
+
#define AT_AP20(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N)
|
| 160 |
+
#define AT_AP21(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N)
|
| 161 |
+
#define AT_AP22(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N)
|
| 162 |
+
#define AT_AP23(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N)
|
| 163 |
+
#define AT_AP24(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N)
|
| 164 |
+
#define AT_AP25(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N)
|
| 165 |
+
#define AT_AP26(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N)
|
| 166 |
+
#define AT_AP27(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N)
|
| 167 |
+
#define AT_AP28(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N)
|
| 168 |
+
#define AT_AP29(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N)
|
| 169 |
+
#define AT_AP30(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N)
|
| 170 |
+
#define AT_AP31(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N)
|
| 171 |
+
#define AT_AP32(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N)
|
| 172 |
+
#define AT_AP33(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N)
|
| 173 |
+
#define AT_AP34(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N)
|
| 174 |
+
#define AT_AP35(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N)
|
| 175 |
+
#define AT_AP36(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N)
|
| 176 |
+
#define AT_AP37(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N)
|
| 177 |
+
#define AT_AP38(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N)
|
| 178 |
+
#define AT_AP39(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N)
|
| 179 |
+
#define AT_AP40(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N)
|
| 180 |
+
#define AT_AP41(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N)
|
| 181 |
+
#define AT_AP42(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N)
|
| 182 |
+
#define AT_AP43(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N)
|
| 183 |
+
#define AT_AP44(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N)
|
| 184 |
+
#define AT_AP45(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N)
|
| 185 |
+
// End generated code
|
| 186 |
+
// clang-format on
|
.venv/lib/python3.11/site-packages/torch/include/ATen/DynamicLibrary.h
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/Utils.h>
|
| 4 |
+
#include <c10/macros/Export.h>
|
| 5 |
+
#include <c10/util/Exception.h>
|
| 6 |
+
|
| 7 |
+
namespace c10 {
|
| 8 |
+
|
| 9 |
+
class DynamicLibraryError : public Error {
|
| 10 |
+
using Error::Error;
|
| 11 |
+
};
|
| 12 |
+
|
| 13 |
+
} // namespace c10
|
| 14 |
+
|
| 15 |
+
namespace at {
|
| 16 |
+
|
| 17 |
+
struct DynamicLibrary {
|
| 18 |
+
AT_DISALLOW_COPY_AND_ASSIGN(DynamicLibrary);
|
| 19 |
+
|
| 20 |
+
TORCH_API DynamicLibrary(
|
| 21 |
+
const char* name,
|
| 22 |
+
const char* alt_name = nullptr,
|
| 23 |
+
bool leak_handle = false);
|
| 24 |
+
|
| 25 |
+
TORCH_API void* sym(const char* name);
|
| 26 |
+
|
| 27 |
+
TORCH_API ~DynamicLibrary();
|
| 28 |
+
|
| 29 |
+
private:
|
| 30 |
+
bool leak_handle;
|
| 31 |
+
void* handle = nullptr;
|
| 32 |
+
};
|
| 33 |
+
|
| 34 |
+
} // namespace at
|
.venv/lib/python3.11/site-packages/torch/include/ATen/ExpandUtils.h
ADDED
|
@@ -0,0 +1,527 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#ifndef AT_PER_OPERATOR_HEADERS
|
| 4 |
+
#include <ATen/Functions.h>
|
| 5 |
+
#else
|
| 6 |
+
#include <ATen/ops/view.h>
|
| 7 |
+
#include <ATen/ops/view_copy.h>
|
| 8 |
+
#endif
|
| 9 |
+
|
| 10 |
+
#include <ATen/Tensor.h>
|
| 11 |
+
#include <ATen/core/DimVector.h>
|
| 12 |
+
#include <c10/util/Exception.h>
|
| 13 |
+
#include <c10/util/MaybeOwned.h>
|
| 14 |
+
#include <c10/util/irange.h>
|
| 15 |
+
|
| 16 |
+
#include <functional>
|
| 17 |
+
#include <tuple>
|
| 18 |
+
#include <utility>
|
| 19 |
+
|
| 20 |
+
namespace at {
|
| 21 |
+
|
| 22 |
+
TORCH_API std::vector<int64_t> infer_size(IntArrayRef a, IntArrayRef b);
|
| 23 |
+
TORCH_API std::vector<SymInt> infer_size_symint(
|
| 24 |
+
SymIntArrayRef a,
|
| 25 |
+
SymIntArrayRef b);
|
| 26 |
+
TORCH_API DimVector infer_size_dimvector(IntArrayRef a, IntArrayRef b);
|
| 27 |
+
TORCH_API SymDimVector
|
| 28 |
+
infer_size_symdimvector(SymIntArrayRef a, SymIntArrayRef b);
|
| 29 |
+
|
| 30 |
+
// Named type instead of a pair/tuple so that we can be sure to
|
| 31 |
+
// construct the vectors in place and get NRVO.
|
| 32 |
+
template <typename Container>
|
| 33 |
+
struct InferExpandGeometryResult {
|
| 34 |
+
Container sizes;
|
| 35 |
+
Container strides;
|
| 36 |
+
explicit InferExpandGeometryResult(size_t ndim)
|
| 37 |
+
: sizes(ndim), strides(ndim) {}
|
| 38 |
+
explicit InferExpandGeometryResult(IntArrayRef sizes_, size_t ndim)
|
| 39 |
+
: sizes(sizes_.begin(), sizes_.end()), strides(ndim) {}
|
| 40 |
+
};
|
| 41 |
+
|
| 42 |
+
TORCH_API std::tuple<std::vector<int64_t>, std::vector<int64_t>>
|
| 43 |
+
inferExpandGeometry(
|
| 44 |
+
IntArrayRef tensor_sizes,
|
| 45 |
+
IntArrayRef tensor_strides,
|
| 46 |
+
IntArrayRef sizes);
|
| 47 |
+
|
| 48 |
+
TORCH_API InferExpandGeometryResult<DimVector> inferExpandGeometry_dimvector(
|
| 49 |
+
IntArrayRef tensor_sizes,
|
| 50 |
+
IntArrayRef tensor_strides,
|
| 51 |
+
IntArrayRef sizes);
|
| 52 |
+
|
| 53 |
+
TORCH_API std::vector<int64_t> infer_dense_strides(
|
| 54 |
+
IntArrayRef tensor_sizes,
|
| 55 |
+
IntArrayRef tensor_strides);
|
| 56 |
+
|
| 57 |
+
// True if input shapes are expandable
|
| 58 |
+
// NOTE: infer_size did a similar check, please keep them sync if change is
|
| 59 |
+
// needed
|
| 60 |
+
inline bool are_expandable(IntArrayRef shape1, IntArrayRef shape2) {
|
| 61 |
+
size_t ndim1 = shape1.size();
|
| 62 |
+
size_t ndim2 = shape2.size();
|
| 63 |
+
size_t ndim = ndim1 < ndim2 ? ndim1 : ndim2;
|
| 64 |
+
|
| 65 |
+
for (int64_t i = static_cast<int64_t>(ndim) - 1; i >= 0; --i) {
|
| 66 |
+
if (shape1[--ndim1] == shape2[--ndim2] || shape1[ndim1] == 1 ||
|
| 67 |
+
shape2[ndim2] == 1) {
|
| 68 |
+
continue;
|
| 69 |
+
}
|
| 70 |
+
return false;
|
| 71 |
+
}
|
| 72 |
+
return true;
|
| 73 |
+
}
|
| 74 |
+
|
| 75 |
+
// avoid copy-construction of Tensor by using a reference_wrapper.
|
| 76 |
+
inline void check_defined(
|
| 77 |
+
std::initializer_list<std::reference_wrapper<const Tensor>> tensors,
|
| 78 |
+
const char* api_name) {
|
| 79 |
+
for (auto& t : tensors) {
|
| 80 |
+
if (!t.get().defined()) {
|
| 81 |
+
AT_ERROR(api_name, "(...) called with an undefined Tensor");
|
| 82 |
+
}
|
| 83 |
+
}
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
// NOTE [ ExpandUtils Borrowing ]
|
| 87 |
+
//
|
| 88 |
+
// Functions in ExpandUtils return `c10::MaybeOwned<Tensor>` because
|
| 89 |
+
// expansion may not actually be needed, in which case we can improve
|
| 90 |
+
// efficiency by returning
|
| 91 |
+
// `c10::MaybeOwned<Tensor>::borrowed(to_expand)`. However, this means
|
| 92 |
+
// that you need to be careful: the returned `c10::MaybeOwned<Tensor>`
|
| 93 |
+
// must not outlive the original `Tensor` object that `to_expand`
|
| 94 |
+
// referred to! The deleted rvalue reference overloads of these
|
| 95 |
+
// functions help with this by preventing trivial use of a temporary
|
| 96 |
+
// resulting from a function call, but it is still possible to make a
|
| 97 |
+
// mistake.
|
| 98 |
+
|
| 99 |
+
inline c10::MaybeOwned<Tensor> expand_inplace(
|
| 100 |
+
const Tensor& tensor,
|
| 101 |
+
const Tensor& to_expand) {
|
| 102 |
+
if (tensor.sym_sizes().equals(to_expand.sym_sizes())) {
|
| 103 |
+
return c10::MaybeOwned<Tensor>::borrowed(to_expand);
|
| 104 |
+
}
|
| 105 |
+
return c10::MaybeOwned<Tensor>::owned(
|
| 106 |
+
to_expand.expand_symint(tensor.sym_sizes()));
|
| 107 |
+
}
|
| 108 |
+
|
| 109 |
+
inline c10::MaybeOwned<Tensor> expand_inplace(
|
| 110 |
+
const Tensor& tensor,
|
| 111 |
+
Tensor&& to_expand) = delete;
|
| 112 |
+
|
| 113 |
+
inline c10::MaybeOwned<Tensor> expand_inplace(
|
| 114 |
+
const Tensor& tensor,
|
| 115 |
+
const Tensor& to_expand,
|
| 116 |
+
const char* api_name) {
|
| 117 |
+
check_defined({tensor, to_expand}, api_name);
|
| 118 |
+
return expand_inplace(tensor, to_expand);
|
| 119 |
+
}
|
| 120 |
+
|
| 121 |
+
inline c10::MaybeOwned<Tensor> expand_inplace(
|
| 122 |
+
const Tensor& tensor,
|
| 123 |
+
Tensor&& to_expand,
|
| 124 |
+
const char* api_name) = delete;
|
| 125 |
+
|
| 126 |
+
inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
|
| 127 |
+
expand_inplace(
|
| 128 |
+
const Tensor& tensor,
|
| 129 |
+
const Tensor& to_expand1,
|
| 130 |
+
const Tensor& to_expand2) {
|
| 131 |
+
if (tensor.sizes().equals(to_expand1.sizes()) &&
|
| 132 |
+
tensor.sizes().equals((to_expand2.sizes()))) {
|
| 133 |
+
return std::make_tuple(
|
| 134 |
+
c10::MaybeOwned<Tensor>::borrowed(to_expand1),
|
| 135 |
+
c10::MaybeOwned<Tensor>::borrowed(to_expand2));
|
| 136 |
+
}
|
| 137 |
+
|
| 138 |
+
return std::make_tuple(
|
| 139 |
+
c10::MaybeOwned<Tensor>::owned(to_expand1.expand(tensor.sizes())),
|
| 140 |
+
c10::MaybeOwned<Tensor>::owned(to_expand2.expand(tensor.sizes())));
|
| 141 |
+
}
|
| 142 |
+
|
| 143 |
+
inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
|
| 144 |
+
expand_inplace(
|
| 145 |
+
const Tensor& tensor,
|
| 146 |
+
Tensor&& to_expand1,
|
| 147 |
+
const Tensor& to_expand2) = delete;
|
| 148 |
+
inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
|
| 149 |
+
expand_inplace(
|
| 150 |
+
const Tensor& tensor,
|
| 151 |
+
const Tensor& to_expand1,
|
| 152 |
+
Tensor&& to_expand2) = delete;
|
| 153 |
+
inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
|
| 154 |
+
expand_inplace(const Tensor& tensor, Tensor&& to_expand1, Tensor&& to_expand2) =
|
| 155 |
+
delete;
|
| 156 |
+
|
| 157 |
+
inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
|
| 158 |
+
expand_inplace(
|
| 159 |
+
const Tensor& tensor,
|
| 160 |
+
const Tensor& to_expand1,
|
| 161 |
+
const Tensor& to_expand2,
|
| 162 |
+
const char* api_name) {
|
| 163 |
+
check_defined({tensor, to_expand1, to_expand2}, api_name);
|
| 164 |
+
return expand_inplace(tensor, to_expand1, to_expand2);
|
| 165 |
+
}
|
| 166 |
+
|
| 167 |
+
inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
|
| 168 |
+
expand_inplace(
|
| 169 |
+
const Tensor& tensor,
|
| 170 |
+
Tensor&& to_expand1,
|
| 171 |
+
const Tensor& to_expand2,
|
| 172 |
+
const char* api_name) = delete;
|
| 173 |
+
inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
|
| 174 |
+
expand_inplace(
|
| 175 |
+
const Tensor& tensor,
|
| 176 |
+
const Tensor& to_expand1,
|
| 177 |
+
Tensor&& to_expand2,
|
| 178 |
+
const char* api_name) = delete;
|
| 179 |
+
inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
|
| 180 |
+
expand_inplace(
|
| 181 |
+
const Tensor& tensor,
|
| 182 |
+
Tensor&& to_expand1,
|
| 183 |
+
Tensor&& to_expand2,
|
| 184 |
+
const char* api_name) = delete;
|
| 185 |
+
|
| 186 |
+
// See NOTE [ ExpandUtils Borrowing ] above for `MaybeOwned` explanation.
|
| 187 |
+
inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
|
| 188 |
+
expand_outplace(const Tensor& to_expand1, const Tensor& to_expand2) {
|
| 189 |
+
auto s1 = to_expand1.sym_sizes();
|
| 190 |
+
auto s2 = to_expand2.sym_sizes();
|
| 191 |
+
if (s1.equals(s2)) {
|
| 192 |
+
return std::make_tuple(
|
| 193 |
+
c10::MaybeOwned<Tensor>::borrowed(to_expand1),
|
| 194 |
+
c10::MaybeOwned<Tensor>::borrowed(to_expand2));
|
| 195 |
+
}
|
| 196 |
+
|
| 197 |
+
auto expanded_size = infer_size_symdimvector(s1, s2);
|
| 198 |
+
return std::make_tuple(
|
| 199 |
+
c10::MaybeOwned<Tensor>::owned(to_expand1.expand_symint(expanded_size)),
|
| 200 |
+
c10::MaybeOwned<Tensor>::owned(to_expand2.expand_symint(expanded_size)));
|
| 201 |
+
}
|
| 202 |
+
|
| 203 |
+
inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
|
| 204 |
+
expand_outplace(Tensor&& to_expand1, const Tensor& to_expand2) = delete;
|
| 205 |
+
inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
|
| 206 |
+
expand_outplace(const Tensor& to_expand1, Tensor&& to_expand2) = delete;
|
| 207 |
+
inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
|
| 208 |
+
expand_outplace(Tensor&& to_expand1, Tensor&& to_expand2) = delete;
|
| 209 |
+
|
| 210 |
+
inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
|
| 211 |
+
expand_outplace(
|
| 212 |
+
const Tensor& to_expand1,
|
| 213 |
+
const Tensor& to_expand2,
|
| 214 |
+
const char* api_name) {
|
| 215 |
+
check_defined({to_expand1, to_expand2}, api_name);
|
| 216 |
+
return expand_outplace(to_expand1, to_expand2);
|
| 217 |
+
}
|
| 218 |
+
|
| 219 |
+
inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
|
| 220 |
+
expand_outplace(
|
| 221 |
+
Tensor&& to_expand1,
|
| 222 |
+
const Tensor& to_expand2,
|
| 223 |
+
const char* api_name) = delete;
|
| 224 |
+
inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
|
| 225 |
+
expand_outplace(
|
| 226 |
+
const Tensor& to_expand1,
|
| 227 |
+
Tensor&& to_expand2,
|
| 228 |
+
const char* api_name) = delete;
|
| 229 |
+
inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
|
| 230 |
+
expand_outplace(
|
| 231 |
+
Tensor&& to_expand1,
|
| 232 |
+
Tensor&& to_expand2,
|
| 233 |
+
const char* api_name) = delete;
|
| 234 |
+
|
| 235 |
+
inline std::tuple<
|
| 236 |
+
c10::MaybeOwned<Tensor>,
|
| 237 |
+
c10::MaybeOwned<Tensor>,
|
| 238 |
+
c10::MaybeOwned<Tensor>>
|
| 239 |
+
expand_outplace(
|
| 240 |
+
const Tensor& to_expand1,
|
| 241 |
+
const Tensor& to_expand2,
|
| 242 |
+
const Tensor& to_expand3) {
|
| 243 |
+
if (to_expand1.sizes().equals(to_expand2.sizes()) &&
|
| 244 |
+
to_expand1.sizes().equals(to_expand3.sizes())) {
|
| 245 |
+
return std::make_tuple(
|
| 246 |
+
c10::MaybeOwned<Tensor>::borrowed(to_expand1),
|
| 247 |
+
c10::MaybeOwned<Tensor>::borrowed(to_expand2),
|
| 248 |
+
c10::MaybeOwned<Tensor>::borrowed(to_expand3));
|
| 249 |
+
}
|
| 250 |
+
|
| 251 |
+
auto expanded_size12 =
|
| 252 |
+
infer_size_dimvector(to_expand1.sizes(), to_expand2.sizes());
|
| 253 |
+
auto expanded_size =
|
| 254 |
+
infer_size_dimvector(expanded_size12, to_expand3.sizes());
|
| 255 |
+
return std::make_tuple(
|
| 256 |
+
c10::MaybeOwned<Tensor>::owned(to_expand1.expand(expanded_size)),
|
| 257 |
+
c10::MaybeOwned<Tensor>::owned(to_expand2.expand(expanded_size)),
|
| 258 |
+
c10::MaybeOwned<Tensor>::owned(to_expand3.expand(expanded_size)));
|
| 259 |
+
}
|
| 260 |
+
|
| 261 |
+
inline std::tuple<
|
| 262 |
+
c10::MaybeOwned<Tensor>,
|
| 263 |
+
c10::MaybeOwned<Tensor>,
|
| 264 |
+
c10::MaybeOwned<Tensor>>
|
| 265 |
+
expand_outplace(
|
| 266 |
+
Tensor&& to_expand1,
|
| 267 |
+
const Tensor& to_expand2,
|
| 268 |
+
const Tensor& to_expand3) = delete;
|
| 269 |
+
inline std::tuple<
|
| 270 |
+
c10::MaybeOwned<Tensor>,
|
| 271 |
+
c10::MaybeOwned<Tensor>,
|
| 272 |
+
c10::MaybeOwned<Tensor>>
|
| 273 |
+
expand_outplace(
|
| 274 |
+
const Tensor& to_expand1,
|
| 275 |
+
Tensor&& to_expand2,
|
| 276 |
+
const Tensor& to_expand3) = delete;
|
| 277 |
+
inline std::tuple<
|
| 278 |
+
c10::MaybeOwned<Tensor>,
|
| 279 |
+
c10::MaybeOwned<Tensor>,
|
| 280 |
+
c10::MaybeOwned<Tensor>>
|
| 281 |
+
expand_outplace(
|
| 282 |
+
Tensor&& to_expand1,
|
| 283 |
+
Tensor&& to_expand2,
|
| 284 |
+
const Tensor& to_expand3) = delete;
|
| 285 |
+
inline std::tuple<
|
| 286 |
+
c10::MaybeOwned<Tensor>,
|
| 287 |
+
c10::MaybeOwned<Tensor>,
|
| 288 |
+
c10::MaybeOwned<Tensor>>
|
| 289 |
+
expand_outplace(
|
| 290 |
+
const Tensor& to_expand1,
|
| 291 |
+
const Tensor& to_expand2,
|
| 292 |
+
Tensor&& to_expand3) = delete;
|
| 293 |
+
inline std::tuple<
|
| 294 |
+
c10::MaybeOwned<Tensor>,
|
| 295 |
+
c10::MaybeOwned<Tensor>,
|
| 296 |
+
c10::MaybeOwned<Tensor>>
|
| 297 |
+
expand_outplace(
|
| 298 |
+
Tensor&& to_expand1,
|
| 299 |
+
const Tensor& to_expand2,
|
| 300 |
+
Tensor&& to_expand3) = delete;
|
| 301 |
+
inline std::tuple<
|
| 302 |
+
c10::MaybeOwned<Tensor>,
|
| 303 |
+
c10::MaybeOwned<Tensor>,
|
| 304 |
+
c10::MaybeOwned<Tensor>>
|
| 305 |
+
expand_outplace(
|
| 306 |
+
const Tensor& to_expand1,
|
| 307 |
+
Tensor&& to_expand2,
|
| 308 |
+
Tensor&& to_expand3) = delete;
|
| 309 |
+
inline std::tuple<
|
| 310 |
+
c10::MaybeOwned<Tensor>,
|
| 311 |
+
c10::MaybeOwned<Tensor>,
|
| 312 |
+
c10::MaybeOwned<Tensor>>
|
| 313 |
+
expand_outplace(Tensor&& to_expand1, Tensor&& to_expand2, Tensor&& to_expand3) =
|
| 314 |
+
delete;
|
| 315 |
+
|
| 316 |
+
inline std::tuple<
|
| 317 |
+
c10::MaybeOwned<Tensor>,
|
| 318 |
+
c10::MaybeOwned<Tensor>,
|
| 319 |
+
c10::MaybeOwned<Tensor>>
|
| 320 |
+
expand_outplace(
|
| 321 |
+
const Tensor& to_expand1,
|
| 322 |
+
const Tensor& to_expand2,
|
| 323 |
+
const Tensor& to_expand3,
|
| 324 |
+
const char* api_name) {
|
| 325 |
+
check_defined({to_expand1, to_expand2, to_expand3}, api_name);
|
| 326 |
+
return expand_outplace(to_expand1, to_expand2, to_expand3);
|
| 327 |
+
}
|
| 328 |
+
|
| 329 |
+
inline std::tuple<
|
| 330 |
+
c10::MaybeOwned<Tensor>,
|
| 331 |
+
c10::MaybeOwned<Tensor>,
|
| 332 |
+
c10::MaybeOwned<Tensor>>
|
| 333 |
+
expand_outplace(
|
| 334 |
+
Tensor&& to_expand1,
|
| 335 |
+
const Tensor& to_expand2,
|
| 336 |
+
const Tensor& to_expand3,
|
| 337 |
+
const char* api_name) = delete;
|
| 338 |
+
inline std::tuple<
|
| 339 |
+
c10::MaybeOwned<Tensor>,
|
| 340 |
+
c10::MaybeOwned<Tensor>,
|
| 341 |
+
c10::MaybeOwned<Tensor>>
|
| 342 |
+
expand_outplace(
|
| 343 |
+
const Tensor& to_expand1,
|
| 344 |
+
Tensor&& to_expand2,
|
| 345 |
+
const Tensor& to_expand3,
|
| 346 |
+
const char* api_name) = delete;
|
| 347 |
+
inline std::tuple<
|
| 348 |
+
c10::MaybeOwned<Tensor>,
|
| 349 |
+
c10::MaybeOwned<Tensor>,
|
| 350 |
+
c10::MaybeOwned<Tensor>>
|
| 351 |
+
expand_outplace(
|
| 352 |
+
Tensor&& to_expand1,
|
| 353 |
+
Tensor&& to_expand2,
|
| 354 |
+
const Tensor& to_expand3,
|
| 355 |
+
const char* api_name) = delete;
|
| 356 |
+
inline std::tuple<
|
| 357 |
+
c10::MaybeOwned<Tensor>,
|
| 358 |
+
c10::MaybeOwned<Tensor>,
|
| 359 |
+
c10::MaybeOwned<Tensor>>
|
| 360 |
+
expand_outplace(
|
| 361 |
+
const Tensor& to_expand1,
|
| 362 |
+
const Tensor& to_expand2,
|
| 363 |
+
Tensor&& to_expand3,
|
| 364 |
+
const char* api_name) = delete;
|
| 365 |
+
inline std::tuple<
|
| 366 |
+
c10::MaybeOwned<Tensor>,
|
| 367 |
+
c10::MaybeOwned<Tensor>,
|
| 368 |
+
c10::MaybeOwned<Tensor>>
|
| 369 |
+
expand_outplace(
|
| 370 |
+
Tensor&& to_expand1,
|
| 371 |
+
const Tensor& to_expand2,
|
| 372 |
+
Tensor&& to_expand3,
|
| 373 |
+
const char* api_name) = delete;
|
| 374 |
+
inline std::tuple<
|
| 375 |
+
c10::MaybeOwned<Tensor>,
|
| 376 |
+
c10::MaybeOwned<Tensor>,
|
| 377 |
+
c10::MaybeOwned<Tensor>>
|
| 378 |
+
expand_outplace(
|
| 379 |
+
const Tensor& to_expand1,
|
| 380 |
+
Tensor&& to_expand2,
|
| 381 |
+
Tensor&& to_expand3,
|
| 382 |
+
const char* api_name) = delete;
|
| 383 |
+
inline std::tuple<
|
| 384 |
+
c10::MaybeOwned<Tensor>,
|
| 385 |
+
c10::MaybeOwned<Tensor>,
|
| 386 |
+
c10::MaybeOwned<Tensor>>
|
| 387 |
+
expand_outplace(
|
| 388 |
+
Tensor&& to_expand1,
|
| 389 |
+
Tensor&& to_expand2,
|
| 390 |
+
Tensor&& to_expand3,
|
| 391 |
+
const char* api_name) = delete;
|
| 392 |
+
|
| 393 |
+
inline c10::MaybeOwned<Tensor> expand_size(
|
| 394 |
+
const Tensor& to_expand,
|
| 395 |
+
IntArrayRef sizes) {
|
| 396 |
+
if (to_expand.sizes().equals(sizes)) {
|
| 397 |
+
return c10::MaybeOwned<Tensor>::borrowed(to_expand);
|
| 398 |
+
}
|
| 399 |
+
|
| 400 |
+
return c10::MaybeOwned<Tensor>::owned(to_expand.expand(sizes));
|
| 401 |
+
}
|
| 402 |
+
|
| 403 |
+
inline c10::MaybeOwned<Tensor> expand_size(
|
| 404 |
+
Tensor&& to_expand,
|
| 405 |
+
IntArrayRef sizes) = delete;
|
| 406 |
+
|
| 407 |
+
inline c10::MaybeOwned<Tensor> expand_size(
|
| 408 |
+
const Tensor& to_expand,
|
| 409 |
+
IntArrayRef sizes,
|
| 410 |
+
const char* api_name) {
|
| 411 |
+
check_defined({to_expand}, api_name);
|
| 412 |
+
return expand_size(to_expand, sizes);
|
| 413 |
+
}
|
| 414 |
+
|
| 415 |
+
inline c10::MaybeOwned<Tensor> expand_size(
|
| 416 |
+
Tensor&& to_expand,
|
| 417 |
+
IntArrayRef sizes,
|
| 418 |
+
const char* api_name) = delete;
|
| 419 |
+
|
| 420 |
+
inline std::vector<Tensor> expand_outplace(TensorList to_expand) {
|
| 421 |
+
// expands a list of Tensors; ignores undefined (null) tensors
|
| 422 |
+
bool first = true;
|
| 423 |
+
DimVector sizes;
|
| 424 |
+
for (const auto i : c10::irange(to_expand.size())) {
|
| 425 |
+
if (!to_expand[i].defined()) {
|
| 426 |
+
continue;
|
| 427 |
+
} else if (first) {
|
| 428 |
+
sizes = to_expand[i].sizes();
|
| 429 |
+
first = false;
|
| 430 |
+
} else {
|
| 431 |
+
sizes = infer_size_dimvector(sizes, to_expand[i].sizes());
|
| 432 |
+
}
|
| 433 |
+
}
|
| 434 |
+
|
| 435 |
+
std::vector<Tensor> result(to_expand.size());
|
| 436 |
+
for (const auto i : c10::irange(to_expand.size())) {
|
| 437 |
+
if (!to_expand[i].defined()) {
|
| 438 |
+
continue;
|
| 439 |
+
} else if (to_expand[i].sizes().equals(sizes)) {
|
| 440 |
+
result[i] = to_expand[i];
|
| 441 |
+
} else {
|
| 442 |
+
result[i] = to_expand[i].expand(sizes);
|
| 443 |
+
}
|
| 444 |
+
}
|
| 445 |
+
return result;
|
| 446 |
+
}
|
| 447 |
+
|
| 448 |
+
template <typename T>
|
| 449 |
+
inline Tensor _sum_to(
|
| 450 |
+
Tensor tensor,
|
| 451 |
+
const c10::ArrayRef<T> shape,
|
| 452 |
+
bool always_return_non_view = false) {
|
| 453 |
+
if (shape.size() == 0) {
|
| 454 |
+
return tensor.sum();
|
| 455 |
+
}
|
| 456 |
+
|
| 457 |
+
auto sizes = at::symint::sizes<T>(tensor);
|
| 458 |
+
c10::SmallVector<int64_t, 8> reduce_dims;
|
| 459 |
+
const int64_t leading_dims = sizes.size() - shape.size();
|
| 460 |
+
for (const auto i : c10::irange(leading_dims)) {
|
| 461 |
+
reduce_dims.push_back(i);
|
| 462 |
+
}
|
| 463 |
+
for (int64_t i = leading_dims; i < static_cast<int64_t>(sizes.size()); ++i) {
|
| 464 |
+
if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(shape[i - leading_dims], 1)) &&
|
| 465 |
+
TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(sizes[i], 1))) {
|
| 466 |
+
reduce_dims.push_back(i);
|
| 467 |
+
}
|
| 468 |
+
}
|
| 469 |
+
|
| 470 |
+
if (!reduce_dims.empty()) {
|
| 471 |
+
tensor = tensor.sum(reduce_dims, /*keepdim=*/true);
|
| 472 |
+
}
|
| 473 |
+
|
| 474 |
+
if (always_return_non_view) {
|
| 475 |
+
// This is only actually used by the functionalization pass.
|
| 476 |
+
// We want to be able to guarantee that this function doesn't return a view
|
| 477 |
+
// of the input.
|
| 478 |
+
return leading_dims > 0 ? at::symint::view_copy<T>(tensor, shape)
|
| 479 |
+
: tensor.clone();
|
| 480 |
+
} else {
|
| 481 |
+
return leading_dims > 0 ? at::symint::view<T>(tensor, shape) : tensor;
|
| 482 |
+
}
|
| 483 |
+
}
|
| 484 |
+
|
| 485 |
+
inline Tensor sum_to(
|
| 486 |
+
Tensor tensor,
|
| 487 |
+
const c10::SymIntArrayRef shape,
|
| 488 |
+
bool always_return_non_view = false) {
|
| 489 |
+
return _sum_to(std::move(tensor), shape, always_return_non_view);
|
| 490 |
+
}
|
| 491 |
+
|
| 492 |
+
// Sums `tensor` repeatedly to produce a tensor of shape `shape`.
|
| 493 |
+
// Precondition: is_expandable_to(shape, tensor.sizes()) must be true
|
| 494 |
+
inline Tensor sum_to(
|
| 495 |
+
Tensor tensor,
|
| 496 |
+
const IntArrayRef shape,
|
| 497 |
+
bool always_return_non_view = false) {
|
| 498 |
+
return _sum_to(std::move(tensor), shape, always_return_non_view);
|
| 499 |
+
}
|
| 500 |
+
|
| 501 |
+
inline bool is_expandable_to(
|
| 502 |
+
SymIntArrayRef shape,
|
| 503 |
+
c10::SymIntArrayRef desired) {
|
| 504 |
+
size_t ndim = shape.size();
|
| 505 |
+
size_t target_dim = desired.size();
|
| 506 |
+
if (ndim > target_dim) {
|
| 507 |
+
return false;
|
| 508 |
+
}
|
| 509 |
+
for (const auto i : c10::irange(ndim)) {
|
| 510 |
+
const auto& size = shape[ndim - i - 1];
|
| 511 |
+
const auto& target = desired[target_dim - i - 1];
|
| 512 |
+
if (size != target && size != 1) {
|
| 513 |
+
return false;
|
| 514 |
+
}
|
| 515 |
+
}
|
| 516 |
+
return true;
|
| 517 |
+
}
|
| 518 |
+
|
| 519 |
+
inline bool is_expandable_to(IntArrayRef shape, IntArrayRef desired) {
|
| 520 |
+
auto sym_shape = c10::SymIntArrayRef(
|
| 521 |
+
reinterpret_cast<const c10::SymInt*>(shape.data()), shape.size());
|
| 522 |
+
auto sym_desired = c10::SymIntArrayRef(
|
| 523 |
+
reinterpret_cast<const c10::SymInt*>(desired.data()), desired.size());
|
| 524 |
+
return is_expandable_to(sym_shape, sym_desired);
|
| 525 |
+
}
|
| 526 |
+
|
| 527 |
+
} // namespace at
|
.venv/lib/python3.11/site-packages/torch/include/ATen/Functions.h
ADDED
|
@@ -0,0 +1,1454 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// @generated by torchgen/gen.py from Functions.h
|
| 4 |
+
|
| 5 |
+
#ifdef TORCH_ASSERT_NO_OPERATORS
|
| 6 |
+
#error This change adds a dependency on native_functions.yaml, \
|
| 7 |
+
meaning the file will need to be re-compiled every time an operator \
|
| 8 |
+
is changed or added. Consider if your change would be better placed in \
|
| 9 |
+
another file, or if a more specific header might achieve the same goal. \
|
| 10 |
+
See NOTE: [Tensor vs. TensorBase]
|
| 11 |
+
#endif
|
| 12 |
+
|
| 13 |
+
#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS)
|
| 14 |
+
#error This change adds a dependency on all pytorch operators, meaning the \
|
| 15 |
+
file will need to be re-compiled every time an operator is changed or added. \
|
| 16 |
+
Consider including a specific operator from <ATen/ops/{my_operator}.h> and \
|
| 17 |
+
see NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS].
|
| 18 |
+
#endif
|
| 19 |
+
|
| 20 |
+
// NOTE: [TORCH_ASSERT_ONLY_METHOD_OPERATORS]
|
| 21 |
+
//
|
| 22 |
+
// In ATen, certain generated headers files include the definitions of
|
| 23 |
+
// every single operator in PyTorch. Unfortunately this means every
|
| 24 |
+
// time an operator signature is updated or changed in
|
| 25 |
+
// native_functions.yaml, you (and every other PyTorch developer) need
|
| 26 |
+
// to recompile every source file that includes any of these headers.
|
| 27 |
+
//
|
| 28 |
+
// To break up these header dependencies, and improve incremental
|
| 29 |
+
// build times for all PyTorch developers. These headers are split
|
| 30 |
+
// into per-operator headers in the `ATen/ops` folder. This limits
|
| 31 |
+
// incremental builds to only changes to methods of `Tensor`, or files
|
| 32 |
+
// that use the specific operator being changed. With `at::sum` as an
|
| 33 |
+
// example, you should include
|
| 34 |
+
//
|
| 35 |
+
// <ATen/ops/sum.h> // instead of ATen/Functions.h
|
| 36 |
+
// <ATen/ops/sum_native.h> // instead of ATen/NativeFunctions.h
|
| 37 |
+
// <ATen/ops/sum_ops.h> // instead of ATen/Operators.h
|
| 38 |
+
// <ATen/ops/sum_cpu_dispatch.h> // instead of ATen/CPUFunctions.h
|
| 39 |
+
//
|
| 40 |
+
// However, even if you're careful to use this in your own code.
|
| 41 |
+
// `Functions.h` might be included indirectly through another header
|
| 42 |
+
// without you realising. To avoid this, you can add
|
| 43 |
+
//
|
| 44 |
+
// #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
|
| 45 |
+
//
|
| 46 |
+
// to the top of your source file. This way any time the non-specific
|
| 47 |
+
// headers are included, the compiler will error out.
|
| 48 |
+
//
|
| 49 |
+
// Also, be aware that `ops` are not available in all build
|
| 50 |
+
// configurations (namely fb-internal) so you must guard these
|
| 51 |
+
// includes with `#ifdef AT_PER_OPERATOR_HEADERS`. e.g.
|
| 52 |
+
//
|
| 53 |
+
// #ifndef AT_PER_OPERATOR_HEADERS
|
| 54 |
+
// #include <ATen/Functions.h>
|
| 55 |
+
// #else
|
| 56 |
+
// #include <ATen/ops/sum.h>
|
| 57 |
+
// #endif
|
| 58 |
+
|
| 59 |
+
#include <ATen/Context.h>
|
| 60 |
+
#include <ATen/DeviceGuard.h>
|
| 61 |
+
#include <ATen/TensorUtils.h>
|
| 62 |
+
#include <ATen/TracerMode.h>
|
| 63 |
+
#include <ATen/core/Generator.h>
|
| 64 |
+
#include <ATen/core/Reduction.h>
|
| 65 |
+
#include <c10/core/SymInt.h>
|
| 66 |
+
#include <ATen/core/Tensor.h>
|
| 67 |
+
#include <c10/core/Scalar.h>
|
| 68 |
+
#include <c10/core/Storage.h>
|
| 69 |
+
#include <c10/core/TensorOptions.h>
|
| 70 |
+
#include <c10/util/Deprecated.h>
|
| 71 |
+
#include <optional>
|
| 72 |
+
#include <c10/util/OptionalArrayRef.h>
|
| 73 |
+
|
| 74 |
+
#include <ATen/ops/from_blob.h>
|
| 75 |
+
#include <ATen/ops/tensor.h>
|
| 76 |
+
|
| 77 |
+
#include <ATen/ops/_adaptive_avg_pool2d.h>
|
| 78 |
+
#include <ATen/ops/_adaptive_avg_pool2d_backward.h>
|
| 79 |
+
#include <ATen/ops/_adaptive_avg_pool3d.h>
|
| 80 |
+
#include <ATen/ops/_adaptive_avg_pool3d_backward.h>
|
| 81 |
+
#include <ATen/ops/_add_batch_dim.h>
|
| 82 |
+
#include <ATen/ops/_add_relu.h>
|
| 83 |
+
#include <ATen/ops/_addmm_activation.h>
|
| 84 |
+
#include <ATen/ops/_aminmax.h>
|
| 85 |
+
#include <ATen/ops/_amp_foreach_non_finite_check_and_unscale.h>
|
| 86 |
+
#include <ATen/ops/_amp_update_scale.h>
|
| 87 |
+
#include <ATen/ops/_assert_async.h>
|
| 88 |
+
#include <ATen/ops/_assert_scalar.h>
|
| 89 |
+
#include <ATen/ops/_assert_tensor_metadata.h>
|
| 90 |
+
#include <ATen/ops/_autocast_to_full_precision.h>
|
| 91 |
+
#include <ATen/ops/_autocast_to_reduced_precision.h>
|
| 92 |
+
#include <ATen/ops/_backward.h>
|
| 93 |
+
#include <ATen/ops/_batch_norm_impl_index.h>
|
| 94 |
+
#include <ATen/ops/_batch_norm_impl_index_backward.h>
|
| 95 |
+
#include <ATen/ops/_batch_norm_no_update.h>
|
| 96 |
+
#include <ATen/ops/_batch_norm_with_update.h>
|
| 97 |
+
#include <ATen/ops/_cast_Byte.h>
|
| 98 |
+
#include <ATen/ops/_cast_Char.h>
|
| 99 |
+
#include <ATen/ops/_cast_Double.h>
|
| 100 |
+
#include <ATen/ops/_cast_Float.h>
|
| 101 |
+
#include <ATen/ops/_cast_Half.h>
|
| 102 |
+
#include <ATen/ops/_cast_Int.h>
|
| 103 |
+
#include <ATen/ops/_cast_Long.h>
|
| 104 |
+
#include <ATen/ops/_cast_Short.h>
|
| 105 |
+
#include <ATen/ops/_cdist_backward.h>
|
| 106 |
+
#include <ATen/ops/_cdist_forward.h>
|
| 107 |
+
#include <ATen/ops/_cholesky_solve_helper.h>
|
| 108 |
+
#include <ATen/ops/_choose_qparams_per_tensor.h>
|
| 109 |
+
#include <ATen/ops/_chunk_cat.h>
|
| 110 |
+
#include <ATen/ops/_coalesce.h>
|
| 111 |
+
#include <ATen/ops/_coalesced.h>
|
| 112 |
+
#include <ATen/ops/_compute_linear_combination.h>
|
| 113 |
+
#include <ATen/ops/_conj.h>
|
| 114 |
+
#include <ATen/ops/_conj_copy.h>
|
| 115 |
+
#include <ATen/ops/_conj_physical.h>
|
| 116 |
+
#include <ATen/ops/_conv_depthwise2d.h>
|
| 117 |
+
#include <ATen/ops/_convert_indices_from_coo_to_csr.h>
|
| 118 |
+
#include <ATen/ops/_convert_indices_from_csr_to_coo.h>
|
| 119 |
+
#include <ATen/ops/_convert_weight_to_int4pack.h>
|
| 120 |
+
#include <ATen/ops/_convolution.h>
|
| 121 |
+
#include <ATen/ops/_convolution_double_backward.h>
|
| 122 |
+
#include <ATen/ops/_convolution_mode.h>
|
| 123 |
+
#include <ATen/ops/_copy_from.h>
|
| 124 |
+
#include <ATen/ops/_copy_from_and_resize.h>
|
| 125 |
+
#include <ATen/ops/_cslt_compress.h>
|
| 126 |
+
#include <ATen/ops/_cslt_sparse_mm.h>
|
| 127 |
+
#include <ATen/ops/_cslt_sparse_mm_search.h>
|
| 128 |
+
#include <ATen/ops/_ctc_loss.h>
|
| 129 |
+
#include <ATen/ops/_ctc_loss_backward.h>
|
| 130 |
+
#include <ATen/ops/_cudnn_ctc_loss.h>
|
| 131 |
+
#include <ATen/ops/_cudnn_init_dropout_state.h>
|
| 132 |
+
#include <ATen/ops/_cudnn_rnn.h>
|
| 133 |
+
#include <ATen/ops/_cudnn_rnn_backward.h>
|
| 134 |
+
#include <ATen/ops/_cudnn_rnn_flatten_weight.h>
|
| 135 |
+
#include <ATen/ops/_cufft_clear_plan_cache.h>
|
| 136 |
+
#include <ATen/ops/_cufft_get_plan_cache_max_size.h>
|
| 137 |
+
#include <ATen/ops/_cufft_get_plan_cache_size.h>
|
| 138 |
+
#include <ATen/ops/_cufft_set_plan_cache_max_size.h>
|
| 139 |
+
#include <ATen/ops/_cummax_helper.h>
|
| 140 |
+
#include <ATen/ops/_cummin_helper.h>
|
| 141 |
+
#include <ATen/ops/_debug_has_internal_overlap.h>
|
| 142 |
+
#include <ATen/ops/_dimI.h>
|
| 143 |
+
#include <ATen/ops/_dimV.h>
|
| 144 |
+
#include <ATen/ops/_dim_arange.h>
|
| 145 |
+
#include <ATen/ops/_dirichlet_grad.h>
|
| 146 |
+
#include <ATen/ops/_efficient_attention_backward.h>
|
| 147 |
+
#include <ATen/ops/_efficient_attention_forward.h>
|
| 148 |
+
#include <ATen/ops/_efficientzerotensor.h>
|
| 149 |
+
#include <ATen/ops/_embedding_bag.h>
|
| 150 |
+
#include <ATen/ops/_embedding_bag_backward.h>
|
| 151 |
+
#include <ATen/ops/_embedding_bag_dense_backward.h>
|
| 152 |
+
#include <ATen/ops/_embedding_bag_forward_only.h>
|
| 153 |
+
#include <ATen/ops/_embedding_bag_per_sample_weights_backward.h>
|
| 154 |
+
#include <ATen/ops/_embedding_bag_sparse_backward.h>
|
| 155 |
+
#include <ATen/ops/_empty_affine_quantized.h>
|
| 156 |
+
#include <ATen/ops/_empty_per_channel_affine_quantized.h>
|
| 157 |
+
#include <ATen/ops/_euclidean_dist.h>
|
| 158 |
+
#include <ATen/ops/_fake_quantize_learnable_per_channel_affine.h>
|
| 159 |
+
#include <ATen/ops/_fake_quantize_learnable_per_channel_affine_backward.h>
|
| 160 |
+
#include <ATen/ops/_fake_quantize_learnable_per_tensor_affine.h>
|
| 161 |
+
#include <ATen/ops/_fake_quantize_learnable_per_tensor_affine_backward.h>
|
| 162 |
+
#include <ATen/ops/_fake_quantize_per_tensor_affine_cachemask_tensor_qparams.h>
|
| 163 |
+
#include <ATen/ops/_fft_c2c.h>
|
| 164 |
+
#include <ATen/ops/_fft_c2r.h>
|
| 165 |
+
#include <ATen/ops/_fft_r2c.h>
|
| 166 |
+
#include <ATen/ops/_fill_mem_eff_dropout_mask.h>
|
| 167 |
+
#include <ATen/ops/_flash_attention_backward.h>
|
| 168 |
+
#include <ATen/ops/_flash_attention_forward.h>
|
| 169 |
+
#include <ATen/ops/_foobar.h>
|
| 170 |
+
#include <ATen/ops/_foreach_abs.h>
|
| 171 |
+
#include <ATen/ops/_foreach_acos.h>
|
| 172 |
+
#include <ATen/ops/_foreach_add.h>
|
| 173 |
+
#include <ATen/ops/_foreach_addcdiv.h>
|
| 174 |
+
#include <ATen/ops/_foreach_addcmul.h>
|
| 175 |
+
#include <ATen/ops/_foreach_asin.h>
|
| 176 |
+
#include <ATen/ops/_foreach_atan.h>
|
| 177 |
+
#include <ATen/ops/_foreach_ceil.h>
|
| 178 |
+
#include <ATen/ops/_foreach_clamp_max.h>
|
| 179 |
+
#include <ATen/ops/_foreach_clamp_min.h>
|
| 180 |
+
#include <ATen/ops/_foreach_copy.h>
|
| 181 |
+
#include <ATen/ops/_foreach_cos.h>
|
| 182 |
+
#include <ATen/ops/_foreach_cosh.h>
|
| 183 |
+
#include <ATen/ops/_foreach_div.h>
|
| 184 |
+
#include <ATen/ops/_foreach_erf.h>
|
| 185 |
+
#include <ATen/ops/_foreach_erfc.h>
|
| 186 |
+
#include <ATen/ops/_foreach_exp.h>
|
| 187 |
+
#include <ATen/ops/_foreach_expm1.h>
|
| 188 |
+
#include <ATen/ops/_foreach_floor.h>
|
| 189 |
+
#include <ATen/ops/_foreach_frac.h>
|
| 190 |
+
#include <ATen/ops/_foreach_lerp.h>
|
| 191 |
+
#include <ATen/ops/_foreach_lgamma.h>
|
| 192 |
+
#include <ATen/ops/_foreach_log.h>
|
| 193 |
+
#include <ATen/ops/_foreach_log10.h>
|
| 194 |
+
#include <ATen/ops/_foreach_log1p.h>
|
| 195 |
+
#include <ATen/ops/_foreach_log2.h>
|
| 196 |
+
#include <ATen/ops/_foreach_max.h>
|
| 197 |
+
#include <ATen/ops/_foreach_maximum.h>
|
| 198 |
+
#include <ATen/ops/_foreach_minimum.h>
|
| 199 |
+
#include <ATen/ops/_foreach_mul.h>
|
| 200 |
+
#include <ATen/ops/_foreach_neg.h>
|
| 201 |
+
#include <ATen/ops/_foreach_norm.h>
|
| 202 |
+
#include <ATen/ops/_foreach_pow.h>
|
| 203 |
+
#include <ATen/ops/_foreach_reciprocal.h>
|
| 204 |
+
#include <ATen/ops/_foreach_round.h>
|
| 205 |
+
#include <ATen/ops/_foreach_sigmoid.h>
|
| 206 |
+
#include <ATen/ops/_foreach_sign.h>
|
| 207 |
+
#include <ATen/ops/_foreach_sin.h>
|
| 208 |
+
#include <ATen/ops/_foreach_sinh.h>
|
| 209 |
+
#include <ATen/ops/_foreach_sqrt.h>
|
| 210 |
+
#include <ATen/ops/_foreach_sub.h>
|
| 211 |
+
#include <ATen/ops/_foreach_tan.h>
|
| 212 |
+
#include <ATen/ops/_foreach_tanh.h>
|
| 213 |
+
#include <ATen/ops/_foreach_trunc.h>
|
| 214 |
+
#include <ATen/ops/_foreach_zero.h>
|
| 215 |
+
#include <ATen/ops/_functional_assert_async.h>
|
| 216 |
+
#include <ATen/ops/_functional_assert_scalar.h>
|
| 217 |
+
#include <ATen/ops/_functional_sym_constrain_range.h>
|
| 218 |
+
#include <ATen/ops/_functional_sym_constrain_range_for_size.h>
|
| 219 |
+
#include <ATen/ops/_fused_adagrad.h>
|
| 220 |
+
#include <ATen/ops/_fused_adam.h>
|
| 221 |
+
#include <ATen/ops/_fused_adamw.h>
|
| 222 |
+
#include <ATen/ops/_fused_dropout.h>
|
| 223 |
+
#include <ATen/ops/_fused_moving_avg_obs_fq_helper.h>
|
| 224 |
+
#include <ATen/ops/_fused_sdp_choice.h>
|
| 225 |
+
#include <ATen/ops/_fused_sgd.h>
|
| 226 |
+
#include <ATen/ops/_fw_primal.h>
|
| 227 |
+
#include <ATen/ops/_fw_primal_copy.h>
|
| 228 |
+
#include <ATen/ops/_gather_sparse_backward.h>
|
| 229 |
+
#include <ATen/ops/_grid_sampler_2d_cpu_fallback.h>
|
| 230 |
+
#include <ATen/ops/_grid_sampler_2d_cpu_fallback_backward.h>
|
| 231 |
+
#include <ATen/ops/_has_compatible_shallow_copy_type.h>
|
| 232 |
+
#include <ATen/ops/_has_same_storage_numel.h>
|
| 233 |
+
#include <ATen/ops/_histogramdd_bin_edges.h>
|
| 234 |
+
#include <ATen/ops/_histogramdd_from_bin_cts.h>
|
| 235 |
+
#include <ATen/ops/_histogramdd_from_bin_tensors.h>
|
| 236 |
+
#include <ATen/ops/_index_put_impl.h>
|
| 237 |
+
#include <ATen/ops/_indices.h>
|
| 238 |
+
#include <ATen/ops/_indices_copy.h>
|
| 239 |
+
#include <ATen/ops/_int_mm.h>
|
| 240 |
+
#include <ATen/ops/_is_all_true.h>
|
| 241 |
+
#include <ATen/ops/_is_any_true.h>
|
| 242 |
+
#include <ATen/ops/_is_zerotensor.h>
|
| 243 |
+
#include <ATen/ops/_jagged_to_padded_dense_forward.h>
|
| 244 |
+
#include <ATen/ops/_lazy_clone.h>
|
| 245 |
+
#include <ATen/ops/_linalg_check_errors.h>
|
| 246 |
+
#include <ATen/ops/_linalg_det.h>
|
| 247 |
+
#include <ATen/ops/_linalg_eigh.h>
|
| 248 |
+
#include <ATen/ops/_linalg_eigvals.h>
|
| 249 |
+
#include <ATen/ops/_linalg_slogdet.h>
|
| 250 |
+
#include <ATen/ops/_linalg_solve_ex.h>
|
| 251 |
+
#include <ATen/ops/_linalg_svd.h>
|
| 252 |
+
#include <ATen/ops/_local_scalar_dense.h>
|
| 253 |
+
#include <ATen/ops/_log_softmax.h>
|
| 254 |
+
#include <ATen/ops/_log_softmax_backward_data.h>
|
| 255 |
+
#include <ATen/ops/_logcumsumexp.h>
|
| 256 |
+
#include <ATen/ops/_lstm_mps.h>
|
| 257 |
+
#include <ATen/ops/_lu_with_info.h>
|
| 258 |
+
#include <ATen/ops/_make_dep_token.h>
|
| 259 |
+
#include <ATen/ops/_make_dual.h>
|
| 260 |
+
#include <ATen/ops/_make_dual_copy.h>
|
| 261 |
+
#include <ATen/ops/_make_per_channel_quantized_tensor.h>
|
| 262 |
+
#include <ATen/ops/_make_per_tensor_quantized_tensor.h>
|
| 263 |
+
#include <ATen/ops/_masked_scale.h>
|
| 264 |
+
#include <ATen/ops/_masked_softmax.h>
|
| 265 |
+
#include <ATen/ops/_masked_softmax_backward.h>
|
| 266 |
+
#include <ATen/ops/_mixed_dtypes_linear.h>
|
| 267 |
+
#include <ATen/ops/_mkldnn_reshape.h>
|
| 268 |
+
#include <ATen/ops/_mkldnn_transpose.h>
|
| 269 |
+
#include <ATen/ops/_mps_convolution.h>
|
| 270 |
+
#include <ATen/ops/_mps_convolution_transpose.h>
|
| 271 |
+
#include <ATen/ops/_native_batch_norm_legit.h>
|
| 272 |
+
#include <ATen/ops/_native_batch_norm_legit_no_training.h>
|
| 273 |
+
#include <ATen/ops/_native_multi_head_attention.h>
|
| 274 |
+
#include <ATen/ops/_neg_view.h>
|
| 275 |
+
#include <ATen/ops/_neg_view_copy.h>
|
| 276 |
+
#include <ATen/ops/_nested_compute_contiguous_strides_offsets.h>
|
| 277 |
+
#include <ATen/ops/_nested_from_padded.h>
|
| 278 |
+
#include <ATen/ops/_nested_from_padded_and_nested_example.h>
|
| 279 |
+
#include <ATen/ops/_nested_get_jagged_dummy.h>
|
| 280 |
+
#include <ATen/ops/_nested_get_lengths.h>
|
| 281 |
+
#include <ATen/ops/_nested_get_max_seqlen.h>
|
| 282 |
+
#include <ATen/ops/_nested_get_min_seqlen.h>
|
| 283 |
+
#include <ATen/ops/_nested_get_offsets.h>
|
| 284 |
+
#include <ATen/ops/_nested_get_ragged_idx.h>
|
| 285 |
+
#include <ATen/ops/_nested_get_values.h>
|
| 286 |
+
#include <ATen/ops/_nested_get_values_copy.h>
|
| 287 |
+
#include <ATen/ops/_nested_select_backward.h>
|
| 288 |
+
#include <ATen/ops/_nested_sum_backward.h>
|
| 289 |
+
#include <ATen/ops/_nested_tensor_from_mask.h>
|
| 290 |
+
#include <ATen/ops/_nested_tensor_from_mask_left_aligned.h>
|
| 291 |
+
#include <ATen/ops/_nested_tensor_from_tensor_list.h>
|
| 292 |
+
#include <ATen/ops/_nested_tensor_size.h>
|
| 293 |
+
#include <ATen/ops/_nested_tensor_softmax_with_shape.h>
|
| 294 |
+
#include <ATen/ops/_nested_tensor_storage_offsets.h>
|
| 295 |
+
#include <ATen/ops/_nested_tensor_strides.h>
|
| 296 |
+
#include <ATen/ops/_nested_view_from_buffer.h>
|
| 297 |
+
#include <ATen/ops/_nested_view_from_buffer_copy.h>
|
| 298 |
+
#include <ATen/ops/_nested_view_from_jagged.h>
|
| 299 |
+
#include <ATen/ops/_nested_view_from_jagged_copy.h>
|
| 300 |
+
#include <ATen/ops/_new_zeros_with_same_feature_meta.h>
|
| 301 |
+
#include <ATen/ops/_nnpack_available.h>
|
| 302 |
+
#include <ATen/ops/_nnpack_spatial_convolution.h>
|
| 303 |
+
#include <ATen/ops/_nnz.h>
|
| 304 |
+
#include <ATen/ops/_pack_padded_sequence.h>
|
| 305 |
+
#include <ATen/ops/_pack_padded_sequence_backward.h>
|
| 306 |
+
#include <ATen/ops/_pad_circular.h>
|
| 307 |
+
#include <ATen/ops/_pad_enum.h>
|
| 308 |
+
#include <ATen/ops/_pad_packed_sequence.h>
|
| 309 |
+
#include <ATen/ops/_padded_dense_to_jagged_forward.h>
|
| 310 |
+
#include <ATen/ops/_pdist_backward.h>
|
| 311 |
+
#include <ATen/ops/_pdist_forward.h>
|
| 312 |
+
#include <ATen/ops/_pin_memory.h>
|
| 313 |
+
#include <ATen/ops/_prelu_kernel.h>
|
| 314 |
+
#include <ATen/ops/_prelu_kernel_backward.h>
|
| 315 |
+
#include <ATen/ops/_print.h>
|
| 316 |
+
#include <ATen/ops/_propagate_xla_data.h>
|
| 317 |
+
#include <ATen/ops/_remove_batch_dim.h>
|
| 318 |
+
#include <ATen/ops/_reshape_alias.h>
|
| 319 |
+
#include <ATen/ops/_reshape_alias_copy.h>
|
| 320 |
+
#include <ATen/ops/_reshape_copy.h>
|
| 321 |
+
#include <ATen/ops/_reshape_from_tensor.h>
|
| 322 |
+
#include <ATen/ops/_resize_output.h>
|
| 323 |
+
#include <ATen/ops/_rowwise_prune.h>
|
| 324 |
+
#include <ATen/ops/_safe_softmax.h>
|
| 325 |
+
#include <ATen/ops/_sample_dirichlet.h>
|
| 326 |
+
#include <ATen/ops/_saturate_weight_to_fp16.h>
|
| 327 |
+
#include <ATen/ops/_scaled_dot_product_attention_math.h>
|
| 328 |
+
#include <ATen/ops/_scaled_dot_product_attention_math_for_mps.h>
|
| 329 |
+
#include <ATen/ops/_scaled_dot_product_cudnn_attention.h>
|
| 330 |
+
#include <ATen/ops/_scaled_dot_product_cudnn_attention_backward.h>
|
| 331 |
+
#include <ATen/ops/_scaled_dot_product_efficient_attention.h>
|
| 332 |
+
#include <ATen/ops/_scaled_dot_product_efficient_attention_backward.h>
|
| 333 |
+
#include <ATen/ops/_scaled_dot_product_flash_attention.h>
|
| 334 |
+
#include <ATen/ops/_scaled_dot_product_flash_attention_backward.h>
|
| 335 |
+
#include <ATen/ops/_scaled_dot_product_flash_attention_for_cpu.h>
|
| 336 |
+
#include <ATen/ops/_scaled_dot_product_flash_attention_for_cpu_backward.h>
|
| 337 |
+
#include <ATen/ops/_scaled_dot_product_fused_attention_overrideable.h>
|
| 338 |
+
#include <ATen/ops/_scaled_dot_product_fused_attention_overrideable_backward.h>
|
| 339 |
+
#include <ATen/ops/_scaled_mm.h>
|
| 340 |
+
#include <ATen/ops/_segment_reduce_backward.h>
|
| 341 |
+
#include <ATen/ops/_shape_as_tensor.h>
|
| 342 |
+
#include <ATen/ops/_slow_conv2d_backward.h>
|
| 343 |
+
#include <ATen/ops/_slow_conv2d_forward.h>
|
| 344 |
+
#include <ATen/ops/_sobol_engine_draw.h>
|
| 345 |
+
#include <ATen/ops/_sobol_engine_ff.h>
|
| 346 |
+
#include <ATen/ops/_sobol_engine_initialize_state.h>
|
| 347 |
+
#include <ATen/ops/_sobol_engine_scramble.h>
|
| 348 |
+
#include <ATen/ops/_softmax.h>
|
| 349 |
+
#include <ATen/ops/_softmax_backward_data.h>
|
| 350 |
+
#include <ATen/ops/_sparse_addmm.h>
|
| 351 |
+
#include <ATen/ops/_sparse_broadcast_to.h>
|
| 352 |
+
#include <ATen/ops/_sparse_broadcast_to_copy.h>
|
| 353 |
+
#include <ATen/ops/_sparse_bsc_tensor_unsafe.h>
|
| 354 |
+
#include <ATen/ops/_sparse_bsr_tensor_unsafe.h>
|
| 355 |
+
#include <ATen/ops/_sparse_compressed_tensor_unsafe.h>
|
| 356 |
+
#include <ATen/ops/_sparse_compressed_tensor_with_dims.h>
|
| 357 |
+
#include <ATen/ops/_sparse_coo_tensor_unsafe.h>
|
| 358 |
+
#include <ATen/ops/_sparse_coo_tensor_with_dims.h>
|
| 359 |
+
#include <ATen/ops/_sparse_coo_tensor_with_dims_and_tensors.h>
|
| 360 |
+
#include <ATen/ops/_sparse_csc_tensor_unsafe.h>
|
| 361 |
+
#include <ATen/ops/_sparse_csr_prod.h>
|
| 362 |
+
#include <ATen/ops/_sparse_csr_sum.h>
|
| 363 |
+
#include <ATen/ops/_sparse_csr_tensor_unsafe.h>
|
| 364 |
+
#include <ATen/ops/_sparse_log_softmax.h>
|
| 365 |
+
#include <ATen/ops/_sparse_log_softmax_backward_data.h>
|
| 366 |
+
#include <ATen/ops/_sparse_mask_projection.h>
|
| 367 |
+
#include <ATen/ops/_sparse_mm.h>
|
| 368 |
+
#include <ATen/ops/_sparse_mm_reduce_impl.h>
|
| 369 |
+
#include <ATen/ops/_sparse_mm_reduce_impl_backward.h>
|
| 370 |
+
#include <ATen/ops/_sparse_semi_structured_addmm.h>
|
| 371 |
+
#include <ATen/ops/_sparse_semi_structured_apply.h>
|
| 372 |
+
#include <ATen/ops/_sparse_semi_structured_apply_dense.h>
|
| 373 |
+
#include <ATen/ops/_sparse_semi_structured_linear.h>
|
| 374 |
+
#include <ATen/ops/_sparse_semi_structured_mm.h>
|
| 375 |
+
#include <ATen/ops/_sparse_semi_structured_tile.h>
|
| 376 |
+
#include <ATen/ops/_sparse_softmax.h>
|
| 377 |
+
#include <ATen/ops/_sparse_softmax_backward_data.h>
|
| 378 |
+
#include <ATen/ops/_sparse_sparse_matmul.h>
|
| 379 |
+
#include <ATen/ops/_sparse_sum.h>
|
| 380 |
+
#include <ATen/ops/_sparse_sum_backward.h>
|
| 381 |
+
#include <ATen/ops/_spdiags.h>
|
| 382 |
+
#include <ATen/ops/_spsolve.h>
|
| 383 |
+
#include <ATen/ops/_stack.h>
|
| 384 |
+
#include <ATen/ops/_standard_gamma.h>
|
| 385 |
+
#include <ATen/ops/_standard_gamma_grad.h>
|
| 386 |
+
#include <ATen/ops/_test_ambiguous_defaults.h>
|
| 387 |
+
#include <ATen/ops/_test_autograd_multiple_dispatch.h>
|
| 388 |
+
#include <ATen/ops/_test_autograd_multiple_dispatch_view.h>
|
| 389 |
+
#include <ATen/ops/_test_autograd_multiple_dispatch_view_copy.h>
|
| 390 |
+
#include <ATen/ops/_test_check_tensor.h>
|
| 391 |
+
#include <ATen/ops/_test_functorch_fallback.h>
|
| 392 |
+
#include <ATen/ops/_test_optional_filled_intlist.h>
|
| 393 |
+
#include <ATen/ops/_test_optional_floatlist.h>
|
| 394 |
+
#include <ATen/ops/_test_optional_intlist.h>
|
| 395 |
+
#include <ATen/ops/_test_parallel_materialize.h>
|
| 396 |
+
#include <ATen/ops/_test_serialization_subcmul.h>
|
| 397 |
+
#include <ATen/ops/_test_string_default.h>
|
| 398 |
+
#include <ATen/ops/_test_warn_in_autograd.h>
|
| 399 |
+
#include <ATen/ops/_thnn_differentiable_gru_cell_backward.h>
|
| 400 |
+
#include <ATen/ops/_thnn_differentiable_lstm_cell_backward.h>
|
| 401 |
+
#include <ATen/ops/_thnn_fused_gru_cell.h>
|
| 402 |
+
#include <ATen/ops/_thnn_fused_gru_cell_backward.h>
|
| 403 |
+
#include <ATen/ops/_thnn_fused_lstm_cell.h>
|
| 404 |
+
#include <ATen/ops/_thnn_fused_lstm_cell_backward.h>
|
| 405 |
+
#include <ATen/ops/_thnn_fused_lstm_cell_backward_impl.h>
|
| 406 |
+
#include <ATen/ops/_to_copy.h>
|
| 407 |
+
#include <ATen/ops/_to_cpu.h>
|
| 408 |
+
#include <ATen/ops/_to_dense.h>
|
| 409 |
+
#include <ATen/ops/_to_sparse.h>
|
| 410 |
+
#include <ATen/ops/_to_sparse_bsc.h>
|
| 411 |
+
#include <ATen/ops/_to_sparse_bsr.h>
|
| 412 |
+
#include <ATen/ops/_to_sparse_csc.h>
|
| 413 |
+
#include <ATen/ops/_to_sparse_csr.h>
|
| 414 |
+
#include <ATen/ops/_to_sparse_semi_structured.h>
|
| 415 |
+
#include <ATen/ops/_transform_bias_rescale_qkv.h>
|
| 416 |
+
#include <ATen/ops/_transformer_encoder_layer_fwd.h>
|
| 417 |
+
#include <ATen/ops/_trilinear.h>
|
| 418 |
+
#include <ATen/ops/_triton_multi_head_attention.h>
|
| 419 |
+
#include <ATen/ops/_triton_scaled_dot_attention.h>
|
| 420 |
+
#include <ATen/ops/_unique.h>
|
| 421 |
+
#include <ATen/ops/_unique2.h>
|
| 422 |
+
#include <ATen/ops/_unpack_dual.h>
|
| 423 |
+
#include <ATen/ops/_unsafe_index.h>
|
| 424 |
+
#include <ATen/ops/_unsafe_index_put.h>
|
| 425 |
+
#include <ATen/ops/_unsafe_masked_index.h>
|
| 426 |
+
#include <ATen/ops/_unsafe_masked_index_put_accumulate.h>
|
| 427 |
+
#include <ATen/ops/_unsafe_view.h>
|
| 428 |
+
#include <ATen/ops/_upsample_bicubic2d_aa.h>
|
| 429 |
+
#include <ATen/ops/_upsample_bicubic2d_aa_backward.h>
|
| 430 |
+
#include <ATen/ops/_upsample_bilinear2d_aa.h>
|
| 431 |
+
#include <ATen/ops/_upsample_bilinear2d_aa_backward.h>
|
| 432 |
+
#include <ATen/ops/_upsample_nearest_exact1d.h>
|
| 433 |
+
#include <ATen/ops/_upsample_nearest_exact1d_backward.h>
|
| 434 |
+
#include <ATen/ops/_upsample_nearest_exact2d.h>
|
| 435 |
+
#include <ATen/ops/_upsample_nearest_exact2d_backward.h>
|
| 436 |
+
#include <ATen/ops/_upsample_nearest_exact3d.h>
|
| 437 |
+
#include <ATen/ops/_upsample_nearest_exact3d_backward.h>
|
| 438 |
+
#include <ATen/ops/_use_cudnn_ctc_loss.h>
|
| 439 |
+
#include <ATen/ops/_use_cudnn_rnn_flatten_weight.h>
|
| 440 |
+
#include <ATen/ops/_validate_compressed_sparse_indices.h>
|
| 441 |
+
#include <ATen/ops/_validate_sparse_bsc_tensor_args.h>
|
| 442 |
+
#include <ATen/ops/_validate_sparse_bsr_tensor_args.h>
|
| 443 |
+
#include <ATen/ops/_validate_sparse_compressed_tensor_args.h>
|
| 444 |
+
#include <ATen/ops/_validate_sparse_coo_tensor_args.h>
|
| 445 |
+
#include <ATen/ops/_validate_sparse_csc_tensor_args.h>
|
| 446 |
+
#include <ATen/ops/_validate_sparse_csr_tensor_args.h>
|
| 447 |
+
#include <ATen/ops/_values.h>
|
| 448 |
+
#include <ATen/ops/_values_copy.h>
|
| 449 |
+
#include <ATen/ops/_version.h>
|
| 450 |
+
#include <ATen/ops/_weight_int4pack_mm.h>
|
| 451 |
+
#include <ATen/ops/_weight_int8pack_mm.h>
|
| 452 |
+
#include <ATen/ops/_weight_norm.h>
|
| 453 |
+
#include <ATen/ops/_weight_norm_differentiable_backward.h>
|
| 454 |
+
#include <ATen/ops/_weight_norm_interface.h>
|
| 455 |
+
#include <ATen/ops/_weight_norm_interface_backward.h>
|
| 456 |
+
#include <ATen/ops/_wrapped_linear_prepack.h>
|
| 457 |
+
#include <ATen/ops/_wrapped_quantized_linear_prepacked.h>
|
| 458 |
+
#include <ATen/ops/abs.h>
|
| 459 |
+
#include <ATen/ops/absolute.h>
|
| 460 |
+
#include <ATen/ops/acos.h>
|
| 461 |
+
#include <ATen/ops/acosh.h>
|
| 462 |
+
#include <ATen/ops/adaptive_avg_pool1d.h>
|
| 463 |
+
#include <ATen/ops/adaptive_avg_pool2d.h>
|
| 464 |
+
#include <ATen/ops/adaptive_avg_pool3d.h>
|
| 465 |
+
#include <ATen/ops/adaptive_avg_pool3d_backward.h>
|
| 466 |
+
#include <ATen/ops/adaptive_max_pool1d.h>
|
| 467 |
+
#include <ATen/ops/adaptive_max_pool2d.h>
|
| 468 |
+
#include <ATen/ops/adaptive_max_pool2d_backward.h>
|
| 469 |
+
#include <ATen/ops/adaptive_max_pool3d.h>
|
| 470 |
+
#include <ATen/ops/adaptive_max_pool3d_backward.h>
|
| 471 |
+
#include <ATen/ops/add.h>
|
| 472 |
+
#include <ATen/ops/addbmm.h>
|
| 473 |
+
#include <ATen/ops/addcdiv.h>
|
| 474 |
+
#include <ATen/ops/addcmul.h>
|
| 475 |
+
#include <ATen/ops/addmm.h>
|
| 476 |
+
#include <ATen/ops/addmv.h>
|
| 477 |
+
#include <ATen/ops/addr.h>
|
| 478 |
+
#include <ATen/ops/adjoint.h>
|
| 479 |
+
#include <ATen/ops/affine_grid_generator.h>
|
| 480 |
+
#include <ATen/ops/affine_grid_generator_backward.h>
|
| 481 |
+
#include <ATen/ops/alias.h>
|
| 482 |
+
#include <ATen/ops/alias_copy.h>
|
| 483 |
+
#include <ATen/ops/align_as.h>
|
| 484 |
+
#include <ATen/ops/align_tensors.h>
|
| 485 |
+
#include <ATen/ops/align_to.h>
|
| 486 |
+
#include <ATen/ops/all.h>
|
| 487 |
+
#include <ATen/ops/allclose.h>
|
| 488 |
+
#include <ATen/ops/alpha_dropout.h>
|
| 489 |
+
#include <ATen/ops/amax.h>
|
| 490 |
+
#include <ATen/ops/amin.h>
|
| 491 |
+
#include <ATen/ops/aminmax.h>
|
| 492 |
+
#include <ATen/ops/and.h>
|
| 493 |
+
#include <ATen/ops/angle.h>
|
| 494 |
+
#include <ATen/ops/any.h>
|
| 495 |
+
#include <ATen/ops/arange.h>
|
| 496 |
+
#include <ATen/ops/arccos.h>
|
| 497 |
+
#include <ATen/ops/arccosh.h>
|
| 498 |
+
#include <ATen/ops/arcsin.h>
|
| 499 |
+
#include <ATen/ops/arcsinh.h>
|
| 500 |
+
#include <ATen/ops/arctan.h>
|
| 501 |
+
#include <ATen/ops/arctan2.h>
|
| 502 |
+
#include <ATen/ops/arctanh.h>
|
| 503 |
+
#include <ATen/ops/argmax.h>
|
| 504 |
+
#include <ATen/ops/argmin.h>
|
| 505 |
+
#include <ATen/ops/argsort.h>
|
| 506 |
+
#include <ATen/ops/argwhere.h>
|
| 507 |
+
#include <ATen/ops/as_strided.h>
|
| 508 |
+
#include <ATen/ops/as_strided_copy.h>
|
| 509 |
+
#include <ATen/ops/as_strided_scatter.h>
|
| 510 |
+
#include <ATen/ops/asin.h>
|
| 511 |
+
#include <ATen/ops/asinh.h>
|
| 512 |
+
#include <ATen/ops/atan.h>
|
| 513 |
+
#include <ATen/ops/atan2.h>
|
| 514 |
+
#include <ATen/ops/atanh.h>
|
| 515 |
+
#include <ATen/ops/atleast_1d.h>
|
| 516 |
+
#include <ATen/ops/atleast_2d.h>
|
| 517 |
+
#include <ATen/ops/atleast_3d.h>
|
| 518 |
+
#include <ATen/ops/avg_pool1d.h>
|
| 519 |
+
#include <ATen/ops/avg_pool2d.h>
|
| 520 |
+
#include <ATen/ops/avg_pool2d_backward.h>
|
| 521 |
+
#include <ATen/ops/avg_pool3d.h>
|
| 522 |
+
#include <ATen/ops/avg_pool3d_backward.h>
|
| 523 |
+
#include <ATen/ops/baddbmm.h>
|
| 524 |
+
#include <ATen/ops/bartlett_window.h>
|
| 525 |
+
#include <ATen/ops/batch_norm.h>
|
| 526 |
+
#include <ATen/ops/batch_norm_backward.h>
|
| 527 |
+
#include <ATen/ops/batch_norm_backward_elemt.h>
|
| 528 |
+
#include <ATen/ops/batch_norm_backward_reduce.h>
|
| 529 |
+
#include <ATen/ops/batch_norm_elemt.h>
|
| 530 |
+
#include <ATen/ops/batch_norm_gather_stats.h>
|
| 531 |
+
#include <ATen/ops/batch_norm_gather_stats_with_counts.h>
|
| 532 |
+
#include <ATen/ops/batch_norm_stats.h>
|
| 533 |
+
#include <ATen/ops/batch_norm_update_stats.h>
|
| 534 |
+
#include <ATen/ops/bernoulli.h>
|
| 535 |
+
#include <ATen/ops/bilinear.h>
|
| 536 |
+
#include <ATen/ops/binary_cross_entropy.h>
|
| 537 |
+
#include <ATen/ops/binary_cross_entropy_backward.h>
|
| 538 |
+
#include <ATen/ops/binary_cross_entropy_with_logits.h>
|
| 539 |
+
#include <ATen/ops/bincount.h>
|
| 540 |
+
#include <ATen/ops/binomial.h>
|
| 541 |
+
#include <ATen/ops/bitwise_and.h>
|
| 542 |
+
#include <ATen/ops/bitwise_left_shift.h>
|
| 543 |
+
#include <ATen/ops/bitwise_not.h>
|
| 544 |
+
#include <ATen/ops/bitwise_or.h>
|
| 545 |
+
#include <ATen/ops/bitwise_right_shift.h>
|
| 546 |
+
#include <ATen/ops/bitwise_xor.h>
|
| 547 |
+
#include <ATen/ops/blackman_window.h>
|
| 548 |
+
#include <ATen/ops/block_diag.h>
|
| 549 |
+
#include <ATen/ops/bmm.h>
|
| 550 |
+
#include <ATen/ops/broadcast_tensors.h>
|
| 551 |
+
#include <ATen/ops/broadcast_to.h>
|
| 552 |
+
#include <ATen/ops/bucketize.h>
|
| 553 |
+
#include <ATen/ops/can_cast.h>
|
| 554 |
+
#include <ATen/ops/cartesian_prod.h>
|
| 555 |
+
#include <ATen/ops/cat.h>
|
| 556 |
+
#include <ATen/ops/cauchy.h>
|
| 557 |
+
#include <ATen/ops/ccol_indices.h>
|
| 558 |
+
#include <ATen/ops/ccol_indices_copy.h>
|
| 559 |
+
#include <ATen/ops/cdist.h>
|
| 560 |
+
#include <ATen/ops/ceil.h>
|
| 561 |
+
#include <ATen/ops/celu.h>
|
| 562 |
+
#include <ATen/ops/chain_matmul.h>
|
| 563 |
+
#include <ATen/ops/chalf.h>
|
| 564 |
+
#include <ATen/ops/channel_shuffle.h>
|
| 565 |
+
#include <ATen/ops/cholesky.h>
|
| 566 |
+
#include <ATen/ops/cholesky_inverse.h>
|
| 567 |
+
#include <ATen/ops/cholesky_solve.h>
|
| 568 |
+
#include <ATen/ops/choose_qparams_optimized.h>
|
| 569 |
+
#include <ATen/ops/chunk.h>
|
| 570 |
+
#include <ATen/ops/clamp.h>
|
| 571 |
+
#include <ATen/ops/clamp_max.h>
|
| 572 |
+
#include <ATen/ops/clamp_min.h>
|
| 573 |
+
#include <ATen/ops/clip.h>
|
| 574 |
+
#include <ATen/ops/clone.h>
|
| 575 |
+
#include <ATen/ops/coalesce.h>
|
| 576 |
+
#include <ATen/ops/col2im.h>
|
| 577 |
+
#include <ATen/ops/col_indices.h>
|
| 578 |
+
#include <ATen/ops/col_indices_copy.h>
|
| 579 |
+
#include <ATen/ops/column_stack.h>
|
| 580 |
+
#include <ATen/ops/combinations.h>
|
| 581 |
+
#include <ATen/ops/complex.h>
|
| 582 |
+
#include <ATen/ops/concat.h>
|
| 583 |
+
#include <ATen/ops/concatenate.h>
|
| 584 |
+
#include <ATen/ops/conj.h>
|
| 585 |
+
#include <ATen/ops/conj_physical.h>
|
| 586 |
+
#include <ATen/ops/constant_pad_nd.h>
|
| 587 |
+
#include <ATen/ops/contiguous.h>
|
| 588 |
+
#include <ATen/ops/conv1d.h>
|
| 589 |
+
#include <ATen/ops/conv2d.h>
|
| 590 |
+
#include <ATen/ops/conv3d.h>
|
| 591 |
+
#include <ATen/ops/conv_depthwise3d.h>
|
| 592 |
+
#include <ATen/ops/conv_tbc.h>
|
| 593 |
+
#include <ATen/ops/conv_tbc_backward.h>
|
| 594 |
+
#include <ATen/ops/conv_transpose1d.h>
|
| 595 |
+
#include <ATen/ops/conv_transpose2d.h>
|
| 596 |
+
#include <ATen/ops/conv_transpose3d.h>
|
| 597 |
+
#include <ATen/ops/convolution.h>
|
| 598 |
+
#include <ATen/ops/convolution_backward.h>
|
| 599 |
+
#include <ATen/ops/convolution_backward_overrideable.h>
|
| 600 |
+
#include <ATen/ops/convolution_overrideable.h>
|
| 601 |
+
#include <ATen/ops/copy.h>
|
| 602 |
+
#include <ATen/ops/copy_sparse_to_sparse.h>
|
| 603 |
+
#include <ATen/ops/copysign.h>
|
| 604 |
+
#include <ATen/ops/corrcoef.h>
|
| 605 |
+
#include <ATen/ops/cos.h>
|
| 606 |
+
#include <ATen/ops/cosh.h>
|
| 607 |
+
#include <ATen/ops/cosine_embedding_loss.h>
|
| 608 |
+
#include <ATen/ops/cosine_similarity.h>
|
| 609 |
+
#include <ATen/ops/count_nonzero.h>
|
| 610 |
+
#include <ATen/ops/cov.h>
|
| 611 |
+
#include <ATen/ops/cross.h>
|
| 612 |
+
#include <ATen/ops/cross_entropy_loss.h>
|
| 613 |
+
#include <ATen/ops/crow_indices.h>
|
| 614 |
+
#include <ATen/ops/crow_indices_copy.h>
|
| 615 |
+
#include <ATen/ops/ctc_loss.h>
|
| 616 |
+
#include <ATen/ops/cudnn_affine_grid_generator.h>
|
| 617 |
+
#include <ATen/ops/cudnn_affine_grid_generator_backward.h>
|
| 618 |
+
#include <ATen/ops/cudnn_batch_norm.h>
|
| 619 |
+
#include <ATen/ops/cudnn_batch_norm_backward.h>
|
| 620 |
+
#include <ATen/ops/cudnn_convolution.h>
|
| 621 |
+
#include <ATen/ops/cudnn_convolution_add_relu.h>
|
| 622 |
+
#include <ATen/ops/cudnn_convolution_relu.h>
|
| 623 |
+
#include <ATen/ops/cudnn_convolution_transpose.h>
|
| 624 |
+
#include <ATen/ops/cudnn_grid_sampler.h>
|
| 625 |
+
#include <ATen/ops/cudnn_grid_sampler_backward.h>
|
| 626 |
+
#include <ATen/ops/cudnn_is_acceptable.h>
|
| 627 |
+
#include <ATen/ops/cummax.h>
|
| 628 |
+
#include <ATen/ops/cummaxmin_backward.h>
|
| 629 |
+
#include <ATen/ops/cummin.h>
|
| 630 |
+
#include <ATen/ops/cumprod.h>
|
| 631 |
+
#include <ATen/ops/cumprod_backward.h>
|
| 632 |
+
#include <ATen/ops/cumsum.h>
|
| 633 |
+
#include <ATen/ops/cumulative_trapezoid.h>
|
| 634 |
+
#include <ATen/ops/data.h>
|
| 635 |
+
#include <ATen/ops/deg2rad.h>
|
| 636 |
+
#include <ATen/ops/dense_dim.h>
|
| 637 |
+
#include <ATen/ops/dequantize.h>
|
| 638 |
+
#include <ATen/ops/det.h>
|
| 639 |
+
#include <ATen/ops/detach.h>
|
| 640 |
+
#include <ATen/ops/detach_copy.h>
|
| 641 |
+
#include <ATen/ops/diag.h>
|
| 642 |
+
#include <ATen/ops/diag_embed.h>
|
| 643 |
+
#include <ATen/ops/diagflat.h>
|
| 644 |
+
#include <ATen/ops/diagonal.h>
|
| 645 |
+
#include <ATen/ops/diagonal_backward.h>
|
| 646 |
+
#include <ATen/ops/diagonal_copy.h>
|
| 647 |
+
#include <ATen/ops/diagonal_scatter.h>
|
| 648 |
+
#include <ATen/ops/diff.h>
|
| 649 |
+
#include <ATen/ops/digamma.h>
|
| 650 |
+
#include <ATen/ops/dist.h>
|
| 651 |
+
#include <ATen/ops/div.h>
|
| 652 |
+
#include <ATen/ops/divide.h>
|
| 653 |
+
#include <ATen/ops/dot.h>
|
| 654 |
+
#include <ATen/ops/dropout.h>
|
| 655 |
+
#include <ATen/ops/dsplit.h>
|
| 656 |
+
#include <ATen/ops/dstack.h>
|
| 657 |
+
#include <ATen/ops/einsum.h>
|
| 658 |
+
#include <ATen/ops/elu.h>
|
| 659 |
+
#include <ATen/ops/elu_backward.h>
|
| 660 |
+
#include <ATen/ops/embedding.h>
|
| 661 |
+
#include <ATen/ops/embedding_backward.h>
|
| 662 |
+
#include <ATen/ops/embedding_bag.h>
|
| 663 |
+
#include <ATen/ops/embedding_dense_backward.h>
|
| 664 |
+
#include <ATen/ops/embedding_renorm.h>
|
| 665 |
+
#include <ATen/ops/embedding_sparse_backward.h>
|
| 666 |
+
#include <ATen/ops/empty.h>
|
| 667 |
+
#include <ATen/ops/empty_like.h>
|
| 668 |
+
#include <ATen/ops/empty_permuted.h>
|
| 669 |
+
#include <ATen/ops/empty_quantized.h>
|
| 670 |
+
#include <ATen/ops/empty_strided.h>
|
| 671 |
+
#include <ATen/ops/eq.h>
|
| 672 |
+
#include <ATen/ops/equal.h>
|
| 673 |
+
#include <ATen/ops/erf.h>
|
| 674 |
+
#include <ATen/ops/erfc.h>
|
| 675 |
+
#include <ATen/ops/erfinv.h>
|
| 676 |
+
#include <ATen/ops/exp.h>
|
| 677 |
+
#include <ATen/ops/exp2.h>
|
| 678 |
+
#include <ATen/ops/expand.h>
|
| 679 |
+
#include <ATen/ops/expand_as.h>
|
| 680 |
+
#include <ATen/ops/expand_copy.h>
|
| 681 |
+
#include <ATen/ops/expm1.h>
|
| 682 |
+
#include <ATen/ops/exponential.h>
|
| 683 |
+
#include <ATen/ops/eye.h>
|
| 684 |
+
#include <ATen/ops/fake_quantize_per_channel_affine.h>
|
| 685 |
+
#include <ATen/ops/fake_quantize_per_channel_affine_cachemask.h>
|
| 686 |
+
#include <ATen/ops/fake_quantize_per_channel_affine_cachemask_backward.h>
|
| 687 |
+
#include <ATen/ops/fake_quantize_per_tensor_affine.h>
|
| 688 |
+
#include <ATen/ops/fake_quantize_per_tensor_affine_cachemask.h>
|
| 689 |
+
#include <ATen/ops/fake_quantize_per_tensor_affine_cachemask_backward.h>
|
| 690 |
+
#include <ATen/ops/fbgemm_linear_fp16_weight.h>
|
| 691 |
+
#include <ATen/ops/fbgemm_linear_fp16_weight_fp32_activation.h>
|
| 692 |
+
#include <ATen/ops/fbgemm_linear_int8_weight.h>
|
| 693 |
+
#include <ATen/ops/fbgemm_linear_int8_weight_fp32_activation.h>
|
| 694 |
+
#include <ATen/ops/fbgemm_linear_quantize_weight.h>
|
| 695 |
+
#include <ATen/ops/fbgemm_pack_gemm_matrix_fp16.h>
|
| 696 |
+
#include <ATen/ops/fbgemm_pack_quantized_matrix.h>
|
| 697 |
+
#include <ATen/ops/feature_alpha_dropout.h>
|
| 698 |
+
#include <ATen/ops/feature_dropout.h>
|
| 699 |
+
#include <ATen/ops/fft_fft.h>
|
| 700 |
+
#include <ATen/ops/fft_fft2.h>
|
| 701 |
+
#include <ATen/ops/fft_fftfreq.h>
|
| 702 |
+
#include <ATen/ops/fft_fftn.h>
|
| 703 |
+
#include <ATen/ops/fft_fftshift.h>
|
| 704 |
+
#include <ATen/ops/fft_hfft.h>
|
| 705 |
+
#include <ATen/ops/fft_hfft2.h>
|
| 706 |
+
#include <ATen/ops/fft_hfftn.h>
|
| 707 |
+
#include <ATen/ops/fft_ifft.h>
|
| 708 |
+
#include <ATen/ops/fft_ifft2.h>
|
| 709 |
+
#include <ATen/ops/fft_ifftn.h>
|
| 710 |
+
#include <ATen/ops/fft_ifftshift.h>
|
| 711 |
+
#include <ATen/ops/fft_ihfft.h>
|
| 712 |
+
#include <ATen/ops/fft_ihfft2.h>
|
| 713 |
+
#include <ATen/ops/fft_ihfftn.h>
|
| 714 |
+
#include <ATen/ops/fft_irfft.h>
|
| 715 |
+
#include <ATen/ops/fft_irfft2.h>
|
| 716 |
+
#include <ATen/ops/fft_irfftn.h>
|
| 717 |
+
#include <ATen/ops/fft_rfft.h>
|
| 718 |
+
#include <ATen/ops/fft_rfft2.h>
|
| 719 |
+
#include <ATen/ops/fft_rfftfreq.h>
|
| 720 |
+
#include <ATen/ops/fft_rfftn.h>
|
| 721 |
+
#include <ATen/ops/fill.h>
|
| 722 |
+
#include <ATen/ops/fill_diagonal.h>
|
| 723 |
+
#include <ATen/ops/fix.h>
|
| 724 |
+
#include <ATen/ops/flatten.h>
|
| 725 |
+
#include <ATen/ops/flatten_dense_tensors.h>
|
| 726 |
+
#include <ATen/ops/flip.h>
|
| 727 |
+
#include <ATen/ops/fliplr.h>
|
| 728 |
+
#include <ATen/ops/flipud.h>
|
| 729 |
+
#include <ATen/ops/float_power.h>
|
| 730 |
+
#include <ATen/ops/floor.h>
|
| 731 |
+
#include <ATen/ops/floor_divide.h>
|
| 732 |
+
#include <ATen/ops/fmax.h>
|
| 733 |
+
#include <ATen/ops/fmin.h>
|
| 734 |
+
#include <ATen/ops/fmod.h>
|
| 735 |
+
#include <ATen/ops/frac.h>
|
| 736 |
+
#include <ATen/ops/fractional_max_pool2d.h>
|
| 737 |
+
#include <ATen/ops/fractional_max_pool2d_backward.h>
|
| 738 |
+
#include <ATen/ops/fractional_max_pool3d.h>
|
| 739 |
+
#include <ATen/ops/fractional_max_pool3d_backward.h>
|
| 740 |
+
#include <ATen/ops/frexp.h>
|
| 741 |
+
#include <ATen/ops/frobenius_norm.h>
|
| 742 |
+
#include <ATen/ops/from_file.h>
|
| 743 |
+
#include <ATen/ops/full.h>
|
| 744 |
+
#include <ATen/ops/full_like.h>
|
| 745 |
+
#include <ATen/ops/fused_moving_avg_obs_fake_quant.h>
|
| 746 |
+
#include <ATen/ops/gather.h>
|
| 747 |
+
#include <ATen/ops/gather_backward.h>
|
| 748 |
+
#include <ATen/ops/gcd.h>
|
| 749 |
+
#include <ATen/ops/ge.h>
|
| 750 |
+
#include <ATen/ops/gelu.h>
|
| 751 |
+
#include <ATen/ops/gelu_backward.h>
|
| 752 |
+
#include <ATen/ops/geometric.h>
|
| 753 |
+
#include <ATen/ops/geqrf.h>
|
| 754 |
+
#include <ATen/ops/ger.h>
|
| 755 |
+
#include <ATen/ops/glu.h>
|
| 756 |
+
#include <ATen/ops/glu_backward.h>
|
| 757 |
+
#include <ATen/ops/glu_backward_jvp.h>
|
| 758 |
+
#include <ATen/ops/glu_jvp.h>
|
| 759 |
+
#include <ATen/ops/gradient.h>
|
| 760 |
+
#include <ATen/ops/greater.h>
|
| 761 |
+
#include <ATen/ops/greater_equal.h>
|
| 762 |
+
#include <ATen/ops/grid_sampler.h>
|
| 763 |
+
#include <ATen/ops/grid_sampler_2d.h>
|
| 764 |
+
#include <ATen/ops/grid_sampler_2d_backward.h>
|
| 765 |
+
#include <ATen/ops/grid_sampler_3d.h>
|
| 766 |
+
#include <ATen/ops/grid_sampler_3d_backward.h>
|
| 767 |
+
#include <ATen/ops/group_norm.h>
|
| 768 |
+
#include <ATen/ops/gru.h>
|
| 769 |
+
#include <ATen/ops/gru_cell.h>
|
| 770 |
+
#include <ATen/ops/gt.h>
|
| 771 |
+
#include <ATen/ops/hamming_window.h>
|
| 772 |
+
#include <ATen/ops/hann_window.h>
|
| 773 |
+
#include <ATen/ops/hardshrink.h>
|
| 774 |
+
#include <ATen/ops/hardshrink_backward.h>
|
| 775 |
+
#include <ATen/ops/hardsigmoid.h>
|
| 776 |
+
#include <ATen/ops/hardsigmoid_backward.h>
|
| 777 |
+
#include <ATen/ops/hardswish.h>
|
| 778 |
+
#include <ATen/ops/hardswish_backward.h>
|
| 779 |
+
#include <ATen/ops/hardtanh.h>
|
| 780 |
+
#include <ATen/ops/hardtanh_backward.h>
|
| 781 |
+
#include <ATen/ops/heaviside.h>
|
| 782 |
+
#include <ATen/ops/hinge_embedding_loss.h>
|
| 783 |
+
#include <ATen/ops/histc.h>
|
| 784 |
+
#include <ATen/ops/histogram.h>
|
| 785 |
+
#include <ATen/ops/histogramdd.h>
|
| 786 |
+
#include <ATen/ops/hsplit.h>
|
| 787 |
+
#include <ATen/ops/hspmm.h>
|
| 788 |
+
#include <ATen/ops/hstack.h>
|
| 789 |
+
#include <ATen/ops/huber_loss.h>
|
| 790 |
+
#include <ATen/ops/huber_loss_backward.h>
|
| 791 |
+
#include <ATen/ops/hypot.h>
|
| 792 |
+
#include <ATen/ops/i0.h>
|
| 793 |
+
#include <ATen/ops/igamma.h>
|
| 794 |
+
#include <ATen/ops/igammac.h>
|
| 795 |
+
#include <ATen/ops/im2col.h>
|
| 796 |
+
#include <ATen/ops/imag.h>
|
| 797 |
+
#include <ATen/ops/index.h>
|
| 798 |
+
#include <ATen/ops/index_add.h>
|
| 799 |
+
#include <ATen/ops/index_copy.h>
|
| 800 |
+
#include <ATen/ops/index_fill.h>
|
| 801 |
+
#include <ATen/ops/index_put.h>
|
| 802 |
+
#include <ATen/ops/index_reduce.h>
|
| 803 |
+
#include <ATen/ops/index_select.h>
|
| 804 |
+
#include <ATen/ops/index_select_backward.h>
|
| 805 |
+
#include <ATen/ops/indices.h>
|
| 806 |
+
#include <ATen/ops/indices_copy.h>
|
| 807 |
+
#include <ATen/ops/infinitely_differentiable_gelu_backward.h>
|
| 808 |
+
#include <ATen/ops/inner.h>
|
| 809 |
+
#include <ATen/ops/instance_norm.h>
|
| 810 |
+
#include <ATen/ops/int_repr.h>
|
| 811 |
+
#include <ATen/ops/inverse.h>
|
| 812 |
+
#include <ATen/ops/is_coalesced.h>
|
| 813 |
+
#include <ATen/ops/is_complex.h>
|
| 814 |
+
#include <ATen/ops/is_conj.h>
|
| 815 |
+
#include <ATen/ops/is_distributed.h>
|
| 816 |
+
#include <ATen/ops/is_floating_point.h>
|
| 817 |
+
#include <ATen/ops/is_inference.h>
|
| 818 |
+
#include <ATen/ops/is_leaf.h>
|
| 819 |
+
#include <ATen/ops/is_neg.h>
|
| 820 |
+
#include <ATen/ops/is_nonzero.h>
|
| 821 |
+
#include <ATen/ops/is_pinned.h>
|
| 822 |
+
#include <ATen/ops/is_same_size.h>
|
| 823 |
+
#include <ATen/ops/is_set_to.h>
|
| 824 |
+
#include <ATen/ops/is_signed.h>
|
| 825 |
+
#include <ATen/ops/is_vulkan_available.h>
|
| 826 |
+
#include <ATen/ops/isclose.h>
|
| 827 |
+
#include <ATen/ops/isfinite.h>
|
| 828 |
+
#include <ATen/ops/isin.h>
|
| 829 |
+
#include <ATen/ops/isinf.h>
|
| 830 |
+
#include <ATen/ops/isnan.h>
|
| 831 |
+
#include <ATen/ops/isneginf.h>
|
| 832 |
+
#include <ATen/ops/isposinf.h>
|
| 833 |
+
#include <ATen/ops/isreal.h>
|
| 834 |
+
#include <ATen/ops/istft.h>
|
| 835 |
+
#include <ATen/ops/item.h>
|
| 836 |
+
#include <ATen/ops/kaiser_window.h>
|
| 837 |
+
#include <ATen/ops/kl_div.h>
|
| 838 |
+
#include <ATen/ops/kron.h>
|
| 839 |
+
#include <ATen/ops/kthvalue.h>
|
| 840 |
+
#include <ATen/ops/l1_loss.h>
|
| 841 |
+
#include <ATen/ops/layer_norm.h>
|
| 842 |
+
#include <ATen/ops/lcm.h>
|
| 843 |
+
#include <ATen/ops/ldexp.h>
|
| 844 |
+
#include <ATen/ops/le.h>
|
| 845 |
+
#include <ATen/ops/leaky_relu.h>
|
| 846 |
+
#include <ATen/ops/leaky_relu_backward.h>
|
| 847 |
+
#include <ATen/ops/lerp.h>
|
| 848 |
+
#include <ATen/ops/less.h>
|
| 849 |
+
#include <ATen/ops/less_equal.h>
|
| 850 |
+
#include <ATen/ops/lgamma.h>
|
| 851 |
+
#include <ATen/ops/lift.h>
|
| 852 |
+
#include <ATen/ops/lift_fresh.h>
|
| 853 |
+
#include <ATen/ops/lift_fresh_copy.h>
|
| 854 |
+
#include <ATen/ops/linalg_cholesky.h>
|
| 855 |
+
#include <ATen/ops/linalg_cholesky_ex.h>
|
| 856 |
+
#include <ATen/ops/linalg_cond.h>
|
| 857 |
+
#include <ATen/ops/linalg_cross.h>
|
| 858 |
+
#include <ATen/ops/linalg_det.h>
|
| 859 |
+
#include <ATen/ops/linalg_diagonal.h>
|
| 860 |
+
#include <ATen/ops/linalg_eig.h>
|
| 861 |
+
#include <ATen/ops/linalg_eigh.h>
|
| 862 |
+
#include <ATen/ops/linalg_eigvals.h>
|
| 863 |
+
#include <ATen/ops/linalg_eigvalsh.h>
|
| 864 |
+
#include <ATen/ops/linalg_householder_product.h>
|
| 865 |
+
#include <ATen/ops/linalg_inv.h>
|
| 866 |
+
#include <ATen/ops/linalg_inv_ex.h>
|
| 867 |
+
#include <ATen/ops/linalg_ldl_factor.h>
|
| 868 |
+
#include <ATen/ops/linalg_ldl_factor_ex.h>
|
| 869 |
+
#include <ATen/ops/linalg_ldl_solve.h>
|
| 870 |
+
#include <ATen/ops/linalg_lstsq.h>
|
| 871 |
+
#include <ATen/ops/linalg_lu.h>
|
| 872 |
+
#include <ATen/ops/linalg_lu_factor.h>
|
| 873 |
+
#include <ATen/ops/linalg_lu_factor_ex.h>
|
| 874 |
+
#include <ATen/ops/linalg_lu_solve.h>
|
| 875 |
+
#include <ATen/ops/linalg_matmul.h>
|
| 876 |
+
#include <ATen/ops/linalg_matrix_exp.h>
|
| 877 |
+
#include <ATen/ops/linalg_matrix_norm.h>
|
| 878 |
+
#include <ATen/ops/linalg_matrix_power.h>
|
| 879 |
+
#include <ATen/ops/linalg_matrix_rank.h>
|
| 880 |
+
#include <ATen/ops/linalg_multi_dot.h>
|
| 881 |
+
#include <ATen/ops/linalg_norm.h>
|
| 882 |
+
#include <ATen/ops/linalg_pinv.h>
|
| 883 |
+
#include <ATen/ops/linalg_qr.h>
|
| 884 |
+
#include <ATen/ops/linalg_slogdet.h>
|
| 885 |
+
#include <ATen/ops/linalg_solve.h>
|
| 886 |
+
#include <ATen/ops/linalg_solve_ex.h>
|
| 887 |
+
#include <ATen/ops/linalg_solve_triangular.h>
|
| 888 |
+
#include <ATen/ops/linalg_svd.h>
|
| 889 |
+
#include <ATen/ops/linalg_svdvals.h>
|
| 890 |
+
#include <ATen/ops/linalg_tensorinv.h>
|
| 891 |
+
#include <ATen/ops/linalg_tensorsolve.h>
|
| 892 |
+
#include <ATen/ops/linalg_vander.h>
|
| 893 |
+
#include <ATen/ops/linalg_vecdot.h>
|
| 894 |
+
#include <ATen/ops/linalg_vector_norm.h>
|
| 895 |
+
#include <ATen/ops/linear.h>
|
| 896 |
+
#include <ATen/ops/linear_backward.h>
|
| 897 |
+
#include <ATen/ops/linspace.h>
|
| 898 |
+
#include <ATen/ops/log.h>
|
| 899 |
+
#include <ATen/ops/log10.h>
|
| 900 |
+
#include <ATen/ops/log1p.h>
|
| 901 |
+
#include <ATen/ops/log2.h>
|
| 902 |
+
#include <ATen/ops/log_normal.h>
|
| 903 |
+
#include <ATen/ops/log_sigmoid.h>
|
| 904 |
+
#include <ATen/ops/log_sigmoid_backward.h>
|
| 905 |
+
#include <ATen/ops/log_sigmoid_forward.h>
|
| 906 |
+
#include <ATen/ops/log_softmax.h>
|
| 907 |
+
#include <ATen/ops/logaddexp.h>
|
| 908 |
+
#include <ATen/ops/logaddexp2.h>
|
| 909 |
+
#include <ATen/ops/logcumsumexp.h>
|
| 910 |
+
#include <ATen/ops/logdet.h>
|
| 911 |
+
#include <ATen/ops/logical_and.h>
|
| 912 |
+
#include <ATen/ops/logical_not.h>
|
| 913 |
+
#include <ATen/ops/logical_or.h>
|
| 914 |
+
#include <ATen/ops/logical_xor.h>
|
| 915 |
+
#include <ATen/ops/logit.h>
|
| 916 |
+
#include <ATen/ops/logit_backward.h>
|
| 917 |
+
#include <ATen/ops/logspace.h>
|
| 918 |
+
#include <ATen/ops/logsumexp.h>
|
| 919 |
+
#include <ATen/ops/lshift.h>
|
| 920 |
+
#include <ATen/ops/lstm.h>
|
| 921 |
+
#include <ATen/ops/lstm_cell.h>
|
| 922 |
+
#include <ATen/ops/lstm_mps_backward.h>
|
| 923 |
+
#include <ATen/ops/lt.h>
|
| 924 |
+
#include <ATen/ops/lu_solve.h>
|
| 925 |
+
#include <ATen/ops/lu_unpack.h>
|
| 926 |
+
#include <ATen/ops/mH.h>
|
| 927 |
+
#include <ATen/ops/mT.h>
|
| 928 |
+
#include <ATen/ops/margin_ranking_loss.h>
|
| 929 |
+
#include <ATen/ops/masked_fill.h>
|
| 930 |
+
#include <ATen/ops/masked_scatter.h>
|
| 931 |
+
#include <ATen/ops/masked_scatter_backward.h>
|
| 932 |
+
#include <ATen/ops/masked_select.h>
|
| 933 |
+
#include <ATen/ops/masked_select_backward.h>
|
| 934 |
+
#include <ATen/ops/matmul.h>
|
| 935 |
+
#include <ATen/ops/matmul_backward.h>
|
| 936 |
+
#include <ATen/ops/matrix_H.h>
|
| 937 |
+
#include <ATen/ops/matrix_exp.h>
|
| 938 |
+
#include <ATen/ops/matrix_exp_backward.h>
|
| 939 |
+
#include <ATen/ops/matrix_power.h>
|
| 940 |
+
#include <ATen/ops/max.h>
|
| 941 |
+
#include <ATen/ops/max_pool1d.h>
|
| 942 |
+
#include <ATen/ops/max_pool1d_with_indices.h>
|
| 943 |
+
#include <ATen/ops/max_pool2d.h>
|
| 944 |
+
#include <ATen/ops/max_pool2d_backward.h>
|
| 945 |
+
#include <ATen/ops/max_pool2d_with_indices.h>
|
| 946 |
+
#include <ATen/ops/max_pool2d_with_indices_backward.h>
|
| 947 |
+
#include <ATen/ops/max_pool3d.h>
|
| 948 |
+
#include <ATen/ops/max_pool3d_with_indices.h>
|
| 949 |
+
#include <ATen/ops/max_pool3d_with_indices_backward.h>
|
| 950 |
+
#include <ATen/ops/max_unpool2d.h>
|
| 951 |
+
#include <ATen/ops/max_unpool3d.h>
|
| 952 |
+
#include <ATen/ops/maximum.h>
|
| 953 |
+
#include <ATen/ops/mean.h>
|
| 954 |
+
#include <ATen/ops/median.h>
|
| 955 |
+
#include <ATen/ops/meshgrid.h>
|
| 956 |
+
#include <ATen/ops/min.h>
|
| 957 |
+
#include <ATen/ops/minimum.h>
|
| 958 |
+
#include <ATen/ops/miopen_batch_norm.h>
|
| 959 |
+
#include <ATen/ops/miopen_batch_norm_backward.h>
|
| 960 |
+
#include <ATen/ops/miopen_convolution.h>
|
| 961 |
+
#include <ATen/ops/miopen_convolution_add_relu.h>
|
| 962 |
+
#include <ATen/ops/miopen_convolution_relu.h>
|
| 963 |
+
#include <ATen/ops/miopen_convolution_transpose.h>
|
| 964 |
+
#include <ATen/ops/miopen_depthwise_convolution.h>
|
| 965 |
+
#include <ATen/ops/miopen_rnn.h>
|
| 966 |
+
#include <ATen/ops/miopen_rnn_backward.h>
|
| 967 |
+
#include <ATen/ops/mish.h>
|
| 968 |
+
#include <ATen/ops/mish_backward.h>
|
| 969 |
+
#include <ATen/ops/mkldnn_adaptive_avg_pool2d.h>
|
| 970 |
+
#include <ATen/ops/mkldnn_adaptive_avg_pool2d_backward.h>
|
| 971 |
+
#include <ATen/ops/mkldnn_convolution.h>
|
| 972 |
+
#include <ATen/ops/mkldnn_linear.h>
|
| 973 |
+
#include <ATen/ops/mkldnn_linear_backward.h>
|
| 974 |
+
#include <ATen/ops/mkldnn_linear_backward_input.h>
|
| 975 |
+
#include <ATen/ops/mkldnn_linear_backward_weights.h>
|
| 976 |
+
#include <ATen/ops/mkldnn_max_pool2d.h>
|
| 977 |
+
#include <ATen/ops/mkldnn_max_pool2d_backward.h>
|
| 978 |
+
#include <ATen/ops/mkldnn_max_pool3d.h>
|
| 979 |
+
#include <ATen/ops/mkldnn_max_pool3d_backward.h>
|
| 980 |
+
#include <ATen/ops/mkldnn_reorder_conv2d_weight.h>
|
| 981 |
+
#include <ATen/ops/mkldnn_reorder_conv3d_weight.h>
|
| 982 |
+
#include <ATen/ops/mkldnn_rnn_layer.h>
|
| 983 |
+
#include <ATen/ops/mkldnn_rnn_layer_backward.h>
|
| 984 |
+
#include <ATen/ops/mm.h>
|
| 985 |
+
#include <ATen/ops/mode.h>
|
| 986 |
+
#include <ATen/ops/moveaxis.h>
|
| 987 |
+
#include <ATen/ops/movedim.h>
|
| 988 |
+
#include <ATen/ops/mps_convolution_backward.h>
|
| 989 |
+
#include <ATen/ops/mps_convolution_transpose_backward.h>
|
| 990 |
+
#include <ATen/ops/mse_loss.h>
|
| 991 |
+
#include <ATen/ops/mse_loss_backward.h>
|
| 992 |
+
#include <ATen/ops/msort.h>
|
| 993 |
+
#include <ATen/ops/mul.h>
|
| 994 |
+
#include <ATen/ops/multi_margin_loss.h>
|
| 995 |
+
#include <ATen/ops/multi_margin_loss_backward.h>
|
| 996 |
+
#include <ATen/ops/multilabel_margin_loss.h>
|
| 997 |
+
#include <ATen/ops/multilabel_margin_loss_backward.h>
|
| 998 |
+
#include <ATen/ops/multilabel_margin_loss_forward.h>
|
| 999 |
+
#include <ATen/ops/multinomial.h>
|
| 1000 |
+
#include <ATen/ops/multiply.h>
|
| 1001 |
+
#include <ATen/ops/mv.h>
|
| 1002 |
+
#include <ATen/ops/mvlgamma.h>
|
| 1003 |
+
#include <ATen/ops/nan_to_num.h>
|
| 1004 |
+
#include <ATen/ops/nanmean.h>
|
| 1005 |
+
#include <ATen/ops/nanmedian.h>
|
| 1006 |
+
#include <ATen/ops/nanquantile.h>
|
| 1007 |
+
#include <ATen/ops/nansum.h>
|
| 1008 |
+
#include <ATen/ops/narrow.h>
|
| 1009 |
+
#include <ATen/ops/narrow_copy.h>
|
| 1010 |
+
#include <ATen/ops/native_batch_norm.h>
|
| 1011 |
+
#include <ATen/ops/native_batch_norm_backward.h>
|
| 1012 |
+
#include <ATen/ops/native_channel_shuffle.h>
|
| 1013 |
+
#include <ATen/ops/native_dropout.h>
|
| 1014 |
+
#include <ATen/ops/native_dropout_backward.h>
|
| 1015 |
+
#include <ATen/ops/native_group_norm.h>
|
| 1016 |
+
#include <ATen/ops/native_group_norm_backward.h>
|
| 1017 |
+
#include <ATen/ops/native_layer_norm.h>
|
| 1018 |
+
#include <ATen/ops/native_layer_norm_backward.h>
|
| 1019 |
+
#include <ATen/ops/native_norm.h>
|
| 1020 |
+
#include <ATen/ops/ne.h>
|
| 1021 |
+
#include <ATen/ops/neg.h>
|
| 1022 |
+
#include <ATen/ops/negative.h>
|
| 1023 |
+
#include <ATen/ops/nested_to_padded_tensor.h>
|
| 1024 |
+
#include <ATen/ops/new_empty.h>
|
| 1025 |
+
#include <ATen/ops/new_empty_strided.h>
|
| 1026 |
+
#include <ATen/ops/new_full.h>
|
| 1027 |
+
#include <ATen/ops/new_ones.h>
|
| 1028 |
+
#include <ATen/ops/new_zeros.h>
|
| 1029 |
+
#include <ATen/ops/nextafter.h>
|
| 1030 |
+
#include <ATen/ops/nll_loss.h>
|
| 1031 |
+
#include <ATen/ops/nll_loss2d.h>
|
| 1032 |
+
#include <ATen/ops/nll_loss2d_backward.h>
|
| 1033 |
+
#include <ATen/ops/nll_loss2d_forward.h>
|
| 1034 |
+
#include <ATen/ops/nll_loss_backward.h>
|
| 1035 |
+
#include <ATen/ops/nll_loss_forward.h>
|
| 1036 |
+
#include <ATen/ops/nll_loss_nd.h>
|
| 1037 |
+
#include <ATen/ops/nonzero.h>
|
| 1038 |
+
#include <ATen/ops/nonzero_numpy.h>
|
| 1039 |
+
#include <ATen/ops/nonzero_static.h>
|
| 1040 |
+
#include <ATen/ops/norm.h>
|
| 1041 |
+
#include <ATen/ops/norm_except_dim.h>
|
| 1042 |
+
#include <ATen/ops/normal.h>
|
| 1043 |
+
#include <ATen/ops/not_equal.h>
|
| 1044 |
+
#include <ATen/ops/nuclear_norm.h>
|
| 1045 |
+
#include <ATen/ops/numpy_T.h>
|
| 1046 |
+
#include <ATen/ops/one_hot.h>
|
| 1047 |
+
#include <ATen/ops/ones.h>
|
| 1048 |
+
#include <ATen/ops/ones_like.h>
|
| 1049 |
+
#include <ATen/ops/or.h>
|
| 1050 |
+
#include <ATen/ops/orgqr.h>
|
| 1051 |
+
#include <ATen/ops/ormqr.h>
|
| 1052 |
+
#include <ATen/ops/outer.h>
|
| 1053 |
+
#include <ATen/ops/output_nr.h>
|
| 1054 |
+
#include <ATen/ops/pad.h>
|
| 1055 |
+
#include <ATen/ops/pad_sequence.h>
|
| 1056 |
+
#include <ATen/ops/pairwise_distance.h>
|
| 1057 |
+
#include <ATen/ops/pdist.h>
|
| 1058 |
+
#include <ATen/ops/permute.h>
|
| 1059 |
+
#include <ATen/ops/permute_copy.h>
|
| 1060 |
+
#include <ATen/ops/pin_memory.h>
|
| 1061 |
+
#include <ATen/ops/pinverse.h>
|
| 1062 |
+
#include <ATen/ops/pixel_shuffle.h>
|
| 1063 |
+
#include <ATen/ops/pixel_unshuffle.h>
|
| 1064 |
+
#include <ATen/ops/poisson.h>
|
| 1065 |
+
#include <ATen/ops/poisson_nll_loss.h>
|
| 1066 |
+
#include <ATen/ops/polar.h>
|
| 1067 |
+
#include <ATen/ops/polygamma.h>
|
| 1068 |
+
#include <ATen/ops/positive.h>
|
| 1069 |
+
#include <ATen/ops/pow.h>
|
| 1070 |
+
#include <ATen/ops/prelu.h>
|
| 1071 |
+
#include <ATen/ops/prod.h>
|
| 1072 |
+
#include <ATen/ops/promote_types.h>
|
| 1073 |
+
#include <ATen/ops/put.h>
|
| 1074 |
+
#include <ATen/ops/q_per_channel_axis.h>
|
| 1075 |
+
#include <ATen/ops/q_per_channel_scales.h>
|
| 1076 |
+
#include <ATen/ops/q_per_channel_zero_points.h>
|
| 1077 |
+
#include <ATen/ops/q_scale.h>
|
| 1078 |
+
#include <ATen/ops/q_zero_point.h>
|
| 1079 |
+
#include <ATen/ops/qr.h>
|
| 1080 |
+
#include <ATen/ops/qscheme.h>
|
| 1081 |
+
#include <ATen/ops/quantile.h>
|
| 1082 |
+
#include <ATen/ops/quantize_per_channel.h>
|
| 1083 |
+
#include <ATen/ops/quantize_per_tensor.h>
|
| 1084 |
+
#include <ATen/ops/quantize_per_tensor_dynamic.h>
|
| 1085 |
+
#include <ATen/ops/quantized_batch_norm.h>
|
| 1086 |
+
#include <ATen/ops/quantized_gru_cell.h>
|
| 1087 |
+
#include <ATen/ops/quantized_lstm_cell.h>
|
| 1088 |
+
#include <ATen/ops/quantized_max_pool1d.h>
|
| 1089 |
+
#include <ATen/ops/quantized_max_pool2d.h>
|
| 1090 |
+
#include <ATen/ops/quantized_max_pool3d.h>
|
| 1091 |
+
#include <ATen/ops/quantized_rnn_relu_cell.h>
|
| 1092 |
+
#include <ATen/ops/quantized_rnn_tanh_cell.h>
|
| 1093 |
+
#include <ATen/ops/rad2deg.h>
|
| 1094 |
+
#include <ATen/ops/rand.h>
|
| 1095 |
+
#include <ATen/ops/rand_like.h>
|
| 1096 |
+
#include <ATen/ops/randint.h>
|
| 1097 |
+
#include <ATen/ops/randint_like.h>
|
| 1098 |
+
#include <ATen/ops/randn.h>
|
| 1099 |
+
#include <ATen/ops/randn_like.h>
|
| 1100 |
+
#include <ATen/ops/random.h>
|
| 1101 |
+
#include <ATen/ops/randperm.h>
|
| 1102 |
+
#include <ATen/ops/range.h>
|
| 1103 |
+
#include <ATen/ops/ravel.h>
|
| 1104 |
+
#include <ATen/ops/real.h>
|
| 1105 |
+
#include <ATen/ops/reciprocal.h>
|
| 1106 |
+
#include <ATen/ops/record_stream.h>
|
| 1107 |
+
#include <ATen/ops/refine_names.h>
|
| 1108 |
+
#include <ATen/ops/reflection_pad1d.h>
|
| 1109 |
+
#include <ATen/ops/reflection_pad1d_backward.h>
|
| 1110 |
+
#include <ATen/ops/reflection_pad2d.h>
|
| 1111 |
+
#include <ATen/ops/reflection_pad2d_backward.h>
|
| 1112 |
+
#include <ATen/ops/reflection_pad3d.h>
|
| 1113 |
+
#include <ATen/ops/reflection_pad3d_backward.h>
|
| 1114 |
+
#include <ATen/ops/relu.h>
|
| 1115 |
+
#include <ATen/ops/relu6.h>
|
| 1116 |
+
#include <ATen/ops/remainder.h>
|
| 1117 |
+
#include <ATen/ops/rename.h>
|
| 1118 |
+
#include <ATen/ops/renorm.h>
|
| 1119 |
+
#include <ATen/ops/repeat.h>
|
| 1120 |
+
#include <ATen/ops/repeat_interleave.h>
|
| 1121 |
+
#include <ATen/ops/replication_pad1d.h>
|
| 1122 |
+
#include <ATen/ops/replication_pad1d_backward.h>
|
| 1123 |
+
#include <ATen/ops/replication_pad2d.h>
|
| 1124 |
+
#include <ATen/ops/replication_pad2d_backward.h>
|
| 1125 |
+
#include <ATen/ops/replication_pad3d.h>
|
| 1126 |
+
#include <ATen/ops/replication_pad3d_backward.h>
|
| 1127 |
+
#include <ATen/ops/requires_grad.h>
|
| 1128 |
+
#include <ATen/ops/reshape.h>
|
| 1129 |
+
#include <ATen/ops/reshape_as.h>
|
| 1130 |
+
#include <ATen/ops/resize.h>
|
| 1131 |
+
#include <ATen/ops/resize_as.h>
|
| 1132 |
+
#include <ATen/ops/resize_as_sparse.h>
|
| 1133 |
+
#include <ATen/ops/resolve_conj.h>
|
| 1134 |
+
#include <ATen/ops/resolve_neg.h>
|
| 1135 |
+
#include <ATen/ops/result_type.h>
|
| 1136 |
+
#include <ATen/ops/retain_grad.h>
|
| 1137 |
+
#include <ATen/ops/retains_grad.h>
|
| 1138 |
+
#include <ATen/ops/rms_norm.h>
|
| 1139 |
+
#include <ATen/ops/rnn_relu.h>
|
| 1140 |
+
#include <ATen/ops/rnn_relu_cell.h>
|
| 1141 |
+
#include <ATen/ops/rnn_tanh.h>
|
| 1142 |
+
#include <ATen/ops/rnn_tanh_cell.h>
|
| 1143 |
+
#include <ATen/ops/roll.h>
|
| 1144 |
+
#include <ATen/ops/rot90.h>
|
| 1145 |
+
#include <ATen/ops/round.h>
|
| 1146 |
+
#include <ATen/ops/row_indices.h>
|
| 1147 |
+
#include <ATen/ops/row_indices_copy.h>
|
| 1148 |
+
#include <ATen/ops/row_stack.h>
|
| 1149 |
+
#include <ATen/ops/rrelu.h>
|
| 1150 |
+
#include <ATen/ops/rrelu_with_noise.h>
|
| 1151 |
+
#include <ATen/ops/rrelu_with_noise_backward.h>
|
| 1152 |
+
#include <ATen/ops/rshift.h>
|
| 1153 |
+
#include <ATen/ops/rsqrt.h>
|
| 1154 |
+
#include <ATen/ops/rsub.h>
|
| 1155 |
+
#include <ATen/ops/scalar_tensor.h>
|
| 1156 |
+
#include <ATen/ops/scaled_dot_product_attention.h>
|
| 1157 |
+
#include <ATen/ops/scatter.h>
|
| 1158 |
+
#include <ATen/ops/scatter_add.h>
|
| 1159 |
+
#include <ATen/ops/scatter_reduce.h>
|
| 1160 |
+
#include <ATen/ops/searchsorted.h>
|
| 1161 |
+
#include <ATen/ops/segment_reduce.h>
|
| 1162 |
+
#include <ATen/ops/select.h>
|
| 1163 |
+
#include <ATen/ops/select_backward.h>
|
| 1164 |
+
#include <ATen/ops/select_copy.h>
|
| 1165 |
+
#include <ATen/ops/select_scatter.h>
|
| 1166 |
+
#include <ATen/ops/selu.h>
|
| 1167 |
+
#include <ATen/ops/set.h>
|
| 1168 |
+
#include <ATen/ops/set_data.h>
|
| 1169 |
+
#include <ATen/ops/sgn.h>
|
| 1170 |
+
#include <ATen/ops/sigmoid.h>
|
| 1171 |
+
#include <ATen/ops/sigmoid_backward.h>
|
| 1172 |
+
#include <ATen/ops/sign.h>
|
| 1173 |
+
#include <ATen/ops/signbit.h>
|
| 1174 |
+
#include <ATen/ops/silu.h>
|
| 1175 |
+
#include <ATen/ops/silu_backward.h>
|
| 1176 |
+
#include <ATen/ops/sin.h>
|
| 1177 |
+
#include <ATen/ops/sinc.h>
|
| 1178 |
+
#include <ATen/ops/sinh.h>
|
| 1179 |
+
#include <ATen/ops/size.h>
|
| 1180 |
+
#include <ATen/ops/slice.h>
|
| 1181 |
+
#include <ATen/ops/slice_backward.h>
|
| 1182 |
+
#include <ATen/ops/slice_copy.h>
|
| 1183 |
+
#include <ATen/ops/slice_inverse.h>
|
| 1184 |
+
#include <ATen/ops/slice_scatter.h>
|
| 1185 |
+
#include <ATen/ops/slogdet.h>
|
| 1186 |
+
#include <ATen/ops/slow_conv3d.h>
|
| 1187 |
+
#include <ATen/ops/slow_conv3d_forward.h>
|
| 1188 |
+
#include <ATen/ops/slow_conv_dilated2d.h>
|
| 1189 |
+
#include <ATen/ops/slow_conv_dilated3d.h>
|
| 1190 |
+
#include <ATen/ops/slow_conv_transpose2d.h>
|
| 1191 |
+
#include <ATen/ops/slow_conv_transpose3d.h>
|
| 1192 |
+
#include <ATen/ops/smm.h>
|
| 1193 |
+
#include <ATen/ops/smooth_l1_loss.h>
|
| 1194 |
+
#include <ATen/ops/smooth_l1_loss_backward.h>
|
| 1195 |
+
#include <ATen/ops/soft_margin_loss.h>
|
| 1196 |
+
#include <ATen/ops/soft_margin_loss_backward.h>
|
| 1197 |
+
#include <ATen/ops/softmax.h>
|
| 1198 |
+
#include <ATen/ops/softplus.h>
|
| 1199 |
+
#include <ATen/ops/softplus_backward.h>
|
| 1200 |
+
#include <ATen/ops/softshrink.h>
|
| 1201 |
+
#include <ATen/ops/softshrink_backward.h>
|
| 1202 |
+
#include <ATen/ops/sort.h>
|
| 1203 |
+
#include <ATen/ops/sparse_bsc_tensor.h>
|
| 1204 |
+
#include <ATen/ops/sparse_bsr_tensor.h>
|
| 1205 |
+
#include <ATen/ops/sparse_compressed_tensor.h>
|
| 1206 |
+
#include <ATen/ops/sparse_coo_tensor.h>
|
| 1207 |
+
#include <ATen/ops/sparse_csc_tensor.h>
|
| 1208 |
+
#include <ATen/ops/sparse_csr_tensor.h>
|
| 1209 |
+
#include <ATen/ops/sparse_dim.h>
|
| 1210 |
+
#include <ATen/ops/sparse_mask.h>
|
| 1211 |
+
#include <ATen/ops/sparse_resize.h>
|
| 1212 |
+
#include <ATen/ops/sparse_resize_and_clear.h>
|
| 1213 |
+
#include <ATen/ops/sparse_sampled_addmm.h>
|
| 1214 |
+
#include <ATen/ops/special_airy_ai.h>
|
| 1215 |
+
#include <ATen/ops/special_bessel_j0.h>
|
| 1216 |
+
#include <ATen/ops/special_bessel_j1.h>
|
| 1217 |
+
#include <ATen/ops/special_bessel_y0.h>
|
| 1218 |
+
#include <ATen/ops/special_bessel_y1.h>
|
| 1219 |
+
#include <ATen/ops/special_chebyshev_polynomial_t.h>
|
| 1220 |
+
#include <ATen/ops/special_chebyshev_polynomial_u.h>
|
| 1221 |
+
#include <ATen/ops/special_chebyshev_polynomial_v.h>
|
| 1222 |
+
#include <ATen/ops/special_chebyshev_polynomial_w.h>
|
| 1223 |
+
#include <ATen/ops/special_digamma.h>
|
| 1224 |
+
#include <ATen/ops/special_entr.h>
|
| 1225 |
+
#include <ATen/ops/special_erf.h>
|
| 1226 |
+
#include <ATen/ops/special_erfc.h>
|
| 1227 |
+
#include <ATen/ops/special_erfcx.h>
|
| 1228 |
+
#include <ATen/ops/special_erfinv.h>
|
| 1229 |
+
#include <ATen/ops/special_exp2.h>
|
| 1230 |
+
#include <ATen/ops/special_expit.h>
|
| 1231 |
+
#include <ATen/ops/special_expm1.h>
|
| 1232 |
+
#include <ATen/ops/special_gammainc.h>
|
| 1233 |
+
#include <ATen/ops/special_gammaincc.h>
|
| 1234 |
+
#include <ATen/ops/special_gammaln.h>
|
| 1235 |
+
#include <ATen/ops/special_hermite_polynomial_h.h>
|
| 1236 |
+
#include <ATen/ops/special_hermite_polynomial_he.h>
|
| 1237 |
+
#include <ATen/ops/special_i0.h>
|
| 1238 |
+
#include <ATen/ops/special_i0e.h>
|
| 1239 |
+
#include <ATen/ops/special_i1.h>
|
| 1240 |
+
#include <ATen/ops/special_i1e.h>
|
| 1241 |
+
#include <ATen/ops/special_laguerre_polynomial_l.h>
|
| 1242 |
+
#include <ATen/ops/special_legendre_polynomial_p.h>
|
| 1243 |
+
#include <ATen/ops/special_log1p.h>
|
| 1244 |
+
#include <ATen/ops/special_log_ndtr.h>
|
| 1245 |
+
#include <ATen/ops/special_log_softmax.h>
|
| 1246 |
+
#include <ATen/ops/special_logit.h>
|
| 1247 |
+
#include <ATen/ops/special_logsumexp.h>
|
| 1248 |
+
#include <ATen/ops/special_modified_bessel_i0.h>
|
| 1249 |
+
#include <ATen/ops/special_modified_bessel_i1.h>
|
| 1250 |
+
#include <ATen/ops/special_modified_bessel_k0.h>
|
| 1251 |
+
#include <ATen/ops/special_modified_bessel_k1.h>
|
| 1252 |
+
#include <ATen/ops/special_multigammaln.h>
|
| 1253 |
+
#include <ATen/ops/special_ndtr.h>
|
| 1254 |
+
#include <ATen/ops/special_ndtri.h>
|
| 1255 |
+
#include <ATen/ops/special_polygamma.h>
|
| 1256 |
+
#include <ATen/ops/special_psi.h>
|
| 1257 |
+
#include <ATen/ops/special_round.h>
|
| 1258 |
+
#include <ATen/ops/special_scaled_modified_bessel_k0.h>
|
| 1259 |
+
#include <ATen/ops/special_scaled_modified_bessel_k1.h>
|
| 1260 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_t.h>
|
| 1261 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_u.h>
|
| 1262 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_v.h>
|
| 1263 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_w.h>
|
| 1264 |
+
#include <ATen/ops/special_sinc.h>
|
| 1265 |
+
#include <ATen/ops/special_softmax.h>
|
| 1266 |
+
#include <ATen/ops/special_spherical_bessel_j0.h>
|
| 1267 |
+
#include <ATen/ops/special_xlog1py.h>
|
| 1268 |
+
#include <ATen/ops/special_xlogy.h>
|
| 1269 |
+
#include <ATen/ops/special_zeta.h>
|
| 1270 |
+
#include <ATen/ops/split.h>
|
| 1271 |
+
#include <ATen/ops/split_copy.h>
|
| 1272 |
+
#include <ATen/ops/split_with_sizes.h>
|
| 1273 |
+
#include <ATen/ops/split_with_sizes_copy.h>
|
| 1274 |
+
#include <ATen/ops/sqrt.h>
|
| 1275 |
+
#include <ATen/ops/square.h>
|
| 1276 |
+
#include <ATen/ops/squeeze.h>
|
| 1277 |
+
#include <ATen/ops/squeeze_copy.h>
|
| 1278 |
+
#include <ATen/ops/sspaddmm.h>
|
| 1279 |
+
#include <ATen/ops/stack.h>
|
| 1280 |
+
#include <ATen/ops/std.h>
|
| 1281 |
+
#include <ATen/ops/std_mean.h>
|
| 1282 |
+
#include <ATen/ops/stft.h>
|
| 1283 |
+
#include <ATen/ops/stride.h>
|
| 1284 |
+
#include <ATen/ops/sub.h>
|
| 1285 |
+
#include <ATen/ops/subtract.h>
|
| 1286 |
+
#include <ATen/ops/sum.h>
|
| 1287 |
+
#include <ATen/ops/sum_to_size.h>
|
| 1288 |
+
#include <ATen/ops/svd.h>
|
| 1289 |
+
#include <ATen/ops/swapaxes.h>
|
| 1290 |
+
#include <ATen/ops/swapdims.h>
|
| 1291 |
+
#include <ATen/ops/sym_constrain_range.h>
|
| 1292 |
+
#include <ATen/ops/sym_constrain_range_for_size.h>
|
| 1293 |
+
#include <ATen/ops/sym_numel.h>
|
| 1294 |
+
#include <ATen/ops/sym_size.h>
|
| 1295 |
+
#include <ATen/ops/sym_storage_offset.h>
|
| 1296 |
+
#include <ATen/ops/sym_stride.h>
|
| 1297 |
+
#include <ATen/ops/t.h>
|
| 1298 |
+
#include <ATen/ops/t_copy.h>
|
| 1299 |
+
#include <ATen/ops/take.h>
|
| 1300 |
+
#include <ATen/ops/take_along_dim.h>
|
| 1301 |
+
#include <ATen/ops/tan.h>
|
| 1302 |
+
#include <ATen/ops/tanh.h>
|
| 1303 |
+
#include <ATen/ops/tanh_backward.h>
|
| 1304 |
+
#include <ATen/ops/tensor_split.h>
|
| 1305 |
+
#include <ATen/ops/tensordot.h>
|
| 1306 |
+
#include <ATen/ops/thnn_conv2d.h>
|
| 1307 |
+
#include <ATen/ops/threshold.h>
|
| 1308 |
+
#include <ATen/ops/threshold_backward.h>
|
| 1309 |
+
#include <ATen/ops/tile.h>
|
| 1310 |
+
#include <ATen/ops/to.h>
|
| 1311 |
+
#include <ATen/ops/to_dense.h>
|
| 1312 |
+
#include <ATen/ops/to_dense_backward.h>
|
| 1313 |
+
#include <ATen/ops/to_mkldnn.h>
|
| 1314 |
+
#include <ATen/ops/to_mkldnn_backward.h>
|
| 1315 |
+
#include <ATen/ops/to_padded_tensor.h>
|
| 1316 |
+
#include <ATen/ops/to_sparse.h>
|
| 1317 |
+
#include <ATen/ops/to_sparse_bsc.h>
|
| 1318 |
+
#include <ATen/ops/to_sparse_bsr.h>
|
| 1319 |
+
#include <ATen/ops/to_sparse_csc.h>
|
| 1320 |
+
#include <ATen/ops/to_sparse_csr.h>
|
| 1321 |
+
#include <ATen/ops/topk.h>
|
| 1322 |
+
#include <ATen/ops/trace.h>
|
| 1323 |
+
#include <ATen/ops/trace_backward.h>
|
| 1324 |
+
#include <ATen/ops/transpose.h>
|
| 1325 |
+
#include <ATen/ops/transpose_copy.h>
|
| 1326 |
+
#include <ATen/ops/trapezoid.h>
|
| 1327 |
+
#include <ATen/ops/trapz.h>
|
| 1328 |
+
#include <ATen/ops/triangular_solve.h>
|
| 1329 |
+
#include <ATen/ops/tril.h>
|
| 1330 |
+
#include <ATen/ops/tril_indices.h>
|
| 1331 |
+
#include <ATen/ops/triplet_margin_loss.h>
|
| 1332 |
+
#include <ATen/ops/triu.h>
|
| 1333 |
+
#include <ATen/ops/triu_indices.h>
|
| 1334 |
+
#include <ATen/ops/true_divide.h>
|
| 1335 |
+
#include <ATen/ops/trunc.h>
|
| 1336 |
+
#include <ATen/ops/type_as.h>
|
| 1337 |
+
#include <ATen/ops/unbind.h>
|
| 1338 |
+
#include <ATen/ops/unbind_copy.h>
|
| 1339 |
+
#include <ATen/ops/unflatten.h>
|
| 1340 |
+
#include <ATen/ops/unflatten_dense_tensors.h>
|
| 1341 |
+
#include <ATen/ops/unfold.h>
|
| 1342 |
+
#include <ATen/ops/unfold_backward.h>
|
| 1343 |
+
#include <ATen/ops/unfold_copy.h>
|
| 1344 |
+
#include <ATen/ops/uniform.h>
|
| 1345 |
+
#include <ATen/ops/unique_consecutive.h>
|
| 1346 |
+
#include <ATen/ops/unique_dim.h>
|
| 1347 |
+
#include <ATen/ops/unique_dim_consecutive.h>
|
| 1348 |
+
#include <ATen/ops/unsafe_chunk.h>
|
| 1349 |
+
#include <ATen/ops/unsafe_split.h>
|
| 1350 |
+
#include <ATen/ops/unsafe_split_with_sizes.h>
|
| 1351 |
+
#include <ATen/ops/unsqueeze.h>
|
| 1352 |
+
#include <ATen/ops/unsqueeze_copy.h>
|
| 1353 |
+
#include <ATen/ops/upsample_bicubic2d.h>
|
| 1354 |
+
#include <ATen/ops/upsample_bicubic2d_backward.h>
|
| 1355 |
+
#include <ATen/ops/upsample_bilinear2d.h>
|
| 1356 |
+
#include <ATen/ops/upsample_bilinear2d_backward.h>
|
| 1357 |
+
#include <ATen/ops/upsample_linear1d.h>
|
| 1358 |
+
#include <ATen/ops/upsample_linear1d_backward.h>
|
| 1359 |
+
#include <ATen/ops/upsample_nearest1d.h>
|
| 1360 |
+
#include <ATen/ops/upsample_nearest1d_backward.h>
|
| 1361 |
+
#include <ATen/ops/upsample_nearest2d.h>
|
| 1362 |
+
#include <ATen/ops/upsample_nearest2d_backward.h>
|
| 1363 |
+
#include <ATen/ops/upsample_nearest3d.h>
|
| 1364 |
+
#include <ATen/ops/upsample_nearest3d_backward.h>
|
| 1365 |
+
#include <ATen/ops/upsample_trilinear3d.h>
|
| 1366 |
+
#include <ATen/ops/upsample_trilinear3d_backward.h>
|
| 1367 |
+
#include <ATen/ops/value_selecting_reduction_backward.h>
|
| 1368 |
+
#include <ATen/ops/values.h>
|
| 1369 |
+
#include <ATen/ops/values_copy.h>
|
| 1370 |
+
#include <ATen/ops/vander.h>
|
| 1371 |
+
#include <ATen/ops/var.h>
|
| 1372 |
+
#include <ATen/ops/var_mean.h>
|
| 1373 |
+
#include <ATen/ops/vdot.h>
|
| 1374 |
+
#include <ATen/ops/view.h>
|
| 1375 |
+
#include <ATen/ops/view_as.h>
|
| 1376 |
+
#include <ATen/ops/view_as_complex.h>
|
| 1377 |
+
#include <ATen/ops/view_as_complex_copy.h>
|
| 1378 |
+
#include <ATen/ops/view_as_real.h>
|
| 1379 |
+
#include <ATen/ops/view_as_real_copy.h>
|
| 1380 |
+
#include <ATen/ops/view_copy.h>
|
| 1381 |
+
#include <ATen/ops/vsplit.h>
|
| 1382 |
+
#include <ATen/ops/vstack.h>
|
| 1383 |
+
#include <ATen/ops/where.h>
|
| 1384 |
+
#include <ATen/ops/xlogy.h>
|
| 1385 |
+
#include <ATen/ops/xor.h>
|
| 1386 |
+
#include <ATen/ops/zero.h>
|
| 1387 |
+
#include <ATen/ops/zeros.h>
|
| 1388 |
+
#include <ATen/ops/zeros_like.h>
|
| 1389 |
+
|
| 1390 |
+
namespace at {
|
| 1391 |
+
|
| 1392 |
+
|
| 1393 |
+
|
| 1394 |
+
// Special C++ only overloads for std()-like functions (See gh-40287)
|
| 1395 |
+
// These are needed because int -> bool conversion takes precedence over int -> IntArrayRef
|
| 1396 |
+
// So, for example std(0) would select the std(unbiased=False) overload
|
| 1397 |
+
TORCH_API inline Tensor var(const Tensor& self, int dim) {
|
| 1398 |
+
return at::var(self, IntArrayRef{dim});
|
| 1399 |
+
}
|
| 1400 |
+
TORCH_API inline std::tuple<Tensor, Tensor> var_mean(const Tensor& self, int dim) {
|
| 1401 |
+
return at::var_mean(self, IntArrayRef{dim});
|
| 1402 |
+
}
|
| 1403 |
+
TORCH_API inline Tensor std(const Tensor& self, int dim) {
|
| 1404 |
+
return at::std(self, IntArrayRef{dim});
|
| 1405 |
+
}
|
| 1406 |
+
TORCH_API inline std::tuple<Tensor, Tensor> std_mean(const Tensor& self, int dim) {
|
| 1407 |
+
return at::std_mean(self, IntArrayRef{dim});
|
| 1408 |
+
}
|
| 1409 |
+
|
| 1410 |
+
inline int64_t numel(const Tensor& tensor) {
|
| 1411 |
+
return tensor.numel();
|
| 1412 |
+
}
|
| 1413 |
+
|
| 1414 |
+
inline int64_t size(const Tensor& tensor, int64_t dim) {
|
| 1415 |
+
return tensor.size(dim);
|
| 1416 |
+
}
|
| 1417 |
+
|
| 1418 |
+
inline int64_t stride(const Tensor& tensor, int64_t dim) {
|
| 1419 |
+
return tensor.stride(dim);
|
| 1420 |
+
}
|
| 1421 |
+
|
| 1422 |
+
inline bool is_complex(const Tensor& tensor) {
|
| 1423 |
+
return tensor.is_complex();
|
| 1424 |
+
}
|
| 1425 |
+
|
| 1426 |
+
inline bool is_floating_point(const Tensor& tensor) {
|
| 1427 |
+
return tensor.is_floating_point();
|
| 1428 |
+
}
|
| 1429 |
+
|
| 1430 |
+
inline bool is_signed(const Tensor& tensor) {
|
| 1431 |
+
return tensor.is_signed();
|
| 1432 |
+
}
|
| 1433 |
+
|
| 1434 |
+
inline bool is_inference(const Tensor& tensor) {
|
| 1435 |
+
return tensor.is_inference();
|
| 1436 |
+
}
|
| 1437 |
+
|
| 1438 |
+
inline bool _is_zerotensor(const Tensor& tensor) {
|
| 1439 |
+
return tensor._is_zerotensor();
|
| 1440 |
+
}
|
| 1441 |
+
|
| 1442 |
+
inline bool is_conj(const Tensor& tensor) {
|
| 1443 |
+
return tensor.is_conj();
|
| 1444 |
+
}
|
| 1445 |
+
|
| 1446 |
+
inline Tensor conj(const Tensor& tensor) {
|
| 1447 |
+
return tensor.conj();
|
| 1448 |
+
}
|
| 1449 |
+
|
| 1450 |
+
inline bool is_neg(const Tensor& tensor) {
|
| 1451 |
+
return tensor.is_neg();
|
| 1452 |
+
}
|
| 1453 |
+
|
| 1454 |
+
}
|
.venv/lib/python3.11/site-packages/torch/include/ATen/Generator.h
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <ATen/core/Generator.h>
|
.venv/lib/python3.11/site-packages/torch/include/ATen/LinalgBackend.h
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <c10/util/Exception.h>
|
| 4 |
+
|
| 5 |
+
#include <ostream>
|
| 6 |
+
#include <string>
|
| 7 |
+
|
| 8 |
+
namespace at {
|
| 9 |
+
|
| 10 |
+
enum class LinalgBackend : int8_t { Default, Cusolver, Magma };
|
| 11 |
+
|
| 12 |
+
inline std::string LinalgBackendToString(at::LinalgBackend backend) {
|
| 13 |
+
switch (backend) {
|
| 14 |
+
case LinalgBackend::Default:
|
| 15 |
+
return "at::LinalgBackend::Default";
|
| 16 |
+
case LinalgBackend::Cusolver:
|
| 17 |
+
return "at::LinalgBackend::Cusolver";
|
| 18 |
+
case LinalgBackend::Magma:
|
| 19 |
+
return "at::LinalgBackend::Magma";
|
| 20 |
+
default:
|
| 21 |
+
TORCH_CHECK(false, "Unknown linalg backend");
|
| 22 |
+
}
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
+
inline std::ostream& operator<<(
|
| 26 |
+
std::ostream& stream,
|
| 27 |
+
at::LinalgBackend backend) {
|
| 28 |
+
return stream << LinalgBackendToString(backend);
|
| 29 |
+
}
|
| 30 |
+
|
| 31 |
+
} // namespace at
|
.venv/lib/python3.11/site-packages/torch/include/ATen/MemoryOverlap.h
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <c10/macros/Export.h>
|
| 4 |
+
|
| 5 |
+
namespace c10 {
|
| 6 |
+
struct TensorImpl;
|
| 7 |
+
}
|
| 8 |
+
|
| 9 |
+
namespace at {
|
| 10 |
+
class TensorBase;
|
| 11 |
+
|
| 12 |
+
// MemOverlap: Whether or not there is memory overlap
|
| 13 |
+
//
|
| 14 |
+
// No: Absolutely no memory overlap
|
| 15 |
+
// Yes: Absolutely yes memory overlap
|
| 16 |
+
// TooHard: There might be memory overlap, but it was too expensive to compute.
|
| 17 |
+
//
|
| 18 |
+
// NB: Please update the python test for these if you renumber them.
|
| 19 |
+
enum class MemOverlap { No, Yes, TooHard };
|
| 20 |
+
|
| 21 |
+
enum class MemOverlapStatus { Full, Partial, No, TooHard };
|
| 22 |
+
|
| 23 |
+
TORCH_API MemOverlap has_internal_overlap(const TensorBase& t);
|
| 24 |
+
TORCH_API MemOverlap has_internal_overlap(c10::TensorImpl* t);
|
| 25 |
+
|
| 26 |
+
TORCH_API void assert_no_internal_overlap(const TensorBase& t);
|
| 27 |
+
TORCH_API void assert_no_internal_overlap(c10::TensorImpl* t);
|
| 28 |
+
|
| 29 |
+
TORCH_API MemOverlapStatus
|
| 30 |
+
get_overlap_status(const TensorBase& a, const TensorBase& b);
|
| 31 |
+
TORCH_API MemOverlapStatus
|
| 32 |
+
get_overlap_status(const c10::TensorImpl* a, const c10::TensorImpl* b);
|
| 33 |
+
|
| 34 |
+
TORCH_API void assert_no_partial_overlap(
|
| 35 |
+
const TensorBase& a,
|
| 36 |
+
const TensorBase& b);
|
| 37 |
+
void assert_no_partial_overlap(c10::TensorImpl* a, c10::TensorImpl* b);
|
| 38 |
+
|
| 39 |
+
TORCH_API void assert_no_overlap(const TensorBase& a, const TensorBase& b);
|
| 40 |
+
TORCH_API void assert_no_overlap(c10::TensorImpl* a, c10::TensorImpl* b);
|
| 41 |
+
|
| 42 |
+
} // namespace at
|
.venv/lib/python3.11/site-packages/torch/include/ATen/NativeMetaFunctions.h
ADDED
|
@@ -0,0 +1,1330 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// @generated by torchgen/gen.py from NativeMetaFunctions.h
|
| 4 |
+
|
| 5 |
+
#include <ATen/core/Tensor.h>
|
| 6 |
+
#include <ATen/core/IListRef.h>
|
| 7 |
+
#include <ATen/TensorMeta.h>
|
| 8 |
+
#include <ATen/TensorIterator.h>
|
| 9 |
+
|
| 10 |
+
#include <ATen/ops/_adaptive_avg_pool2d_meta.h>
|
| 11 |
+
#include <ATen/ops/_adaptive_avg_pool2d_backward_meta.h>
|
| 12 |
+
#include <ATen/ops/_adaptive_avg_pool3d_meta.h>
|
| 13 |
+
#include <ATen/ops/_adaptive_avg_pool3d_backward_meta.h>
|
| 14 |
+
#include <ATen/ops/_add_batch_dim_meta.h>
|
| 15 |
+
#include <ATen/ops/_add_relu_meta.h>
|
| 16 |
+
#include <ATen/ops/_addmm_activation_meta.h>
|
| 17 |
+
#include <ATen/ops/_aminmax_meta.h>
|
| 18 |
+
#include <ATen/ops/_amp_foreach_non_finite_check_and_unscale_meta.h>
|
| 19 |
+
#include <ATen/ops/_amp_update_scale_meta.h>
|
| 20 |
+
#include <ATen/ops/_assert_async_meta.h>
|
| 21 |
+
#include <ATen/ops/_assert_scalar_meta.h>
|
| 22 |
+
#include <ATen/ops/_assert_tensor_metadata_meta.h>
|
| 23 |
+
#include <ATen/ops/_autocast_to_full_precision_meta.h>
|
| 24 |
+
#include <ATen/ops/_autocast_to_reduced_precision_meta.h>
|
| 25 |
+
#include <ATen/ops/_backward_meta.h>
|
| 26 |
+
#include <ATen/ops/_batch_norm_impl_index_meta.h>
|
| 27 |
+
#include <ATen/ops/_batch_norm_impl_index_backward_meta.h>
|
| 28 |
+
#include <ATen/ops/_batch_norm_no_update_meta.h>
|
| 29 |
+
#include <ATen/ops/_batch_norm_with_update_meta.h>
|
| 30 |
+
#include <ATen/ops/_cast_Byte_meta.h>
|
| 31 |
+
#include <ATen/ops/_cast_Char_meta.h>
|
| 32 |
+
#include <ATen/ops/_cast_Double_meta.h>
|
| 33 |
+
#include <ATen/ops/_cast_Float_meta.h>
|
| 34 |
+
#include <ATen/ops/_cast_Half_meta.h>
|
| 35 |
+
#include <ATen/ops/_cast_Int_meta.h>
|
| 36 |
+
#include <ATen/ops/_cast_Long_meta.h>
|
| 37 |
+
#include <ATen/ops/_cast_Short_meta.h>
|
| 38 |
+
#include <ATen/ops/_cdist_backward_meta.h>
|
| 39 |
+
#include <ATen/ops/_cdist_forward_meta.h>
|
| 40 |
+
#include <ATen/ops/_cholesky_solve_helper_meta.h>
|
| 41 |
+
#include <ATen/ops/_choose_qparams_per_tensor_meta.h>
|
| 42 |
+
#include <ATen/ops/_chunk_cat_meta.h>
|
| 43 |
+
#include <ATen/ops/_coalesce_meta.h>
|
| 44 |
+
#include <ATen/ops/_coalesced_meta.h>
|
| 45 |
+
#include <ATen/ops/_compute_linear_combination_meta.h>
|
| 46 |
+
#include <ATen/ops/_conj_meta.h>
|
| 47 |
+
#include <ATen/ops/_conj_copy_meta.h>
|
| 48 |
+
#include <ATen/ops/_conj_physical_meta.h>
|
| 49 |
+
#include <ATen/ops/_conv_depthwise2d_meta.h>
|
| 50 |
+
#include <ATen/ops/_convert_indices_from_coo_to_csr_meta.h>
|
| 51 |
+
#include <ATen/ops/_convert_indices_from_csr_to_coo_meta.h>
|
| 52 |
+
#include <ATen/ops/_convert_weight_to_int4pack_meta.h>
|
| 53 |
+
#include <ATen/ops/_convolution_meta.h>
|
| 54 |
+
#include <ATen/ops/_convolution_double_backward_meta.h>
|
| 55 |
+
#include <ATen/ops/_convolution_mode_meta.h>
|
| 56 |
+
#include <ATen/ops/_copy_from_meta.h>
|
| 57 |
+
#include <ATen/ops/_copy_from_and_resize_meta.h>
|
| 58 |
+
#include <ATen/ops/_cslt_compress_meta.h>
|
| 59 |
+
#include <ATen/ops/_cslt_sparse_mm_meta.h>
|
| 60 |
+
#include <ATen/ops/_cslt_sparse_mm_search_meta.h>
|
| 61 |
+
#include <ATen/ops/_ctc_loss_meta.h>
|
| 62 |
+
#include <ATen/ops/_ctc_loss_backward_meta.h>
|
| 63 |
+
#include <ATen/ops/_cudnn_ctc_loss_meta.h>
|
| 64 |
+
#include <ATen/ops/_cudnn_init_dropout_state_meta.h>
|
| 65 |
+
#include <ATen/ops/_cudnn_rnn_meta.h>
|
| 66 |
+
#include <ATen/ops/_cudnn_rnn_backward_meta.h>
|
| 67 |
+
#include <ATen/ops/_cudnn_rnn_flatten_weight_meta.h>
|
| 68 |
+
#include <ATen/ops/_cufft_clear_plan_cache_meta.h>
|
| 69 |
+
#include <ATen/ops/_cufft_get_plan_cache_max_size_meta.h>
|
| 70 |
+
#include <ATen/ops/_cufft_get_plan_cache_size_meta.h>
|
| 71 |
+
#include <ATen/ops/_cufft_set_plan_cache_max_size_meta.h>
|
| 72 |
+
#include <ATen/ops/_cummax_helper_meta.h>
|
| 73 |
+
#include <ATen/ops/_cummin_helper_meta.h>
|
| 74 |
+
#include <ATen/ops/_debug_has_internal_overlap_meta.h>
|
| 75 |
+
#include <ATen/ops/_dimI_meta.h>
|
| 76 |
+
#include <ATen/ops/_dimV_meta.h>
|
| 77 |
+
#include <ATen/ops/_dim_arange_meta.h>
|
| 78 |
+
#include <ATen/ops/_dirichlet_grad_meta.h>
|
| 79 |
+
#include <ATen/ops/_efficient_attention_backward_meta.h>
|
| 80 |
+
#include <ATen/ops/_efficient_attention_forward_meta.h>
|
| 81 |
+
#include <ATen/ops/_efficientzerotensor_meta.h>
|
| 82 |
+
#include <ATen/ops/_embedding_bag_meta.h>
|
| 83 |
+
#include <ATen/ops/_embedding_bag_backward_meta.h>
|
| 84 |
+
#include <ATen/ops/_embedding_bag_dense_backward_meta.h>
|
| 85 |
+
#include <ATen/ops/_embedding_bag_forward_only_meta.h>
|
| 86 |
+
#include <ATen/ops/_embedding_bag_per_sample_weights_backward_meta.h>
|
| 87 |
+
#include <ATen/ops/_embedding_bag_sparse_backward_meta.h>
|
| 88 |
+
#include <ATen/ops/_empty_affine_quantized_meta.h>
|
| 89 |
+
#include <ATen/ops/_empty_per_channel_affine_quantized_meta.h>
|
| 90 |
+
#include <ATen/ops/_euclidean_dist_meta.h>
|
| 91 |
+
#include <ATen/ops/_fake_quantize_learnable_per_channel_affine_meta.h>
|
| 92 |
+
#include <ATen/ops/_fake_quantize_learnable_per_channel_affine_backward_meta.h>
|
| 93 |
+
#include <ATen/ops/_fake_quantize_learnable_per_tensor_affine_meta.h>
|
| 94 |
+
#include <ATen/ops/_fake_quantize_learnable_per_tensor_affine_backward_meta.h>
|
| 95 |
+
#include <ATen/ops/_fake_quantize_per_tensor_affine_cachemask_tensor_qparams_meta.h>
|
| 96 |
+
#include <ATen/ops/_fft_c2c_meta.h>
|
| 97 |
+
#include <ATen/ops/_fft_c2r_meta.h>
|
| 98 |
+
#include <ATen/ops/_fft_r2c_meta.h>
|
| 99 |
+
#include <ATen/ops/_fill_mem_eff_dropout_mask_meta.h>
|
| 100 |
+
#include <ATen/ops/_flash_attention_backward_meta.h>
|
| 101 |
+
#include <ATen/ops/_flash_attention_forward_meta.h>
|
| 102 |
+
#include <ATen/ops/_foobar_meta.h>
|
| 103 |
+
#include <ATen/ops/_foreach_abs_meta.h>
|
| 104 |
+
#include <ATen/ops/_foreach_acos_meta.h>
|
| 105 |
+
#include <ATen/ops/_foreach_add_meta.h>
|
| 106 |
+
#include <ATen/ops/_foreach_addcdiv_meta.h>
|
| 107 |
+
#include <ATen/ops/_foreach_addcmul_meta.h>
|
| 108 |
+
#include <ATen/ops/_foreach_asin_meta.h>
|
| 109 |
+
#include <ATen/ops/_foreach_atan_meta.h>
|
| 110 |
+
#include <ATen/ops/_foreach_ceil_meta.h>
|
| 111 |
+
#include <ATen/ops/_foreach_clamp_max_meta.h>
|
| 112 |
+
#include <ATen/ops/_foreach_clamp_min_meta.h>
|
| 113 |
+
#include <ATen/ops/_foreach_copy_meta.h>
|
| 114 |
+
#include <ATen/ops/_foreach_cos_meta.h>
|
| 115 |
+
#include <ATen/ops/_foreach_cosh_meta.h>
|
| 116 |
+
#include <ATen/ops/_foreach_div_meta.h>
|
| 117 |
+
#include <ATen/ops/_foreach_erf_meta.h>
|
| 118 |
+
#include <ATen/ops/_foreach_erfc_meta.h>
|
| 119 |
+
#include <ATen/ops/_foreach_exp_meta.h>
|
| 120 |
+
#include <ATen/ops/_foreach_expm1_meta.h>
|
| 121 |
+
#include <ATen/ops/_foreach_floor_meta.h>
|
| 122 |
+
#include <ATen/ops/_foreach_frac_meta.h>
|
| 123 |
+
#include <ATen/ops/_foreach_lerp_meta.h>
|
| 124 |
+
#include <ATen/ops/_foreach_lgamma_meta.h>
|
| 125 |
+
#include <ATen/ops/_foreach_log_meta.h>
|
| 126 |
+
#include <ATen/ops/_foreach_log10_meta.h>
|
| 127 |
+
#include <ATen/ops/_foreach_log1p_meta.h>
|
| 128 |
+
#include <ATen/ops/_foreach_log2_meta.h>
|
| 129 |
+
#include <ATen/ops/_foreach_max_meta.h>
|
| 130 |
+
#include <ATen/ops/_foreach_maximum_meta.h>
|
| 131 |
+
#include <ATen/ops/_foreach_minimum_meta.h>
|
| 132 |
+
#include <ATen/ops/_foreach_mul_meta.h>
|
| 133 |
+
#include <ATen/ops/_foreach_neg_meta.h>
|
| 134 |
+
#include <ATen/ops/_foreach_norm_meta.h>
|
| 135 |
+
#include <ATen/ops/_foreach_pow_meta.h>
|
| 136 |
+
#include <ATen/ops/_foreach_reciprocal_meta.h>
|
| 137 |
+
#include <ATen/ops/_foreach_round_meta.h>
|
| 138 |
+
#include <ATen/ops/_foreach_sigmoid_meta.h>
|
| 139 |
+
#include <ATen/ops/_foreach_sign_meta.h>
|
| 140 |
+
#include <ATen/ops/_foreach_sin_meta.h>
|
| 141 |
+
#include <ATen/ops/_foreach_sinh_meta.h>
|
| 142 |
+
#include <ATen/ops/_foreach_sqrt_meta.h>
|
| 143 |
+
#include <ATen/ops/_foreach_sub_meta.h>
|
| 144 |
+
#include <ATen/ops/_foreach_tan_meta.h>
|
| 145 |
+
#include <ATen/ops/_foreach_tanh_meta.h>
|
| 146 |
+
#include <ATen/ops/_foreach_trunc_meta.h>
|
| 147 |
+
#include <ATen/ops/_foreach_zero_meta.h>
|
| 148 |
+
#include <ATen/ops/_functional_assert_async_meta.h>
|
| 149 |
+
#include <ATen/ops/_functional_assert_scalar_meta.h>
|
| 150 |
+
#include <ATen/ops/_functional_sym_constrain_range_meta.h>
|
| 151 |
+
#include <ATen/ops/_functional_sym_constrain_range_for_size_meta.h>
|
| 152 |
+
#include <ATen/ops/_fused_adagrad_meta.h>
|
| 153 |
+
#include <ATen/ops/_fused_adam_meta.h>
|
| 154 |
+
#include <ATen/ops/_fused_adamw_meta.h>
|
| 155 |
+
#include <ATen/ops/_fused_dropout_meta.h>
|
| 156 |
+
#include <ATen/ops/_fused_moving_avg_obs_fq_helper_meta.h>
|
| 157 |
+
#include <ATen/ops/_fused_sdp_choice_meta.h>
|
| 158 |
+
#include <ATen/ops/_fused_sgd_meta.h>
|
| 159 |
+
#include <ATen/ops/_fw_primal_meta.h>
|
| 160 |
+
#include <ATen/ops/_fw_primal_copy_meta.h>
|
| 161 |
+
#include <ATen/ops/_gather_sparse_backward_meta.h>
|
| 162 |
+
#include <ATen/ops/_grid_sampler_2d_cpu_fallback_meta.h>
|
| 163 |
+
#include <ATen/ops/_grid_sampler_2d_cpu_fallback_backward_meta.h>
|
| 164 |
+
#include <ATen/ops/_has_compatible_shallow_copy_type_meta.h>
|
| 165 |
+
#include <ATen/ops/_has_same_storage_numel_meta.h>
|
| 166 |
+
#include <ATen/ops/_histogramdd_bin_edges_meta.h>
|
| 167 |
+
#include <ATen/ops/_histogramdd_from_bin_cts_meta.h>
|
| 168 |
+
#include <ATen/ops/_histogramdd_from_bin_tensors_meta.h>
|
| 169 |
+
#include <ATen/ops/_index_put_impl_meta.h>
|
| 170 |
+
#include <ATen/ops/_indices_meta.h>
|
| 171 |
+
#include <ATen/ops/_indices_copy_meta.h>
|
| 172 |
+
#include <ATen/ops/_int_mm_meta.h>
|
| 173 |
+
#include <ATen/ops/_is_all_true_meta.h>
|
| 174 |
+
#include <ATen/ops/_is_any_true_meta.h>
|
| 175 |
+
#include <ATen/ops/_is_zerotensor_meta.h>
|
| 176 |
+
#include <ATen/ops/_jagged_to_padded_dense_forward_meta.h>
|
| 177 |
+
#include <ATen/ops/_lazy_clone_meta.h>
|
| 178 |
+
#include <ATen/ops/_linalg_check_errors_meta.h>
|
| 179 |
+
#include <ATen/ops/_linalg_det_meta.h>
|
| 180 |
+
#include <ATen/ops/_linalg_eigh_meta.h>
|
| 181 |
+
#include <ATen/ops/_linalg_eigvals_meta.h>
|
| 182 |
+
#include <ATen/ops/_linalg_slogdet_meta.h>
|
| 183 |
+
#include <ATen/ops/_linalg_solve_ex_meta.h>
|
| 184 |
+
#include <ATen/ops/_linalg_svd_meta.h>
|
| 185 |
+
#include <ATen/ops/_local_scalar_dense_meta.h>
|
| 186 |
+
#include <ATen/ops/_log_softmax_meta.h>
|
| 187 |
+
#include <ATen/ops/_log_softmax_backward_data_meta.h>
|
| 188 |
+
#include <ATen/ops/_logcumsumexp_meta.h>
|
| 189 |
+
#include <ATen/ops/_lstm_mps_meta.h>
|
| 190 |
+
#include <ATen/ops/_lu_with_info_meta.h>
|
| 191 |
+
#include <ATen/ops/_make_dep_token_meta.h>
|
| 192 |
+
#include <ATen/ops/_make_dual_meta.h>
|
| 193 |
+
#include <ATen/ops/_make_dual_copy_meta.h>
|
| 194 |
+
#include <ATen/ops/_make_per_channel_quantized_tensor_meta.h>
|
| 195 |
+
#include <ATen/ops/_make_per_tensor_quantized_tensor_meta.h>
|
| 196 |
+
#include <ATen/ops/_masked_scale_meta.h>
|
| 197 |
+
#include <ATen/ops/_masked_softmax_meta.h>
|
| 198 |
+
#include <ATen/ops/_masked_softmax_backward_meta.h>
|
| 199 |
+
#include <ATen/ops/_mixed_dtypes_linear_meta.h>
|
| 200 |
+
#include <ATen/ops/_mkldnn_reshape_meta.h>
|
| 201 |
+
#include <ATen/ops/_mkldnn_transpose_meta.h>
|
| 202 |
+
#include <ATen/ops/_mps_convolution_meta.h>
|
| 203 |
+
#include <ATen/ops/_mps_convolution_transpose_meta.h>
|
| 204 |
+
#include <ATen/ops/_native_batch_norm_legit_meta.h>
|
| 205 |
+
#include <ATen/ops/_native_batch_norm_legit_no_training_meta.h>
|
| 206 |
+
#include <ATen/ops/_native_multi_head_attention_meta.h>
|
| 207 |
+
#include <ATen/ops/_neg_view_meta.h>
|
| 208 |
+
#include <ATen/ops/_neg_view_copy_meta.h>
|
| 209 |
+
#include <ATen/ops/_nested_compute_contiguous_strides_offsets_meta.h>
|
| 210 |
+
#include <ATen/ops/_nested_from_padded_meta.h>
|
| 211 |
+
#include <ATen/ops/_nested_from_padded_and_nested_example_meta.h>
|
| 212 |
+
#include <ATen/ops/_nested_get_jagged_dummy_meta.h>
|
| 213 |
+
#include <ATen/ops/_nested_get_lengths_meta.h>
|
| 214 |
+
#include <ATen/ops/_nested_get_max_seqlen_meta.h>
|
| 215 |
+
#include <ATen/ops/_nested_get_min_seqlen_meta.h>
|
| 216 |
+
#include <ATen/ops/_nested_get_offsets_meta.h>
|
| 217 |
+
#include <ATen/ops/_nested_get_ragged_idx_meta.h>
|
| 218 |
+
#include <ATen/ops/_nested_get_values_meta.h>
|
| 219 |
+
#include <ATen/ops/_nested_get_values_copy_meta.h>
|
| 220 |
+
#include <ATen/ops/_nested_select_backward_meta.h>
|
| 221 |
+
#include <ATen/ops/_nested_sum_backward_meta.h>
|
| 222 |
+
#include <ATen/ops/_nested_tensor_from_mask_meta.h>
|
| 223 |
+
#include <ATen/ops/_nested_tensor_from_mask_left_aligned_meta.h>
|
| 224 |
+
#include <ATen/ops/_nested_tensor_from_tensor_list_meta.h>
|
| 225 |
+
#include <ATen/ops/_nested_tensor_size_meta.h>
|
| 226 |
+
#include <ATen/ops/_nested_tensor_softmax_with_shape_meta.h>
|
| 227 |
+
#include <ATen/ops/_nested_tensor_storage_offsets_meta.h>
|
| 228 |
+
#include <ATen/ops/_nested_tensor_strides_meta.h>
|
| 229 |
+
#include <ATen/ops/_nested_view_from_buffer_meta.h>
|
| 230 |
+
#include <ATen/ops/_nested_view_from_buffer_copy_meta.h>
|
| 231 |
+
#include <ATen/ops/_nested_view_from_jagged_meta.h>
|
| 232 |
+
#include <ATen/ops/_nested_view_from_jagged_copy_meta.h>
|
| 233 |
+
#include <ATen/ops/_new_zeros_with_same_feature_meta_meta.h>
|
| 234 |
+
#include <ATen/ops/_nnpack_available_meta.h>
|
| 235 |
+
#include <ATen/ops/_nnpack_spatial_convolution_meta.h>
|
| 236 |
+
#include <ATen/ops/_nnz_meta.h>
|
| 237 |
+
#include <ATen/ops/_pack_padded_sequence_meta.h>
|
| 238 |
+
#include <ATen/ops/_pack_padded_sequence_backward_meta.h>
|
| 239 |
+
#include <ATen/ops/_pad_circular_meta.h>
|
| 240 |
+
#include <ATen/ops/_pad_enum_meta.h>
|
| 241 |
+
#include <ATen/ops/_pad_packed_sequence_meta.h>
|
| 242 |
+
#include <ATen/ops/_padded_dense_to_jagged_forward_meta.h>
|
| 243 |
+
#include <ATen/ops/_pdist_backward_meta.h>
|
| 244 |
+
#include <ATen/ops/_pdist_forward_meta.h>
|
| 245 |
+
#include <ATen/ops/_pin_memory_meta.h>
|
| 246 |
+
#include <ATen/ops/_prelu_kernel_meta.h>
|
| 247 |
+
#include <ATen/ops/_prelu_kernel_backward_meta.h>
|
| 248 |
+
#include <ATen/ops/_print_meta.h>
|
| 249 |
+
#include <ATen/ops/_propagate_xla_data_meta.h>
|
| 250 |
+
#include <ATen/ops/_remove_batch_dim_meta.h>
|
| 251 |
+
#include <ATen/ops/_reshape_alias_meta.h>
|
| 252 |
+
#include <ATen/ops/_reshape_alias_copy_meta.h>
|
| 253 |
+
#include <ATen/ops/_reshape_copy_meta.h>
|
| 254 |
+
#include <ATen/ops/_reshape_from_tensor_meta.h>
|
| 255 |
+
#include <ATen/ops/_resize_output_meta.h>
|
| 256 |
+
#include <ATen/ops/_rowwise_prune_meta.h>
|
| 257 |
+
#include <ATen/ops/_safe_softmax_meta.h>
|
| 258 |
+
#include <ATen/ops/_sample_dirichlet_meta.h>
|
| 259 |
+
#include <ATen/ops/_saturate_weight_to_fp16_meta.h>
|
| 260 |
+
#include <ATen/ops/_scaled_dot_product_attention_math_meta.h>
|
| 261 |
+
#include <ATen/ops/_scaled_dot_product_attention_math_for_mps_meta.h>
|
| 262 |
+
#include <ATen/ops/_scaled_dot_product_cudnn_attention_meta.h>
|
| 263 |
+
#include <ATen/ops/_scaled_dot_product_cudnn_attention_backward_meta.h>
|
| 264 |
+
#include <ATen/ops/_scaled_dot_product_efficient_attention_meta.h>
|
| 265 |
+
#include <ATen/ops/_scaled_dot_product_efficient_attention_backward_meta.h>
|
| 266 |
+
#include <ATen/ops/_scaled_dot_product_flash_attention_meta.h>
|
| 267 |
+
#include <ATen/ops/_scaled_dot_product_flash_attention_backward_meta.h>
|
| 268 |
+
#include <ATen/ops/_scaled_dot_product_flash_attention_for_cpu_meta.h>
|
| 269 |
+
#include <ATen/ops/_scaled_dot_product_flash_attention_for_cpu_backward_meta.h>
|
| 270 |
+
#include <ATen/ops/_scaled_dot_product_fused_attention_overrideable_meta.h>
|
| 271 |
+
#include <ATen/ops/_scaled_dot_product_fused_attention_overrideable_backward_meta.h>
|
| 272 |
+
#include <ATen/ops/_scaled_mm_meta.h>
|
| 273 |
+
#include <ATen/ops/_segment_reduce_backward_meta.h>
|
| 274 |
+
#include <ATen/ops/_shape_as_tensor_meta.h>
|
| 275 |
+
#include <ATen/ops/_slow_conv2d_backward_meta.h>
|
| 276 |
+
#include <ATen/ops/_slow_conv2d_forward_meta.h>
|
| 277 |
+
#include <ATen/ops/_sobol_engine_draw_meta.h>
|
| 278 |
+
#include <ATen/ops/_sobol_engine_ff_meta.h>
|
| 279 |
+
#include <ATen/ops/_sobol_engine_initialize_state_meta.h>
|
| 280 |
+
#include <ATen/ops/_sobol_engine_scramble_meta.h>
|
| 281 |
+
#include <ATen/ops/_softmax_meta.h>
|
| 282 |
+
#include <ATen/ops/_softmax_backward_data_meta.h>
|
| 283 |
+
#include <ATen/ops/_sparse_addmm_meta.h>
|
| 284 |
+
#include <ATen/ops/_sparse_broadcast_to_meta.h>
|
| 285 |
+
#include <ATen/ops/_sparse_broadcast_to_copy_meta.h>
|
| 286 |
+
#include <ATen/ops/_sparse_bsc_tensor_unsafe_meta.h>
|
| 287 |
+
#include <ATen/ops/_sparse_bsr_tensor_unsafe_meta.h>
|
| 288 |
+
#include <ATen/ops/_sparse_compressed_tensor_unsafe_meta.h>
|
| 289 |
+
#include <ATen/ops/_sparse_compressed_tensor_with_dims_meta.h>
|
| 290 |
+
#include <ATen/ops/_sparse_coo_tensor_unsafe_meta.h>
|
| 291 |
+
#include <ATen/ops/_sparse_coo_tensor_with_dims_meta.h>
|
| 292 |
+
#include <ATen/ops/_sparse_coo_tensor_with_dims_and_tensors_meta.h>
|
| 293 |
+
#include <ATen/ops/_sparse_csc_tensor_unsafe_meta.h>
|
| 294 |
+
#include <ATen/ops/_sparse_csr_prod_meta.h>
|
| 295 |
+
#include <ATen/ops/_sparse_csr_sum_meta.h>
|
| 296 |
+
#include <ATen/ops/_sparse_csr_tensor_unsafe_meta.h>
|
| 297 |
+
#include <ATen/ops/_sparse_log_softmax_meta.h>
|
| 298 |
+
#include <ATen/ops/_sparse_log_softmax_backward_data_meta.h>
|
| 299 |
+
#include <ATen/ops/_sparse_mask_projection_meta.h>
|
| 300 |
+
#include <ATen/ops/_sparse_mm_meta.h>
|
| 301 |
+
#include <ATen/ops/_sparse_mm_reduce_impl_meta.h>
|
| 302 |
+
#include <ATen/ops/_sparse_mm_reduce_impl_backward_meta.h>
|
| 303 |
+
#include <ATen/ops/_sparse_semi_structured_addmm_meta.h>
|
| 304 |
+
#include <ATen/ops/_sparse_semi_structured_apply_meta.h>
|
| 305 |
+
#include <ATen/ops/_sparse_semi_structured_apply_dense_meta.h>
|
| 306 |
+
#include <ATen/ops/_sparse_semi_structured_linear_meta.h>
|
| 307 |
+
#include <ATen/ops/_sparse_semi_structured_mm_meta.h>
|
| 308 |
+
#include <ATen/ops/_sparse_semi_structured_tile_meta.h>
|
| 309 |
+
#include <ATen/ops/_sparse_softmax_meta.h>
|
| 310 |
+
#include <ATen/ops/_sparse_softmax_backward_data_meta.h>
|
| 311 |
+
#include <ATen/ops/_sparse_sparse_matmul_meta.h>
|
| 312 |
+
#include <ATen/ops/_sparse_sum_meta.h>
|
| 313 |
+
#include <ATen/ops/_sparse_sum_backward_meta.h>
|
| 314 |
+
#include <ATen/ops/_spdiags_meta.h>
|
| 315 |
+
#include <ATen/ops/_spsolve_meta.h>
|
| 316 |
+
#include <ATen/ops/_stack_meta.h>
|
| 317 |
+
#include <ATen/ops/_standard_gamma_meta.h>
|
| 318 |
+
#include <ATen/ops/_standard_gamma_grad_meta.h>
|
| 319 |
+
#include <ATen/ops/_test_ambiguous_defaults_meta.h>
|
| 320 |
+
#include <ATen/ops/_test_autograd_multiple_dispatch_meta.h>
|
| 321 |
+
#include <ATen/ops/_test_autograd_multiple_dispatch_view_meta.h>
|
| 322 |
+
#include <ATen/ops/_test_autograd_multiple_dispatch_view_copy_meta.h>
|
| 323 |
+
#include <ATen/ops/_test_check_tensor_meta.h>
|
| 324 |
+
#include <ATen/ops/_test_functorch_fallback_meta.h>
|
| 325 |
+
#include <ATen/ops/_test_optional_filled_intlist_meta.h>
|
| 326 |
+
#include <ATen/ops/_test_optional_floatlist_meta.h>
|
| 327 |
+
#include <ATen/ops/_test_optional_intlist_meta.h>
|
| 328 |
+
#include <ATen/ops/_test_parallel_materialize_meta.h>
|
| 329 |
+
#include <ATen/ops/_test_serialization_subcmul_meta.h>
|
| 330 |
+
#include <ATen/ops/_test_string_default_meta.h>
|
| 331 |
+
#include <ATen/ops/_test_warn_in_autograd_meta.h>
|
| 332 |
+
#include <ATen/ops/_thnn_differentiable_gru_cell_backward_meta.h>
|
| 333 |
+
#include <ATen/ops/_thnn_differentiable_lstm_cell_backward_meta.h>
|
| 334 |
+
#include <ATen/ops/_thnn_fused_gru_cell_meta.h>
|
| 335 |
+
#include <ATen/ops/_thnn_fused_gru_cell_backward_meta.h>
|
| 336 |
+
#include <ATen/ops/_thnn_fused_lstm_cell_meta.h>
|
| 337 |
+
#include <ATen/ops/_thnn_fused_lstm_cell_backward_meta.h>
|
| 338 |
+
#include <ATen/ops/_thnn_fused_lstm_cell_backward_impl_meta.h>
|
| 339 |
+
#include <ATen/ops/_to_copy_meta.h>
|
| 340 |
+
#include <ATen/ops/_to_cpu_meta.h>
|
| 341 |
+
#include <ATen/ops/_to_dense_meta.h>
|
| 342 |
+
#include <ATen/ops/_to_sparse_meta.h>
|
| 343 |
+
#include <ATen/ops/_to_sparse_bsc_meta.h>
|
| 344 |
+
#include <ATen/ops/_to_sparse_bsr_meta.h>
|
| 345 |
+
#include <ATen/ops/_to_sparse_csc_meta.h>
|
| 346 |
+
#include <ATen/ops/_to_sparse_csr_meta.h>
|
| 347 |
+
#include <ATen/ops/_to_sparse_semi_structured_meta.h>
|
| 348 |
+
#include <ATen/ops/_transform_bias_rescale_qkv_meta.h>
|
| 349 |
+
#include <ATen/ops/_transformer_encoder_layer_fwd_meta.h>
|
| 350 |
+
#include <ATen/ops/_trilinear_meta.h>
|
| 351 |
+
#include <ATen/ops/_triton_multi_head_attention_meta.h>
|
| 352 |
+
#include <ATen/ops/_triton_scaled_dot_attention_meta.h>
|
| 353 |
+
#include <ATen/ops/_unique_meta.h>
|
| 354 |
+
#include <ATen/ops/_unique2_meta.h>
|
| 355 |
+
#include <ATen/ops/_unpack_dual_meta.h>
|
| 356 |
+
#include <ATen/ops/_unsafe_index_meta.h>
|
| 357 |
+
#include <ATen/ops/_unsafe_index_put_meta.h>
|
| 358 |
+
#include <ATen/ops/_unsafe_masked_index_meta.h>
|
| 359 |
+
#include <ATen/ops/_unsafe_masked_index_put_accumulate_meta.h>
|
| 360 |
+
#include <ATen/ops/_unsafe_view_meta.h>
|
| 361 |
+
#include <ATen/ops/_upsample_bicubic2d_aa_meta.h>
|
| 362 |
+
#include <ATen/ops/_upsample_bicubic2d_aa_backward_meta.h>
|
| 363 |
+
#include <ATen/ops/_upsample_bilinear2d_aa_meta.h>
|
| 364 |
+
#include <ATen/ops/_upsample_bilinear2d_aa_backward_meta.h>
|
| 365 |
+
#include <ATen/ops/_upsample_nearest_exact1d_meta.h>
|
| 366 |
+
#include <ATen/ops/_upsample_nearest_exact1d_backward_meta.h>
|
| 367 |
+
#include <ATen/ops/_upsample_nearest_exact2d_meta.h>
|
| 368 |
+
#include <ATen/ops/_upsample_nearest_exact2d_backward_meta.h>
|
| 369 |
+
#include <ATen/ops/_upsample_nearest_exact3d_meta.h>
|
| 370 |
+
#include <ATen/ops/_upsample_nearest_exact3d_backward_meta.h>
|
| 371 |
+
#include <ATen/ops/_use_cudnn_ctc_loss_meta.h>
|
| 372 |
+
#include <ATen/ops/_use_cudnn_rnn_flatten_weight_meta.h>
|
| 373 |
+
#include <ATen/ops/_validate_compressed_sparse_indices_meta.h>
|
| 374 |
+
#include <ATen/ops/_validate_sparse_bsc_tensor_args_meta.h>
|
| 375 |
+
#include <ATen/ops/_validate_sparse_bsr_tensor_args_meta.h>
|
| 376 |
+
#include <ATen/ops/_validate_sparse_compressed_tensor_args_meta.h>
|
| 377 |
+
#include <ATen/ops/_validate_sparse_coo_tensor_args_meta.h>
|
| 378 |
+
#include <ATen/ops/_validate_sparse_csc_tensor_args_meta.h>
|
| 379 |
+
#include <ATen/ops/_validate_sparse_csr_tensor_args_meta.h>
|
| 380 |
+
#include <ATen/ops/_values_meta.h>
|
| 381 |
+
#include <ATen/ops/_values_copy_meta.h>
|
| 382 |
+
#include <ATen/ops/_version_meta.h>
|
| 383 |
+
#include <ATen/ops/_weight_int4pack_mm_meta.h>
|
| 384 |
+
#include <ATen/ops/_weight_int8pack_mm_meta.h>
|
| 385 |
+
#include <ATen/ops/_weight_norm_meta.h>
|
| 386 |
+
#include <ATen/ops/_weight_norm_differentiable_backward_meta.h>
|
| 387 |
+
#include <ATen/ops/_weight_norm_interface_meta.h>
|
| 388 |
+
#include <ATen/ops/_weight_norm_interface_backward_meta.h>
|
| 389 |
+
#include <ATen/ops/_wrapped_linear_prepack_meta.h>
|
| 390 |
+
#include <ATen/ops/_wrapped_quantized_linear_prepacked_meta.h>
|
| 391 |
+
#include <ATen/ops/abs_meta.h>
|
| 392 |
+
#include <ATen/ops/absolute_meta.h>
|
| 393 |
+
#include <ATen/ops/acos_meta.h>
|
| 394 |
+
#include <ATen/ops/acosh_meta.h>
|
| 395 |
+
#include <ATen/ops/adaptive_avg_pool1d_meta.h>
|
| 396 |
+
#include <ATen/ops/adaptive_avg_pool2d_meta.h>
|
| 397 |
+
#include <ATen/ops/adaptive_avg_pool3d_meta.h>
|
| 398 |
+
#include <ATen/ops/adaptive_avg_pool3d_backward_meta.h>
|
| 399 |
+
#include <ATen/ops/adaptive_max_pool1d_meta.h>
|
| 400 |
+
#include <ATen/ops/adaptive_max_pool2d_meta.h>
|
| 401 |
+
#include <ATen/ops/adaptive_max_pool2d_backward_meta.h>
|
| 402 |
+
#include <ATen/ops/adaptive_max_pool3d_meta.h>
|
| 403 |
+
#include <ATen/ops/adaptive_max_pool3d_backward_meta.h>
|
| 404 |
+
#include <ATen/ops/add_meta.h>
|
| 405 |
+
#include <ATen/ops/addbmm_meta.h>
|
| 406 |
+
#include <ATen/ops/addcdiv_meta.h>
|
| 407 |
+
#include <ATen/ops/addcmul_meta.h>
|
| 408 |
+
#include <ATen/ops/addmm_meta.h>
|
| 409 |
+
#include <ATen/ops/addmv_meta.h>
|
| 410 |
+
#include <ATen/ops/addr_meta.h>
|
| 411 |
+
#include <ATen/ops/adjoint_meta.h>
|
| 412 |
+
#include <ATen/ops/affine_grid_generator_meta.h>
|
| 413 |
+
#include <ATen/ops/affine_grid_generator_backward_meta.h>
|
| 414 |
+
#include <ATen/ops/alias_meta.h>
|
| 415 |
+
#include <ATen/ops/alias_copy_meta.h>
|
| 416 |
+
#include <ATen/ops/align_as_meta.h>
|
| 417 |
+
#include <ATen/ops/align_tensors_meta.h>
|
| 418 |
+
#include <ATen/ops/align_to_meta.h>
|
| 419 |
+
#include <ATen/ops/all_meta.h>
|
| 420 |
+
#include <ATen/ops/allclose_meta.h>
|
| 421 |
+
#include <ATen/ops/alpha_dropout_meta.h>
|
| 422 |
+
#include <ATen/ops/amax_meta.h>
|
| 423 |
+
#include <ATen/ops/amin_meta.h>
|
| 424 |
+
#include <ATen/ops/aminmax_meta.h>
|
| 425 |
+
#include <ATen/ops/and_meta.h>
|
| 426 |
+
#include <ATen/ops/angle_meta.h>
|
| 427 |
+
#include <ATen/ops/any_meta.h>
|
| 428 |
+
#include <ATen/ops/arange_meta.h>
|
| 429 |
+
#include <ATen/ops/arccos_meta.h>
|
| 430 |
+
#include <ATen/ops/arccosh_meta.h>
|
| 431 |
+
#include <ATen/ops/arcsin_meta.h>
|
| 432 |
+
#include <ATen/ops/arcsinh_meta.h>
|
| 433 |
+
#include <ATen/ops/arctan_meta.h>
|
| 434 |
+
#include <ATen/ops/arctan2_meta.h>
|
| 435 |
+
#include <ATen/ops/arctanh_meta.h>
|
| 436 |
+
#include <ATen/ops/argmax_meta.h>
|
| 437 |
+
#include <ATen/ops/argmin_meta.h>
|
| 438 |
+
#include <ATen/ops/argsort_meta.h>
|
| 439 |
+
#include <ATen/ops/argwhere_meta.h>
|
| 440 |
+
#include <ATen/ops/as_strided_meta.h>
|
| 441 |
+
#include <ATen/ops/as_strided_copy_meta.h>
|
| 442 |
+
#include <ATen/ops/as_strided_scatter_meta.h>
|
| 443 |
+
#include <ATen/ops/asin_meta.h>
|
| 444 |
+
#include <ATen/ops/asinh_meta.h>
|
| 445 |
+
#include <ATen/ops/atan_meta.h>
|
| 446 |
+
#include <ATen/ops/atan2_meta.h>
|
| 447 |
+
#include <ATen/ops/atanh_meta.h>
|
| 448 |
+
#include <ATen/ops/atleast_1d_meta.h>
|
| 449 |
+
#include <ATen/ops/atleast_2d_meta.h>
|
| 450 |
+
#include <ATen/ops/atleast_3d_meta.h>
|
| 451 |
+
#include <ATen/ops/avg_pool1d_meta.h>
|
| 452 |
+
#include <ATen/ops/avg_pool2d_meta.h>
|
| 453 |
+
#include <ATen/ops/avg_pool2d_backward_meta.h>
|
| 454 |
+
#include <ATen/ops/avg_pool3d_meta.h>
|
| 455 |
+
#include <ATen/ops/avg_pool3d_backward_meta.h>
|
| 456 |
+
#include <ATen/ops/baddbmm_meta.h>
|
| 457 |
+
#include <ATen/ops/bartlett_window_meta.h>
|
| 458 |
+
#include <ATen/ops/batch_norm_meta.h>
|
| 459 |
+
#include <ATen/ops/batch_norm_backward_meta.h>
|
| 460 |
+
#include <ATen/ops/batch_norm_backward_elemt_meta.h>
|
| 461 |
+
#include <ATen/ops/batch_norm_backward_reduce_meta.h>
|
| 462 |
+
#include <ATen/ops/batch_norm_elemt_meta.h>
|
| 463 |
+
#include <ATen/ops/batch_norm_gather_stats_meta.h>
|
| 464 |
+
#include <ATen/ops/batch_norm_gather_stats_with_counts_meta.h>
|
| 465 |
+
#include <ATen/ops/batch_norm_stats_meta.h>
|
| 466 |
+
#include <ATen/ops/batch_norm_update_stats_meta.h>
|
| 467 |
+
#include <ATen/ops/bernoulli_meta.h>
|
| 468 |
+
#include <ATen/ops/bilinear_meta.h>
|
| 469 |
+
#include <ATen/ops/binary_cross_entropy_meta.h>
|
| 470 |
+
#include <ATen/ops/binary_cross_entropy_backward_meta.h>
|
| 471 |
+
#include <ATen/ops/binary_cross_entropy_with_logits_meta.h>
|
| 472 |
+
#include <ATen/ops/bincount_meta.h>
|
| 473 |
+
#include <ATen/ops/binomial_meta.h>
|
| 474 |
+
#include <ATen/ops/bitwise_and_meta.h>
|
| 475 |
+
#include <ATen/ops/bitwise_left_shift_meta.h>
|
| 476 |
+
#include <ATen/ops/bitwise_not_meta.h>
|
| 477 |
+
#include <ATen/ops/bitwise_or_meta.h>
|
| 478 |
+
#include <ATen/ops/bitwise_right_shift_meta.h>
|
| 479 |
+
#include <ATen/ops/bitwise_xor_meta.h>
|
| 480 |
+
#include <ATen/ops/blackman_window_meta.h>
|
| 481 |
+
#include <ATen/ops/block_diag_meta.h>
|
| 482 |
+
#include <ATen/ops/bmm_meta.h>
|
| 483 |
+
#include <ATen/ops/broadcast_tensors_meta.h>
|
| 484 |
+
#include <ATen/ops/broadcast_to_meta.h>
|
| 485 |
+
#include <ATen/ops/bucketize_meta.h>
|
| 486 |
+
#include <ATen/ops/can_cast_meta.h>
|
| 487 |
+
#include <ATen/ops/cartesian_prod_meta.h>
|
| 488 |
+
#include <ATen/ops/cat_meta.h>
|
| 489 |
+
#include <ATen/ops/cauchy_meta.h>
|
| 490 |
+
#include <ATen/ops/ccol_indices_meta.h>
|
| 491 |
+
#include <ATen/ops/ccol_indices_copy_meta.h>
|
| 492 |
+
#include <ATen/ops/cdist_meta.h>
|
| 493 |
+
#include <ATen/ops/ceil_meta.h>
|
| 494 |
+
#include <ATen/ops/celu_meta.h>
|
| 495 |
+
#include <ATen/ops/chain_matmul_meta.h>
|
| 496 |
+
#include <ATen/ops/chalf_meta.h>
|
| 497 |
+
#include <ATen/ops/channel_shuffle_meta.h>
|
| 498 |
+
#include <ATen/ops/cholesky_meta.h>
|
| 499 |
+
#include <ATen/ops/cholesky_inverse_meta.h>
|
| 500 |
+
#include <ATen/ops/cholesky_solve_meta.h>
|
| 501 |
+
#include <ATen/ops/choose_qparams_optimized_meta.h>
|
| 502 |
+
#include <ATen/ops/chunk_meta.h>
|
| 503 |
+
#include <ATen/ops/clamp_meta.h>
|
| 504 |
+
#include <ATen/ops/clamp_max_meta.h>
|
| 505 |
+
#include <ATen/ops/clamp_min_meta.h>
|
| 506 |
+
#include <ATen/ops/clip_meta.h>
|
| 507 |
+
#include <ATen/ops/clone_meta.h>
|
| 508 |
+
#include <ATen/ops/coalesce_meta.h>
|
| 509 |
+
#include <ATen/ops/col2im_meta.h>
|
| 510 |
+
#include <ATen/ops/col_indices_meta.h>
|
| 511 |
+
#include <ATen/ops/col_indices_copy_meta.h>
|
| 512 |
+
#include <ATen/ops/column_stack_meta.h>
|
| 513 |
+
#include <ATen/ops/combinations_meta.h>
|
| 514 |
+
#include <ATen/ops/complex_meta.h>
|
| 515 |
+
#include <ATen/ops/concat_meta.h>
|
| 516 |
+
#include <ATen/ops/concatenate_meta.h>
|
| 517 |
+
#include <ATen/ops/conj_meta.h>
|
| 518 |
+
#include <ATen/ops/conj_physical_meta.h>
|
| 519 |
+
#include <ATen/ops/constant_pad_nd_meta.h>
|
| 520 |
+
#include <ATen/ops/contiguous_meta.h>
|
| 521 |
+
#include <ATen/ops/conv1d_meta.h>
|
| 522 |
+
#include <ATen/ops/conv2d_meta.h>
|
| 523 |
+
#include <ATen/ops/conv3d_meta.h>
|
| 524 |
+
#include <ATen/ops/conv_depthwise3d_meta.h>
|
| 525 |
+
#include <ATen/ops/conv_tbc_meta.h>
|
| 526 |
+
#include <ATen/ops/conv_tbc_backward_meta.h>
|
| 527 |
+
#include <ATen/ops/conv_transpose1d_meta.h>
|
| 528 |
+
#include <ATen/ops/conv_transpose2d_meta.h>
|
| 529 |
+
#include <ATen/ops/conv_transpose3d_meta.h>
|
| 530 |
+
#include <ATen/ops/convolution_meta.h>
|
| 531 |
+
#include <ATen/ops/convolution_backward_meta.h>
|
| 532 |
+
#include <ATen/ops/convolution_backward_overrideable_meta.h>
|
| 533 |
+
#include <ATen/ops/convolution_overrideable_meta.h>
|
| 534 |
+
#include <ATen/ops/copy_meta.h>
|
| 535 |
+
#include <ATen/ops/copy_sparse_to_sparse_meta.h>
|
| 536 |
+
#include <ATen/ops/copysign_meta.h>
|
| 537 |
+
#include <ATen/ops/corrcoef_meta.h>
|
| 538 |
+
#include <ATen/ops/cos_meta.h>
|
| 539 |
+
#include <ATen/ops/cosh_meta.h>
|
| 540 |
+
#include <ATen/ops/cosine_embedding_loss_meta.h>
|
| 541 |
+
#include <ATen/ops/cosine_similarity_meta.h>
|
| 542 |
+
#include <ATen/ops/count_nonzero_meta.h>
|
| 543 |
+
#include <ATen/ops/cov_meta.h>
|
| 544 |
+
#include <ATen/ops/cross_meta.h>
|
| 545 |
+
#include <ATen/ops/cross_entropy_loss_meta.h>
|
| 546 |
+
#include <ATen/ops/crow_indices_meta.h>
|
| 547 |
+
#include <ATen/ops/crow_indices_copy_meta.h>
|
| 548 |
+
#include <ATen/ops/ctc_loss_meta.h>
|
| 549 |
+
#include <ATen/ops/cudnn_affine_grid_generator_meta.h>
|
| 550 |
+
#include <ATen/ops/cudnn_affine_grid_generator_backward_meta.h>
|
| 551 |
+
#include <ATen/ops/cudnn_batch_norm_meta.h>
|
| 552 |
+
#include <ATen/ops/cudnn_batch_norm_backward_meta.h>
|
| 553 |
+
#include <ATen/ops/cudnn_convolution_meta.h>
|
| 554 |
+
#include <ATen/ops/cudnn_convolution_add_relu_meta.h>
|
| 555 |
+
#include <ATen/ops/cudnn_convolution_relu_meta.h>
|
| 556 |
+
#include <ATen/ops/cudnn_convolution_transpose_meta.h>
|
| 557 |
+
#include <ATen/ops/cudnn_grid_sampler_meta.h>
|
| 558 |
+
#include <ATen/ops/cudnn_grid_sampler_backward_meta.h>
|
| 559 |
+
#include <ATen/ops/cudnn_is_acceptable_meta.h>
|
| 560 |
+
#include <ATen/ops/cummax_meta.h>
|
| 561 |
+
#include <ATen/ops/cummaxmin_backward_meta.h>
|
| 562 |
+
#include <ATen/ops/cummin_meta.h>
|
| 563 |
+
#include <ATen/ops/cumprod_meta.h>
|
| 564 |
+
#include <ATen/ops/cumprod_backward_meta.h>
|
| 565 |
+
#include <ATen/ops/cumsum_meta.h>
|
| 566 |
+
#include <ATen/ops/cumulative_trapezoid_meta.h>
|
| 567 |
+
#include <ATen/ops/data_meta.h>
|
| 568 |
+
#include <ATen/ops/deg2rad_meta.h>
|
| 569 |
+
#include <ATen/ops/dense_dim_meta.h>
|
| 570 |
+
#include <ATen/ops/dequantize_meta.h>
|
| 571 |
+
#include <ATen/ops/det_meta.h>
|
| 572 |
+
#include <ATen/ops/detach_meta.h>
|
| 573 |
+
#include <ATen/ops/detach_copy_meta.h>
|
| 574 |
+
#include <ATen/ops/diag_meta.h>
|
| 575 |
+
#include <ATen/ops/diag_embed_meta.h>
|
| 576 |
+
#include <ATen/ops/diagflat_meta.h>
|
| 577 |
+
#include <ATen/ops/diagonal_meta.h>
|
| 578 |
+
#include <ATen/ops/diagonal_backward_meta.h>
|
| 579 |
+
#include <ATen/ops/diagonal_copy_meta.h>
|
| 580 |
+
#include <ATen/ops/diagonal_scatter_meta.h>
|
| 581 |
+
#include <ATen/ops/diff_meta.h>
|
| 582 |
+
#include <ATen/ops/digamma_meta.h>
|
| 583 |
+
#include <ATen/ops/dist_meta.h>
|
| 584 |
+
#include <ATen/ops/div_meta.h>
|
| 585 |
+
#include <ATen/ops/divide_meta.h>
|
| 586 |
+
#include <ATen/ops/dot_meta.h>
|
| 587 |
+
#include <ATen/ops/dropout_meta.h>
|
| 588 |
+
#include <ATen/ops/dsplit_meta.h>
|
| 589 |
+
#include <ATen/ops/dstack_meta.h>
|
| 590 |
+
#include <ATen/ops/einsum_meta.h>
|
| 591 |
+
#include <ATen/ops/elu_meta.h>
|
| 592 |
+
#include <ATen/ops/elu_backward_meta.h>
|
| 593 |
+
#include <ATen/ops/embedding_meta.h>
|
| 594 |
+
#include <ATen/ops/embedding_backward_meta.h>
|
| 595 |
+
#include <ATen/ops/embedding_bag_meta.h>
|
| 596 |
+
#include <ATen/ops/embedding_dense_backward_meta.h>
|
| 597 |
+
#include <ATen/ops/embedding_renorm_meta.h>
|
| 598 |
+
#include <ATen/ops/embedding_sparse_backward_meta.h>
|
| 599 |
+
#include <ATen/ops/empty_meta.h>
|
| 600 |
+
#include <ATen/ops/empty_like_meta.h>
|
| 601 |
+
#include <ATen/ops/empty_permuted_meta.h>
|
| 602 |
+
#include <ATen/ops/empty_quantized_meta.h>
|
| 603 |
+
#include <ATen/ops/empty_strided_meta.h>
|
| 604 |
+
#include <ATen/ops/eq_meta.h>
|
| 605 |
+
#include <ATen/ops/equal_meta.h>
|
| 606 |
+
#include <ATen/ops/erf_meta.h>
|
| 607 |
+
#include <ATen/ops/erfc_meta.h>
|
| 608 |
+
#include <ATen/ops/erfinv_meta.h>
|
| 609 |
+
#include <ATen/ops/exp_meta.h>
|
| 610 |
+
#include <ATen/ops/exp2_meta.h>
|
| 611 |
+
#include <ATen/ops/expand_meta.h>
|
| 612 |
+
#include <ATen/ops/expand_as_meta.h>
|
| 613 |
+
#include <ATen/ops/expand_copy_meta.h>
|
| 614 |
+
#include <ATen/ops/expm1_meta.h>
|
| 615 |
+
#include <ATen/ops/exponential_meta.h>
|
| 616 |
+
#include <ATen/ops/eye_meta.h>
|
| 617 |
+
#include <ATen/ops/fake_quantize_per_channel_affine_meta.h>
|
| 618 |
+
#include <ATen/ops/fake_quantize_per_channel_affine_cachemask_meta.h>
|
| 619 |
+
#include <ATen/ops/fake_quantize_per_channel_affine_cachemask_backward_meta.h>
|
| 620 |
+
#include <ATen/ops/fake_quantize_per_tensor_affine_meta.h>
|
| 621 |
+
#include <ATen/ops/fake_quantize_per_tensor_affine_cachemask_meta.h>
|
| 622 |
+
#include <ATen/ops/fake_quantize_per_tensor_affine_cachemask_backward_meta.h>
|
| 623 |
+
#include <ATen/ops/fbgemm_linear_fp16_weight_meta.h>
|
| 624 |
+
#include <ATen/ops/fbgemm_linear_fp16_weight_fp32_activation_meta.h>
|
| 625 |
+
#include <ATen/ops/fbgemm_linear_int8_weight_meta.h>
|
| 626 |
+
#include <ATen/ops/fbgemm_linear_int8_weight_fp32_activation_meta.h>
|
| 627 |
+
#include <ATen/ops/fbgemm_linear_quantize_weight_meta.h>
|
| 628 |
+
#include <ATen/ops/fbgemm_pack_gemm_matrix_fp16_meta.h>
|
| 629 |
+
#include <ATen/ops/fbgemm_pack_quantized_matrix_meta.h>
|
| 630 |
+
#include <ATen/ops/feature_alpha_dropout_meta.h>
|
| 631 |
+
#include <ATen/ops/feature_dropout_meta.h>
|
| 632 |
+
#include <ATen/ops/fft_fft_meta.h>
|
| 633 |
+
#include <ATen/ops/fft_fft2_meta.h>
|
| 634 |
+
#include <ATen/ops/fft_fftfreq_meta.h>
|
| 635 |
+
#include <ATen/ops/fft_fftn_meta.h>
|
| 636 |
+
#include <ATen/ops/fft_fftshift_meta.h>
|
| 637 |
+
#include <ATen/ops/fft_hfft_meta.h>
|
| 638 |
+
#include <ATen/ops/fft_hfft2_meta.h>
|
| 639 |
+
#include <ATen/ops/fft_hfftn_meta.h>
|
| 640 |
+
#include <ATen/ops/fft_ifft_meta.h>
|
| 641 |
+
#include <ATen/ops/fft_ifft2_meta.h>
|
| 642 |
+
#include <ATen/ops/fft_ifftn_meta.h>
|
| 643 |
+
#include <ATen/ops/fft_ifftshift_meta.h>
|
| 644 |
+
#include <ATen/ops/fft_ihfft_meta.h>
|
| 645 |
+
#include <ATen/ops/fft_ihfft2_meta.h>
|
| 646 |
+
#include <ATen/ops/fft_ihfftn_meta.h>
|
| 647 |
+
#include <ATen/ops/fft_irfft_meta.h>
|
| 648 |
+
#include <ATen/ops/fft_irfft2_meta.h>
|
| 649 |
+
#include <ATen/ops/fft_irfftn_meta.h>
|
| 650 |
+
#include <ATen/ops/fft_rfft_meta.h>
|
| 651 |
+
#include <ATen/ops/fft_rfft2_meta.h>
|
| 652 |
+
#include <ATen/ops/fft_rfftfreq_meta.h>
|
| 653 |
+
#include <ATen/ops/fft_rfftn_meta.h>
|
| 654 |
+
#include <ATen/ops/fill_meta.h>
|
| 655 |
+
#include <ATen/ops/fill_diagonal_meta.h>
|
| 656 |
+
#include <ATen/ops/fix_meta.h>
|
| 657 |
+
#include <ATen/ops/flatten_meta.h>
|
| 658 |
+
#include <ATen/ops/flatten_dense_tensors_meta.h>
|
| 659 |
+
#include <ATen/ops/flip_meta.h>
|
| 660 |
+
#include <ATen/ops/fliplr_meta.h>
|
| 661 |
+
#include <ATen/ops/flipud_meta.h>
|
| 662 |
+
#include <ATen/ops/float_power_meta.h>
|
| 663 |
+
#include <ATen/ops/floor_meta.h>
|
| 664 |
+
#include <ATen/ops/floor_divide_meta.h>
|
| 665 |
+
#include <ATen/ops/fmax_meta.h>
|
| 666 |
+
#include <ATen/ops/fmin_meta.h>
|
| 667 |
+
#include <ATen/ops/fmod_meta.h>
|
| 668 |
+
#include <ATen/ops/frac_meta.h>
|
| 669 |
+
#include <ATen/ops/fractional_max_pool2d_meta.h>
|
| 670 |
+
#include <ATen/ops/fractional_max_pool2d_backward_meta.h>
|
| 671 |
+
#include <ATen/ops/fractional_max_pool3d_meta.h>
|
| 672 |
+
#include <ATen/ops/fractional_max_pool3d_backward_meta.h>
|
| 673 |
+
#include <ATen/ops/frexp_meta.h>
|
| 674 |
+
#include <ATen/ops/frobenius_norm_meta.h>
|
| 675 |
+
#include <ATen/ops/from_file_meta.h>
|
| 676 |
+
#include <ATen/ops/full_meta.h>
|
| 677 |
+
#include <ATen/ops/full_like_meta.h>
|
| 678 |
+
#include <ATen/ops/fused_moving_avg_obs_fake_quant_meta.h>
|
| 679 |
+
#include <ATen/ops/gather_meta.h>
|
| 680 |
+
#include <ATen/ops/gather_backward_meta.h>
|
| 681 |
+
#include <ATen/ops/gcd_meta.h>
|
| 682 |
+
#include <ATen/ops/ge_meta.h>
|
| 683 |
+
#include <ATen/ops/gelu_meta.h>
|
| 684 |
+
#include <ATen/ops/gelu_backward_meta.h>
|
| 685 |
+
#include <ATen/ops/geometric_meta.h>
|
| 686 |
+
#include <ATen/ops/geqrf_meta.h>
|
| 687 |
+
#include <ATen/ops/ger_meta.h>
|
| 688 |
+
#include <ATen/ops/glu_meta.h>
|
| 689 |
+
#include <ATen/ops/glu_backward_meta.h>
|
| 690 |
+
#include <ATen/ops/glu_backward_jvp_meta.h>
|
| 691 |
+
#include <ATen/ops/glu_jvp_meta.h>
|
| 692 |
+
#include <ATen/ops/gradient_meta.h>
|
| 693 |
+
#include <ATen/ops/greater_meta.h>
|
| 694 |
+
#include <ATen/ops/greater_equal_meta.h>
|
| 695 |
+
#include <ATen/ops/grid_sampler_meta.h>
|
| 696 |
+
#include <ATen/ops/grid_sampler_2d_meta.h>
|
| 697 |
+
#include <ATen/ops/grid_sampler_2d_backward_meta.h>
|
| 698 |
+
#include <ATen/ops/grid_sampler_3d_meta.h>
|
| 699 |
+
#include <ATen/ops/grid_sampler_3d_backward_meta.h>
|
| 700 |
+
#include <ATen/ops/group_norm_meta.h>
|
| 701 |
+
#include <ATen/ops/gru_meta.h>
|
| 702 |
+
#include <ATen/ops/gru_cell_meta.h>
|
| 703 |
+
#include <ATen/ops/gt_meta.h>
|
| 704 |
+
#include <ATen/ops/hamming_window_meta.h>
|
| 705 |
+
#include <ATen/ops/hann_window_meta.h>
|
| 706 |
+
#include <ATen/ops/hardshrink_meta.h>
|
| 707 |
+
#include <ATen/ops/hardshrink_backward_meta.h>
|
| 708 |
+
#include <ATen/ops/hardsigmoid_meta.h>
|
| 709 |
+
#include <ATen/ops/hardsigmoid_backward_meta.h>
|
| 710 |
+
#include <ATen/ops/hardswish_meta.h>
|
| 711 |
+
#include <ATen/ops/hardswish_backward_meta.h>
|
| 712 |
+
#include <ATen/ops/hardtanh_meta.h>
|
| 713 |
+
#include <ATen/ops/hardtanh_backward_meta.h>
|
| 714 |
+
#include <ATen/ops/heaviside_meta.h>
|
| 715 |
+
#include <ATen/ops/hinge_embedding_loss_meta.h>
|
| 716 |
+
#include <ATen/ops/histc_meta.h>
|
| 717 |
+
#include <ATen/ops/histogram_meta.h>
|
| 718 |
+
#include <ATen/ops/histogramdd_meta.h>
|
| 719 |
+
#include <ATen/ops/hsplit_meta.h>
|
| 720 |
+
#include <ATen/ops/hspmm_meta.h>
|
| 721 |
+
#include <ATen/ops/hstack_meta.h>
|
| 722 |
+
#include <ATen/ops/huber_loss_meta.h>
|
| 723 |
+
#include <ATen/ops/huber_loss_backward_meta.h>
|
| 724 |
+
#include <ATen/ops/hypot_meta.h>
|
| 725 |
+
#include <ATen/ops/i0_meta.h>
|
| 726 |
+
#include <ATen/ops/igamma_meta.h>
|
| 727 |
+
#include <ATen/ops/igammac_meta.h>
|
| 728 |
+
#include <ATen/ops/im2col_meta.h>
|
| 729 |
+
#include <ATen/ops/imag_meta.h>
|
| 730 |
+
#include <ATen/ops/index_meta.h>
|
| 731 |
+
#include <ATen/ops/index_add_meta.h>
|
| 732 |
+
#include <ATen/ops/index_copy_meta.h>
|
| 733 |
+
#include <ATen/ops/index_fill_meta.h>
|
| 734 |
+
#include <ATen/ops/index_put_meta.h>
|
| 735 |
+
#include <ATen/ops/index_reduce_meta.h>
|
| 736 |
+
#include <ATen/ops/index_select_meta.h>
|
| 737 |
+
#include <ATen/ops/index_select_backward_meta.h>
|
| 738 |
+
#include <ATen/ops/indices_meta.h>
|
| 739 |
+
#include <ATen/ops/indices_copy_meta.h>
|
| 740 |
+
#include <ATen/ops/infinitely_differentiable_gelu_backward_meta.h>
|
| 741 |
+
#include <ATen/ops/inner_meta.h>
|
| 742 |
+
#include <ATen/ops/instance_norm_meta.h>
|
| 743 |
+
#include <ATen/ops/int_repr_meta.h>
|
| 744 |
+
#include <ATen/ops/inverse_meta.h>
|
| 745 |
+
#include <ATen/ops/is_coalesced_meta.h>
|
| 746 |
+
#include <ATen/ops/is_complex_meta.h>
|
| 747 |
+
#include <ATen/ops/is_conj_meta.h>
|
| 748 |
+
#include <ATen/ops/is_distributed_meta.h>
|
| 749 |
+
#include <ATen/ops/is_floating_point_meta.h>
|
| 750 |
+
#include <ATen/ops/is_inference_meta.h>
|
| 751 |
+
#include <ATen/ops/is_leaf_meta.h>
|
| 752 |
+
#include <ATen/ops/is_neg_meta.h>
|
| 753 |
+
#include <ATen/ops/is_nonzero_meta.h>
|
| 754 |
+
#include <ATen/ops/is_pinned_meta.h>
|
| 755 |
+
#include <ATen/ops/is_same_size_meta.h>
|
| 756 |
+
#include <ATen/ops/is_set_to_meta.h>
|
| 757 |
+
#include <ATen/ops/is_signed_meta.h>
|
| 758 |
+
#include <ATen/ops/is_vulkan_available_meta.h>
|
| 759 |
+
#include <ATen/ops/isclose_meta.h>
|
| 760 |
+
#include <ATen/ops/isfinite_meta.h>
|
| 761 |
+
#include <ATen/ops/isin_meta.h>
|
| 762 |
+
#include <ATen/ops/isinf_meta.h>
|
| 763 |
+
#include <ATen/ops/isnan_meta.h>
|
| 764 |
+
#include <ATen/ops/isneginf_meta.h>
|
| 765 |
+
#include <ATen/ops/isposinf_meta.h>
|
| 766 |
+
#include <ATen/ops/isreal_meta.h>
|
| 767 |
+
#include <ATen/ops/istft_meta.h>
|
| 768 |
+
#include <ATen/ops/item_meta.h>
|
| 769 |
+
#include <ATen/ops/kaiser_window_meta.h>
|
| 770 |
+
#include <ATen/ops/kl_div_meta.h>
|
| 771 |
+
#include <ATen/ops/kron_meta.h>
|
| 772 |
+
#include <ATen/ops/kthvalue_meta.h>
|
| 773 |
+
#include <ATen/ops/l1_loss_meta.h>
|
| 774 |
+
#include <ATen/ops/layer_norm_meta.h>
|
| 775 |
+
#include <ATen/ops/lcm_meta.h>
|
| 776 |
+
#include <ATen/ops/ldexp_meta.h>
|
| 777 |
+
#include <ATen/ops/le_meta.h>
|
| 778 |
+
#include <ATen/ops/leaky_relu_meta.h>
|
| 779 |
+
#include <ATen/ops/leaky_relu_backward_meta.h>
|
| 780 |
+
#include <ATen/ops/lerp_meta.h>
|
| 781 |
+
#include <ATen/ops/less_meta.h>
|
| 782 |
+
#include <ATen/ops/less_equal_meta.h>
|
| 783 |
+
#include <ATen/ops/lgamma_meta.h>
|
| 784 |
+
#include <ATen/ops/lift_meta.h>
|
| 785 |
+
#include <ATen/ops/lift_fresh_meta.h>
|
| 786 |
+
#include <ATen/ops/lift_fresh_copy_meta.h>
|
| 787 |
+
#include <ATen/ops/linalg_cholesky_meta.h>
|
| 788 |
+
#include <ATen/ops/linalg_cholesky_ex_meta.h>
|
| 789 |
+
#include <ATen/ops/linalg_cond_meta.h>
|
| 790 |
+
#include <ATen/ops/linalg_cross_meta.h>
|
| 791 |
+
#include <ATen/ops/linalg_det_meta.h>
|
| 792 |
+
#include <ATen/ops/linalg_diagonal_meta.h>
|
| 793 |
+
#include <ATen/ops/linalg_eig_meta.h>
|
| 794 |
+
#include <ATen/ops/linalg_eigh_meta.h>
|
| 795 |
+
#include <ATen/ops/linalg_eigvals_meta.h>
|
| 796 |
+
#include <ATen/ops/linalg_eigvalsh_meta.h>
|
| 797 |
+
#include <ATen/ops/linalg_householder_product_meta.h>
|
| 798 |
+
#include <ATen/ops/linalg_inv_meta.h>
|
| 799 |
+
#include <ATen/ops/linalg_inv_ex_meta.h>
|
| 800 |
+
#include <ATen/ops/linalg_ldl_factor_meta.h>
|
| 801 |
+
#include <ATen/ops/linalg_ldl_factor_ex_meta.h>
|
| 802 |
+
#include <ATen/ops/linalg_ldl_solve_meta.h>
|
| 803 |
+
#include <ATen/ops/linalg_lstsq_meta.h>
|
| 804 |
+
#include <ATen/ops/linalg_lu_meta.h>
|
| 805 |
+
#include <ATen/ops/linalg_lu_factor_meta.h>
|
| 806 |
+
#include <ATen/ops/linalg_lu_factor_ex_meta.h>
|
| 807 |
+
#include <ATen/ops/linalg_lu_solve_meta.h>
|
| 808 |
+
#include <ATen/ops/linalg_matmul_meta.h>
|
| 809 |
+
#include <ATen/ops/linalg_matrix_exp_meta.h>
|
| 810 |
+
#include <ATen/ops/linalg_matrix_norm_meta.h>
|
| 811 |
+
#include <ATen/ops/linalg_matrix_power_meta.h>
|
| 812 |
+
#include <ATen/ops/linalg_matrix_rank_meta.h>
|
| 813 |
+
#include <ATen/ops/linalg_multi_dot_meta.h>
|
| 814 |
+
#include <ATen/ops/linalg_norm_meta.h>
|
| 815 |
+
#include <ATen/ops/linalg_pinv_meta.h>
|
| 816 |
+
#include <ATen/ops/linalg_qr_meta.h>
|
| 817 |
+
#include <ATen/ops/linalg_slogdet_meta.h>
|
| 818 |
+
#include <ATen/ops/linalg_solve_meta.h>
|
| 819 |
+
#include <ATen/ops/linalg_solve_ex_meta.h>
|
| 820 |
+
#include <ATen/ops/linalg_solve_triangular_meta.h>
|
| 821 |
+
#include <ATen/ops/linalg_svd_meta.h>
|
| 822 |
+
#include <ATen/ops/linalg_svdvals_meta.h>
|
| 823 |
+
#include <ATen/ops/linalg_tensorinv_meta.h>
|
| 824 |
+
#include <ATen/ops/linalg_tensorsolve_meta.h>
|
| 825 |
+
#include <ATen/ops/linalg_vander_meta.h>
|
| 826 |
+
#include <ATen/ops/linalg_vecdot_meta.h>
|
| 827 |
+
#include <ATen/ops/linalg_vector_norm_meta.h>
|
| 828 |
+
#include <ATen/ops/linear_meta.h>
|
| 829 |
+
#include <ATen/ops/linear_backward_meta.h>
|
| 830 |
+
#include <ATen/ops/linspace_meta.h>
|
| 831 |
+
#include <ATen/ops/log_meta.h>
|
| 832 |
+
#include <ATen/ops/log10_meta.h>
|
| 833 |
+
#include <ATen/ops/log1p_meta.h>
|
| 834 |
+
#include <ATen/ops/log2_meta.h>
|
| 835 |
+
#include <ATen/ops/log_normal_meta.h>
|
| 836 |
+
#include <ATen/ops/log_sigmoid_meta.h>
|
| 837 |
+
#include <ATen/ops/log_sigmoid_backward_meta.h>
|
| 838 |
+
#include <ATen/ops/log_sigmoid_forward_meta.h>
|
| 839 |
+
#include <ATen/ops/log_softmax_meta.h>
|
| 840 |
+
#include <ATen/ops/logaddexp_meta.h>
|
| 841 |
+
#include <ATen/ops/logaddexp2_meta.h>
|
| 842 |
+
#include <ATen/ops/logcumsumexp_meta.h>
|
| 843 |
+
#include <ATen/ops/logdet_meta.h>
|
| 844 |
+
#include <ATen/ops/logical_and_meta.h>
|
| 845 |
+
#include <ATen/ops/logical_not_meta.h>
|
| 846 |
+
#include <ATen/ops/logical_or_meta.h>
|
| 847 |
+
#include <ATen/ops/logical_xor_meta.h>
|
| 848 |
+
#include <ATen/ops/logit_meta.h>
|
| 849 |
+
#include <ATen/ops/logit_backward_meta.h>
|
| 850 |
+
#include <ATen/ops/logspace_meta.h>
|
| 851 |
+
#include <ATen/ops/logsumexp_meta.h>
|
| 852 |
+
#include <ATen/ops/lshift_meta.h>
|
| 853 |
+
#include <ATen/ops/lstm_meta.h>
|
| 854 |
+
#include <ATen/ops/lstm_cell_meta.h>
|
| 855 |
+
#include <ATen/ops/lstm_mps_backward_meta.h>
|
| 856 |
+
#include <ATen/ops/lt_meta.h>
|
| 857 |
+
#include <ATen/ops/lu_solve_meta.h>
|
| 858 |
+
#include <ATen/ops/lu_unpack_meta.h>
|
| 859 |
+
#include <ATen/ops/mH_meta.h>
|
| 860 |
+
#include <ATen/ops/mT_meta.h>
|
| 861 |
+
#include <ATen/ops/margin_ranking_loss_meta.h>
|
| 862 |
+
#include <ATen/ops/masked_fill_meta.h>
|
| 863 |
+
#include <ATen/ops/masked_scatter_meta.h>
|
| 864 |
+
#include <ATen/ops/masked_scatter_backward_meta.h>
|
| 865 |
+
#include <ATen/ops/masked_select_meta.h>
|
| 866 |
+
#include <ATen/ops/masked_select_backward_meta.h>
|
| 867 |
+
#include <ATen/ops/matmul_meta.h>
|
| 868 |
+
#include <ATen/ops/matmul_backward_meta.h>
|
| 869 |
+
#include <ATen/ops/matrix_H_meta.h>
|
| 870 |
+
#include <ATen/ops/matrix_exp_meta.h>
|
| 871 |
+
#include <ATen/ops/matrix_exp_backward_meta.h>
|
| 872 |
+
#include <ATen/ops/matrix_power_meta.h>
|
| 873 |
+
#include <ATen/ops/max_meta.h>
|
| 874 |
+
#include <ATen/ops/max_pool1d_meta.h>
|
| 875 |
+
#include <ATen/ops/max_pool1d_with_indices_meta.h>
|
| 876 |
+
#include <ATen/ops/max_pool2d_meta.h>
|
| 877 |
+
#include <ATen/ops/max_pool2d_backward_meta.h>
|
| 878 |
+
#include <ATen/ops/max_pool2d_with_indices_meta.h>
|
| 879 |
+
#include <ATen/ops/max_pool2d_with_indices_backward_meta.h>
|
| 880 |
+
#include <ATen/ops/max_pool3d_meta.h>
|
| 881 |
+
#include <ATen/ops/max_pool3d_with_indices_meta.h>
|
| 882 |
+
#include <ATen/ops/max_pool3d_with_indices_backward_meta.h>
|
| 883 |
+
#include <ATen/ops/max_unpool2d_meta.h>
|
| 884 |
+
#include <ATen/ops/max_unpool3d_meta.h>
|
| 885 |
+
#include <ATen/ops/maximum_meta.h>
|
| 886 |
+
#include <ATen/ops/mean_meta.h>
|
| 887 |
+
#include <ATen/ops/median_meta.h>
|
| 888 |
+
#include <ATen/ops/meshgrid_meta.h>
|
| 889 |
+
#include <ATen/ops/min_meta.h>
|
| 890 |
+
#include <ATen/ops/minimum_meta.h>
|
| 891 |
+
#include <ATen/ops/miopen_batch_norm_meta.h>
|
| 892 |
+
#include <ATen/ops/miopen_batch_norm_backward_meta.h>
|
| 893 |
+
#include <ATen/ops/miopen_convolution_meta.h>
|
| 894 |
+
#include <ATen/ops/miopen_convolution_add_relu_meta.h>
|
| 895 |
+
#include <ATen/ops/miopen_convolution_relu_meta.h>
|
| 896 |
+
#include <ATen/ops/miopen_convolution_transpose_meta.h>
|
| 897 |
+
#include <ATen/ops/miopen_depthwise_convolution_meta.h>
|
| 898 |
+
#include <ATen/ops/miopen_rnn_meta.h>
|
| 899 |
+
#include <ATen/ops/miopen_rnn_backward_meta.h>
|
| 900 |
+
#include <ATen/ops/mish_meta.h>
|
| 901 |
+
#include <ATen/ops/mish_backward_meta.h>
|
| 902 |
+
#include <ATen/ops/mkldnn_adaptive_avg_pool2d_meta.h>
|
| 903 |
+
#include <ATen/ops/mkldnn_adaptive_avg_pool2d_backward_meta.h>
|
| 904 |
+
#include <ATen/ops/mkldnn_convolution_meta.h>
|
| 905 |
+
#include <ATen/ops/mkldnn_linear_meta.h>
|
| 906 |
+
#include <ATen/ops/mkldnn_linear_backward_meta.h>
|
| 907 |
+
#include <ATen/ops/mkldnn_linear_backward_input_meta.h>
|
| 908 |
+
#include <ATen/ops/mkldnn_linear_backward_weights_meta.h>
|
| 909 |
+
#include <ATen/ops/mkldnn_max_pool2d_meta.h>
|
| 910 |
+
#include <ATen/ops/mkldnn_max_pool2d_backward_meta.h>
|
| 911 |
+
#include <ATen/ops/mkldnn_max_pool3d_meta.h>
|
| 912 |
+
#include <ATen/ops/mkldnn_max_pool3d_backward_meta.h>
|
| 913 |
+
#include <ATen/ops/mkldnn_reorder_conv2d_weight_meta.h>
|
| 914 |
+
#include <ATen/ops/mkldnn_reorder_conv3d_weight_meta.h>
|
| 915 |
+
#include <ATen/ops/mkldnn_rnn_layer_meta.h>
|
| 916 |
+
#include <ATen/ops/mkldnn_rnn_layer_backward_meta.h>
|
| 917 |
+
#include <ATen/ops/mm_meta.h>
|
| 918 |
+
#include <ATen/ops/mode_meta.h>
|
| 919 |
+
#include <ATen/ops/moveaxis_meta.h>
|
| 920 |
+
#include <ATen/ops/movedim_meta.h>
|
| 921 |
+
#include <ATen/ops/mps_convolution_backward_meta.h>
|
| 922 |
+
#include <ATen/ops/mps_convolution_transpose_backward_meta.h>
|
| 923 |
+
#include <ATen/ops/mse_loss_meta.h>
|
| 924 |
+
#include <ATen/ops/mse_loss_backward_meta.h>
|
| 925 |
+
#include <ATen/ops/msort_meta.h>
|
| 926 |
+
#include <ATen/ops/mul_meta.h>
|
| 927 |
+
#include <ATen/ops/multi_margin_loss_meta.h>
|
| 928 |
+
#include <ATen/ops/multi_margin_loss_backward_meta.h>
|
| 929 |
+
#include <ATen/ops/multilabel_margin_loss_meta.h>
|
| 930 |
+
#include <ATen/ops/multilabel_margin_loss_backward_meta.h>
|
| 931 |
+
#include <ATen/ops/multilabel_margin_loss_forward_meta.h>
|
| 932 |
+
#include <ATen/ops/multinomial_meta.h>
|
| 933 |
+
#include <ATen/ops/multiply_meta.h>
|
| 934 |
+
#include <ATen/ops/mv_meta.h>
|
| 935 |
+
#include <ATen/ops/mvlgamma_meta.h>
|
| 936 |
+
#include <ATen/ops/nan_to_num_meta.h>
|
| 937 |
+
#include <ATen/ops/nanmean_meta.h>
|
| 938 |
+
#include <ATen/ops/nanmedian_meta.h>
|
| 939 |
+
#include <ATen/ops/nanquantile_meta.h>
|
| 940 |
+
#include <ATen/ops/nansum_meta.h>
|
| 941 |
+
#include <ATen/ops/narrow_meta.h>
|
| 942 |
+
#include <ATen/ops/narrow_copy_meta.h>
|
| 943 |
+
#include <ATen/ops/native_batch_norm_meta.h>
|
| 944 |
+
#include <ATen/ops/native_batch_norm_backward_meta.h>
|
| 945 |
+
#include <ATen/ops/native_channel_shuffle_meta.h>
|
| 946 |
+
#include <ATen/ops/native_dropout_meta.h>
|
| 947 |
+
#include <ATen/ops/native_dropout_backward_meta.h>
|
| 948 |
+
#include <ATen/ops/native_group_norm_meta.h>
|
| 949 |
+
#include <ATen/ops/native_group_norm_backward_meta.h>
|
| 950 |
+
#include <ATen/ops/native_layer_norm_meta.h>
|
| 951 |
+
#include <ATen/ops/native_layer_norm_backward_meta.h>
|
| 952 |
+
#include <ATen/ops/native_norm_meta.h>
|
| 953 |
+
#include <ATen/ops/ne_meta.h>
|
| 954 |
+
#include <ATen/ops/neg_meta.h>
|
| 955 |
+
#include <ATen/ops/negative_meta.h>
|
| 956 |
+
#include <ATen/ops/nested_to_padded_tensor_meta.h>
|
| 957 |
+
#include <ATen/ops/new_empty_meta.h>
|
| 958 |
+
#include <ATen/ops/new_empty_strided_meta.h>
|
| 959 |
+
#include <ATen/ops/new_full_meta.h>
|
| 960 |
+
#include <ATen/ops/new_ones_meta.h>
|
| 961 |
+
#include <ATen/ops/new_zeros_meta.h>
|
| 962 |
+
#include <ATen/ops/nextafter_meta.h>
|
| 963 |
+
#include <ATen/ops/nll_loss_meta.h>
|
| 964 |
+
#include <ATen/ops/nll_loss2d_meta.h>
|
| 965 |
+
#include <ATen/ops/nll_loss2d_backward_meta.h>
|
| 966 |
+
#include <ATen/ops/nll_loss2d_forward_meta.h>
|
| 967 |
+
#include <ATen/ops/nll_loss_backward_meta.h>
|
| 968 |
+
#include <ATen/ops/nll_loss_forward_meta.h>
|
| 969 |
+
#include <ATen/ops/nll_loss_nd_meta.h>
|
| 970 |
+
#include <ATen/ops/nonzero_meta.h>
|
| 971 |
+
#include <ATen/ops/nonzero_numpy_meta.h>
|
| 972 |
+
#include <ATen/ops/nonzero_static_meta.h>
|
| 973 |
+
#include <ATen/ops/norm_meta.h>
|
| 974 |
+
#include <ATen/ops/norm_except_dim_meta.h>
|
| 975 |
+
#include <ATen/ops/normal_meta.h>
|
| 976 |
+
#include <ATen/ops/not_equal_meta.h>
|
| 977 |
+
#include <ATen/ops/nuclear_norm_meta.h>
|
| 978 |
+
#include <ATen/ops/numpy_T_meta.h>
|
| 979 |
+
#include <ATen/ops/one_hot_meta.h>
|
| 980 |
+
#include <ATen/ops/ones_meta.h>
|
| 981 |
+
#include <ATen/ops/ones_like_meta.h>
|
| 982 |
+
#include <ATen/ops/or_meta.h>
|
| 983 |
+
#include <ATen/ops/orgqr_meta.h>
|
| 984 |
+
#include <ATen/ops/ormqr_meta.h>
|
| 985 |
+
#include <ATen/ops/outer_meta.h>
|
| 986 |
+
#include <ATen/ops/output_nr_meta.h>
|
| 987 |
+
#include <ATen/ops/pad_meta.h>
|
| 988 |
+
#include <ATen/ops/pad_sequence_meta.h>
|
| 989 |
+
#include <ATen/ops/pairwise_distance_meta.h>
|
| 990 |
+
#include <ATen/ops/pdist_meta.h>
|
| 991 |
+
#include <ATen/ops/permute_meta.h>
|
| 992 |
+
#include <ATen/ops/permute_copy_meta.h>
|
| 993 |
+
#include <ATen/ops/pin_memory_meta.h>
|
| 994 |
+
#include <ATen/ops/pinverse_meta.h>
|
| 995 |
+
#include <ATen/ops/pixel_shuffle_meta.h>
|
| 996 |
+
#include <ATen/ops/pixel_unshuffle_meta.h>
|
| 997 |
+
#include <ATen/ops/poisson_meta.h>
|
| 998 |
+
#include <ATen/ops/poisson_nll_loss_meta.h>
|
| 999 |
+
#include <ATen/ops/polar_meta.h>
|
| 1000 |
+
#include <ATen/ops/polygamma_meta.h>
|
| 1001 |
+
#include <ATen/ops/positive_meta.h>
|
| 1002 |
+
#include <ATen/ops/pow_meta.h>
|
| 1003 |
+
#include <ATen/ops/prelu_meta.h>
|
| 1004 |
+
#include <ATen/ops/prod_meta.h>
|
| 1005 |
+
#include <ATen/ops/promote_types_meta.h>
|
| 1006 |
+
#include <ATen/ops/put_meta.h>
|
| 1007 |
+
#include <ATen/ops/q_per_channel_axis_meta.h>
|
| 1008 |
+
#include <ATen/ops/q_per_channel_scales_meta.h>
|
| 1009 |
+
#include <ATen/ops/q_per_channel_zero_points_meta.h>
|
| 1010 |
+
#include <ATen/ops/q_scale_meta.h>
|
| 1011 |
+
#include <ATen/ops/q_zero_point_meta.h>
|
| 1012 |
+
#include <ATen/ops/qr_meta.h>
|
| 1013 |
+
#include <ATen/ops/qscheme_meta.h>
|
| 1014 |
+
#include <ATen/ops/quantile_meta.h>
|
| 1015 |
+
#include <ATen/ops/quantize_per_channel_meta.h>
|
| 1016 |
+
#include <ATen/ops/quantize_per_tensor_meta.h>
|
| 1017 |
+
#include <ATen/ops/quantize_per_tensor_dynamic_meta.h>
|
| 1018 |
+
#include <ATen/ops/quantized_batch_norm_meta.h>
|
| 1019 |
+
#include <ATen/ops/quantized_gru_cell_meta.h>
|
| 1020 |
+
#include <ATen/ops/quantized_lstm_cell_meta.h>
|
| 1021 |
+
#include <ATen/ops/quantized_max_pool1d_meta.h>
|
| 1022 |
+
#include <ATen/ops/quantized_max_pool2d_meta.h>
|
| 1023 |
+
#include <ATen/ops/quantized_max_pool3d_meta.h>
|
| 1024 |
+
#include <ATen/ops/quantized_rnn_relu_cell_meta.h>
|
| 1025 |
+
#include <ATen/ops/quantized_rnn_tanh_cell_meta.h>
|
| 1026 |
+
#include <ATen/ops/rad2deg_meta.h>
|
| 1027 |
+
#include <ATen/ops/rand_meta.h>
|
| 1028 |
+
#include <ATen/ops/rand_like_meta.h>
|
| 1029 |
+
#include <ATen/ops/randint_meta.h>
|
| 1030 |
+
#include <ATen/ops/randint_like_meta.h>
|
| 1031 |
+
#include <ATen/ops/randn_meta.h>
|
| 1032 |
+
#include <ATen/ops/randn_like_meta.h>
|
| 1033 |
+
#include <ATen/ops/random_meta.h>
|
| 1034 |
+
#include <ATen/ops/randperm_meta.h>
|
| 1035 |
+
#include <ATen/ops/range_meta.h>
|
| 1036 |
+
#include <ATen/ops/ravel_meta.h>
|
| 1037 |
+
#include <ATen/ops/real_meta.h>
|
| 1038 |
+
#include <ATen/ops/reciprocal_meta.h>
|
| 1039 |
+
#include <ATen/ops/record_stream_meta.h>
|
| 1040 |
+
#include <ATen/ops/refine_names_meta.h>
|
| 1041 |
+
#include <ATen/ops/reflection_pad1d_meta.h>
|
| 1042 |
+
#include <ATen/ops/reflection_pad1d_backward_meta.h>
|
| 1043 |
+
#include <ATen/ops/reflection_pad2d_meta.h>
|
| 1044 |
+
#include <ATen/ops/reflection_pad2d_backward_meta.h>
|
| 1045 |
+
#include <ATen/ops/reflection_pad3d_meta.h>
|
| 1046 |
+
#include <ATen/ops/reflection_pad3d_backward_meta.h>
|
| 1047 |
+
#include <ATen/ops/relu_meta.h>
|
| 1048 |
+
#include <ATen/ops/relu6_meta.h>
|
| 1049 |
+
#include <ATen/ops/remainder_meta.h>
|
| 1050 |
+
#include <ATen/ops/rename_meta.h>
|
| 1051 |
+
#include <ATen/ops/renorm_meta.h>
|
| 1052 |
+
#include <ATen/ops/repeat_meta.h>
|
| 1053 |
+
#include <ATen/ops/repeat_interleave_meta.h>
|
| 1054 |
+
#include <ATen/ops/replication_pad1d_meta.h>
|
| 1055 |
+
#include <ATen/ops/replication_pad1d_backward_meta.h>
|
| 1056 |
+
#include <ATen/ops/replication_pad2d_meta.h>
|
| 1057 |
+
#include <ATen/ops/replication_pad2d_backward_meta.h>
|
| 1058 |
+
#include <ATen/ops/replication_pad3d_meta.h>
|
| 1059 |
+
#include <ATen/ops/replication_pad3d_backward_meta.h>
|
| 1060 |
+
#include <ATen/ops/requires_grad_meta.h>
|
| 1061 |
+
#include <ATen/ops/reshape_meta.h>
|
| 1062 |
+
#include <ATen/ops/reshape_as_meta.h>
|
| 1063 |
+
#include <ATen/ops/resize_meta.h>
|
| 1064 |
+
#include <ATen/ops/resize_as_meta.h>
|
| 1065 |
+
#include <ATen/ops/resize_as_sparse_meta.h>
|
| 1066 |
+
#include <ATen/ops/resolve_conj_meta.h>
|
| 1067 |
+
#include <ATen/ops/resolve_neg_meta.h>
|
| 1068 |
+
#include <ATen/ops/result_type_meta.h>
|
| 1069 |
+
#include <ATen/ops/retain_grad_meta.h>
|
| 1070 |
+
#include <ATen/ops/retains_grad_meta.h>
|
| 1071 |
+
#include <ATen/ops/rms_norm_meta.h>
|
| 1072 |
+
#include <ATen/ops/rnn_relu_meta.h>
|
| 1073 |
+
#include <ATen/ops/rnn_relu_cell_meta.h>
|
| 1074 |
+
#include <ATen/ops/rnn_tanh_meta.h>
|
| 1075 |
+
#include <ATen/ops/rnn_tanh_cell_meta.h>
|
| 1076 |
+
#include <ATen/ops/roll_meta.h>
|
| 1077 |
+
#include <ATen/ops/rot90_meta.h>
|
| 1078 |
+
#include <ATen/ops/round_meta.h>
|
| 1079 |
+
#include <ATen/ops/row_indices_meta.h>
|
| 1080 |
+
#include <ATen/ops/row_indices_copy_meta.h>
|
| 1081 |
+
#include <ATen/ops/row_stack_meta.h>
|
| 1082 |
+
#include <ATen/ops/rrelu_meta.h>
|
| 1083 |
+
#include <ATen/ops/rrelu_with_noise_meta.h>
|
| 1084 |
+
#include <ATen/ops/rrelu_with_noise_backward_meta.h>
|
| 1085 |
+
#include <ATen/ops/rshift_meta.h>
|
| 1086 |
+
#include <ATen/ops/rsqrt_meta.h>
|
| 1087 |
+
#include <ATen/ops/rsub_meta.h>
|
| 1088 |
+
#include <ATen/ops/scalar_tensor_meta.h>
|
| 1089 |
+
#include <ATen/ops/scaled_dot_product_attention_meta.h>
|
| 1090 |
+
#include <ATen/ops/scatter_meta.h>
|
| 1091 |
+
#include <ATen/ops/scatter_add_meta.h>
|
| 1092 |
+
#include <ATen/ops/scatter_reduce_meta.h>
|
| 1093 |
+
#include <ATen/ops/searchsorted_meta.h>
|
| 1094 |
+
#include <ATen/ops/segment_reduce_meta.h>
|
| 1095 |
+
#include <ATen/ops/select_meta.h>
|
| 1096 |
+
#include <ATen/ops/select_backward_meta.h>
|
| 1097 |
+
#include <ATen/ops/select_copy_meta.h>
|
| 1098 |
+
#include <ATen/ops/select_scatter_meta.h>
|
| 1099 |
+
#include <ATen/ops/selu_meta.h>
|
| 1100 |
+
#include <ATen/ops/set_meta.h>
|
| 1101 |
+
#include <ATen/ops/set_data_meta.h>
|
| 1102 |
+
#include <ATen/ops/sgn_meta.h>
|
| 1103 |
+
#include <ATen/ops/sigmoid_meta.h>
|
| 1104 |
+
#include <ATen/ops/sigmoid_backward_meta.h>
|
| 1105 |
+
#include <ATen/ops/sign_meta.h>
|
| 1106 |
+
#include <ATen/ops/signbit_meta.h>
|
| 1107 |
+
#include <ATen/ops/silu_meta.h>
|
| 1108 |
+
#include <ATen/ops/silu_backward_meta.h>
|
| 1109 |
+
#include <ATen/ops/sin_meta.h>
|
| 1110 |
+
#include <ATen/ops/sinc_meta.h>
|
| 1111 |
+
#include <ATen/ops/sinh_meta.h>
|
| 1112 |
+
#include <ATen/ops/size_meta.h>
|
| 1113 |
+
#include <ATen/ops/slice_meta.h>
|
| 1114 |
+
#include <ATen/ops/slice_backward_meta.h>
|
| 1115 |
+
#include <ATen/ops/slice_copy_meta.h>
|
| 1116 |
+
#include <ATen/ops/slice_inverse_meta.h>
|
| 1117 |
+
#include <ATen/ops/slice_scatter_meta.h>
|
| 1118 |
+
#include <ATen/ops/slogdet_meta.h>
|
| 1119 |
+
#include <ATen/ops/slow_conv3d_meta.h>
|
| 1120 |
+
#include <ATen/ops/slow_conv3d_forward_meta.h>
|
| 1121 |
+
#include <ATen/ops/slow_conv_dilated2d_meta.h>
|
| 1122 |
+
#include <ATen/ops/slow_conv_dilated3d_meta.h>
|
| 1123 |
+
#include <ATen/ops/slow_conv_transpose2d_meta.h>
|
| 1124 |
+
#include <ATen/ops/slow_conv_transpose3d_meta.h>
|
| 1125 |
+
#include <ATen/ops/smm_meta.h>
|
| 1126 |
+
#include <ATen/ops/smooth_l1_loss_meta.h>
|
| 1127 |
+
#include <ATen/ops/smooth_l1_loss_backward_meta.h>
|
| 1128 |
+
#include <ATen/ops/soft_margin_loss_meta.h>
|
| 1129 |
+
#include <ATen/ops/soft_margin_loss_backward_meta.h>
|
| 1130 |
+
#include <ATen/ops/softmax_meta.h>
|
| 1131 |
+
#include <ATen/ops/softplus_meta.h>
|
| 1132 |
+
#include <ATen/ops/softplus_backward_meta.h>
|
| 1133 |
+
#include <ATen/ops/softshrink_meta.h>
|
| 1134 |
+
#include <ATen/ops/softshrink_backward_meta.h>
|
| 1135 |
+
#include <ATen/ops/sort_meta.h>
|
| 1136 |
+
#include <ATen/ops/sparse_bsc_tensor_meta.h>
|
| 1137 |
+
#include <ATen/ops/sparse_bsr_tensor_meta.h>
|
| 1138 |
+
#include <ATen/ops/sparse_compressed_tensor_meta.h>
|
| 1139 |
+
#include <ATen/ops/sparse_coo_tensor_meta.h>
|
| 1140 |
+
#include <ATen/ops/sparse_csc_tensor_meta.h>
|
| 1141 |
+
#include <ATen/ops/sparse_csr_tensor_meta.h>
|
| 1142 |
+
#include <ATen/ops/sparse_dim_meta.h>
|
| 1143 |
+
#include <ATen/ops/sparse_mask_meta.h>
|
| 1144 |
+
#include <ATen/ops/sparse_resize_meta.h>
|
| 1145 |
+
#include <ATen/ops/sparse_resize_and_clear_meta.h>
|
| 1146 |
+
#include <ATen/ops/sparse_sampled_addmm_meta.h>
|
| 1147 |
+
#include <ATen/ops/special_airy_ai_meta.h>
|
| 1148 |
+
#include <ATen/ops/special_bessel_j0_meta.h>
|
| 1149 |
+
#include <ATen/ops/special_bessel_j1_meta.h>
|
| 1150 |
+
#include <ATen/ops/special_bessel_y0_meta.h>
|
| 1151 |
+
#include <ATen/ops/special_bessel_y1_meta.h>
|
| 1152 |
+
#include <ATen/ops/special_chebyshev_polynomial_t_meta.h>
|
| 1153 |
+
#include <ATen/ops/special_chebyshev_polynomial_u_meta.h>
|
| 1154 |
+
#include <ATen/ops/special_chebyshev_polynomial_v_meta.h>
|
| 1155 |
+
#include <ATen/ops/special_chebyshev_polynomial_w_meta.h>
|
| 1156 |
+
#include <ATen/ops/special_digamma_meta.h>
|
| 1157 |
+
#include <ATen/ops/special_entr_meta.h>
|
| 1158 |
+
#include <ATen/ops/special_erf_meta.h>
|
| 1159 |
+
#include <ATen/ops/special_erfc_meta.h>
|
| 1160 |
+
#include <ATen/ops/special_erfcx_meta.h>
|
| 1161 |
+
#include <ATen/ops/special_erfinv_meta.h>
|
| 1162 |
+
#include <ATen/ops/special_exp2_meta.h>
|
| 1163 |
+
#include <ATen/ops/special_expit_meta.h>
|
| 1164 |
+
#include <ATen/ops/special_expm1_meta.h>
|
| 1165 |
+
#include <ATen/ops/special_gammainc_meta.h>
|
| 1166 |
+
#include <ATen/ops/special_gammaincc_meta.h>
|
| 1167 |
+
#include <ATen/ops/special_gammaln_meta.h>
|
| 1168 |
+
#include <ATen/ops/special_hermite_polynomial_h_meta.h>
|
| 1169 |
+
#include <ATen/ops/special_hermite_polynomial_he_meta.h>
|
| 1170 |
+
#include <ATen/ops/special_i0_meta.h>
|
| 1171 |
+
#include <ATen/ops/special_i0e_meta.h>
|
| 1172 |
+
#include <ATen/ops/special_i1_meta.h>
|
| 1173 |
+
#include <ATen/ops/special_i1e_meta.h>
|
| 1174 |
+
#include <ATen/ops/special_laguerre_polynomial_l_meta.h>
|
| 1175 |
+
#include <ATen/ops/special_legendre_polynomial_p_meta.h>
|
| 1176 |
+
#include <ATen/ops/special_log1p_meta.h>
|
| 1177 |
+
#include <ATen/ops/special_log_ndtr_meta.h>
|
| 1178 |
+
#include <ATen/ops/special_log_softmax_meta.h>
|
| 1179 |
+
#include <ATen/ops/special_logit_meta.h>
|
| 1180 |
+
#include <ATen/ops/special_logsumexp_meta.h>
|
| 1181 |
+
#include <ATen/ops/special_modified_bessel_i0_meta.h>
|
| 1182 |
+
#include <ATen/ops/special_modified_bessel_i1_meta.h>
|
| 1183 |
+
#include <ATen/ops/special_modified_bessel_k0_meta.h>
|
| 1184 |
+
#include <ATen/ops/special_modified_bessel_k1_meta.h>
|
| 1185 |
+
#include <ATen/ops/special_multigammaln_meta.h>
|
| 1186 |
+
#include <ATen/ops/special_ndtr_meta.h>
|
| 1187 |
+
#include <ATen/ops/special_ndtri_meta.h>
|
| 1188 |
+
#include <ATen/ops/special_polygamma_meta.h>
|
| 1189 |
+
#include <ATen/ops/special_psi_meta.h>
|
| 1190 |
+
#include <ATen/ops/special_round_meta.h>
|
| 1191 |
+
#include <ATen/ops/special_scaled_modified_bessel_k0_meta.h>
|
| 1192 |
+
#include <ATen/ops/special_scaled_modified_bessel_k1_meta.h>
|
| 1193 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_t_meta.h>
|
| 1194 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_u_meta.h>
|
| 1195 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_v_meta.h>
|
| 1196 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_w_meta.h>
|
| 1197 |
+
#include <ATen/ops/special_sinc_meta.h>
|
| 1198 |
+
#include <ATen/ops/special_softmax_meta.h>
|
| 1199 |
+
#include <ATen/ops/special_spherical_bessel_j0_meta.h>
|
| 1200 |
+
#include <ATen/ops/special_xlog1py_meta.h>
|
| 1201 |
+
#include <ATen/ops/special_xlogy_meta.h>
|
| 1202 |
+
#include <ATen/ops/special_zeta_meta.h>
|
| 1203 |
+
#include <ATen/ops/split_meta.h>
|
| 1204 |
+
#include <ATen/ops/split_copy_meta.h>
|
| 1205 |
+
#include <ATen/ops/split_with_sizes_meta.h>
|
| 1206 |
+
#include <ATen/ops/split_with_sizes_copy_meta.h>
|
| 1207 |
+
#include <ATen/ops/sqrt_meta.h>
|
| 1208 |
+
#include <ATen/ops/square_meta.h>
|
| 1209 |
+
#include <ATen/ops/squeeze_meta.h>
|
| 1210 |
+
#include <ATen/ops/squeeze_copy_meta.h>
|
| 1211 |
+
#include <ATen/ops/sspaddmm_meta.h>
|
| 1212 |
+
#include <ATen/ops/stack_meta.h>
|
| 1213 |
+
#include <ATen/ops/std_meta.h>
|
| 1214 |
+
#include <ATen/ops/std_mean_meta.h>
|
| 1215 |
+
#include <ATen/ops/stft_meta.h>
|
| 1216 |
+
#include <ATen/ops/stride_meta.h>
|
| 1217 |
+
#include <ATen/ops/sub_meta.h>
|
| 1218 |
+
#include <ATen/ops/subtract_meta.h>
|
| 1219 |
+
#include <ATen/ops/sum_meta.h>
|
| 1220 |
+
#include <ATen/ops/sum_to_size_meta.h>
|
| 1221 |
+
#include <ATen/ops/svd_meta.h>
|
| 1222 |
+
#include <ATen/ops/swapaxes_meta.h>
|
| 1223 |
+
#include <ATen/ops/swapdims_meta.h>
|
| 1224 |
+
#include <ATen/ops/sym_constrain_range_meta.h>
|
| 1225 |
+
#include <ATen/ops/sym_constrain_range_for_size_meta.h>
|
| 1226 |
+
#include <ATen/ops/sym_numel_meta.h>
|
| 1227 |
+
#include <ATen/ops/sym_size_meta.h>
|
| 1228 |
+
#include <ATen/ops/sym_storage_offset_meta.h>
|
| 1229 |
+
#include <ATen/ops/sym_stride_meta.h>
|
| 1230 |
+
#include <ATen/ops/t_meta.h>
|
| 1231 |
+
#include <ATen/ops/t_copy_meta.h>
|
| 1232 |
+
#include <ATen/ops/take_meta.h>
|
| 1233 |
+
#include <ATen/ops/take_along_dim_meta.h>
|
| 1234 |
+
#include <ATen/ops/tan_meta.h>
|
| 1235 |
+
#include <ATen/ops/tanh_meta.h>
|
| 1236 |
+
#include <ATen/ops/tanh_backward_meta.h>
|
| 1237 |
+
#include <ATen/ops/tensor_split_meta.h>
|
| 1238 |
+
#include <ATen/ops/tensordot_meta.h>
|
| 1239 |
+
#include <ATen/ops/thnn_conv2d_meta.h>
|
| 1240 |
+
#include <ATen/ops/threshold_meta.h>
|
| 1241 |
+
#include <ATen/ops/threshold_backward_meta.h>
|
| 1242 |
+
#include <ATen/ops/tile_meta.h>
|
| 1243 |
+
#include <ATen/ops/to_meta.h>
|
| 1244 |
+
#include <ATen/ops/to_dense_meta.h>
|
| 1245 |
+
#include <ATen/ops/to_dense_backward_meta.h>
|
| 1246 |
+
#include <ATen/ops/to_mkldnn_meta.h>
|
| 1247 |
+
#include <ATen/ops/to_mkldnn_backward_meta.h>
|
| 1248 |
+
#include <ATen/ops/to_padded_tensor_meta.h>
|
| 1249 |
+
#include <ATen/ops/to_sparse_meta.h>
|
| 1250 |
+
#include <ATen/ops/to_sparse_bsc_meta.h>
|
| 1251 |
+
#include <ATen/ops/to_sparse_bsr_meta.h>
|
| 1252 |
+
#include <ATen/ops/to_sparse_csc_meta.h>
|
| 1253 |
+
#include <ATen/ops/to_sparse_csr_meta.h>
|
| 1254 |
+
#include <ATen/ops/topk_meta.h>
|
| 1255 |
+
#include <ATen/ops/trace_meta.h>
|
| 1256 |
+
#include <ATen/ops/trace_backward_meta.h>
|
| 1257 |
+
#include <ATen/ops/transpose_meta.h>
|
| 1258 |
+
#include <ATen/ops/transpose_copy_meta.h>
|
| 1259 |
+
#include <ATen/ops/trapezoid_meta.h>
|
| 1260 |
+
#include <ATen/ops/trapz_meta.h>
|
| 1261 |
+
#include <ATen/ops/triangular_solve_meta.h>
|
| 1262 |
+
#include <ATen/ops/tril_meta.h>
|
| 1263 |
+
#include <ATen/ops/tril_indices_meta.h>
|
| 1264 |
+
#include <ATen/ops/triplet_margin_loss_meta.h>
|
| 1265 |
+
#include <ATen/ops/triu_meta.h>
|
| 1266 |
+
#include <ATen/ops/triu_indices_meta.h>
|
| 1267 |
+
#include <ATen/ops/true_divide_meta.h>
|
| 1268 |
+
#include <ATen/ops/trunc_meta.h>
|
| 1269 |
+
#include <ATen/ops/type_as_meta.h>
|
| 1270 |
+
#include <ATen/ops/unbind_meta.h>
|
| 1271 |
+
#include <ATen/ops/unbind_copy_meta.h>
|
| 1272 |
+
#include <ATen/ops/unflatten_meta.h>
|
| 1273 |
+
#include <ATen/ops/unflatten_dense_tensors_meta.h>
|
| 1274 |
+
#include <ATen/ops/unfold_meta.h>
|
| 1275 |
+
#include <ATen/ops/unfold_backward_meta.h>
|
| 1276 |
+
#include <ATen/ops/unfold_copy_meta.h>
|
| 1277 |
+
#include <ATen/ops/uniform_meta.h>
|
| 1278 |
+
#include <ATen/ops/unique_consecutive_meta.h>
|
| 1279 |
+
#include <ATen/ops/unique_dim_meta.h>
|
| 1280 |
+
#include <ATen/ops/unique_dim_consecutive_meta.h>
|
| 1281 |
+
#include <ATen/ops/unsafe_chunk_meta.h>
|
| 1282 |
+
#include <ATen/ops/unsafe_split_meta.h>
|
| 1283 |
+
#include <ATen/ops/unsafe_split_with_sizes_meta.h>
|
| 1284 |
+
#include <ATen/ops/unsqueeze_meta.h>
|
| 1285 |
+
#include <ATen/ops/unsqueeze_copy_meta.h>
|
| 1286 |
+
#include <ATen/ops/upsample_bicubic2d_meta.h>
|
| 1287 |
+
#include <ATen/ops/upsample_bicubic2d_backward_meta.h>
|
| 1288 |
+
#include <ATen/ops/upsample_bilinear2d_meta.h>
|
| 1289 |
+
#include <ATen/ops/upsample_bilinear2d_backward_meta.h>
|
| 1290 |
+
#include <ATen/ops/upsample_linear1d_meta.h>
|
| 1291 |
+
#include <ATen/ops/upsample_linear1d_backward_meta.h>
|
| 1292 |
+
#include <ATen/ops/upsample_nearest1d_meta.h>
|
| 1293 |
+
#include <ATen/ops/upsample_nearest1d_backward_meta.h>
|
| 1294 |
+
#include <ATen/ops/upsample_nearest2d_meta.h>
|
| 1295 |
+
#include <ATen/ops/upsample_nearest2d_backward_meta.h>
|
| 1296 |
+
#include <ATen/ops/upsample_nearest3d_meta.h>
|
| 1297 |
+
#include <ATen/ops/upsample_nearest3d_backward_meta.h>
|
| 1298 |
+
#include <ATen/ops/upsample_trilinear3d_meta.h>
|
| 1299 |
+
#include <ATen/ops/upsample_trilinear3d_backward_meta.h>
|
| 1300 |
+
#include <ATen/ops/value_selecting_reduction_backward_meta.h>
|
| 1301 |
+
#include <ATen/ops/values_meta.h>
|
| 1302 |
+
#include <ATen/ops/values_copy_meta.h>
|
| 1303 |
+
#include <ATen/ops/vander_meta.h>
|
| 1304 |
+
#include <ATen/ops/var_meta.h>
|
| 1305 |
+
#include <ATen/ops/var_mean_meta.h>
|
| 1306 |
+
#include <ATen/ops/vdot_meta.h>
|
| 1307 |
+
#include <ATen/ops/view_meta.h>
|
| 1308 |
+
#include <ATen/ops/view_as_meta.h>
|
| 1309 |
+
#include <ATen/ops/view_as_complex_meta.h>
|
| 1310 |
+
#include <ATen/ops/view_as_complex_copy_meta.h>
|
| 1311 |
+
#include <ATen/ops/view_as_real_meta.h>
|
| 1312 |
+
#include <ATen/ops/view_as_real_copy_meta.h>
|
| 1313 |
+
#include <ATen/ops/view_copy_meta.h>
|
| 1314 |
+
#include <ATen/ops/vsplit_meta.h>
|
| 1315 |
+
#include <ATen/ops/vstack_meta.h>
|
| 1316 |
+
#include <ATen/ops/where_meta.h>
|
| 1317 |
+
#include <ATen/ops/xlogy_meta.h>
|
| 1318 |
+
#include <ATen/ops/xor_meta.h>
|
| 1319 |
+
#include <ATen/ops/zero_meta.h>
|
| 1320 |
+
#include <ATen/ops/zeros_meta.h>
|
| 1321 |
+
#include <ATen/ops/zeros_like_meta.h>
|
| 1322 |
+
|
| 1323 |
+
namespace at {
|
| 1324 |
+
|
| 1325 |
+
namespace meta {
|
| 1326 |
+
|
| 1327 |
+
|
| 1328 |
+
|
| 1329 |
+
} // namespace meta
|
| 1330 |
+
} // namespace at
|
.venv/lib/python3.11/site-packages/torch/include/ATen/NumericUtils.h
ADDED
|
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#ifdef __HIPCC__
|
| 4 |
+
#include <hip/hip_runtime.h>
|
| 5 |
+
#endif
|
| 6 |
+
|
| 7 |
+
#include <c10/macros/Macros.h>
|
| 8 |
+
#include <c10/util/BFloat16.h>
|
| 9 |
+
#include <c10/util/Float8_e4m3fn.h>
|
| 10 |
+
#include <c10/util/Float8_e4m3fnuz.h>
|
| 11 |
+
#include <c10/util/Float8_e5m2.h>
|
| 12 |
+
#include <c10/util/Float8_e5m2fnuz.h>
|
| 13 |
+
#include <c10/util/Half.h>
|
| 14 |
+
#include <c10/util/complex.h>
|
| 15 |
+
|
| 16 |
+
#include <cmath>
|
| 17 |
+
#include <type_traits>
|
| 18 |
+
|
| 19 |
+
namespace at {
|
| 20 |
+
|
| 21 |
+
// std::isnan isn't performant to use on integral types; it will
|
| 22 |
+
// (uselessly) convert to floating point and then do the test.
|
| 23 |
+
// This function is.
|
| 24 |
+
|
| 25 |
+
template <typename T, std::enable_if_t<std::is_integral_v<T>, int> = 0>
|
| 26 |
+
inline C10_HOST_DEVICE bool _isnan(T /*val*/) {
|
| 27 |
+
return false;
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
template <typename T, std::enable_if_t<std::is_floating_point_v<T>, int> = 0>
|
| 31 |
+
inline C10_HOST_DEVICE bool _isnan(T val) {
|
| 32 |
+
#if defined(__CUDACC__) || defined(__HIPCC__)
|
| 33 |
+
return ::isnan(val);
|
| 34 |
+
#else
|
| 35 |
+
return std::isnan(val);
|
| 36 |
+
#endif
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
template <typename T, std::enable_if_t<c10::is_complex<T>::value, int> = 0>
|
| 40 |
+
inline C10_HOST_DEVICE bool _isnan(T val) {
|
| 41 |
+
return std::isnan(val.real()) || std::isnan(val.imag());
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
template <typename T, std::enable_if_t<std::is_same_v<T, at::Half>, int> = 0>
|
| 45 |
+
inline C10_HOST_DEVICE bool _isnan(T val) {
|
| 46 |
+
return at::_isnan(static_cast<float>(val));
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
template <
|
| 50 |
+
typename T,
|
| 51 |
+
std::enable_if_t<std::is_same_v<T, at::BFloat16>, int> = 0>
|
| 52 |
+
inline C10_HOST_DEVICE bool _isnan(at::BFloat16 val) {
|
| 53 |
+
return at::_isnan(static_cast<float>(val));
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
inline C10_HOST_DEVICE bool _isnan(at::BFloat16 val) {
|
| 57 |
+
return at::_isnan(static_cast<float>(val));
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
template <
|
| 61 |
+
typename T,
|
| 62 |
+
std::enable_if_t<std::is_same_v<T, at::Float8_e5m2>, int> = 0>
|
| 63 |
+
inline C10_HOST_DEVICE bool _isnan(T val) {
|
| 64 |
+
return val.isnan();
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
template <
|
| 68 |
+
typename T,
|
| 69 |
+
std::enable_if_t<std::is_same_v<T, at::Float8_e4m3fn>, int> = 0>
|
| 70 |
+
inline C10_HOST_DEVICE bool _isnan(T val) {
|
| 71 |
+
return val.isnan();
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
template <
|
| 75 |
+
typename T,
|
| 76 |
+
std::enable_if_t<std::is_same_v<T, at::Float8_e5m2fnuz>, int> = 0>
|
| 77 |
+
inline C10_HOST_DEVICE bool _isnan(T val) {
|
| 78 |
+
return val.isnan();
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
template <
|
| 82 |
+
typename T,
|
| 83 |
+
std::enable_if_t<std::is_same_v<T, at::Float8_e4m3fnuz>, int> = 0>
|
| 84 |
+
inline C10_HOST_DEVICE bool _isnan(T val) {
|
| 85 |
+
return val.isnan();
|
| 86 |
+
}
|
| 87 |
+
|
| 88 |
+
// std::isinf isn't performant to use on integral types; it will
|
| 89 |
+
// (uselessly) convert to floating point and then do the test.
|
| 90 |
+
// This function is.
|
| 91 |
+
|
| 92 |
+
template <typename T, std::enable_if_t<std::is_integral_v<T>, int> = 0>
|
| 93 |
+
inline C10_HOST_DEVICE bool _isinf(T /*val*/) {
|
| 94 |
+
return false;
|
| 95 |
+
}
|
| 96 |
+
|
| 97 |
+
template <typename T, std::enable_if_t<std::is_floating_point_v<T>, int> = 0>
|
| 98 |
+
inline C10_HOST_DEVICE bool _isinf(T val) {
|
| 99 |
+
#if defined(__CUDACC__) || defined(__HIPCC__)
|
| 100 |
+
return ::isinf(val);
|
| 101 |
+
#else
|
| 102 |
+
return std::isinf(val);
|
| 103 |
+
#endif
|
| 104 |
+
}
|
| 105 |
+
|
| 106 |
+
inline C10_HOST_DEVICE bool _isinf(at::Half val) {
|
| 107 |
+
return at::_isinf(static_cast<float>(val));
|
| 108 |
+
}
|
| 109 |
+
|
| 110 |
+
inline C10_HOST_DEVICE bool _isinf(at::BFloat16 val) {
|
| 111 |
+
return at::_isinf(static_cast<float>(val));
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
inline C10_HOST_DEVICE bool _isinf(at::Float8_e5m2 val) {
|
| 115 |
+
return val.isinf();
|
| 116 |
+
}
|
| 117 |
+
|
| 118 |
+
inline C10_HOST_DEVICE bool _isinf(at::Float8_e4m3fn val [[maybe_unused]]) {
|
| 119 |
+
return false;
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
inline C10_HOST_DEVICE bool _isinf(at::Float8_e5m2fnuz val [[maybe_unused]]) {
|
| 123 |
+
return false;
|
| 124 |
+
}
|
| 125 |
+
|
| 126 |
+
inline C10_HOST_DEVICE bool _isinf(at::Float8_e4m3fnuz val [[maybe_unused]]) {
|
| 127 |
+
return false;
|
| 128 |
+
}
|
| 129 |
+
|
| 130 |
+
template <typename T>
|
| 131 |
+
C10_HOST_DEVICE inline T exp(T x) {
|
| 132 |
+
static_assert(
|
| 133 |
+
!std::is_same_v<T, double>,
|
| 134 |
+
"this template must be used with float or less precise type");
|
| 135 |
+
#if defined(__CUDA_ARCH__) || defined(__HIP_ARCH__)
|
| 136 |
+
// use __expf fast approximation for peak bandwidth
|
| 137 |
+
return __expf(x);
|
| 138 |
+
#else
|
| 139 |
+
return ::exp(x);
|
| 140 |
+
#endif
|
| 141 |
+
}
|
| 142 |
+
|
| 143 |
+
template <>
|
| 144 |
+
C10_HOST_DEVICE inline double exp<double>(double x) {
|
| 145 |
+
return ::exp(x);
|
| 146 |
+
}
|
| 147 |
+
|
| 148 |
+
template <typename T>
|
| 149 |
+
C10_HOST_DEVICE inline T log(T x) {
|
| 150 |
+
static_assert(
|
| 151 |
+
!std::is_same_v<T, double>,
|
| 152 |
+
"this template must be used with float or less precise type");
|
| 153 |
+
#if defined(__CUDA_ARCH__) || defined(__HIP_ARCH__)
|
| 154 |
+
// use __logf fast approximation for peak bandwidth
|
| 155 |
+
return __logf(x);
|
| 156 |
+
#else
|
| 157 |
+
return ::log(x);
|
| 158 |
+
#endif
|
| 159 |
+
}
|
| 160 |
+
|
| 161 |
+
template <>
|
| 162 |
+
C10_HOST_DEVICE inline double log<double>(double x) {
|
| 163 |
+
return ::log(x);
|
| 164 |
+
}
|
| 165 |
+
|
| 166 |
+
template <typename T>
|
| 167 |
+
C10_HOST_DEVICE inline T log1p(T x) {
|
| 168 |
+
static_assert(
|
| 169 |
+
!std::is_same_v<T, double>,
|
| 170 |
+
"this template must be used with float or less precise type");
|
| 171 |
+
#if defined(__CUDA_ARCH__) || defined(__HIP_ARCH__)
|
| 172 |
+
// use __logf fast approximation for peak bandwidth
|
| 173 |
+
// NOTE: There is no __log1pf so unfortunately we lose precision.
|
| 174 |
+
return __logf(1.0f + x);
|
| 175 |
+
#else
|
| 176 |
+
return ::log1p(x);
|
| 177 |
+
#endif
|
| 178 |
+
}
|
| 179 |
+
|
| 180 |
+
template <>
|
| 181 |
+
C10_HOST_DEVICE inline double log1p<double>(double x) {
|
| 182 |
+
return ::log1p(x);
|
| 183 |
+
}
|
| 184 |
+
|
| 185 |
+
template <typename T>
|
| 186 |
+
C10_HOST_DEVICE inline T tan(T x) {
|
| 187 |
+
static_assert(
|
| 188 |
+
!std::is_same_v<T, double>,
|
| 189 |
+
"this template must be used with float or less precise type");
|
| 190 |
+
#if defined(__CUDA_ARCH__) || defined(__HIP_ARCH__)
|
| 191 |
+
// use __tanf fast approximation for peak bandwidth
|
| 192 |
+
return __tanf(x);
|
| 193 |
+
#else
|
| 194 |
+
return ::tan(x);
|
| 195 |
+
#endif
|
| 196 |
+
}
|
| 197 |
+
|
| 198 |
+
template <>
|
| 199 |
+
C10_HOST_DEVICE inline double tan<double>(double x) {
|
| 200 |
+
return ::tan(x);
|
| 201 |
+
}
|
| 202 |
+
|
| 203 |
+
} // namespace at
|
.venv/lib/python3.11/site-packages/torch/include/ATen/OpaqueTensorImpl.h
ADDED
|
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <c10/core/MemoryFormat.h>
|
| 4 |
+
#include <c10/core/SymIntArrayRef.h>
|
| 5 |
+
#include <c10/core/TensorImpl.h>
|
| 6 |
+
#include <c10/util/Exception.h>
|
| 7 |
+
|
| 8 |
+
namespace at {
|
| 9 |
+
|
| 10 |
+
// An "Opaque" TensorImpl -- there are no strides and (for now)
|
| 11 |
+
// even data() is not supported (thus no pointer arithmetic).
|
| 12 |
+
|
| 13 |
+
// NOTE: We could allow data() in the future, but would have to ensure pointer
|
| 14 |
+
// arithmetic code is properly guarded.
|
| 15 |
+
//
|
| 16 |
+
// NOTE: This does not support resize_ (and other metadata-changing ops) because
|
| 17 |
+
// of `shallow_copy_and_detach`. We would need to define an interface to
|
| 18 |
+
// "shallow copy" in order to add support.
|
| 19 |
+
|
| 20 |
+
template <typename OpaqueHandle>
|
| 21 |
+
struct TORCH_API OpaqueTensorImpl : public TensorImpl {
|
| 22 |
+
// public constructor for now...
|
| 23 |
+
OpaqueTensorImpl(
|
| 24 |
+
at::DispatchKeySet key_set,
|
| 25 |
+
const caffe2::TypeMeta data_type,
|
| 26 |
+
c10::Device device,
|
| 27 |
+
OpaqueHandle opaque_handle,
|
| 28 |
+
c10::IntArrayRef sizes,
|
| 29 |
+
bool is_non_overlapping_and_dense = true)
|
| 30 |
+
: TensorImpl(key_set, data_type, device),
|
| 31 |
+
opaque_handle_(std::move(opaque_handle)) {
|
| 32 |
+
set_storage_access_should_throw();
|
| 33 |
+
set_custom_sizes_strides(SizesStridesPolicy::CustomStrides);
|
| 34 |
+
sizes_and_strides_.set_sizes(sizes);
|
| 35 |
+
refresh_numel();
|
| 36 |
+
// NOLINTNEXTLINE(cppcoreguidelines-prefer-member-initializer)
|
| 37 |
+
is_non_overlapping_and_dense_ = is_non_overlapping_and_dense;
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
// Destructor doesn't call release_resources because it's
|
| 41 |
+
// unnecessary; don't forget to change that if needed!
|
| 42 |
+
void release_resources() override {
|
| 43 |
+
TensorImpl::release_resources();
|
| 44 |
+
opaque_handle_ = {};
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
void set_size(int64_t dim, int64_t new_size) override {
|
| 48 |
+
AT_ERROR("opaque tensors do not have set_size");
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
void set_stride(int64_t dim, int64_t new_stride) override {
|
| 52 |
+
AT_ERROR("opaque tensors do not have set_stride");
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
void set_storage_offset(int64_t storage_offset) override {
|
| 56 |
+
AT_ERROR("opaque tensors do not have set_storage_offset");
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
#ifdef DEBUG
|
| 60 |
+
bool has_storage() const override {
|
| 61 |
+
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
|
| 62 |
+
!storage_, "OpaqueTensorImpl assumes that storage_ is never set");
|
| 63 |
+
return false;
|
| 64 |
+
}
|
| 65 |
+
#endif
|
| 66 |
+
|
| 67 |
+
/**
|
| 68 |
+
* Return a TensorImpl that is a shallow-copy of this TensorImpl.
|
| 69 |
+
*
|
| 70 |
+
* For usage of `version_counter` and `allow_tensor_metadata_change`,
|
| 71 |
+
* see NOTE [ TensorImpl Shallow-Copying ].
|
| 72 |
+
*/
|
| 73 |
+
c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
|
| 74 |
+
const c10::VariableVersion& version_counter,
|
| 75 |
+
bool allow_tensor_metadata_change) const override {
|
| 76 |
+
auto impl = c10::make_intrusive<OpaqueTensorImpl<OpaqueHandle>>(
|
| 77 |
+
key_set(),
|
| 78 |
+
dtype(),
|
| 79 |
+
device(),
|
| 80 |
+
opaque_handle_,
|
| 81 |
+
sizes_and_strides_.sizes_arrayref());
|
| 82 |
+
copy_tensor_metadata(
|
| 83 |
+
/*src_opaque_impl=*/this,
|
| 84 |
+
/*dest_opaque_impl=*/impl.get(),
|
| 85 |
+
/*version_counter=*/version_counter,
|
| 86 |
+
/*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
|
| 87 |
+
impl->refresh_numel();
|
| 88 |
+
return impl;
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
/**
|
| 92 |
+
* Return a TensorImpl that is a shallow-copy of this TensorImpl.
|
| 93 |
+
*
|
| 94 |
+
* For usage of `version_counter` and `allow_tensor_metadata_change`,
|
| 95 |
+
* see NOTE [ TensorImpl Shallow-Copying ].
|
| 96 |
+
*/
|
| 97 |
+
c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
|
| 98 |
+
c10::VariableVersion&& version_counter,
|
| 99 |
+
bool allow_tensor_metadata_change) const override {
|
| 100 |
+
auto impl = c10::make_intrusive<OpaqueTensorImpl<OpaqueHandle>>(
|
| 101 |
+
key_set(),
|
| 102 |
+
dtype(),
|
| 103 |
+
device(),
|
| 104 |
+
opaque_handle_,
|
| 105 |
+
sizes_and_strides_.sizes_arrayref());
|
| 106 |
+
copy_tensor_metadata(
|
| 107 |
+
/*src_opaque_impl=*/this,
|
| 108 |
+
/*dest_opaque_impl=*/impl.get(),
|
| 109 |
+
/*version_counter=*/std::move(version_counter),
|
| 110 |
+
/*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
|
| 111 |
+
impl->refresh_numel();
|
| 112 |
+
return impl;
|
| 113 |
+
}
|
| 114 |
+
|
| 115 |
+
/**
|
| 116 |
+
* Shallow-copies data from another TensorImpl into this TensorImpl.
|
| 117 |
+
*
|
| 118 |
+
* For why this function doesn't check this TensorImpl's
|
| 119 |
+
* `allow_tensor_metadata_change_`, see NOTE [ TensorImpl Shallow-Copying ].
|
| 120 |
+
*/
|
| 121 |
+
void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override {
|
| 122 |
+
AT_ASSERT(has_compatible_shallow_copy_type(impl->key_set()));
|
| 123 |
+
auto opaque_impl =
|
| 124 |
+
static_cast<const OpaqueTensorImpl<OpaqueHandle>*>(impl.get());
|
| 125 |
+
copy_tensor_metadata(
|
| 126 |
+
/*src_impl=*/opaque_impl,
|
| 127 |
+
/*dest_impl=*/this,
|
| 128 |
+
/*version_counter=*/version_counter(),
|
| 129 |
+
/*allow_tensor_metadata_change=*/allow_tensor_metadata_change());
|
| 130 |
+
refresh_numel();
|
| 131 |
+
}
|
| 132 |
+
|
| 133 |
+
const OpaqueHandle& opaque_handle() const {
|
| 134 |
+
return opaque_handle_;
|
| 135 |
+
}
|
| 136 |
+
|
| 137 |
+
OpaqueHandle& unsafe_opaque_handle() {
|
| 138 |
+
return opaque_handle_;
|
| 139 |
+
}
|
| 140 |
+
|
| 141 |
+
protected:
|
| 142 |
+
/**
|
| 143 |
+
* Copy the tensor metadata fields (e.g. sizes / strides / storage pointer /
|
| 144 |
+
* storage_offset) from one TensorImpl to another TensorImpl.
|
| 145 |
+
*
|
| 146 |
+
* For usage of `version_counter` and `allow_tensor_metadata_change`, see NOTE
|
| 147 |
+
* [ TensorImpl Shallow-Copying ].
|
| 148 |
+
*/
|
| 149 |
+
static void copy_tensor_metadata(
|
| 150 |
+
const OpaqueTensorImpl<OpaqueHandle>* src_opaque_impl,
|
| 151 |
+
OpaqueTensorImpl<OpaqueHandle>* dest_opaque_impl,
|
| 152 |
+
const c10::VariableVersion& version_counter,
|
| 153 |
+
bool allow_tensor_metadata_change) {
|
| 154 |
+
TensorImpl::copy_tensor_metadata(
|
| 155 |
+
src_opaque_impl,
|
| 156 |
+
dest_opaque_impl,
|
| 157 |
+
version_counter,
|
| 158 |
+
allow_tensor_metadata_change);
|
| 159 |
+
|
| 160 |
+
// OpaqueTensorImpl-specific fields.
|
| 161 |
+
dest_opaque_impl->opaque_handle_ = src_opaque_impl->opaque_handle_;
|
| 162 |
+
}
|
| 163 |
+
|
| 164 |
+
static void copy_tensor_metadata(
|
| 165 |
+
const OpaqueTensorImpl<OpaqueHandle>* src_opaque_impl,
|
| 166 |
+
OpaqueTensorImpl<OpaqueHandle>* dest_opaque_impl,
|
| 167 |
+
c10::VariableVersion&& version_counter,
|
| 168 |
+
bool allow_tensor_metadata_change) {
|
| 169 |
+
TensorImpl::copy_tensor_metadata(
|
| 170 |
+
src_opaque_impl,
|
| 171 |
+
dest_opaque_impl,
|
| 172 |
+
std::move(version_counter),
|
| 173 |
+
allow_tensor_metadata_change);
|
| 174 |
+
|
| 175 |
+
// OpaqueTensorImpl-specific fields.
|
| 176 |
+
dest_opaque_impl->opaque_handle_ = src_opaque_impl->opaque_handle_;
|
| 177 |
+
}
|
| 178 |
+
|
| 179 |
+
private:
|
| 180 |
+
const char* tensorimpl_type_name() const override {
|
| 181 |
+
return "OpaqueTensorImpl";
|
| 182 |
+
}
|
| 183 |
+
|
| 184 |
+
OpaqueHandle opaque_handle_;
|
| 185 |
+
};
|
| 186 |
+
|
| 187 |
+
} // namespace at
|
.venv/lib/python3.11/site-packages/torch/include/ATen/Operators.h
ADDED
|
@@ -0,0 +1,1385 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// @generated by torchgen/gen.py from Operators.h
|
| 4 |
+
|
| 5 |
+
#ifdef TORCH_ASSERT_NO_OPERATORS
|
| 6 |
+
#error This change adds a dependency on native_functions.yaml, \
|
| 7 |
+
meaning the file will need to be re-compiled every time an operator \
|
| 8 |
+
is changed or added. Consider if your change would be better placed in \
|
| 9 |
+
another file, or if a more specific header might achieve the same goal. \
|
| 10 |
+
See NOTE: [Tensor vs. TensorBase]
|
| 11 |
+
#endif
|
| 12 |
+
|
| 13 |
+
#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS)
|
| 14 |
+
#error This change adds a dependency on all pytorch operators, meaning the \
|
| 15 |
+
file will need to be re-compiled every time an operator is changed or added. \
|
| 16 |
+
Consider including a specific operator from <ATen/ops/{my_operator}_ops.h> \
|
| 17 |
+
and see NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS].
|
| 18 |
+
#endif
|
| 19 |
+
|
| 20 |
+
#include <c10/core/SymInt.h>
|
| 21 |
+
#include <c10/core/SymIntArrayRef.h>
|
| 22 |
+
#include <c10/core/Scalar.h>
|
| 23 |
+
#include <c10/core/TensorOptions.h>
|
| 24 |
+
#include <c10/core/QScheme.h>
|
| 25 |
+
#include <c10/util/OptionalArrayRef.h>
|
| 26 |
+
#include <tuple>
|
| 27 |
+
#include <vector>
|
| 28 |
+
|
| 29 |
+
#include <ATen/ops/_adaptive_avg_pool2d_ops.h>
|
| 30 |
+
#include <ATen/ops/_adaptive_avg_pool2d_backward_ops.h>
|
| 31 |
+
#include <ATen/ops/_adaptive_avg_pool3d_ops.h>
|
| 32 |
+
#include <ATen/ops/_adaptive_avg_pool3d_backward_ops.h>
|
| 33 |
+
#include <ATen/ops/_add_batch_dim_ops.h>
|
| 34 |
+
#include <ATen/ops/_add_relu_ops.h>
|
| 35 |
+
#include <ATen/ops/_addmm_activation_ops.h>
|
| 36 |
+
#include <ATen/ops/_aminmax_ops.h>
|
| 37 |
+
#include <ATen/ops/_amp_foreach_non_finite_check_and_unscale_ops.h>
|
| 38 |
+
#include <ATen/ops/_amp_update_scale_ops.h>
|
| 39 |
+
#include <ATen/ops/_assert_async_ops.h>
|
| 40 |
+
#include <ATen/ops/_assert_scalar_ops.h>
|
| 41 |
+
#include <ATen/ops/_assert_tensor_metadata_ops.h>
|
| 42 |
+
#include <ATen/ops/_autocast_to_full_precision_ops.h>
|
| 43 |
+
#include <ATen/ops/_autocast_to_reduced_precision_ops.h>
|
| 44 |
+
#include <ATen/ops/_backward_ops.h>
|
| 45 |
+
#include <ATen/ops/_batch_norm_impl_index_ops.h>
|
| 46 |
+
#include <ATen/ops/_batch_norm_impl_index_backward_ops.h>
|
| 47 |
+
#include <ATen/ops/_batch_norm_no_update_ops.h>
|
| 48 |
+
#include <ATen/ops/_batch_norm_with_update_ops.h>
|
| 49 |
+
#include <ATen/ops/_cast_Byte_ops.h>
|
| 50 |
+
#include <ATen/ops/_cast_Char_ops.h>
|
| 51 |
+
#include <ATen/ops/_cast_Double_ops.h>
|
| 52 |
+
#include <ATen/ops/_cast_Float_ops.h>
|
| 53 |
+
#include <ATen/ops/_cast_Half_ops.h>
|
| 54 |
+
#include <ATen/ops/_cast_Int_ops.h>
|
| 55 |
+
#include <ATen/ops/_cast_Long_ops.h>
|
| 56 |
+
#include <ATen/ops/_cast_Short_ops.h>
|
| 57 |
+
#include <ATen/ops/_cdist_backward_ops.h>
|
| 58 |
+
#include <ATen/ops/_cdist_forward_ops.h>
|
| 59 |
+
#include <ATen/ops/_cholesky_solve_helper_ops.h>
|
| 60 |
+
#include <ATen/ops/_choose_qparams_per_tensor_ops.h>
|
| 61 |
+
#include <ATen/ops/_chunk_cat_ops.h>
|
| 62 |
+
#include <ATen/ops/_coalesce_ops.h>
|
| 63 |
+
#include <ATen/ops/_coalesced_ops.h>
|
| 64 |
+
#include <ATen/ops/_compute_linear_combination_ops.h>
|
| 65 |
+
#include <ATen/ops/_conj_ops.h>
|
| 66 |
+
#include <ATen/ops/_conj_copy_ops.h>
|
| 67 |
+
#include <ATen/ops/_conj_physical_ops.h>
|
| 68 |
+
#include <ATen/ops/_conv_depthwise2d_ops.h>
|
| 69 |
+
#include <ATen/ops/_convert_indices_from_coo_to_csr_ops.h>
|
| 70 |
+
#include <ATen/ops/_convert_indices_from_csr_to_coo_ops.h>
|
| 71 |
+
#include <ATen/ops/_convert_weight_to_int4pack_ops.h>
|
| 72 |
+
#include <ATen/ops/_convolution_ops.h>
|
| 73 |
+
#include <ATen/ops/_convolution_double_backward_ops.h>
|
| 74 |
+
#include <ATen/ops/_convolution_mode_ops.h>
|
| 75 |
+
#include <ATen/ops/_copy_from_ops.h>
|
| 76 |
+
#include <ATen/ops/_copy_from_and_resize_ops.h>
|
| 77 |
+
#include <ATen/ops/_cslt_compress_ops.h>
|
| 78 |
+
#include <ATen/ops/_cslt_sparse_mm_ops.h>
|
| 79 |
+
#include <ATen/ops/_cslt_sparse_mm_search_ops.h>
|
| 80 |
+
#include <ATen/ops/_ctc_loss_ops.h>
|
| 81 |
+
#include <ATen/ops/_ctc_loss_backward_ops.h>
|
| 82 |
+
#include <ATen/ops/_cudnn_ctc_loss_ops.h>
|
| 83 |
+
#include <ATen/ops/_cudnn_init_dropout_state_ops.h>
|
| 84 |
+
#include <ATen/ops/_cudnn_rnn_ops.h>
|
| 85 |
+
#include <ATen/ops/_cudnn_rnn_backward_ops.h>
|
| 86 |
+
#include <ATen/ops/_cudnn_rnn_flatten_weight_ops.h>
|
| 87 |
+
#include <ATen/ops/_cufft_clear_plan_cache_ops.h>
|
| 88 |
+
#include <ATen/ops/_cufft_get_plan_cache_max_size_ops.h>
|
| 89 |
+
#include <ATen/ops/_cufft_get_plan_cache_size_ops.h>
|
| 90 |
+
#include <ATen/ops/_cufft_set_plan_cache_max_size_ops.h>
|
| 91 |
+
#include <ATen/ops/_cummax_helper_ops.h>
|
| 92 |
+
#include <ATen/ops/_cummin_helper_ops.h>
|
| 93 |
+
#include <ATen/ops/_debug_has_internal_overlap_ops.h>
|
| 94 |
+
#include <ATen/ops/_dimI_ops.h>
|
| 95 |
+
#include <ATen/ops/_dimV_ops.h>
|
| 96 |
+
#include <ATen/ops/_dim_arange_ops.h>
|
| 97 |
+
#include <ATen/ops/_dirichlet_grad_ops.h>
|
| 98 |
+
#include <ATen/ops/_efficient_attention_backward_ops.h>
|
| 99 |
+
#include <ATen/ops/_efficient_attention_forward_ops.h>
|
| 100 |
+
#include <ATen/ops/_efficientzerotensor_ops.h>
|
| 101 |
+
#include <ATen/ops/_embedding_bag_ops.h>
|
| 102 |
+
#include <ATen/ops/_embedding_bag_backward_ops.h>
|
| 103 |
+
#include <ATen/ops/_embedding_bag_dense_backward_ops.h>
|
| 104 |
+
#include <ATen/ops/_embedding_bag_forward_only_ops.h>
|
| 105 |
+
#include <ATen/ops/_embedding_bag_per_sample_weights_backward_ops.h>
|
| 106 |
+
#include <ATen/ops/_embedding_bag_sparse_backward_ops.h>
|
| 107 |
+
#include <ATen/ops/_empty_affine_quantized_ops.h>
|
| 108 |
+
#include <ATen/ops/_empty_per_channel_affine_quantized_ops.h>
|
| 109 |
+
#include <ATen/ops/_euclidean_dist_ops.h>
|
| 110 |
+
#include <ATen/ops/_fake_quantize_learnable_per_channel_affine_ops.h>
|
| 111 |
+
#include <ATen/ops/_fake_quantize_learnable_per_channel_affine_backward_ops.h>
|
| 112 |
+
#include <ATen/ops/_fake_quantize_learnable_per_tensor_affine_ops.h>
|
| 113 |
+
#include <ATen/ops/_fake_quantize_learnable_per_tensor_affine_backward_ops.h>
|
| 114 |
+
#include <ATen/ops/_fake_quantize_per_tensor_affine_cachemask_tensor_qparams_ops.h>
|
| 115 |
+
#include <ATen/ops/_fft_c2c_ops.h>
|
| 116 |
+
#include <ATen/ops/_fft_c2r_ops.h>
|
| 117 |
+
#include <ATen/ops/_fft_r2c_ops.h>
|
| 118 |
+
#include <ATen/ops/_fill_mem_eff_dropout_mask_ops.h>
|
| 119 |
+
#include <ATen/ops/_flash_attention_backward_ops.h>
|
| 120 |
+
#include <ATen/ops/_flash_attention_forward_ops.h>
|
| 121 |
+
#include <ATen/ops/_foobar_ops.h>
|
| 122 |
+
#include <ATen/ops/_foreach_abs_ops.h>
|
| 123 |
+
#include <ATen/ops/_foreach_acos_ops.h>
|
| 124 |
+
#include <ATen/ops/_foreach_add_ops.h>
|
| 125 |
+
#include <ATen/ops/_foreach_addcdiv_ops.h>
|
| 126 |
+
#include <ATen/ops/_foreach_addcmul_ops.h>
|
| 127 |
+
#include <ATen/ops/_foreach_asin_ops.h>
|
| 128 |
+
#include <ATen/ops/_foreach_atan_ops.h>
|
| 129 |
+
#include <ATen/ops/_foreach_ceil_ops.h>
|
| 130 |
+
#include <ATen/ops/_foreach_clamp_max_ops.h>
|
| 131 |
+
#include <ATen/ops/_foreach_clamp_min_ops.h>
|
| 132 |
+
#include <ATen/ops/_foreach_copy_ops.h>
|
| 133 |
+
#include <ATen/ops/_foreach_cos_ops.h>
|
| 134 |
+
#include <ATen/ops/_foreach_cosh_ops.h>
|
| 135 |
+
#include <ATen/ops/_foreach_div_ops.h>
|
| 136 |
+
#include <ATen/ops/_foreach_erf_ops.h>
|
| 137 |
+
#include <ATen/ops/_foreach_erfc_ops.h>
|
| 138 |
+
#include <ATen/ops/_foreach_exp_ops.h>
|
| 139 |
+
#include <ATen/ops/_foreach_expm1_ops.h>
|
| 140 |
+
#include <ATen/ops/_foreach_floor_ops.h>
|
| 141 |
+
#include <ATen/ops/_foreach_frac_ops.h>
|
| 142 |
+
#include <ATen/ops/_foreach_lerp_ops.h>
|
| 143 |
+
#include <ATen/ops/_foreach_lgamma_ops.h>
|
| 144 |
+
#include <ATen/ops/_foreach_log_ops.h>
|
| 145 |
+
#include <ATen/ops/_foreach_log10_ops.h>
|
| 146 |
+
#include <ATen/ops/_foreach_log1p_ops.h>
|
| 147 |
+
#include <ATen/ops/_foreach_log2_ops.h>
|
| 148 |
+
#include <ATen/ops/_foreach_max_ops.h>
|
| 149 |
+
#include <ATen/ops/_foreach_maximum_ops.h>
|
| 150 |
+
#include <ATen/ops/_foreach_minimum_ops.h>
|
| 151 |
+
#include <ATen/ops/_foreach_mul_ops.h>
|
| 152 |
+
#include <ATen/ops/_foreach_neg_ops.h>
|
| 153 |
+
#include <ATen/ops/_foreach_norm_ops.h>
|
| 154 |
+
#include <ATen/ops/_foreach_pow_ops.h>
|
| 155 |
+
#include <ATen/ops/_foreach_reciprocal_ops.h>
|
| 156 |
+
#include <ATen/ops/_foreach_round_ops.h>
|
| 157 |
+
#include <ATen/ops/_foreach_sigmoid_ops.h>
|
| 158 |
+
#include <ATen/ops/_foreach_sign_ops.h>
|
| 159 |
+
#include <ATen/ops/_foreach_sin_ops.h>
|
| 160 |
+
#include <ATen/ops/_foreach_sinh_ops.h>
|
| 161 |
+
#include <ATen/ops/_foreach_sqrt_ops.h>
|
| 162 |
+
#include <ATen/ops/_foreach_sub_ops.h>
|
| 163 |
+
#include <ATen/ops/_foreach_tan_ops.h>
|
| 164 |
+
#include <ATen/ops/_foreach_tanh_ops.h>
|
| 165 |
+
#include <ATen/ops/_foreach_trunc_ops.h>
|
| 166 |
+
#include <ATen/ops/_foreach_zero_ops.h>
|
| 167 |
+
#include <ATen/ops/_functional_assert_async_ops.h>
|
| 168 |
+
#include <ATen/ops/_functional_assert_scalar_ops.h>
|
| 169 |
+
#include <ATen/ops/_functional_sym_constrain_range_ops.h>
|
| 170 |
+
#include <ATen/ops/_functional_sym_constrain_range_for_size_ops.h>
|
| 171 |
+
#include <ATen/ops/_fused_adagrad_ops.h>
|
| 172 |
+
#include <ATen/ops/_fused_adam_ops.h>
|
| 173 |
+
#include <ATen/ops/_fused_adamw_ops.h>
|
| 174 |
+
#include <ATen/ops/_fused_dropout_ops.h>
|
| 175 |
+
#include <ATen/ops/_fused_moving_avg_obs_fq_helper_ops.h>
|
| 176 |
+
#include <ATen/ops/_fused_sdp_choice_ops.h>
|
| 177 |
+
#include <ATen/ops/_fused_sgd_ops.h>
|
| 178 |
+
#include <ATen/ops/_fw_primal_ops.h>
|
| 179 |
+
#include <ATen/ops/_fw_primal_copy_ops.h>
|
| 180 |
+
#include <ATen/ops/_gather_sparse_backward_ops.h>
|
| 181 |
+
#include <ATen/ops/_grid_sampler_2d_cpu_fallback_ops.h>
|
| 182 |
+
#include <ATen/ops/_grid_sampler_2d_cpu_fallback_backward_ops.h>
|
| 183 |
+
#include <ATen/ops/_has_compatible_shallow_copy_type_ops.h>
|
| 184 |
+
#include <ATen/ops/_has_same_storage_numel_ops.h>
|
| 185 |
+
#include <ATen/ops/_histogramdd_bin_edges_ops.h>
|
| 186 |
+
#include <ATen/ops/_histogramdd_from_bin_cts_ops.h>
|
| 187 |
+
#include <ATen/ops/_histogramdd_from_bin_tensors_ops.h>
|
| 188 |
+
#include <ATen/ops/_index_put_impl_ops.h>
|
| 189 |
+
#include <ATen/ops/_indices_ops.h>
|
| 190 |
+
#include <ATen/ops/_indices_copy_ops.h>
|
| 191 |
+
#include <ATen/ops/_int_mm_ops.h>
|
| 192 |
+
#include <ATen/ops/_is_all_true_ops.h>
|
| 193 |
+
#include <ATen/ops/_is_any_true_ops.h>
|
| 194 |
+
#include <ATen/ops/_is_zerotensor_ops.h>
|
| 195 |
+
#include <ATen/ops/_jagged_to_padded_dense_forward_ops.h>
|
| 196 |
+
#include <ATen/ops/_lazy_clone_ops.h>
|
| 197 |
+
#include <ATen/ops/_linalg_check_errors_ops.h>
|
| 198 |
+
#include <ATen/ops/_linalg_det_ops.h>
|
| 199 |
+
#include <ATen/ops/_linalg_eigh_ops.h>
|
| 200 |
+
#include <ATen/ops/_linalg_eigvals_ops.h>
|
| 201 |
+
#include <ATen/ops/_linalg_slogdet_ops.h>
|
| 202 |
+
#include <ATen/ops/_linalg_solve_ex_ops.h>
|
| 203 |
+
#include <ATen/ops/_linalg_svd_ops.h>
|
| 204 |
+
#include <ATen/ops/_local_scalar_dense_ops.h>
|
| 205 |
+
#include <ATen/ops/_log_softmax_ops.h>
|
| 206 |
+
#include <ATen/ops/_log_softmax_backward_data_ops.h>
|
| 207 |
+
#include <ATen/ops/_logcumsumexp_ops.h>
|
| 208 |
+
#include <ATen/ops/_lstm_mps_ops.h>
|
| 209 |
+
#include <ATen/ops/_lu_with_info_ops.h>
|
| 210 |
+
#include <ATen/ops/_make_dep_token_ops.h>
|
| 211 |
+
#include <ATen/ops/_make_dual_ops.h>
|
| 212 |
+
#include <ATen/ops/_make_dual_copy_ops.h>
|
| 213 |
+
#include <ATen/ops/_make_per_channel_quantized_tensor_ops.h>
|
| 214 |
+
#include <ATen/ops/_make_per_tensor_quantized_tensor_ops.h>
|
| 215 |
+
#include <ATen/ops/_masked_scale_ops.h>
|
| 216 |
+
#include <ATen/ops/_masked_softmax_ops.h>
|
| 217 |
+
#include <ATen/ops/_masked_softmax_backward_ops.h>
|
| 218 |
+
#include <ATen/ops/_mixed_dtypes_linear_ops.h>
|
| 219 |
+
#include <ATen/ops/_mkldnn_reshape_ops.h>
|
| 220 |
+
#include <ATen/ops/_mkldnn_transpose_ops.h>
|
| 221 |
+
#include <ATen/ops/_mps_convolution_ops.h>
|
| 222 |
+
#include <ATen/ops/_mps_convolution_transpose_ops.h>
|
| 223 |
+
#include <ATen/ops/_native_batch_norm_legit_ops.h>
|
| 224 |
+
#include <ATen/ops/_native_batch_norm_legit_no_training_ops.h>
|
| 225 |
+
#include <ATen/ops/_native_multi_head_attention_ops.h>
|
| 226 |
+
#include <ATen/ops/_neg_view_ops.h>
|
| 227 |
+
#include <ATen/ops/_neg_view_copy_ops.h>
|
| 228 |
+
#include <ATen/ops/_nested_compute_contiguous_strides_offsets_ops.h>
|
| 229 |
+
#include <ATen/ops/_nested_from_padded_ops.h>
|
| 230 |
+
#include <ATen/ops/_nested_from_padded_and_nested_example_ops.h>
|
| 231 |
+
#include <ATen/ops/_nested_get_jagged_dummy_ops.h>
|
| 232 |
+
#include <ATen/ops/_nested_get_lengths_ops.h>
|
| 233 |
+
#include <ATen/ops/_nested_get_max_seqlen_ops.h>
|
| 234 |
+
#include <ATen/ops/_nested_get_min_seqlen_ops.h>
|
| 235 |
+
#include <ATen/ops/_nested_get_offsets_ops.h>
|
| 236 |
+
#include <ATen/ops/_nested_get_ragged_idx_ops.h>
|
| 237 |
+
#include <ATen/ops/_nested_get_values_ops.h>
|
| 238 |
+
#include <ATen/ops/_nested_get_values_copy_ops.h>
|
| 239 |
+
#include <ATen/ops/_nested_select_backward_ops.h>
|
| 240 |
+
#include <ATen/ops/_nested_sum_backward_ops.h>
|
| 241 |
+
#include <ATen/ops/_nested_tensor_from_mask_ops.h>
|
| 242 |
+
#include <ATen/ops/_nested_tensor_from_mask_left_aligned_ops.h>
|
| 243 |
+
#include <ATen/ops/_nested_tensor_from_tensor_list_ops.h>
|
| 244 |
+
#include <ATen/ops/_nested_tensor_size_ops.h>
|
| 245 |
+
#include <ATen/ops/_nested_tensor_softmax_with_shape_ops.h>
|
| 246 |
+
#include <ATen/ops/_nested_tensor_storage_offsets_ops.h>
|
| 247 |
+
#include <ATen/ops/_nested_tensor_strides_ops.h>
|
| 248 |
+
#include <ATen/ops/_nested_view_from_buffer_ops.h>
|
| 249 |
+
#include <ATen/ops/_nested_view_from_buffer_copy_ops.h>
|
| 250 |
+
#include <ATen/ops/_nested_view_from_jagged_ops.h>
|
| 251 |
+
#include <ATen/ops/_nested_view_from_jagged_copy_ops.h>
|
| 252 |
+
#include <ATen/ops/_new_zeros_with_same_feature_meta_ops.h>
|
| 253 |
+
#include <ATen/ops/_nnpack_available_ops.h>
|
| 254 |
+
#include <ATen/ops/_nnpack_spatial_convolution_ops.h>
|
| 255 |
+
#include <ATen/ops/_nnz_ops.h>
|
| 256 |
+
#include <ATen/ops/_pack_padded_sequence_ops.h>
|
| 257 |
+
#include <ATen/ops/_pack_padded_sequence_backward_ops.h>
|
| 258 |
+
#include <ATen/ops/_pad_circular_ops.h>
|
| 259 |
+
#include <ATen/ops/_pad_enum_ops.h>
|
| 260 |
+
#include <ATen/ops/_pad_packed_sequence_ops.h>
|
| 261 |
+
#include <ATen/ops/_padded_dense_to_jagged_forward_ops.h>
|
| 262 |
+
#include <ATen/ops/_pdist_backward_ops.h>
|
| 263 |
+
#include <ATen/ops/_pdist_forward_ops.h>
|
| 264 |
+
#include <ATen/ops/_pin_memory_ops.h>
|
| 265 |
+
#include <ATen/ops/_prelu_kernel_ops.h>
|
| 266 |
+
#include <ATen/ops/_prelu_kernel_backward_ops.h>
|
| 267 |
+
#include <ATen/ops/_print_ops.h>
|
| 268 |
+
#include <ATen/ops/_propagate_xla_data_ops.h>
|
| 269 |
+
#include <ATen/ops/_remove_batch_dim_ops.h>
|
| 270 |
+
#include <ATen/ops/_reshape_alias_ops.h>
|
| 271 |
+
#include <ATen/ops/_reshape_alias_copy_ops.h>
|
| 272 |
+
#include <ATen/ops/_reshape_copy_ops.h>
|
| 273 |
+
#include <ATen/ops/_reshape_from_tensor_ops.h>
|
| 274 |
+
#include <ATen/ops/_resize_output_ops.h>
|
| 275 |
+
#include <ATen/ops/_rowwise_prune_ops.h>
|
| 276 |
+
#include <ATen/ops/_safe_softmax_ops.h>
|
| 277 |
+
#include <ATen/ops/_sample_dirichlet_ops.h>
|
| 278 |
+
#include <ATen/ops/_saturate_weight_to_fp16_ops.h>
|
| 279 |
+
#include <ATen/ops/_scaled_dot_product_attention_math_ops.h>
|
| 280 |
+
#include <ATen/ops/_scaled_dot_product_attention_math_for_mps_ops.h>
|
| 281 |
+
#include <ATen/ops/_scaled_dot_product_cudnn_attention_ops.h>
|
| 282 |
+
#include <ATen/ops/_scaled_dot_product_cudnn_attention_backward_ops.h>
|
| 283 |
+
#include <ATen/ops/_scaled_dot_product_efficient_attention_ops.h>
|
| 284 |
+
#include <ATen/ops/_scaled_dot_product_efficient_attention_backward_ops.h>
|
| 285 |
+
#include <ATen/ops/_scaled_dot_product_flash_attention_ops.h>
|
| 286 |
+
#include <ATen/ops/_scaled_dot_product_flash_attention_backward_ops.h>
|
| 287 |
+
#include <ATen/ops/_scaled_dot_product_flash_attention_for_cpu_ops.h>
|
| 288 |
+
#include <ATen/ops/_scaled_dot_product_flash_attention_for_cpu_backward_ops.h>
|
| 289 |
+
#include <ATen/ops/_scaled_dot_product_fused_attention_overrideable_ops.h>
|
| 290 |
+
#include <ATen/ops/_scaled_dot_product_fused_attention_overrideable_backward_ops.h>
|
| 291 |
+
#include <ATen/ops/_scaled_mm_ops.h>
|
| 292 |
+
#include <ATen/ops/_segment_reduce_backward_ops.h>
|
| 293 |
+
#include <ATen/ops/_shape_as_tensor_ops.h>
|
| 294 |
+
#include <ATen/ops/_slow_conv2d_backward_ops.h>
|
| 295 |
+
#include <ATen/ops/_slow_conv2d_forward_ops.h>
|
| 296 |
+
#include <ATen/ops/_sobol_engine_draw_ops.h>
|
| 297 |
+
#include <ATen/ops/_sobol_engine_ff_ops.h>
|
| 298 |
+
#include <ATen/ops/_sobol_engine_initialize_state_ops.h>
|
| 299 |
+
#include <ATen/ops/_sobol_engine_scramble_ops.h>
|
| 300 |
+
#include <ATen/ops/_softmax_ops.h>
|
| 301 |
+
#include <ATen/ops/_softmax_backward_data_ops.h>
|
| 302 |
+
#include <ATen/ops/_sparse_addmm_ops.h>
|
| 303 |
+
#include <ATen/ops/_sparse_broadcast_to_ops.h>
|
| 304 |
+
#include <ATen/ops/_sparse_broadcast_to_copy_ops.h>
|
| 305 |
+
#include <ATen/ops/_sparse_bsc_tensor_unsafe_ops.h>
|
| 306 |
+
#include <ATen/ops/_sparse_bsr_tensor_unsafe_ops.h>
|
| 307 |
+
#include <ATen/ops/_sparse_compressed_tensor_unsafe_ops.h>
|
| 308 |
+
#include <ATen/ops/_sparse_compressed_tensor_with_dims_ops.h>
|
| 309 |
+
#include <ATen/ops/_sparse_coo_tensor_unsafe_ops.h>
|
| 310 |
+
#include <ATen/ops/_sparse_coo_tensor_with_dims_ops.h>
|
| 311 |
+
#include <ATen/ops/_sparse_coo_tensor_with_dims_and_tensors_ops.h>
|
| 312 |
+
#include <ATen/ops/_sparse_csc_tensor_unsafe_ops.h>
|
| 313 |
+
#include <ATen/ops/_sparse_csr_prod_ops.h>
|
| 314 |
+
#include <ATen/ops/_sparse_csr_sum_ops.h>
|
| 315 |
+
#include <ATen/ops/_sparse_csr_tensor_unsafe_ops.h>
|
| 316 |
+
#include <ATen/ops/_sparse_log_softmax_ops.h>
|
| 317 |
+
#include <ATen/ops/_sparse_log_softmax_backward_data_ops.h>
|
| 318 |
+
#include <ATen/ops/_sparse_mask_projection_ops.h>
|
| 319 |
+
#include <ATen/ops/_sparse_mm_ops.h>
|
| 320 |
+
#include <ATen/ops/_sparse_mm_reduce_impl_ops.h>
|
| 321 |
+
#include <ATen/ops/_sparse_mm_reduce_impl_backward_ops.h>
|
| 322 |
+
#include <ATen/ops/_sparse_semi_structured_addmm_ops.h>
|
| 323 |
+
#include <ATen/ops/_sparse_semi_structured_apply_ops.h>
|
| 324 |
+
#include <ATen/ops/_sparse_semi_structured_apply_dense_ops.h>
|
| 325 |
+
#include <ATen/ops/_sparse_semi_structured_linear_ops.h>
|
| 326 |
+
#include <ATen/ops/_sparse_semi_structured_mm_ops.h>
|
| 327 |
+
#include <ATen/ops/_sparse_semi_structured_tile_ops.h>
|
| 328 |
+
#include <ATen/ops/_sparse_softmax_ops.h>
|
| 329 |
+
#include <ATen/ops/_sparse_softmax_backward_data_ops.h>
|
| 330 |
+
#include <ATen/ops/_sparse_sparse_matmul_ops.h>
|
| 331 |
+
#include <ATen/ops/_sparse_sum_ops.h>
|
| 332 |
+
#include <ATen/ops/_sparse_sum_backward_ops.h>
|
| 333 |
+
#include <ATen/ops/_spdiags_ops.h>
|
| 334 |
+
#include <ATen/ops/_spsolve_ops.h>
|
| 335 |
+
#include <ATen/ops/_stack_ops.h>
|
| 336 |
+
#include <ATen/ops/_standard_gamma_ops.h>
|
| 337 |
+
#include <ATen/ops/_standard_gamma_grad_ops.h>
|
| 338 |
+
#include <ATen/ops/_test_ambiguous_defaults_ops.h>
|
| 339 |
+
#include <ATen/ops/_test_autograd_multiple_dispatch_ops.h>
|
| 340 |
+
#include <ATen/ops/_test_autograd_multiple_dispatch_view_ops.h>
|
| 341 |
+
#include <ATen/ops/_test_autograd_multiple_dispatch_view_copy_ops.h>
|
| 342 |
+
#include <ATen/ops/_test_check_tensor_ops.h>
|
| 343 |
+
#include <ATen/ops/_test_functorch_fallback_ops.h>
|
| 344 |
+
#include <ATen/ops/_test_optional_filled_intlist_ops.h>
|
| 345 |
+
#include <ATen/ops/_test_optional_floatlist_ops.h>
|
| 346 |
+
#include <ATen/ops/_test_optional_intlist_ops.h>
|
| 347 |
+
#include <ATen/ops/_test_parallel_materialize_ops.h>
|
| 348 |
+
#include <ATen/ops/_test_serialization_subcmul_ops.h>
|
| 349 |
+
#include <ATen/ops/_test_string_default_ops.h>
|
| 350 |
+
#include <ATen/ops/_test_warn_in_autograd_ops.h>
|
| 351 |
+
#include <ATen/ops/_thnn_differentiable_gru_cell_backward_ops.h>
|
| 352 |
+
#include <ATen/ops/_thnn_differentiable_lstm_cell_backward_ops.h>
|
| 353 |
+
#include <ATen/ops/_thnn_fused_gru_cell_ops.h>
|
| 354 |
+
#include <ATen/ops/_thnn_fused_gru_cell_backward_ops.h>
|
| 355 |
+
#include <ATen/ops/_thnn_fused_lstm_cell_ops.h>
|
| 356 |
+
#include <ATen/ops/_thnn_fused_lstm_cell_backward_ops.h>
|
| 357 |
+
#include <ATen/ops/_thnn_fused_lstm_cell_backward_impl_ops.h>
|
| 358 |
+
#include <ATen/ops/_to_copy_ops.h>
|
| 359 |
+
#include <ATen/ops/_to_cpu_ops.h>
|
| 360 |
+
#include <ATen/ops/_to_dense_ops.h>
|
| 361 |
+
#include <ATen/ops/_to_sparse_ops.h>
|
| 362 |
+
#include <ATen/ops/_to_sparse_bsc_ops.h>
|
| 363 |
+
#include <ATen/ops/_to_sparse_bsr_ops.h>
|
| 364 |
+
#include <ATen/ops/_to_sparse_csc_ops.h>
|
| 365 |
+
#include <ATen/ops/_to_sparse_csr_ops.h>
|
| 366 |
+
#include <ATen/ops/_to_sparse_semi_structured_ops.h>
|
| 367 |
+
#include <ATen/ops/_transform_bias_rescale_qkv_ops.h>
|
| 368 |
+
#include <ATen/ops/_transformer_encoder_layer_fwd_ops.h>
|
| 369 |
+
#include <ATen/ops/_trilinear_ops.h>
|
| 370 |
+
#include <ATen/ops/_triton_multi_head_attention_ops.h>
|
| 371 |
+
#include <ATen/ops/_triton_scaled_dot_attention_ops.h>
|
| 372 |
+
#include <ATen/ops/_unique_ops.h>
|
| 373 |
+
#include <ATen/ops/_unique2_ops.h>
|
| 374 |
+
#include <ATen/ops/_unpack_dual_ops.h>
|
| 375 |
+
#include <ATen/ops/_unsafe_index_ops.h>
|
| 376 |
+
#include <ATen/ops/_unsafe_index_put_ops.h>
|
| 377 |
+
#include <ATen/ops/_unsafe_masked_index_ops.h>
|
| 378 |
+
#include <ATen/ops/_unsafe_masked_index_put_accumulate_ops.h>
|
| 379 |
+
#include <ATen/ops/_unsafe_view_ops.h>
|
| 380 |
+
#include <ATen/ops/_upsample_bicubic2d_aa_ops.h>
|
| 381 |
+
#include <ATen/ops/_upsample_bicubic2d_aa_backward_ops.h>
|
| 382 |
+
#include <ATen/ops/_upsample_bilinear2d_aa_ops.h>
|
| 383 |
+
#include <ATen/ops/_upsample_bilinear2d_aa_backward_ops.h>
|
| 384 |
+
#include <ATen/ops/_upsample_nearest_exact1d_ops.h>
|
| 385 |
+
#include <ATen/ops/_upsample_nearest_exact1d_backward_ops.h>
|
| 386 |
+
#include <ATen/ops/_upsample_nearest_exact2d_ops.h>
|
| 387 |
+
#include <ATen/ops/_upsample_nearest_exact2d_backward_ops.h>
|
| 388 |
+
#include <ATen/ops/_upsample_nearest_exact3d_ops.h>
|
| 389 |
+
#include <ATen/ops/_upsample_nearest_exact3d_backward_ops.h>
|
| 390 |
+
#include <ATen/ops/_use_cudnn_ctc_loss_ops.h>
|
| 391 |
+
#include <ATen/ops/_use_cudnn_rnn_flatten_weight_ops.h>
|
| 392 |
+
#include <ATen/ops/_validate_compressed_sparse_indices_ops.h>
|
| 393 |
+
#include <ATen/ops/_validate_sparse_bsc_tensor_args_ops.h>
|
| 394 |
+
#include <ATen/ops/_validate_sparse_bsr_tensor_args_ops.h>
|
| 395 |
+
#include <ATen/ops/_validate_sparse_compressed_tensor_args_ops.h>
|
| 396 |
+
#include <ATen/ops/_validate_sparse_coo_tensor_args_ops.h>
|
| 397 |
+
#include <ATen/ops/_validate_sparse_csc_tensor_args_ops.h>
|
| 398 |
+
#include <ATen/ops/_validate_sparse_csr_tensor_args_ops.h>
|
| 399 |
+
#include <ATen/ops/_values_ops.h>
|
| 400 |
+
#include <ATen/ops/_values_copy_ops.h>
|
| 401 |
+
#include <ATen/ops/_version_ops.h>
|
| 402 |
+
#include <ATen/ops/_weight_int4pack_mm_ops.h>
|
| 403 |
+
#include <ATen/ops/_weight_int8pack_mm_ops.h>
|
| 404 |
+
#include <ATen/ops/_weight_norm_ops.h>
|
| 405 |
+
#include <ATen/ops/_weight_norm_differentiable_backward_ops.h>
|
| 406 |
+
#include <ATen/ops/_weight_norm_interface_ops.h>
|
| 407 |
+
#include <ATen/ops/_weight_norm_interface_backward_ops.h>
|
| 408 |
+
#include <ATen/ops/_wrapped_linear_prepack_ops.h>
|
| 409 |
+
#include <ATen/ops/_wrapped_quantized_linear_prepacked_ops.h>
|
| 410 |
+
#include <ATen/ops/abs_ops.h>
|
| 411 |
+
#include <ATen/ops/absolute_ops.h>
|
| 412 |
+
#include <ATen/ops/acos_ops.h>
|
| 413 |
+
#include <ATen/ops/acosh_ops.h>
|
| 414 |
+
#include <ATen/ops/adaptive_avg_pool1d_ops.h>
|
| 415 |
+
#include <ATen/ops/adaptive_avg_pool2d_ops.h>
|
| 416 |
+
#include <ATen/ops/adaptive_avg_pool3d_ops.h>
|
| 417 |
+
#include <ATen/ops/adaptive_avg_pool3d_backward_ops.h>
|
| 418 |
+
#include <ATen/ops/adaptive_max_pool1d_ops.h>
|
| 419 |
+
#include <ATen/ops/adaptive_max_pool2d_ops.h>
|
| 420 |
+
#include <ATen/ops/adaptive_max_pool2d_backward_ops.h>
|
| 421 |
+
#include <ATen/ops/adaptive_max_pool3d_ops.h>
|
| 422 |
+
#include <ATen/ops/adaptive_max_pool3d_backward_ops.h>
|
| 423 |
+
#include <ATen/ops/add_ops.h>
|
| 424 |
+
#include <ATen/ops/addbmm_ops.h>
|
| 425 |
+
#include <ATen/ops/addcdiv_ops.h>
|
| 426 |
+
#include <ATen/ops/addcmul_ops.h>
|
| 427 |
+
#include <ATen/ops/addmm_ops.h>
|
| 428 |
+
#include <ATen/ops/addmv_ops.h>
|
| 429 |
+
#include <ATen/ops/addr_ops.h>
|
| 430 |
+
#include <ATen/ops/adjoint_ops.h>
|
| 431 |
+
#include <ATen/ops/affine_grid_generator_ops.h>
|
| 432 |
+
#include <ATen/ops/affine_grid_generator_backward_ops.h>
|
| 433 |
+
#include <ATen/ops/alias_ops.h>
|
| 434 |
+
#include <ATen/ops/alias_copy_ops.h>
|
| 435 |
+
#include <ATen/ops/align_as_ops.h>
|
| 436 |
+
#include <ATen/ops/align_tensors_ops.h>
|
| 437 |
+
#include <ATen/ops/align_to_ops.h>
|
| 438 |
+
#include <ATen/ops/all_ops.h>
|
| 439 |
+
#include <ATen/ops/allclose_ops.h>
|
| 440 |
+
#include <ATen/ops/alpha_dropout_ops.h>
|
| 441 |
+
#include <ATen/ops/amax_ops.h>
|
| 442 |
+
#include <ATen/ops/amin_ops.h>
|
| 443 |
+
#include <ATen/ops/aminmax_ops.h>
|
| 444 |
+
#include <ATen/ops/and_ops.h>
|
| 445 |
+
#include <ATen/ops/angle_ops.h>
|
| 446 |
+
#include <ATen/ops/any_ops.h>
|
| 447 |
+
#include <ATen/ops/arange_ops.h>
|
| 448 |
+
#include <ATen/ops/arccos_ops.h>
|
| 449 |
+
#include <ATen/ops/arccosh_ops.h>
|
| 450 |
+
#include <ATen/ops/arcsin_ops.h>
|
| 451 |
+
#include <ATen/ops/arcsinh_ops.h>
|
| 452 |
+
#include <ATen/ops/arctan_ops.h>
|
| 453 |
+
#include <ATen/ops/arctan2_ops.h>
|
| 454 |
+
#include <ATen/ops/arctanh_ops.h>
|
| 455 |
+
#include <ATen/ops/argmax_ops.h>
|
| 456 |
+
#include <ATen/ops/argmin_ops.h>
|
| 457 |
+
#include <ATen/ops/argsort_ops.h>
|
| 458 |
+
#include <ATen/ops/argwhere_ops.h>
|
| 459 |
+
#include <ATen/ops/as_strided_ops.h>
|
| 460 |
+
#include <ATen/ops/as_strided_copy_ops.h>
|
| 461 |
+
#include <ATen/ops/as_strided_scatter_ops.h>
|
| 462 |
+
#include <ATen/ops/asin_ops.h>
|
| 463 |
+
#include <ATen/ops/asinh_ops.h>
|
| 464 |
+
#include <ATen/ops/atan_ops.h>
|
| 465 |
+
#include <ATen/ops/atan2_ops.h>
|
| 466 |
+
#include <ATen/ops/atanh_ops.h>
|
| 467 |
+
#include <ATen/ops/atleast_1d_ops.h>
|
| 468 |
+
#include <ATen/ops/atleast_2d_ops.h>
|
| 469 |
+
#include <ATen/ops/atleast_3d_ops.h>
|
| 470 |
+
#include <ATen/ops/avg_pool1d_ops.h>
|
| 471 |
+
#include <ATen/ops/avg_pool2d_ops.h>
|
| 472 |
+
#include <ATen/ops/avg_pool2d_backward_ops.h>
|
| 473 |
+
#include <ATen/ops/avg_pool3d_ops.h>
|
| 474 |
+
#include <ATen/ops/avg_pool3d_backward_ops.h>
|
| 475 |
+
#include <ATen/ops/baddbmm_ops.h>
|
| 476 |
+
#include <ATen/ops/bartlett_window_ops.h>
|
| 477 |
+
#include <ATen/ops/batch_norm_ops.h>
|
| 478 |
+
#include <ATen/ops/batch_norm_backward_ops.h>
|
| 479 |
+
#include <ATen/ops/batch_norm_backward_elemt_ops.h>
|
| 480 |
+
#include <ATen/ops/batch_norm_backward_reduce_ops.h>
|
| 481 |
+
#include <ATen/ops/batch_norm_elemt_ops.h>
|
| 482 |
+
#include <ATen/ops/batch_norm_gather_stats_ops.h>
|
| 483 |
+
#include <ATen/ops/batch_norm_gather_stats_with_counts_ops.h>
|
| 484 |
+
#include <ATen/ops/batch_norm_stats_ops.h>
|
| 485 |
+
#include <ATen/ops/batch_norm_update_stats_ops.h>
|
| 486 |
+
#include <ATen/ops/bernoulli_ops.h>
|
| 487 |
+
#include <ATen/ops/bilinear_ops.h>
|
| 488 |
+
#include <ATen/ops/binary_cross_entropy_ops.h>
|
| 489 |
+
#include <ATen/ops/binary_cross_entropy_backward_ops.h>
|
| 490 |
+
#include <ATen/ops/binary_cross_entropy_with_logits_ops.h>
|
| 491 |
+
#include <ATen/ops/bincount_ops.h>
|
| 492 |
+
#include <ATen/ops/binomial_ops.h>
|
| 493 |
+
#include <ATen/ops/bitwise_and_ops.h>
|
| 494 |
+
#include <ATen/ops/bitwise_left_shift_ops.h>
|
| 495 |
+
#include <ATen/ops/bitwise_not_ops.h>
|
| 496 |
+
#include <ATen/ops/bitwise_or_ops.h>
|
| 497 |
+
#include <ATen/ops/bitwise_right_shift_ops.h>
|
| 498 |
+
#include <ATen/ops/bitwise_xor_ops.h>
|
| 499 |
+
#include <ATen/ops/blackman_window_ops.h>
|
| 500 |
+
#include <ATen/ops/block_diag_ops.h>
|
| 501 |
+
#include <ATen/ops/bmm_ops.h>
|
| 502 |
+
#include <ATen/ops/broadcast_tensors_ops.h>
|
| 503 |
+
#include <ATen/ops/broadcast_to_ops.h>
|
| 504 |
+
#include <ATen/ops/bucketize_ops.h>
|
| 505 |
+
#include <ATen/ops/can_cast_ops.h>
|
| 506 |
+
#include <ATen/ops/cartesian_prod_ops.h>
|
| 507 |
+
#include <ATen/ops/cat_ops.h>
|
| 508 |
+
#include <ATen/ops/cauchy_ops.h>
|
| 509 |
+
#include <ATen/ops/ccol_indices_ops.h>
|
| 510 |
+
#include <ATen/ops/ccol_indices_copy_ops.h>
|
| 511 |
+
#include <ATen/ops/cdist_ops.h>
|
| 512 |
+
#include <ATen/ops/ceil_ops.h>
|
| 513 |
+
#include <ATen/ops/celu_ops.h>
|
| 514 |
+
#include <ATen/ops/chain_matmul_ops.h>
|
| 515 |
+
#include <ATen/ops/chalf_ops.h>
|
| 516 |
+
#include <ATen/ops/channel_shuffle_ops.h>
|
| 517 |
+
#include <ATen/ops/cholesky_ops.h>
|
| 518 |
+
#include <ATen/ops/cholesky_inverse_ops.h>
|
| 519 |
+
#include <ATen/ops/cholesky_solve_ops.h>
|
| 520 |
+
#include <ATen/ops/choose_qparams_optimized_ops.h>
|
| 521 |
+
#include <ATen/ops/chunk_ops.h>
|
| 522 |
+
#include <ATen/ops/clamp_ops.h>
|
| 523 |
+
#include <ATen/ops/clamp_max_ops.h>
|
| 524 |
+
#include <ATen/ops/clamp_min_ops.h>
|
| 525 |
+
#include <ATen/ops/clip_ops.h>
|
| 526 |
+
#include <ATen/ops/clone_ops.h>
|
| 527 |
+
#include <ATen/ops/coalesce_ops.h>
|
| 528 |
+
#include <ATen/ops/col2im_ops.h>
|
| 529 |
+
#include <ATen/ops/col_indices_ops.h>
|
| 530 |
+
#include <ATen/ops/col_indices_copy_ops.h>
|
| 531 |
+
#include <ATen/ops/column_stack_ops.h>
|
| 532 |
+
#include <ATen/ops/combinations_ops.h>
|
| 533 |
+
#include <ATen/ops/complex_ops.h>
|
| 534 |
+
#include <ATen/ops/concat_ops.h>
|
| 535 |
+
#include <ATen/ops/concatenate_ops.h>
|
| 536 |
+
#include <ATen/ops/conj_ops.h>
|
| 537 |
+
#include <ATen/ops/conj_physical_ops.h>
|
| 538 |
+
#include <ATen/ops/constant_pad_nd_ops.h>
|
| 539 |
+
#include <ATen/ops/contiguous_ops.h>
|
| 540 |
+
#include <ATen/ops/conv1d_ops.h>
|
| 541 |
+
#include <ATen/ops/conv2d_ops.h>
|
| 542 |
+
#include <ATen/ops/conv3d_ops.h>
|
| 543 |
+
#include <ATen/ops/conv_depthwise3d_ops.h>
|
| 544 |
+
#include <ATen/ops/conv_tbc_ops.h>
|
| 545 |
+
#include <ATen/ops/conv_tbc_backward_ops.h>
|
| 546 |
+
#include <ATen/ops/conv_transpose1d_ops.h>
|
| 547 |
+
#include <ATen/ops/conv_transpose2d_ops.h>
|
| 548 |
+
#include <ATen/ops/conv_transpose3d_ops.h>
|
| 549 |
+
#include <ATen/ops/convolution_ops.h>
|
| 550 |
+
#include <ATen/ops/convolution_backward_ops.h>
|
| 551 |
+
#include <ATen/ops/convolution_backward_overrideable_ops.h>
|
| 552 |
+
#include <ATen/ops/convolution_overrideable_ops.h>
|
| 553 |
+
#include <ATen/ops/copy_ops.h>
|
| 554 |
+
#include <ATen/ops/copy_sparse_to_sparse_ops.h>
|
| 555 |
+
#include <ATen/ops/copysign_ops.h>
|
| 556 |
+
#include <ATen/ops/corrcoef_ops.h>
|
| 557 |
+
#include <ATen/ops/cos_ops.h>
|
| 558 |
+
#include <ATen/ops/cosh_ops.h>
|
| 559 |
+
#include <ATen/ops/cosine_embedding_loss_ops.h>
|
| 560 |
+
#include <ATen/ops/cosine_similarity_ops.h>
|
| 561 |
+
#include <ATen/ops/count_nonzero_ops.h>
|
| 562 |
+
#include <ATen/ops/cov_ops.h>
|
| 563 |
+
#include <ATen/ops/cross_ops.h>
|
| 564 |
+
#include <ATen/ops/cross_entropy_loss_ops.h>
|
| 565 |
+
#include <ATen/ops/crow_indices_ops.h>
|
| 566 |
+
#include <ATen/ops/crow_indices_copy_ops.h>
|
| 567 |
+
#include <ATen/ops/ctc_loss_ops.h>
|
| 568 |
+
#include <ATen/ops/cudnn_affine_grid_generator_ops.h>
|
| 569 |
+
#include <ATen/ops/cudnn_affine_grid_generator_backward_ops.h>
|
| 570 |
+
#include <ATen/ops/cudnn_batch_norm_ops.h>
|
| 571 |
+
#include <ATen/ops/cudnn_batch_norm_backward_ops.h>
|
| 572 |
+
#include <ATen/ops/cudnn_convolution_ops.h>
|
| 573 |
+
#include <ATen/ops/cudnn_convolution_add_relu_ops.h>
|
| 574 |
+
#include <ATen/ops/cudnn_convolution_relu_ops.h>
|
| 575 |
+
#include <ATen/ops/cudnn_convolution_transpose_ops.h>
|
| 576 |
+
#include <ATen/ops/cudnn_grid_sampler_ops.h>
|
| 577 |
+
#include <ATen/ops/cudnn_grid_sampler_backward_ops.h>
|
| 578 |
+
#include <ATen/ops/cudnn_is_acceptable_ops.h>
|
| 579 |
+
#include <ATen/ops/cummax_ops.h>
|
| 580 |
+
#include <ATen/ops/cummaxmin_backward_ops.h>
|
| 581 |
+
#include <ATen/ops/cummin_ops.h>
|
| 582 |
+
#include <ATen/ops/cumprod_ops.h>
|
| 583 |
+
#include <ATen/ops/cumprod_backward_ops.h>
|
| 584 |
+
#include <ATen/ops/cumsum_ops.h>
|
| 585 |
+
#include <ATen/ops/cumulative_trapezoid_ops.h>
|
| 586 |
+
#include <ATen/ops/data_ops.h>
|
| 587 |
+
#include <ATen/ops/deg2rad_ops.h>
|
| 588 |
+
#include <ATen/ops/dense_dim_ops.h>
|
| 589 |
+
#include <ATen/ops/dequantize_ops.h>
|
| 590 |
+
#include <ATen/ops/det_ops.h>
|
| 591 |
+
#include <ATen/ops/detach_ops.h>
|
| 592 |
+
#include <ATen/ops/detach_copy_ops.h>
|
| 593 |
+
#include <ATen/ops/diag_ops.h>
|
| 594 |
+
#include <ATen/ops/diag_embed_ops.h>
|
| 595 |
+
#include <ATen/ops/diagflat_ops.h>
|
| 596 |
+
#include <ATen/ops/diagonal_ops.h>
|
| 597 |
+
#include <ATen/ops/diagonal_backward_ops.h>
|
| 598 |
+
#include <ATen/ops/diagonal_copy_ops.h>
|
| 599 |
+
#include <ATen/ops/diagonal_scatter_ops.h>
|
| 600 |
+
#include <ATen/ops/diff_ops.h>
|
| 601 |
+
#include <ATen/ops/digamma_ops.h>
|
| 602 |
+
#include <ATen/ops/dist_ops.h>
|
| 603 |
+
#include <ATen/ops/div_ops.h>
|
| 604 |
+
#include <ATen/ops/divide_ops.h>
|
| 605 |
+
#include <ATen/ops/dot_ops.h>
|
| 606 |
+
#include <ATen/ops/dropout_ops.h>
|
| 607 |
+
#include <ATen/ops/dsplit_ops.h>
|
| 608 |
+
#include <ATen/ops/dstack_ops.h>
|
| 609 |
+
#include <ATen/ops/einsum_ops.h>
|
| 610 |
+
#include <ATen/ops/elu_ops.h>
|
| 611 |
+
#include <ATen/ops/elu_backward_ops.h>
|
| 612 |
+
#include <ATen/ops/embedding_ops.h>
|
| 613 |
+
#include <ATen/ops/embedding_backward_ops.h>
|
| 614 |
+
#include <ATen/ops/embedding_bag_ops.h>
|
| 615 |
+
#include <ATen/ops/embedding_dense_backward_ops.h>
|
| 616 |
+
#include <ATen/ops/embedding_renorm_ops.h>
|
| 617 |
+
#include <ATen/ops/embedding_sparse_backward_ops.h>
|
| 618 |
+
#include <ATen/ops/empty_ops.h>
|
| 619 |
+
#include <ATen/ops/empty_like_ops.h>
|
| 620 |
+
#include <ATen/ops/empty_permuted_ops.h>
|
| 621 |
+
#include <ATen/ops/empty_quantized_ops.h>
|
| 622 |
+
#include <ATen/ops/empty_strided_ops.h>
|
| 623 |
+
#include <ATen/ops/eq_ops.h>
|
| 624 |
+
#include <ATen/ops/equal_ops.h>
|
| 625 |
+
#include <ATen/ops/erf_ops.h>
|
| 626 |
+
#include <ATen/ops/erfc_ops.h>
|
| 627 |
+
#include <ATen/ops/erfinv_ops.h>
|
| 628 |
+
#include <ATen/ops/exp_ops.h>
|
| 629 |
+
#include <ATen/ops/exp2_ops.h>
|
| 630 |
+
#include <ATen/ops/expand_ops.h>
|
| 631 |
+
#include <ATen/ops/expand_as_ops.h>
|
| 632 |
+
#include <ATen/ops/expand_copy_ops.h>
|
| 633 |
+
#include <ATen/ops/expm1_ops.h>
|
| 634 |
+
#include <ATen/ops/exponential_ops.h>
|
| 635 |
+
#include <ATen/ops/eye_ops.h>
|
| 636 |
+
#include <ATen/ops/fake_quantize_per_channel_affine_ops.h>
|
| 637 |
+
#include <ATen/ops/fake_quantize_per_channel_affine_cachemask_ops.h>
|
| 638 |
+
#include <ATen/ops/fake_quantize_per_channel_affine_cachemask_backward_ops.h>
|
| 639 |
+
#include <ATen/ops/fake_quantize_per_tensor_affine_ops.h>
|
| 640 |
+
#include <ATen/ops/fake_quantize_per_tensor_affine_cachemask_ops.h>
|
| 641 |
+
#include <ATen/ops/fake_quantize_per_tensor_affine_cachemask_backward_ops.h>
|
| 642 |
+
#include <ATen/ops/fbgemm_linear_fp16_weight_ops.h>
|
| 643 |
+
#include <ATen/ops/fbgemm_linear_fp16_weight_fp32_activation_ops.h>
|
| 644 |
+
#include <ATen/ops/fbgemm_linear_int8_weight_ops.h>
|
| 645 |
+
#include <ATen/ops/fbgemm_linear_int8_weight_fp32_activation_ops.h>
|
| 646 |
+
#include <ATen/ops/fbgemm_linear_quantize_weight_ops.h>
|
| 647 |
+
#include <ATen/ops/fbgemm_pack_gemm_matrix_fp16_ops.h>
|
| 648 |
+
#include <ATen/ops/fbgemm_pack_quantized_matrix_ops.h>
|
| 649 |
+
#include <ATen/ops/feature_alpha_dropout_ops.h>
|
| 650 |
+
#include <ATen/ops/feature_dropout_ops.h>
|
| 651 |
+
#include <ATen/ops/fft_fft_ops.h>
|
| 652 |
+
#include <ATen/ops/fft_fft2_ops.h>
|
| 653 |
+
#include <ATen/ops/fft_fftfreq_ops.h>
|
| 654 |
+
#include <ATen/ops/fft_fftn_ops.h>
|
| 655 |
+
#include <ATen/ops/fft_fftshift_ops.h>
|
| 656 |
+
#include <ATen/ops/fft_hfft_ops.h>
|
| 657 |
+
#include <ATen/ops/fft_hfft2_ops.h>
|
| 658 |
+
#include <ATen/ops/fft_hfftn_ops.h>
|
| 659 |
+
#include <ATen/ops/fft_ifft_ops.h>
|
| 660 |
+
#include <ATen/ops/fft_ifft2_ops.h>
|
| 661 |
+
#include <ATen/ops/fft_ifftn_ops.h>
|
| 662 |
+
#include <ATen/ops/fft_ifftshift_ops.h>
|
| 663 |
+
#include <ATen/ops/fft_ihfft_ops.h>
|
| 664 |
+
#include <ATen/ops/fft_ihfft2_ops.h>
|
| 665 |
+
#include <ATen/ops/fft_ihfftn_ops.h>
|
| 666 |
+
#include <ATen/ops/fft_irfft_ops.h>
|
| 667 |
+
#include <ATen/ops/fft_irfft2_ops.h>
|
| 668 |
+
#include <ATen/ops/fft_irfftn_ops.h>
|
| 669 |
+
#include <ATen/ops/fft_rfft_ops.h>
|
| 670 |
+
#include <ATen/ops/fft_rfft2_ops.h>
|
| 671 |
+
#include <ATen/ops/fft_rfftfreq_ops.h>
|
| 672 |
+
#include <ATen/ops/fft_rfftn_ops.h>
|
| 673 |
+
#include <ATen/ops/fill_ops.h>
|
| 674 |
+
#include <ATen/ops/fill_diagonal_ops.h>
|
| 675 |
+
#include <ATen/ops/fix_ops.h>
|
| 676 |
+
#include <ATen/ops/flatten_ops.h>
|
| 677 |
+
#include <ATen/ops/flatten_dense_tensors_ops.h>
|
| 678 |
+
#include <ATen/ops/flip_ops.h>
|
| 679 |
+
#include <ATen/ops/fliplr_ops.h>
|
| 680 |
+
#include <ATen/ops/flipud_ops.h>
|
| 681 |
+
#include <ATen/ops/float_power_ops.h>
|
| 682 |
+
#include <ATen/ops/floor_ops.h>
|
| 683 |
+
#include <ATen/ops/floor_divide_ops.h>
|
| 684 |
+
#include <ATen/ops/fmax_ops.h>
|
| 685 |
+
#include <ATen/ops/fmin_ops.h>
|
| 686 |
+
#include <ATen/ops/fmod_ops.h>
|
| 687 |
+
#include <ATen/ops/frac_ops.h>
|
| 688 |
+
#include <ATen/ops/fractional_max_pool2d_ops.h>
|
| 689 |
+
#include <ATen/ops/fractional_max_pool2d_backward_ops.h>
|
| 690 |
+
#include <ATen/ops/fractional_max_pool3d_ops.h>
|
| 691 |
+
#include <ATen/ops/fractional_max_pool3d_backward_ops.h>
|
| 692 |
+
#include <ATen/ops/frexp_ops.h>
|
| 693 |
+
#include <ATen/ops/frobenius_norm_ops.h>
|
| 694 |
+
#include <ATen/ops/from_file_ops.h>
|
| 695 |
+
#include <ATen/ops/full_ops.h>
|
| 696 |
+
#include <ATen/ops/full_like_ops.h>
|
| 697 |
+
#include <ATen/ops/fused_moving_avg_obs_fake_quant_ops.h>
|
| 698 |
+
#include <ATen/ops/gather_ops.h>
|
| 699 |
+
#include <ATen/ops/gather_backward_ops.h>
|
| 700 |
+
#include <ATen/ops/gcd_ops.h>
|
| 701 |
+
#include <ATen/ops/ge_ops.h>
|
| 702 |
+
#include <ATen/ops/gelu_ops.h>
|
| 703 |
+
#include <ATen/ops/gelu_backward_ops.h>
|
| 704 |
+
#include <ATen/ops/geometric_ops.h>
|
| 705 |
+
#include <ATen/ops/geqrf_ops.h>
|
| 706 |
+
#include <ATen/ops/ger_ops.h>
|
| 707 |
+
#include <ATen/ops/glu_ops.h>
|
| 708 |
+
#include <ATen/ops/glu_backward_ops.h>
|
| 709 |
+
#include <ATen/ops/glu_backward_jvp_ops.h>
|
| 710 |
+
#include <ATen/ops/glu_jvp_ops.h>
|
| 711 |
+
#include <ATen/ops/gradient_ops.h>
|
| 712 |
+
#include <ATen/ops/greater_ops.h>
|
| 713 |
+
#include <ATen/ops/greater_equal_ops.h>
|
| 714 |
+
#include <ATen/ops/grid_sampler_ops.h>
|
| 715 |
+
#include <ATen/ops/grid_sampler_2d_ops.h>
|
| 716 |
+
#include <ATen/ops/grid_sampler_2d_backward_ops.h>
|
| 717 |
+
#include <ATen/ops/grid_sampler_3d_ops.h>
|
| 718 |
+
#include <ATen/ops/grid_sampler_3d_backward_ops.h>
|
| 719 |
+
#include <ATen/ops/group_norm_ops.h>
|
| 720 |
+
#include <ATen/ops/gru_ops.h>
|
| 721 |
+
#include <ATen/ops/gru_cell_ops.h>
|
| 722 |
+
#include <ATen/ops/gt_ops.h>
|
| 723 |
+
#include <ATen/ops/hamming_window_ops.h>
|
| 724 |
+
#include <ATen/ops/hann_window_ops.h>
|
| 725 |
+
#include <ATen/ops/hardshrink_ops.h>
|
| 726 |
+
#include <ATen/ops/hardshrink_backward_ops.h>
|
| 727 |
+
#include <ATen/ops/hardsigmoid_ops.h>
|
| 728 |
+
#include <ATen/ops/hardsigmoid_backward_ops.h>
|
| 729 |
+
#include <ATen/ops/hardswish_ops.h>
|
| 730 |
+
#include <ATen/ops/hardswish_backward_ops.h>
|
| 731 |
+
#include <ATen/ops/hardtanh_ops.h>
|
| 732 |
+
#include <ATen/ops/hardtanh_backward_ops.h>
|
| 733 |
+
#include <ATen/ops/heaviside_ops.h>
|
| 734 |
+
#include <ATen/ops/hinge_embedding_loss_ops.h>
|
| 735 |
+
#include <ATen/ops/histc_ops.h>
|
| 736 |
+
#include <ATen/ops/histogram_ops.h>
|
| 737 |
+
#include <ATen/ops/histogramdd_ops.h>
|
| 738 |
+
#include <ATen/ops/hsplit_ops.h>
|
| 739 |
+
#include <ATen/ops/hspmm_ops.h>
|
| 740 |
+
#include <ATen/ops/hstack_ops.h>
|
| 741 |
+
#include <ATen/ops/huber_loss_ops.h>
|
| 742 |
+
#include <ATen/ops/huber_loss_backward_ops.h>
|
| 743 |
+
#include <ATen/ops/hypot_ops.h>
|
| 744 |
+
#include <ATen/ops/i0_ops.h>
|
| 745 |
+
#include <ATen/ops/igamma_ops.h>
|
| 746 |
+
#include <ATen/ops/igammac_ops.h>
|
| 747 |
+
#include <ATen/ops/im2col_ops.h>
|
| 748 |
+
#include <ATen/ops/imag_ops.h>
|
| 749 |
+
#include <ATen/ops/index_ops.h>
|
| 750 |
+
#include <ATen/ops/index_add_ops.h>
|
| 751 |
+
#include <ATen/ops/index_copy_ops.h>
|
| 752 |
+
#include <ATen/ops/index_fill_ops.h>
|
| 753 |
+
#include <ATen/ops/index_put_ops.h>
|
| 754 |
+
#include <ATen/ops/index_reduce_ops.h>
|
| 755 |
+
#include <ATen/ops/index_select_ops.h>
|
| 756 |
+
#include <ATen/ops/index_select_backward_ops.h>
|
| 757 |
+
#include <ATen/ops/indices_ops.h>
|
| 758 |
+
#include <ATen/ops/indices_copy_ops.h>
|
| 759 |
+
#include <ATen/ops/infinitely_differentiable_gelu_backward_ops.h>
|
| 760 |
+
#include <ATen/ops/inner_ops.h>
|
| 761 |
+
#include <ATen/ops/instance_norm_ops.h>
|
| 762 |
+
#include <ATen/ops/int_repr_ops.h>
|
| 763 |
+
#include <ATen/ops/inverse_ops.h>
|
| 764 |
+
#include <ATen/ops/is_coalesced_ops.h>
|
| 765 |
+
#include <ATen/ops/is_complex_ops.h>
|
| 766 |
+
#include <ATen/ops/is_conj_ops.h>
|
| 767 |
+
#include <ATen/ops/is_distributed_ops.h>
|
| 768 |
+
#include <ATen/ops/is_floating_point_ops.h>
|
| 769 |
+
#include <ATen/ops/is_inference_ops.h>
|
| 770 |
+
#include <ATen/ops/is_leaf_ops.h>
|
| 771 |
+
#include <ATen/ops/is_neg_ops.h>
|
| 772 |
+
#include <ATen/ops/is_nonzero_ops.h>
|
| 773 |
+
#include <ATen/ops/is_pinned_ops.h>
|
| 774 |
+
#include <ATen/ops/is_same_size_ops.h>
|
| 775 |
+
#include <ATen/ops/is_set_to_ops.h>
|
| 776 |
+
#include <ATen/ops/is_signed_ops.h>
|
| 777 |
+
#include <ATen/ops/is_vulkan_available_ops.h>
|
| 778 |
+
#include <ATen/ops/isclose_ops.h>
|
| 779 |
+
#include <ATen/ops/isfinite_ops.h>
|
| 780 |
+
#include <ATen/ops/isin_ops.h>
|
| 781 |
+
#include <ATen/ops/isinf_ops.h>
|
| 782 |
+
#include <ATen/ops/isnan_ops.h>
|
| 783 |
+
#include <ATen/ops/isneginf_ops.h>
|
| 784 |
+
#include <ATen/ops/isposinf_ops.h>
|
| 785 |
+
#include <ATen/ops/isreal_ops.h>
|
| 786 |
+
#include <ATen/ops/istft_ops.h>
|
| 787 |
+
#include <ATen/ops/item_ops.h>
|
| 788 |
+
#include <ATen/ops/kaiser_window_ops.h>
|
| 789 |
+
#include <ATen/ops/kl_div_ops.h>
|
| 790 |
+
#include <ATen/ops/kron_ops.h>
|
| 791 |
+
#include <ATen/ops/kthvalue_ops.h>
|
| 792 |
+
#include <ATen/ops/l1_loss_ops.h>
|
| 793 |
+
#include <ATen/ops/layer_norm_ops.h>
|
| 794 |
+
#include <ATen/ops/lcm_ops.h>
|
| 795 |
+
#include <ATen/ops/ldexp_ops.h>
|
| 796 |
+
#include <ATen/ops/le_ops.h>
|
| 797 |
+
#include <ATen/ops/leaky_relu_ops.h>
|
| 798 |
+
#include <ATen/ops/leaky_relu_backward_ops.h>
|
| 799 |
+
#include <ATen/ops/lerp_ops.h>
|
| 800 |
+
#include <ATen/ops/less_ops.h>
|
| 801 |
+
#include <ATen/ops/less_equal_ops.h>
|
| 802 |
+
#include <ATen/ops/lgamma_ops.h>
|
| 803 |
+
#include <ATen/ops/lift_ops.h>
|
| 804 |
+
#include <ATen/ops/lift_fresh_ops.h>
|
| 805 |
+
#include <ATen/ops/lift_fresh_copy_ops.h>
|
| 806 |
+
#include <ATen/ops/linalg_cholesky_ops.h>
|
| 807 |
+
#include <ATen/ops/linalg_cholesky_ex_ops.h>
|
| 808 |
+
#include <ATen/ops/linalg_cond_ops.h>
|
| 809 |
+
#include <ATen/ops/linalg_cross_ops.h>
|
| 810 |
+
#include <ATen/ops/linalg_det_ops.h>
|
| 811 |
+
#include <ATen/ops/linalg_diagonal_ops.h>
|
| 812 |
+
#include <ATen/ops/linalg_eig_ops.h>
|
| 813 |
+
#include <ATen/ops/linalg_eigh_ops.h>
|
| 814 |
+
#include <ATen/ops/linalg_eigvals_ops.h>
|
| 815 |
+
#include <ATen/ops/linalg_eigvalsh_ops.h>
|
| 816 |
+
#include <ATen/ops/linalg_householder_product_ops.h>
|
| 817 |
+
#include <ATen/ops/linalg_inv_ops.h>
|
| 818 |
+
#include <ATen/ops/linalg_inv_ex_ops.h>
|
| 819 |
+
#include <ATen/ops/linalg_ldl_factor_ops.h>
|
| 820 |
+
#include <ATen/ops/linalg_ldl_factor_ex_ops.h>
|
| 821 |
+
#include <ATen/ops/linalg_ldl_solve_ops.h>
|
| 822 |
+
#include <ATen/ops/linalg_lstsq_ops.h>
|
| 823 |
+
#include <ATen/ops/linalg_lu_ops.h>
|
| 824 |
+
#include <ATen/ops/linalg_lu_factor_ops.h>
|
| 825 |
+
#include <ATen/ops/linalg_lu_factor_ex_ops.h>
|
| 826 |
+
#include <ATen/ops/linalg_lu_solve_ops.h>
|
| 827 |
+
#include <ATen/ops/linalg_matmul_ops.h>
|
| 828 |
+
#include <ATen/ops/linalg_matrix_exp_ops.h>
|
| 829 |
+
#include <ATen/ops/linalg_matrix_norm_ops.h>
|
| 830 |
+
#include <ATen/ops/linalg_matrix_power_ops.h>
|
| 831 |
+
#include <ATen/ops/linalg_matrix_rank_ops.h>
|
| 832 |
+
#include <ATen/ops/linalg_multi_dot_ops.h>
|
| 833 |
+
#include <ATen/ops/linalg_norm_ops.h>
|
| 834 |
+
#include <ATen/ops/linalg_pinv_ops.h>
|
| 835 |
+
#include <ATen/ops/linalg_qr_ops.h>
|
| 836 |
+
#include <ATen/ops/linalg_slogdet_ops.h>
|
| 837 |
+
#include <ATen/ops/linalg_solve_ops.h>
|
| 838 |
+
#include <ATen/ops/linalg_solve_ex_ops.h>
|
| 839 |
+
#include <ATen/ops/linalg_solve_triangular_ops.h>
|
| 840 |
+
#include <ATen/ops/linalg_svd_ops.h>
|
| 841 |
+
#include <ATen/ops/linalg_svdvals_ops.h>
|
| 842 |
+
#include <ATen/ops/linalg_tensorinv_ops.h>
|
| 843 |
+
#include <ATen/ops/linalg_tensorsolve_ops.h>
|
| 844 |
+
#include <ATen/ops/linalg_vander_ops.h>
|
| 845 |
+
#include <ATen/ops/linalg_vecdot_ops.h>
|
| 846 |
+
#include <ATen/ops/linalg_vector_norm_ops.h>
|
| 847 |
+
#include <ATen/ops/linear_ops.h>
|
| 848 |
+
#include <ATen/ops/linear_backward_ops.h>
|
| 849 |
+
#include <ATen/ops/linspace_ops.h>
|
| 850 |
+
#include <ATen/ops/log_ops.h>
|
| 851 |
+
#include <ATen/ops/log10_ops.h>
|
| 852 |
+
#include <ATen/ops/log1p_ops.h>
|
| 853 |
+
#include <ATen/ops/log2_ops.h>
|
| 854 |
+
#include <ATen/ops/log_normal_ops.h>
|
| 855 |
+
#include <ATen/ops/log_sigmoid_ops.h>
|
| 856 |
+
#include <ATen/ops/log_sigmoid_backward_ops.h>
|
| 857 |
+
#include <ATen/ops/log_sigmoid_forward_ops.h>
|
| 858 |
+
#include <ATen/ops/log_softmax_ops.h>
|
| 859 |
+
#include <ATen/ops/logaddexp_ops.h>
|
| 860 |
+
#include <ATen/ops/logaddexp2_ops.h>
|
| 861 |
+
#include <ATen/ops/logcumsumexp_ops.h>
|
| 862 |
+
#include <ATen/ops/logdet_ops.h>
|
| 863 |
+
#include <ATen/ops/logical_and_ops.h>
|
| 864 |
+
#include <ATen/ops/logical_not_ops.h>
|
| 865 |
+
#include <ATen/ops/logical_or_ops.h>
|
| 866 |
+
#include <ATen/ops/logical_xor_ops.h>
|
| 867 |
+
#include <ATen/ops/logit_ops.h>
|
| 868 |
+
#include <ATen/ops/logit_backward_ops.h>
|
| 869 |
+
#include <ATen/ops/logspace_ops.h>
|
| 870 |
+
#include <ATen/ops/logsumexp_ops.h>
|
| 871 |
+
#include <ATen/ops/lshift_ops.h>
|
| 872 |
+
#include <ATen/ops/lstm_ops.h>
|
| 873 |
+
#include <ATen/ops/lstm_cell_ops.h>
|
| 874 |
+
#include <ATen/ops/lstm_mps_backward_ops.h>
|
| 875 |
+
#include <ATen/ops/lt_ops.h>
|
| 876 |
+
#include <ATen/ops/lu_solve_ops.h>
|
| 877 |
+
#include <ATen/ops/lu_unpack_ops.h>
|
| 878 |
+
#include <ATen/ops/mH_ops.h>
|
| 879 |
+
#include <ATen/ops/mT_ops.h>
|
| 880 |
+
#include <ATen/ops/margin_ranking_loss_ops.h>
|
| 881 |
+
#include <ATen/ops/masked_fill_ops.h>
|
| 882 |
+
#include <ATen/ops/masked_scatter_ops.h>
|
| 883 |
+
#include <ATen/ops/masked_scatter_backward_ops.h>
|
| 884 |
+
#include <ATen/ops/masked_select_ops.h>
|
| 885 |
+
#include <ATen/ops/masked_select_backward_ops.h>
|
| 886 |
+
#include <ATen/ops/matmul_ops.h>
|
| 887 |
+
#include <ATen/ops/matmul_backward_ops.h>
|
| 888 |
+
#include <ATen/ops/matrix_H_ops.h>
|
| 889 |
+
#include <ATen/ops/matrix_exp_ops.h>
|
| 890 |
+
#include <ATen/ops/matrix_exp_backward_ops.h>
|
| 891 |
+
#include <ATen/ops/matrix_power_ops.h>
|
| 892 |
+
#include <ATen/ops/max_ops.h>
|
| 893 |
+
#include <ATen/ops/max_pool1d_ops.h>
|
| 894 |
+
#include <ATen/ops/max_pool1d_with_indices_ops.h>
|
| 895 |
+
#include <ATen/ops/max_pool2d_ops.h>
|
| 896 |
+
#include <ATen/ops/max_pool2d_backward_ops.h>
|
| 897 |
+
#include <ATen/ops/max_pool2d_with_indices_ops.h>
|
| 898 |
+
#include <ATen/ops/max_pool2d_with_indices_backward_ops.h>
|
| 899 |
+
#include <ATen/ops/max_pool3d_ops.h>
|
| 900 |
+
#include <ATen/ops/max_pool3d_with_indices_ops.h>
|
| 901 |
+
#include <ATen/ops/max_pool3d_with_indices_backward_ops.h>
|
| 902 |
+
#include <ATen/ops/max_unpool2d_ops.h>
|
| 903 |
+
#include <ATen/ops/max_unpool3d_ops.h>
|
| 904 |
+
#include <ATen/ops/maximum_ops.h>
|
| 905 |
+
#include <ATen/ops/mean_ops.h>
|
| 906 |
+
#include <ATen/ops/median_ops.h>
|
| 907 |
+
#include <ATen/ops/meshgrid_ops.h>
|
| 908 |
+
#include <ATen/ops/min_ops.h>
|
| 909 |
+
#include <ATen/ops/minimum_ops.h>
|
| 910 |
+
#include <ATen/ops/miopen_batch_norm_ops.h>
|
| 911 |
+
#include <ATen/ops/miopen_batch_norm_backward_ops.h>
|
| 912 |
+
#include <ATen/ops/miopen_convolution_ops.h>
|
| 913 |
+
#include <ATen/ops/miopen_convolution_add_relu_ops.h>
|
| 914 |
+
#include <ATen/ops/miopen_convolution_relu_ops.h>
|
| 915 |
+
#include <ATen/ops/miopen_convolution_transpose_ops.h>
|
| 916 |
+
#include <ATen/ops/miopen_depthwise_convolution_ops.h>
|
| 917 |
+
#include <ATen/ops/miopen_rnn_ops.h>
|
| 918 |
+
#include <ATen/ops/miopen_rnn_backward_ops.h>
|
| 919 |
+
#include <ATen/ops/mish_ops.h>
|
| 920 |
+
#include <ATen/ops/mish_backward_ops.h>
|
| 921 |
+
#include <ATen/ops/mkldnn_adaptive_avg_pool2d_ops.h>
|
| 922 |
+
#include <ATen/ops/mkldnn_adaptive_avg_pool2d_backward_ops.h>
|
| 923 |
+
#include <ATen/ops/mkldnn_convolution_ops.h>
|
| 924 |
+
#include <ATen/ops/mkldnn_linear_ops.h>
|
| 925 |
+
#include <ATen/ops/mkldnn_linear_backward_ops.h>
|
| 926 |
+
#include <ATen/ops/mkldnn_linear_backward_input_ops.h>
|
| 927 |
+
#include <ATen/ops/mkldnn_linear_backward_weights_ops.h>
|
| 928 |
+
#include <ATen/ops/mkldnn_max_pool2d_ops.h>
|
| 929 |
+
#include <ATen/ops/mkldnn_max_pool2d_backward_ops.h>
|
| 930 |
+
#include <ATen/ops/mkldnn_max_pool3d_ops.h>
|
| 931 |
+
#include <ATen/ops/mkldnn_max_pool3d_backward_ops.h>
|
| 932 |
+
#include <ATen/ops/mkldnn_reorder_conv2d_weight_ops.h>
|
| 933 |
+
#include <ATen/ops/mkldnn_reorder_conv3d_weight_ops.h>
|
| 934 |
+
#include <ATen/ops/mkldnn_rnn_layer_ops.h>
|
| 935 |
+
#include <ATen/ops/mkldnn_rnn_layer_backward_ops.h>
|
| 936 |
+
#include <ATen/ops/mm_ops.h>
|
| 937 |
+
#include <ATen/ops/mode_ops.h>
|
| 938 |
+
#include <ATen/ops/moveaxis_ops.h>
|
| 939 |
+
#include <ATen/ops/movedim_ops.h>
|
| 940 |
+
#include <ATen/ops/mps_convolution_backward_ops.h>
|
| 941 |
+
#include <ATen/ops/mps_convolution_transpose_backward_ops.h>
|
| 942 |
+
#include <ATen/ops/mse_loss_ops.h>
|
| 943 |
+
#include <ATen/ops/mse_loss_backward_ops.h>
|
| 944 |
+
#include <ATen/ops/msort_ops.h>
|
| 945 |
+
#include <ATen/ops/mul_ops.h>
|
| 946 |
+
#include <ATen/ops/multi_margin_loss_ops.h>
|
| 947 |
+
#include <ATen/ops/multi_margin_loss_backward_ops.h>
|
| 948 |
+
#include <ATen/ops/multilabel_margin_loss_ops.h>
|
| 949 |
+
#include <ATen/ops/multilabel_margin_loss_backward_ops.h>
|
| 950 |
+
#include <ATen/ops/multilabel_margin_loss_forward_ops.h>
|
| 951 |
+
#include <ATen/ops/multinomial_ops.h>
|
| 952 |
+
#include <ATen/ops/multiply_ops.h>
|
| 953 |
+
#include <ATen/ops/mv_ops.h>
|
| 954 |
+
#include <ATen/ops/mvlgamma_ops.h>
|
| 955 |
+
#include <ATen/ops/nan_to_num_ops.h>
|
| 956 |
+
#include <ATen/ops/nanmean_ops.h>
|
| 957 |
+
#include <ATen/ops/nanmedian_ops.h>
|
| 958 |
+
#include <ATen/ops/nanquantile_ops.h>
|
| 959 |
+
#include <ATen/ops/nansum_ops.h>
|
| 960 |
+
#include <ATen/ops/narrow_ops.h>
|
| 961 |
+
#include <ATen/ops/narrow_copy_ops.h>
|
| 962 |
+
#include <ATen/ops/native_batch_norm_ops.h>
|
| 963 |
+
#include <ATen/ops/native_batch_norm_backward_ops.h>
|
| 964 |
+
#include <ATen/ops/native_channel_shuffle_ops.h>
|
| 965 |
+
#include <ATen/ops/native_dropout_ops.h>
|
| 966 |
+
#include <ATen/ops/native_dropout_backward_ops.h>
|
| 967 |
+
#include <ATen/ops/native_group_norm_ops.h>
|
| 968 |
+
#include <ATen/ops/native_group_norm_backward_ops.h>
|
| 969 |
+
#include <ATen/ops/native_layer_norm_ops.h>
|
| 970 |
+
#include <ATen/ops/native_layer_norm_backward_ops.h>
|
| 971 |
+
#include <ATen/ops/native_norm_ops.h>
|
| 972 |
+
#include <ATen/ops/ne_ops.h>
|
| 973 |
+
#include <ATen/ops/neg_ops.h>
|
| 974 |
+
#include <ATen/ops/negative_ops.h>
|
| 975 |
+
#include <ATen/ops/nested_to_padded_tensor_ops.h>
|
| 976 |
+
#include <ATen/ops/new_empty_ops.h>
|
| 977 |
+
#include <ATen/ops/new_empty_strided_ops.h>
|
| 978 |
+
#include <ATen/ops/new_full_ops.h>
|
| 979 |
+
#include <ATen/ops/new_ones_ops.h>
|
| 980 |
+
#include <ATen/ops/new_zeros_ops.h>
|
| 981 |
+
#include <ATen/ops/nextafter_ops.h>
|
| 982 |
+
#include <ATen/ops/nll_loss_ops.h>
|
| 983 |
+
#include <ATen/ops/nll_loss2d_ops.h>
|
| 984 |
+
#include <ATen/ops/nll_loss2d_backward_ops.h>
|
| 985 |
+
#include <ATen/ops/nll_loss2d_forward_ops.h>
|
| 986 |
+
#include <ATen/ops/nll_loss_backward_ops.h>
|
| 987 |
+
#include <ATen/ops/nll_loss_forward_ops.h>
|
| 988 |
+
#include <ATen/ops/nll_loss_nd_ops.h>
|
| 989 |
+
#include <ATen/ops/nonzero_ops.h>
|
| 990 |
+
#include <ATen/ops/nonzero_numpy_ops.h>
|
| 991 |
+
#include <ATen/ops/nonzero_static_ops.h>
|
| 992 |
+
#include <ATen/ops/norm_ops.h>
|
| 993 |
+
#include <ATen/ops/norm_except_dim_ops.h>
|
| 994 |
+
#include <ATen/ops/normal_ops.h>
|
| 995 |
+
#include <ATen/ops/not_equal_ops.h>
|
| 996 |
+
#include <ATen/ops/nuclear_norm_ops.h>
|
| 997 |
+
#include <ATen/ops/numpy_T_ops.h>
|
| 998 |
+
#include <ATen/ops/one_hot_ops.h>
|
| 999 |
+
#include <ATen/ops/ones_ops.h>
|
| 1000 |
+
#include <ATen/ops/ones_like_ops.h>
|
| 1001 |
+
#include <ATen/ops/or_ops.h>
|
| 1002 |
+
#include <ATen/ops/orgqr_ops.h>
|
| 1003 |
+
#include <ATen/ops/ormqr_ops.h>
|
| 1004 |
+
#include <ATen/ops/outer_ops.h>
|
| 1005 |
+
#include <ATen/ops/output_nr_ops.h>
|
| 1006 |
+
#include <ATen/ops/pad_ops.h>
|
| 1007 |
+
#include <ATen/ops/pad_sequence_ops.h>
|
| 1008 |
+
#include <ATen/ops/pairwise_distance_ops.h>
|
| 1009 |
+
#include <ATen/ops/pdist_ops.h>
|
| 1010 |
+
#include <ATen/ops/permute_ops.h>
|
| 1011 |
+
#include <ATen/ops/permute_copy_ops.h>
|
| 1012 |
+
#include <ATen/ops/pin_memory_ops.h>
|
| 1013 |
+
#include <ATen/ops/pinverse_ops.h>
|
| 1014 |
+
#include <ATen/ops/pixel_shuffle_ops.h>
|
| 1015 |
+
#include <ATen/ops/pixel_unshuffle_ops.h>
|
| 1016 |
+
#include <ATen/ops/poisson_ops.h>
|
| 1017 |
+
#include <ATen/ops/poisson_nll_loss_ops.h>
|
| 1018 |
+
#include <ATen/ops/polar_ops.h>
|
| 1019 |
+
#include <ATen/ops/polygamma_ops.h>
|
| 1020 |
+
#include <ATen/ops/positive_ops.h>
|
| 1021 |
+
#include <ATen/ops/pow_ops.h>
|
| 1022 |
+
#include <ATen/ops/prelu_ops.h>
|
| 1023 |
+
#include <ATen/ops/prod_ops.h>
|
| 1024 |
+
#include <ATen/ops/promote_types_ops.h>
|
| 1025 |
+
#include <ATen/ops/put_ops.h>
|
| 1026 |
+
#include <ATen/ops/q_per_channel_axis_ops.h>
|
| 1027 |
+
#include <ATen/ops/q_per_channel_scales_ops.h>
|
| 1028 |
+
#include <ATen/ops/q_per_channel_zero_points_ops.h>
|
| 1029 |
+
#include <ATen/ops/q_scale_ops.h>
|
| 1030 |
+
#include <ATen/ops/q_zero_point_ops.h>
|
| 1031 |
+
#include <ATen/ops/qr_ops.h>
|
| 1032 |
+
#include <ATen/ops/qscheme_ops.h>
|
| 1033 |
+
#include <ATen/ops/quantile_ops.h>
|
| 1034 |
+
#include <ATen/ops/quantize_per_channel_ops.h>
|
| 1035 |
+
#include <ATen/ops/quantize_per_tensor_ops.h>
|
| 1036 |
+
#include <ATen/ops/quantize_per_tensor_dynamic_ops.h>
|
| 1037 |
+
#include <ATen/ops/quantized_batch_norm_ops.h>
|
| 1038 |
+
#include <ATen/ops/quantized_gru_cell_ops.h>
|
| 1039 |
+
#include <ATen/ops/quantized_lstm_cell_ops.h>
|
| 1040 |
+
#include <ATen/ops/quantized_max_pool1d_ops.h>
|
| 1041 |
+
#include <ATen/ops/quantized_max_pool2d_ops.h>
|
| 1042 |
+
#include <ATen/ops/quantized_max_pool3d_ops.h>
|
| 1043 |
+
#include <ATen/ops/quantized_rnn_relu_cell_ops.h>
|
| 1044 |
+
#include <ATen/ops/quantized_rnn_tanh_cell_ops.h>
|
| 1045 |
+
#include <ATen/ops/rad2deg_ops.h>
|
| 1046 |
+
#include <ATen/ops/rand_ops.h>
|
| 1047 |
+
#include <ATen/ops/rand_like_ops.h>
|
| 1048 |
+
#include <ATen/ops/randint_ops.h>
|
| 1049 |
+
#include <ATen/ops/randint_like_ops.h>
|
| 1050 |
+
#include <ATen/ops/randn_ops.h>
|
| 1051 |
+
#include <ATen/ops/randn_like_ops.h>
|
| 1052 |
+
#include <ATen/ops/random_ops.h>
|
| 1053 |
+
#include <ATen/ops/randperm_ops.h>
|
| 1054 |
+
#include <ATen/ops/range_ops.h>
|
| 1055 |
+
#include <ATen/ops/ravel_ops.h>
|
| 1056 |
+
#include <ATen/ops/real_ops.h>
|
| 1057 |
+
#include <ATen/ops/reciprocal_ops.h>
|
| 1058 |
+
#include <ATen/ops/record_stream_ops.h>
|
| 1059 |
+
#include <ATen/ops/refine_names_ops.h>
|
| 1060 |
+
#include <ATen/ops/reflection_pad1d_ops.h>
|
| 1061 |
+
#include <ATen/ops/reflection_pad1d_backward_ops.h>
|
| 1062 |
+
#include <ATen/ops/reflection_pad2d_ops.h>
|
| 1063 |
+
#include <ATen/ops/reflection_pad2d_backward_ops.h>
|
| 1064 |
+
#include <ATen/ops/reflection_pad3d_ops.h>
|
| 1065 |
+
#include <ATen/ops/reflection_pad3d_backward_ops.h>
|
| 1066 |
+
#include <ATen/ops/relu_ops.h>
|
| 1067 |
+
#include <ATen/ops/relu6_ops.h>
|
| 1068 |
+
#include <ATen/ops/remainder_ops.h>
|
| 1069 |
+
#include <ATen/ops/rename_ops.h>
|
| 1070 |
+
#include <ATen/ops/renorm_ops.h>
|
| 1071 |
+
#include <ATen/ops/repeat_ops.h>
|
| 1072 |
+
#include <ATen/ops/repeat_interleave_ops.h>
|
| 1073 |
+
#include <ATen/ops/replication_pad1d_ops.h>
|
| 1074 |
+
#include <ATen/ops/replication_pad1d_backward_ops.h>
|
| 1075 |
+
#include <ATen/ops/replication_pad2d_ops.h>
|
| 1076 |
+
#include <ATen/ops/replication_pad2d_backward_ops.h>
|
| 1077 |
+
#include <ATen/ops/replication_pad3d_ops.h>
|
| 1078 |
+
#include <ATen/ops/replication_pad3d_backward_ops.h>
|
| 1079 |
+
#include <ATen/ops/requires_grad_ops.h>
|
| 1080 |
+
#include <ATen/ops/reshape_ops.h>
|
| 1081 |
+
#include <ATen/ops/reshape_as_ops.h>
|
| 1082 |
+
#include <ATen/ops/resize_ops.h>
|
| 1083 |
+
#include <ATen/ops/resize_as_ops.h>
|
| 1084 |
+
#include <ATen/ops/resize_as_sparse_ops.h>
|
| 1085 |
+
#include <ATen/ops/resolve_conj_ops.h>
|
| 1086 |
+
#include <ATen/ops/resolve_neg_ops.h>
|
| 1087 |
+
#include <ATen/ops/result_type_ops.h>
|
| 1088 |
+
#include <ATen/ops/retain_grad_ops.h>
|
| 1089 |
+
#include <ATen/ops/retains_grad_ops.h>
|
| 1090 |
+
#include <ATen/ops/rms_norm_ops.h>
|
| 1091 |
+
#include <ATen/ops/rnn_relu_ops.h>
|
| 1092 |
+
#include <ATen/ops/rnn_relu_cell_ops.h>
|
| 1093 |
+
#include <ATen/ops/rnn_tanh_ops.h>
|
| 1094 |
+
#include <ATen/ops/rnn_tanh_cell_ops.h>
|
| 1095 |
+
#include <ATen/ops/roll_ops.h>
|
| 1096 |
+
#include <ATen/ops/rot90_ops.h>
|
| 1097 |
+
#include <ATen/ops/round_ops.h>
|
| 1098 |
+
#include <ATen/ops/row_indices_ops.h>
|
| 1099 |
+
#include <ATen/ops/row_indices_copy_ops.h>
|
| 1100 |
+
#include <ATen/ops/row_stack_ops.h>
|
| 1101 |
+
#include <ATen/ops/rrelu_ops.h>
|
| 1102 |
+
#include <ATen/ops/rrelu_with_noise_ops.h>
|
| 1103 |
+
#include <ATen/ops/rrelu_with_noise_backward_ops.h>
|
| 1104 |
+
#include <ATen/ops/rshift_ops.h>
|
| 1105 |
+
#include <ATen/ops/rsqrt_ops.h>
|
| 1106 |
+
#include <ATen/ops/rsub_ops.h>
|
| 1107 |
+
#include <ATen/ops/scalar_tensor_ops.h>
|
| 1108 |
+
#include <ATen/ops/scaled_dot_product_attention_ops.h>
|
| 1109 |
+
#include <ATen/ops/scatter_ops.h>
|
| 1110 |
+
#include <ATen/ops/scatter_add_ops.h>
|
| 1111 |
+
#include <ATen/ops/scatter_reduce_ops.h>
|
| 1112 |
+
#include <ATen/ops/searchsorted_ops.h>
|
| 1113 |
+
#include <ATen/ops/segment_reduce_ops.h>
|
| 1114 |
+
#include <ATen/ops/select_ops.h>
|
| 1115 |
+
#include <ATen/ops/select_backward_ops.h>
|
| 1116 |
+
#include <ATen/ops/select_copy_ops.h>
|
| 1117 |
+
#include <ATen/ops/select_scatter_ops.h>
|
| 1118 |
+
#include <ATen/ops/selu_ops.h>
|
| 1119 |
+
#include <ATen/ops/set_ops.h>
|
| 1120 |
+
#include <ATen/ops/set_data_ops.h>
|
| 1121 |
+
#include <ATen/ops/sgn_ops.h>
|
| 1122 |
+
#include <ATen/ops/sigmoid_ops.h>
|
| 1123 |
+
#include <ATen/ops/sigmoid_backward_ops.h>
|
| 1124 |
+
#include <ATen/ops/sign_ops.h>
|
| 1125 |
+
#include <ATen/ops/signbit_ops.h>
|
| 1126 |
+
#include <ATen/ops/silu_ops.h>
|
| 1127 |
+
#include <ATen/ops/silu_backward_ops.h>
|
| 1128 |
+
#include <ATen/ops/sin_ops.h>
|
| 1129 |
+
#include <ATen/ops/sinc_ops.h>
|
| 1130 |
+
#include <ATen/ops/sinh_ops.h>
|
| 1131 |
+
#include <ATen/ops/size_ops.h>
|
| 1132 |
+
#include <ATen/ops/slice_ops.h>
|
| 1133 |
+
#include <ATen/ops/slice_backward_ops.h>
|
| 1134 |
+
#include <ATen/ops/slice_copy_ops.h>
|
| 1135 |
+
#include <ATen/ops/slice_inverse_ops.h>
|
| 1136 |
+
#include <ATen/ops/slice_scatter_ops.h>
|
| 1137 |
+
#include <ATen/ops/slogdet_ops.h>
|
| 1138 |
+
#include <ATen/ops/slow_conv3d_ops.h>
|
| 1139 |
+
#include <ATen/ops/slow_conv3d_forward_ops.h>
|
| 1140 |
+
#include <ATen/ops/slow_conv_dilated2d_ops.h>
|
| 1141 |
+
#include <ATen/ops/slow_conv_dilated3d_ops.h>
|
| 1142 |
+
#include <ATen/ops/slow_conv_transpose2d_ops.h>
|
| 1143 |
+
#include <ATen/ops/slow_conv_transpose3d_ops.h>
|
| 1144 |
+
#include <ATen/ops/smm_ops.h>
|
| 1145 |
+
#include <ATen/ops/smooth_l1_loss_ops.h>
|
| 1146 |
+
#include <ATen/ops/smooth_l1_loss_backward_ops.h>
|
| 1147 |
+
#include <ATen/ops/soft_margin_loss_ops.h>
|
| 1148 |
+
#include <ATen/ops/soft_margin_loss_backward_ops.h>
|
| 1149 |
+
#include <ATen/ops/softmax_ops.h>
|
| 1150 |
+
#include <ATen/ops/softplus_ops.h>
|
| 1151 |
+
#include <ATen/ops/softplus_backward_ops.h>
|
| 1152 |
+
#include <ATen/ops/softshrink_ops.h>
|
| 1153 |
+
#include <ATen/ops/softshrink_backward_ops.h>
|
| 1154 |
+
#include <ATen/ops/sort_ops.h>
|
| 1155 |
+
#include <ATen/ops/sparse_bsc_tensor_ops.h>
|
| 1156 |
+
#include <ATen/ops/sparse_bsr_tensor_ops.h>
|
| 1157 |
+
#include <ATen/ops/sparse_compressed_tensor_ops.h>
|
| 1158 |
+
#include <ATen/ops/sparse_coo_tensor_ops.h>
|
| 1159 |
+
#include <ATen/ops/sparse_csc_tensor_ops.h>
|
| 1160 |
+
#include <ATen/ops/sparse_csr_tensor_ops.h>
|
| 1161 |
+
#include <ATen/ops/sparse_dim_ops.h>
|
| 1162 |
+
#include <ATen/ops/sparse_mask_ops.h>
|
| 1163 |
+
#include <ATen/ops/sparse_resize_ops.h>
|
| 1164 |
+
#include <ATen/ops/sparse_resize_and_clear_ops.h>
|
| 1165 |
+
#include <ATen/ops/sparse_sampled_addmm_ops.h>
|
| 1166 |
+
#include <ATen/ops/special_airy_ai_ops.h>
|
| 1167 |
+
#include <ATen/ops/special_bessel_j0_ops.h>
|
| 1168 |
+
#include <ATen/ops/special_bessel_j1_ops.h>
|
| 1169 |
+
#include <ATen/ops/special_bessel_y0_ops.h>
|
| 1170 |
+
#include <ATen/ops/special_bessel_y1_ops.h>
|
| 1171 |
+
#include <ATen/ops/special_chebyshev_polynomial_t_ops.h>
|
| 1172 |
+
#include <ATen/ops/special_chebyshev_polynomial_u_ops.h>
|
| 1173 |
+
#include <ATen/ops/special_chebyshev_polynomial_v_ops.h>
|
| 1174 |
+
#include <ATen/ops/special_chebyshev_polynomial_w_ops.h>
|
| 1175 |
+
#include <ATen/ops/special_digamma_ops.h>
|
| 1176 |
+
#include <ATen/ops/special_entr_ops.h>
|
| 1177 |
+
#include <ATen/ops/special_erf_ops.h>
|
| 1178 |
+
#include <ATen/ops/special_erfc_ops.h>
|
| 1179 |
+
#include <ATen/ops/special_erfcx_ops.h>
|
| 1180 |
+
#include <ATen/ops/special_erfinv_ops.h>
|
| 1181 |
+
#include <ATen/ops/special_exp2_ops.h>
|
| 1182 |
+
#include <ATen/ops/special_expit_ops.h>
|
| 1183 |
+
#include <ATen/ops/special_expm1_ops.h>
|
| 1184 |
+
#include <ATen/ops/special_gammainc_ops.h>
|
| 1185 |
+
#include <ATen/ops/special_gammaincc_ops.h>
|
| 1186 |
+
#include <ATen/ops/special_gammaln_ops.h>
|
| 1187 |
+
#include <ATen/ops/special_hermite_polynomial_h_ops.h>
|
| 1188 |
+
#include <ATen/ops/special_hermite_polynomial_he_ops.h>
|
| 1189 |
+
#include <ATen/ops/special_i0_ops.h>
|
| 1190 |
+
#include <ATen/ops/special_i0e_ops.h>
|
| 1191 |
+
#include <ATen/ops/special_i1_ops.h>
|
| 1192 |
+
#include <ATen/ops/special_i1e_ops.h>
|
| 1193 |
+
#include <ATen/ops/special_laguerre_polynomial_l_ops.h>
|
| 1194 |
+
#include <ATen/ops/special_legendre_polynomial_p_ops.h>
|
| 1195 |
+
#include <ATen/ops/special_log1p_ops.h>
|
| 1196 |
+
#include <ATen/ops/special_log_ndtr_ops.h>
|
| 1197 |
+
#include <ATen/ops/special_log_softmax_ops.h>
|
| 1198 |
+
#include <ATen/ops/special_logit_ops.h>
|
| 1199 |
+
#include <ATen/ops/special_logsumexp_ops.h>
|
| 1200 |
+
#include <ATen/ops/special_modified_bessel_i0_ops.h>
|
| 1201 |
+
#include <ATen/ops/special_modified_bessel_i1_ops.h>
|
| 1202 |
+
#include <ATen/ops/special_modified_bessel_k0_ops.h>
|
| 1203 |
+
#include <ATen/ops/special_modified_bessel_k1_ops.h>
|
| 1204 |
+
#include <ATen/ops/special_multigammaln_ops.h>
|
| 1205 |
+
#include <ATen/ops/special_ndtr_ops.h>
|
| 1206 |
+
#include <ATen/ops/special_ndtri_ops.h>
|
| 1207 |
+
#include <ATen/ops/special_polygamma_ops.h>
|
| 1208 |
+
#include <ATen/ops/special_psi_ops.h>
|
| 1209 |
+
#include <ATen/ops/special_round_ops.h>
|
| 1210 |
+
#include <ATen/ops/special_scaled_modified_bessel_k0_ops.h>
|
| 1211 |
+
#include <ATen/ops/special_scaled_modified_bessel_k1_ops.h>
|
| 1212 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_t_ops.h>
|
| 1213 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_u_ops.h>
|
| 1214 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_v_ops.h>
|
| 1215 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_w_ops.h>
|
| 1216 |
+
#include <ATen/ops/special_sinc_ops.h>
|
| 1217 |
+
#include <ATen/ops/special_softmax_ops.h>
|
| 1218 |
+
#include <ATen/ops/special_spherical_bessel_j0_ops.h>
|
| 1219 |
+
#include <ATen/ops/special_xlog1py_ops.h>
|
| 1220 |
+
#include <ATen/ops/special_xlogy_ops.h>
|
| 1221 |
+
#include <ATen/ops/special_zeta_ops.h>
|
| 1222 |
+
#include <ATen/ops/split_ops.h>
|
| 1223 |
+
#include <ATen/ops/split_copy_ops.h>
|
| 1224 |
+
#include <ATen/ops/split_with_sizes_ops.h>
|
| 1225 |
+
#include <ATen/ops/split_with_sizes_copy_ops.h>
|
| 1226 |
+
#include <ATen/ops/sqrt_ops.h>
|
| 1227 |
+
#include <ATen/ops/square_ops.h>
|
| 1228 |
+
#include <ATen/ops/squeeze_ops.h>
|
| 1229 |
+
#include <ATen/ops/squeeze_copy_ops.h>
|
| 1230 |
+
#include <ATen/ops/sspaddmm_ops.h>
|
| 1231 |
+
#include <ATen/ops/stack_ops.h>
|
| 1232 |
+
#include <ATen/ops/std_ops.h>
|
| 1233 |
+
#include <ATen/ops/std_mean_ops.h>
|
| 1234 |
+
#include <ATen/ops/stft_ops.h>
|
| 1235 |
+
#include <ATen/ops/stride_ops.h>
|
| 1236 |
+
#include <ATen/ops/sub_ops.h>
|
| 1237 |
+
#include <ATen/ops/subtract_ops.h>
|
| 1238 |
+
#include <ATen/ops/sum_ops.h>
|
| 1239 |
+
#include <ATen/ops/sum_to_size_ops.h>
|
| 1240 |
+
#include <ATen/ops/svd_ops.h>
|
| 1241 |
+
#include <ATen/ops/swapaxes_ops.h>
|
| 1242 |
+
#include <ATen/ops/swapdims_ops.h>
|
| 1243 |
+
#include <ATen/ops/sym_constrain_range_ops.h>
|
| 1244 |
+
#include <ATen/ops/sym_constrain_range_for_size_ops.h>
|
| 1245 |
+
#include <ATen/ops/sym_numel_ops.h>
|
| 1246 |
+
#include <ATen/ops/sym_size_ops.h>
|
| 1247 |
+
#include <ATen/ops/sym_storage_offset_ops.h>
|
| 1248 |
+
#include <ATen/ops/sym_stride_ops.h>
|
| 1249 |
+
#include <ATen/ops/t_ops.h>
|
| 1250 |
+
#include <ATen/ops/t_copy_ops.h>
|
| 1251 |
+
#include <ATen/ops/take_ops.h>
|
| 1252 |
+
#include <ATen/ops/take_along_dim_ops.h>
|
| 1253 |
+
#include <ATen/ops/tan_ops.h>
|
| 1254 |
+
#include <ATen/ops/tanh_ops.h>
|
| 1255 |
+
#include <ATen/ops/tanh_backward_ops.h>
|
| 1256 |
+
#include <ATen/ops/tensor_split_ops.h>
|
| 1257 |
+
#include <ATen/ops/tensordot_ops.h>
|
| 1258 |
+
#include <ATen/ops/thnn_conv2d_ops.h>
|
| 1259 |
+
#include <ATen/ops/threshold_ops.h>
|
| 1260 |
+
#include <ATen/ops/threshold_backward_ops.h>
|
| 1261 |
+
#include <ATen/ops/tile_ops.h>
|
| 1262 |
+
#include <ATen/ops/to_ops.h>
|
| 1263 |
+
#include <ATen/ops/to_dense_ops.h>
|
| 1264 |
+
#include <ATen/ops/to_dense_backward_ops.h>
|
| 1265 |
+
#include <ATen/ops/to_mkldnn_ops.h>
|
| 1266 |
+
#include <ATen/ops/to_mkldnn_backward_ops.h>
|
| 1267 |
+
#include <ATen/ops/to_padded_tensor_ops.h>
|
| 1268 |
+
#include <ATen/ops/to_sparse_ops.h>
|
| 1269 |
+
#include <ATen/ops/to_sparse_bsc_ops.h>
|
| 1270 |
+
#include <ATen/ops/to_sparse_bsr_ops.h>
|
| 1271 |
+
#include <ATen/ops/to_sparse_csc_ops.h>
|
| 1272 |
+
#include <ATen/ops/to_sparse_csr_ops.h>
|
| 1273 |
+
#include <ATen/ops/topk_ops.h>
|
| 1274 |
+
#include <ATen/ops/trace_ops.h>
|
| 1275 |
+
#include <ATen/ops/trace_backward_ops.h>
|
| 1276 |
+
#include <ATen/ops/transpose_ops.h>
|
| 1277 |
+
#include <ATen/ops/transpose_copy_ops.h>
|
| 1278 |
+
#include <ATen/ops/trapezoid_ops.h>
|
| 1279 |
+
#include <ATen/ops/trapz_ops.h>
|
| 1280 |
+
#include <ATen/ops/triangular_solve_ops.h>
|
| 1281 |
+
#include <ATen/ops/tril_ops.h>
|
| 1282 |
+
#include <ATen/ops/tril_indices_ops.h>
|
| 1283 |
+
#include <ATen/ops/triplet_margin_loss_ops.h>
|
| 1284 |
+
#include <ATen/ops/triu_ops.h>
|
| 1285 |
+
#include <ATen/ops/triu_indices_ops.h>
|
| 1286 |
+
#include <ATen/ops/true_divide_ops.h>
|
| 1287 |
+
#include <ATen/ops/trunc_ops.h>
|
| 1288 |
+
#include <ATen/ops/type_as_ops.h>
|
| 1289 |
+
#include <ATen/ops/unbind_ops.h>
|
| 1290 |
+
#include <ATen/ops/unbind_copy_ops.h>
|
| 1291 |
+
#include <ATen/ops/unflatten_ops.h>
|
| 1292 |
+
#include <ATen/ops/unflatten_dense_tensors_ops.h>
|
| 1293 |
+
#include <ATen/ops/unfold_ops.h>
|
| 1294 |
+
#include <ATen/ops/unfold_backward_ops.h>
|
| 1295 |
+
#include <ATen/ops/unfold_copy_ops.h>
|
| 1296 |
+
#include <ATen/ops/uniform_ops.h>
|
| 1297 |
+
#include <ATen/ops/unique_consecutive_ops.h>
|
| 1298 |
+
#include <ATen/ops/unique_dim_ops.h>
|
| 1299 |
+
#include <ATen/ops/unique_dim_consecutive_ops.h>
|
| 1300 |
+
#include <ATen/ops/unsafe_chunk_ops.h>
|
| 1301 |
+
#include <ATen/ops/unsafe_split_ops.h>
|
| 1302 |
+
#include <ATen/ops/unsafe_split_with_sizes_ops.h>
|
| 1303 |
+
#include <ATen/ops/unsqueeze_ops.h>
|
| 1304 |
+
#include <ATen/ops/unsqueeze_copy_ops.h>
|
| 1305 |
+
#include <ATen/ops/upsample_bicubic2d_ops.h>
|
| 1306 |
+
#include <ATen/ops/upsample_bicubic2d_backward_ops.h>
|
| 1307 |
+
#include <ATen/ops/upsample_bilinear2d_ops.h>
|
| 1308 |
+
#include <ATen/ops/upsample_bilinear2d_backward_ops.h>
|
| 1309 |
+
#include <ATen/ops/upsample_linear1d_ops.h>
|
| 1310 |
+
#include <ATen/ops/upsample_linear1d_backward_ops.h>
|
| 1311 |
+
#include <ATen/ops/upsample_nearest1d_ops.h>
|
| 1312 |
+
#include <ATen/ops/upsample_nearest1d_backward_ops.h>
|
| 1313 |
+
#include <ATen/ops/upsample_nearest2d_ops.h>
|
| 1314 |
+
#include <ATen/ops/upsample_nearest2d_backward_ops.h>
|
| 1315 |
+
#include <ATen/ops/upsample_nearest3d_ops.h>
|
| 1316 |
+
#include <ATen/ops/upsample_nearest3d_backward_ops.h>
|
| 1317 |
+
#include <ATen/ops/upsample_trilinear3d_ops.h>
|
| 1318 |
+
#include <ATen/ops/upsample_trilinear3d_backward_ops.h>
|
| 1319 |
+
#include <ATen/ops/value_selecting_reduction_backward_ops.h>
|
| 1320 |
+
#include <ATen/ops/values_ops.h>
|
| 1321 |
+
#include <ATen/ops/values_copy_ops.h>
|
| 1322 |
+
#include <ATen/ops/vander_ops.h>
|
| 1323 |
+
#include <ATen/ops/var_ops.h>
|
| 1324 |
+
#include <ATen/ops/var_mean_ops.h>
|
| 1325 |
+
#include <ATen/ops/vdot_ops.h>
|
| 1326 |
+
#include <ATen/ops/view_ops.h>
|
| 1327 |
+
#include <ATen/ops/view_as_ops.h>
|
| 1328 |
+
#include <ATen/ops/view_as_complex_ops.h>
|
| 1329 |
+
#include <ATen/ops/view_as_complex_copy_ops.h>
|
| 1330 |
+
#include <ATen/ops/view_as_real_ops.h>
|
| 1331 |
+
#include <ATen/ops/view_as_real_copy_ops.h>
|
| 1332 |
+
#include <ATen/ops/view_copy_ops.h>
|
| 1333 |
+
#include <ATen/ops/vsplit_ops.h>
|
| 1334 |
+
#include <ATen/ops/vstack_ops.h>
|
| 1335 |
+
#include <ATen/ops/where_ops.h>
|
| 1336 |
+
#include <ATen/ops/xlogy_ops.h>
|
| 1337 |
+
#include <ATen/ops/xor_ops.h>
|
| 1338 |
+
#include <ATen/ops/zero_ops.h>
|
| 1339 |
+
#include <ATen/ops/zeros_ops.h>
|
| 1340 |
+
#include <ATen/ops/zeros_like_ops.h>
|
| 1341 |
+
|
| 1342 |
+
// Extension writers: do you write wrapper functions? Are you frustrated with
|
| 1343 |
+
// resolving overloads of operators? Are you frustrated with dealing with
|
| 1344 |
+
// pointer-to-methods and resolving overloads of pointer-to-methods?? Look no
|
| 1345 |
+
// further, this is the utility for you.
|
| 1346 |
+
//
|
| 1347 |
+
// Given an operator schema: aten::op.overload(...
|
| 1348 |
+
//
|
| 1349 |
+
// Use ATEN_FN2(op, overload) to get a *function* version of the operator
|
| 1350 |
+
// that is guaranteed to not be overloaded. This means that you can safely
|
| 1351 |
+
// decltype(&ATEN_FN2(op, overload)) it. NB: the 2 means this macro takes 2 args.
|
| 1352 |
+
//
|
| 1353 |
+
// Given an operator schema without an overload name: aten::op(...
|
| 1354 |
+
//
|
| 1355 |
+
// Use ATEN_FN(op) to get an unambiguous *function* version of the operator.
|
| 1356 |
+
//
|
| 1357 |
+
// There is some interesting behavior for out= operations.
|
| 1358 |
+
// ATEN_FN2(sin, out) gives a function that is *faithful* to the schema;
|
| 1359 |
+
// that is, the order of arguments is exactly what it looks like in the schema.
|
| 1360 |
+
|
| 1361 |
+
#define ATEN_FN2(op_name, overload) at::_ops::op_name##_##overload::call
|
| 1362 |
+
#define ATEN_FN(op_name) at::_ops::op_name::call
|
| 1363 |
+
|
| 1364 |
+
// Separately, ATEN_OP(op) and ATEN_OP2(op, overload) define a class containing compile-time
|
| 1365 |
+
// metadata about a given aten operator.
|
| 1366 |
+
// Notable data on the class includes:
|
| 1367 |
+
// - ATEN_OP2(add, Tensor)::name // returns the string name: "add"
|
| 1368 |
+
// - ATEN_OP2(add, Tensor)::overload_name // returns the string overload name: "Tensor"
|
| 1369 |
+
// - ATEN_OP2(add, Tensor)::schema // returns the C++ schema type: at::Tensor (const at::Tensor &, const at::Tensor &, const at::Scalar &)
|
| 1370 |
+
// - ATEN_OP2(add, Tensor)::schema_str // returns the string jit type: "add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"
|
| 1371 |
+
|
| 1372 |
+
#define ATEN_OP2(op_name, overload) at::_ops::op_name##_##overload
|
| 1373 |
+
#define ATEN_OP(op_name) at::_ops::op_name
|
| 1374 |
+
|
| 1375 |
+
// WARNING: Please do not call any of the ops in the _ops namespace directly.
|
| 1376 |
+
// Use the ATEN_FN macros. We do not guarantee stability of the naming
|
| 1377 |
+
// scheme for the functions in at::_ops
|
| 1378 |
+
|
| 1379 |
+
// See Note [The ATen Operators API] for details of the at::_ops namespace
|
| 1380 |
+
|
| 1381 |
+
namespace at {
|
| 1382 |
+
namespace _ops {
|
| 1383 |
+
|
| 1384 |
+
} // namespace _ops
|
| 1385 |
+
} // namespace at
|
.venv/lib/python3.11/site-packages/torch/include/ATen/Parallel-inl.h
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <c10/util/Exception.h>
|
| 4 |
+
#include <c10/util/ParallelGuard.h>
|
| 5 |
+
#include <c10/util/SmallVector.h>
|
| 6 |
+
|
| 7 |
+
namespace at {
|
| 8 |
+
|
| 9 |
+
template <class F>
|
| 10 |
+
inline void parallel_for(
|
| 11 |
+
const int64_t begin,
|
| 12 |
+
const int64_t end,
|
| 13 |
+
const int64_t grain_size,
|
| 14 |
+
const F& f) {
|
| 15 |
+
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(grain_size >= 0);
|
| 16 |
+
if (begin >= end) {
|
| 17 |
+
return;
|
| 18 |
+
}
|
| 19 |
+
|
| 20 |
+
#ifdef INTRA_OP_PARALLEL
|
| 21 |
+
at::internal::lazy_init_num_threads();
|
| 22 |
+
const auto numiter = end - begin;
|
| 23 |
+
const bool use_parallel =
|
| 24 |
+
(numiter > grain_size && numiter > 1 && !at::in_parallel_region() &&
|
| 25 |
+
at::get_num_threads() > 1);
|
| 26 |
+
if (!use_parallel) {
|
| 27 |
+
internal::ThreadIdGuard tid_guard(0);
|
| 28 |
+
c10::ParallelGuard guard(true);
|
| 29 |
+
f(begin, end);
|
| 30 |
+
return;
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
internal::invoke_parallel(
|
| 34 |
+
begin, end, grain_size, [&](int64_t begin, int64_t end) {
|
| 35 |
+
c10::ParallelGuard guard(true);
|
| 36 |
+
f(begin, end);
|
| 37 |
+
});
|
| 38 |
+
#else
|
| 39 |
+
internal::ThreadIdGuard tid_guard(0);
|
| 40 |
+
c10::ParallelGuard guard(true);
|
| 41 |
+
f(begin, end);
|
| 42 |
+
#endif
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
template <class scalar_t, class F, class SF>
|
| 46 |
+
inline scalar_t parallel_reduce(
|
| 47 |
+
const int64_t begin,
|
| 48 |
+
const int64_t end,
|
| 49 |
+
const int64_t grain_size,
|
| 50 |
+
const scalar_t ident,
|
| 51 |
+
const F& f,
|
| 52 |
+
const SF& sf) {
|
| 53 |
+
TORCH_CHECK(grain_size >= 0);
|
| 54 |
+
if (begin >= end) {
|
| 55 |
+
return ident;
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
#ifdef INTRA_OP_PARALLEL
|
| 59 |
+
at::internal::lazy_init_num_threads();
|
| 60 |
+
const auto max_threads = at::get_num_threads();
|
| 61 |
+
const bool use_parallel =
|
| 62 |
+
((end - begin) > grain_size && !at::in_parallel_region() &&
|
| 63 |
+
max_threads > 1);
|
| 64 |
+
if (!use_parallel) {
|
| 65 |
+
internal::ThreadIdGuard tid_guard(0);
|
| 66 |
+
c10::ParallelGuard guard(true);
|
| 67 |
+
return f(begin, end, ident);
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
c10::SmallVector<scalar_t, 64> results(max_threads, ident);
|
| 71 |
+
internal::invoke_parallel(
|
| 72 |
+
begin,
|
| 73 |
+
end,
|
| 74 |
+
grain_size,
|
| 75 |
+
[&](const int64_t my_begin, const int64_t my_end) {
|
| 76 |
+
const auto tid = at::get_thread_num();
|
| 77 |
+
c10::ParallelGuard guard(true);
|
| 78 |
+
results[tid] = f(my_begin, my_end, ident);
|
| 79 |
+
});
|
| 80 |
+
|
| 81 |
+
scalar_t result = ident;
|
| 82 |
+
for (auto partial_result : results) {
|
| 83 |
+
result = sf(result, partial_result);
|
| 84 |
+
}
|
| 85 |
+
return result;
|
| 86 |
+
#else
|
| 87 |
+
internal::ThreadIdGuard tid_guard(0);
|
| 88 |
+
c10::ParallelGuard guard(true);
|
| 89 |
+
return f(begin, end, ident);
|
| 90 |
+
#endif
|
| 91 |
+
}
|
| 92 |
+
|
| 93 |
+
} // namespace at
|
.venv/lib/python3.11/site-packages/torch/include/ATen/ParallelNative.h
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <c10/util/Exception.h>
|
| 4 |
+
|
| 5 |
+
#define INTRA_OP_PARALLEL
|
| 6 |
+
|
| 7 |
+
namespace at::internal {
|
| 8 |
+
|
| 9 |
+
TORCH_API void invoke_parallel(
|
| 10 |
+
const int64_t begin,
|
| 11 |
+
const int64_t end,
|
| 12 |
+
const int64_t grain_size,
|
| 13 |
+
const std::function<void(int64_t, int64_t)>& f);
|
| 14 |
+
|
| 15 |
+
} // namespace at::internal
|
.venv/lib/python3.11/site-packages/torch/include/ATen/ParallelOpenMP.h
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <algorithm>
|
| 4 |
+
#include <atomic>
|
| 5 |
+
#include <cstddef>
|
| 6 |
+
#include <exception>
|
| 7 |
+
|
| 8 |
+
#ifdef _OPENMP
|
| 9 |
+
#define INTRA_OP_PARALLEL
|
| 10 |
+
|
| 11 |
+
#include <omp.h>
|
| 12 |
+
#endif
|
| 13 |
+
|
| 14 |
+
#ifdef _OPENMP
|
| 15 |
+
namespace at::internal {
|
| 16 |
+
template <typename F>
|
| 17 |
+
inline void invoke_parallel(
|
| 18 |
+
int64_t begin,
|
| 19 |
+
int64_t end,
|
| 20 |
+
int64_t grain_size,
|
| 21 |
+
const F& f) {
|
| 22 |
+
std::atomic_flag err_flag = ATOMIC_FLAG_INIT;
|
| 23 |
+
std::exception_ptr eptr;
|
| 24 |
+
|
| 25 |
+
#pragma omp parallel
|
| 26 |
+
{
|
| 27 |
+
// choose number of tasks based on grain size and number of threads
|
| 28 |
+
// can't use num_threads clause due to bugs in GOMP's thread pool (See
|
| 29 |
+
// #32008)
|
| 30 |
+
int64_t num_threads = omp_get_num_threads();
|
| 31 |
+
if (grain_size > 0) {
|
| 32 |
+
num_threads = std::min(num_threads, divup((end - begin), grain_size));
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
int64_t tid = omp_get_thread_num();
|
| 36 |
+
int64_t chunk_size = divup((end - begin), num_threads);
|
| 37 |
+
int64_t begin_tid = begin + tid * chunk_size;
|
| 38 |
+
if (begin_tid < end) {
|
| 39 |
+
try {
|
| 40 |
+
internal::ThreadIdGuard tid_guard(tid);
|
| 41 |
+
f(begin_tid, std::min(end, chunk_size + begin_tid));
|
| 42 |
+
} catch (...) {
|
| 43 |
+
if (!err_flag.test_and_set()) {
|
| 44 |
+
eptr = std::current_exception();
|
| 45 |
+
}
|
| 46 |
+
}
|
| 47 |
+
}
|
| 48 |
+
}
|
| 49 |
+
if (eptr) {
|
| 50 |
+
std::rethrow_exception(eptr);
|
| 51 |
+
}
|
| 52 |
+
}
|
| 53 |
+
} // namespace at::internal
|
| 54 |
+
#endif // _OPENMP
|
.venv/lib/python3.11/site-packages/torch/include/ATen/PythonTorchFunctionTLS.h
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <c10/core/SafePyObject.h>
|
| 4 |
+
#include <c10/macros/Macros.h>
|
| 5 |
+
|
| 6 |
+
namespace at::impl {
|
| 7 |
+
|
| 8 |
+
enum TorchFunctionDisabledState { ENABLED, SUBCLASSES_DISABLED, ALL_DISABLED };
|
| 9 |
+
|
| 10 |
+
struct TORCH_API PythonTorchFunctionTLS {
|
| 11 |
+
static void set_disabled_state(TorchFunctionDisabledState disabled_state_);
|
| 12 |
+
static TorchFunctionDisabledState get_disabled_state();
|
| 13 |
+
|
| 14 |
+
static void push_onto_stack(std::shared_ptr<SafePyObject> mode);
|
| 15 |
+
static const std::shared_ptr<SafePyObject> pop_stack();
|
| 16 |
+
static const std::shared_ptr<SafePyObject>& get_stack_at(int64_t idx);
|
| 17 |
+
static int64_t stack_len();
|
| 18 |
+
|
| 19 |
+
static const PythonTorchFunctionTLS& get_state();
|
| 20 |
+
static void set_state(const PythonTorchFunctionTLS& state);
|
| 21 |
+
|
| 22 |
+
private:
|
| 23 |
+
// The mode TLS is split into
|
| 24 |
+
// - disabled_state, which says which part of torch function are disabled
|
| 25 |
+
// - stack_, which is a vector of modes representing the stack of user
|
| 26 |
+
// defined modes
|
| 27 |
+
TorchFunctionDisabledState disabled_state_ =
|
| 28 |
+
TorchFunctionDisabledState::ENABLED;
|
| 29 |
+
std::vector<std::shared_ptr<c10::SafePyObject>> stack_;
|
| 30 |
+
};
|
| 31 |
+
|
| 32 |
+
TORCH_API bool torch_function_mode_enabled();
|
| 33 |
+
|
| 34 |
+
TORCH_API bool torch_function_all_disabled();
|
| 35 |
+
|
| 36 |
+
} // namespace at::impl
|
.venv/lib/python3.11/site-packages/torch/include/ATen/RedispatchFunctions.h
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
.venv/lib/python3.11/site-packages/torch/include/ATen/SmallVector.h
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <c10/util/SmallVector.h>
|
.venv/lib/python3.11/site-packages/torch/include/ATen/SparseTensorImpl.h
ADDED
|
@@ -0,0 +1,421 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/Tensor.h>
|
| 4 |
+
#include <c10/core/TensorImpl.h>
|
| 5 |
+
#include <c10/core/impl/TorchDispatchModeTLS.h>
|
| 6 |
+
#include <c10/util/Exception.h>
|
| 7 |
+
#include <c10/util/irange.h>
|
| 8 |
+
|
| 9 |
+
#ifndef AT_PER_OPERATOR_HEADERS
|
| 10 |
+
#include <ATen/Functions.h>
|
| 11 |
+
#else
|
| 12 |
+
#include <ATen/ops/empty.h>
|
| 13 |
+
#include <ATen/ops/resize.h>
|
| 14 |
+
#endif
|
| 15 |
+
|
| 16 |
+
namespace at {
|
| 17 |
+
struct TORCH_API SparseTensorImpl : public TensorImpl {
|
| 18 |
+
// Stored in COO format, indices + values.
|
| 19 |
+
|
| 20 |
+
// INVARIANTS:
|
| 21 |
+
// sparse_dim: range [0, len(shape)]; sparse_dim + dense_dim = len(shape)
|
| 22 |
+
// dense_dim : range [0, len(shape)]; sparse_dim + dense_dim = len(shape)
|
| 23 |
+
// _indices.shape: dimensionality: 2, shape: (sparse_dim, nnz)
|
| 24 |
+
// _values.shape: dimensionality: 1 + dense_dim. shape: (nnz,
|
| 25 |
+
// shape[sparse_dim:])
|
| 26 |
+
|
| 27 |
+
int64_t sparse_dim_ = 0; // number of sparse dimensions
|
| 28 |
+
int64_t dense_dim_ = 0; // number of dense dimensions
|
| 29 |
+
|
| 30 |
+
Tensor indices_; // always a LongTensor
|
| 31 |
+
Tensor values_;
|
| 32 |
+
|
| 33 |
+
// A sparse tensor is 'coalesced' if every index occurs at most once in
|
| 34 |
+
// the indices tensor, and the indices are in sorted order. (This means
|
| 35 |
+
// that it is very easy to convert a coalesced tensor to CSR format: you
|
| 36 |
+
// need only compute CSR format indices.)
|
| 37 |
+
//
|
| 38 |
+
// Most math operations can only be performed on coalesced sparse tensors,
|
| 39 |
+
// because many algorithms proceed by merging two sorted lists (of indices).
|
| 40 |
+
bool coalesced_ = false;
|
| 41 |
+
|
| 42 |
+
// compute_numel with integer multiplication overflow check, see gh-57542
|
| 43 |
+
void refresh_numel() {
|
| 44 |
+
TensorImpl::safe_refresh_numel();
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
public:
|
| 48 |
+
// Public for now...
|
| 49 |
+
explicit SparseTensorImpl(at::DispatchKeySet, const caffe2::TypeMeta);
|
| 50 |
+
|
| 51 |
+
void release_resources() override;
|
| 52 |
+
|
| 53 |
+
int64_t nnz() const {
|
| 54 |
+
return values_.size(0);
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
c10::SymInt sym_nnz() const {
|
| 58 |
+
return values_.sym_size(0);
|
| 59 |
+
}
|
| 60 |
+
int64_t sparse_dim() const {
|
| 61 |
+
return sparse_dim_;
|
| 62 |
+
}
|
| 63 |
+
int64_t dense_dim() const {
|
| 64 |
+
return dense_dim_;
|
| 65 |
+
}
|
| 66 |
+
bool coalesced() const {
|
| 67 |
+
return coalesced_;
|
| 68 |
+
}
|
| 69 |
+
Tensor indices() const {
|
| 70 |
+
return indices_;
|
| 71 |
+
}
|
| 72 |
+
Tensor values() const {
|
| 73 |
+
return values_;
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
void set_size(int64_t dim, int64_t new_size) override;
|
| 77 |
+
void set_stride(int64_t dim, int64_t new_stride) override;
|
| 78 |
+
void set_storage_offset(int64_t storage_offset) override;
|
| 79 |
+
|
| 80 |
+
#ifdef DEBUG
|
| 81 |
+
bool has_storage() const override;
|
| 82 |
+
#endif
|
| 83 |
+
|
| 84 |
+
// WARNING: This function does NOT preserve invariants of sparse_dim/dense_dim
|
| 85 |
+
// with respect to indices and values
|
| 86 |
+
void raw_resize_(int64_t sparse_dim, int64_t dense_dim, IntArrayRef size) {
|
| 87 |
+
TORCH_CHECK(
|
| 88 |
+
allow_tensor_metadata_change(),
|
| 89 |
+
"raw_resize_ ",
|
| 90 |
+
err_msg_tensor_metadata_change_not_allowed);
|
| 91 |
+
TORCH_CHECK(
|
| 92 |
+
!has_symbolic_sizes_strides_,
|
| 93 |
+
"raw_resize_ called on tensor with symbolic shape")
|
| 94 |
+
set_sizes_and_strides(size, std::vector<int64_t>(size.size()));
|
| 95 |
+
sparse_dim_ = sparse_dim;
|
| 96 |
+
dense_dim_ = dense_dim;
|
| 97 |
+
refresh_numel();
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
// NOTE: This function preserves invariants of sparse_dim/dense_dim with
|
| 101 |
+
// respect to indices and values.
|
| 102 |
+
//
|
| 103 |
+
// NOTE: This function supports the following cases:
|
| 104 |
+
// 1. When we keep the number of dense dimensions unchanged, and NOT shrinking
|
| 105 |
+
// the size of any of the dense dimensions.
|
| 106 |
+
// 2. When we keep the number of sparse dimensions unchanged, and NOT
|
| 107 |
+
// shrinking the size of any of the sparse dimensions.
|
| 108 |
+
// 3. When the sparse tensor has zero nnz, in which case we are free to change
|
| 109 |
+
// the shapes of both its sparse and dense dimensions.
|
| 110 |
+
//
|
| 111 |
+
// This function DOESN'T support (and will throw an error) the following
|
| 112 |
+
// cases:
|
| 113 |
+
// 1. When we attempt to change the number of sparse dimensions on a non-empty
|
| 114 |
+
// sparse tensor (such an operation will invalidate the indices stored).
|
| 115 |
+
// 2. When we attempt to change the number of dense dimensions on a non-empty
|
| 116 |
+
// sparse tensor (such an operation will behave differently from an equivalent
|
| 117 |
+
// dense tensor's resize method, and for API consistency we don't support it).
|
| 118 |
+
// 3. When we attempt to shrink the size of any of the dense dimensions on a
|
| 119 |
+
// non-empty sparse tensor (such an operation will behave differently from an
|
| 120 |
+
// equivalent dense tensor's resize method, and for API consistency we don't
|
| 121 |
+
// support it).
|
| 122 |
+
// 4. When we attempt to shrink the size of any of the sparse dimensions on a
|
| 123 |
+
// non-empty sparse tensor (this could make some of the stored indices
|
| 124 |
+
// out-of-bound and thus unsafe).
|
| 125 |
+
template <typename T>
|
| 126 |
+
void _resize_(int64_t sparse_dim, int64_t dense_dim, ArrayRef<T> size) {
|
| 127 |
+
TORCH_CHECK(
|
| 128 |
+
allow_tensor_metadata_change(),
|
| 129 |
+
"resize_ ",
|
| 130 |
+
err_msg_tensor_metadata_change_not_allowed);
|
| 131 |
+
TORCH_CHECK(
|
| 132 |
+
!has_symbolic_sizes_strides_,
|
| 133 |
+
"resize_ called on tensor with symbolic shape")
|
| 134 |
+
TORCH_CHECK(
|
| 135 |
+
sparse_dim + dense_dim == static_cast<int64_t>(size.size()),
|
| 136 |
+
"number of dimensions must be sparse_dim (",
|
| 137 |
+
sparse_dim,
|
| 138 |
+
") + dense_dim (",
|
| 139 |
+
dense_dim,
|
| 140 |
+
"), but got ",
|
| 141 |
+
size.size());
|
| 142 |
+
if (nnz() > 0) {
|
| 143 |
+
[[maybe_unused]] auto constexpr alt_options_msg =
|
| 144 |
+
"You could try the following options:\n\
|
| 145 |
+
1. If you need an empty sparse tensor of this size, call `x = torch.sparse_coo_tensor(size)`.\n\
|
| 146 |
+
2. If you need to resize this tensor, you have the following options:\n\
|
| 147 |
+
1. For both sparse and dense dimensions, keep the number of them constant and the size of them non-shrinking, and then try the same call again.\n\
|
| 148 |
+
2. Or, create a new sparse tensor with the correct indices and values from this sparse tensor.";
|
| 149 |
+
|
| 150 |
+
TORCH_CHECK(
|
| 151 |
+
sparse_dim == sparse_dim_,
|
| 152 |
+
"changing the number of sparse dimensions (from ",
|
| 153 |
+
sparse_dim_,
|
| 154 |
+
" to ",
|
| 155 |
+
sparse_dim,
|
| 156 |
+
") on a non-empty sparse tensor is not supported.\n",
|
| 157 |
+
alt_options_msg);
|
| 158 |
+
|
| 159 |
+
TORCH_CHECK(
|
| 160 |
+
dense_dim == dense_dim_,
|
| 161 |
+
"changing the number of dense dimensions (from ",
|
| 162 |
+
dense_dim_,
|
| 163 |
+
" to ",
|
| 164 |
+
dense_dim,
|
| 165 |
+
") on a non-empty sparse tensor is not supported.\n",
|
| 166 |
+
alt_options_msg);
|
| 167 |
+
|
| 168 |
+
bool shrinking_sparse_dims = false;
|
| 169 |
+
bool shrinking_dense_dim = false;
|
| 170 |
+
auto sparse_size_original = generic_sizes<T>().slice(0, sparse_dim);
|
| 171 |
+
auto sparse_size_new = size.slice(0, sparse_dim);
|
| 172 |
+
for (const auto i : c10::irange(sparse_dim)) {
|
| 173 |
+
if (sparse_size_new[i] < sparse_size_original[i]) {
|
| 174 |
+
shrinking_sparse_dims = true;
|
| 175 |
+
break;
|
| 176 |
+
}
|
| 177 |
+
}
|
| 178 |
+
auto dense_size_original = generic_sizes<T>().slice(sparse_dim);
|
| 179 |
+
auto dense_size_new = size.slice(sparse_dim);
|
| 180 |
+
for (const auto i : c10::irange(dense_dim)) {
|
| 181 |
+
if (dense_size_new[i] < dense_size_original[i]) {
|
| 182 |
+
shrinking_dense_dim = true;
|
| 183 |
+
break;
|
| 184 |
+
}
|
| 185 |
+
}
|
| 186 |
+
|
| 187 |
+
TORCH_CHECK(
|
| 188 |
+
!shrinking_sparse_dims,
|
| 189 |
+
"shrinking the size of sparse dimensions (from ",
|
| 190 |
+
sparse_size_original,
|
| 191 |
+
" to ",
|
| 192 |
+
sparse_size_new,
|
| 193 |
+
") on a non-empty sparse tensor is not supported.\n",
|
| 194 |
+
alt_options_msg);
|
| 195 |
+
|
| 196 |
+
TORCH_CHECK(
|
| 197 |
+
!shrinking_dense_dim,
|
| 198 |
+
"shrinking the size of dense dimensions (from ",
|
| 199 |
+
dense_size_original,
|
| 200 |
+
" to ",
|
| 201 |
+
dense_size_new,
|
| 202 |
+
") on a non-empty sparse tensor is not supported.\n",
|
| 203 |
+
alt_options_msg);
|
| 204 |
+
}
|
| 205 |
+
|
| 206 |
+
auto sizes_and_strides = generic_sizes<T>();
|
| 207 |
+
const bool size_equals_sizes = std::equal(
|
| 208 |
+
size.begin(),
|
| 209 |
+
size.end(),
|
| 210 |
+
sizes_and_strides.begin(),
|
| 211 |
+
sizes_and_strides.end());
|
| 212 |
+
if ((!size_equals_sizes) || (sparse_dim != sparse_dim_) ||
|
| 213 |
+
(dense_dim != dense_dim_)) {
|
| 214 |
+
auto nnz = at::symint::sizes<T>(values())[0];
|
| 215 |
+
std::vector<T> values_size = {nnz};
|
| 216 |
+
auto dense_size = size.slice(sparse_dim);
|
| 217 |
+
values_size.insert(
|
| 218 |
+
values_size.end(), dense_size.begin(), dense_size.end());
|
| 219 |
+
at::symint::resize_<T>(values_, values_size);
|
| 220 |
+
at::symint::resize_<T>(indices_, {T(sparse_dim), nnz});
|
| 221 |
+
}
|
| 222 |
+
|
| 223 |
+
if (!size_equals_sizes) {
|
| 224 |
+
set_sizes_and_strides(size, std::vector<T>(size.size()));
|
| 225 |
+
}
|
| 226 |
+
sparse_dim_ = sparse_dim;
|
| 227 |
+
dense_dim_ = dense_dim;
|
| 228 |
+
refresh_numel();
|
| 229 |
+
}
|
| 230 |
+
|
| 231 |
+
void resize_(int64_t sparse_dim, int64_t dense_dim, ArrayRef<int64_t> size) {
|
| 232 |
+
return _resize_(sparse_dim, dense_dim, size);
|
| 233 |
+
}
|
| 234 |
+
|
| 235 |
+
void resize_(
|
| 236 |
+
int64_t sparse_dim,
|
| 237 |
+
int64_t dense_dim,
|
| 238 |
+
ArrayRef<c10::SymInt> size) {
|
| 239 |
+
return _resize_(sparse_dim, dense_dim, size);
|
| 240 |
+
}
|
| 241 |
+
|
| 242 |
+
// NOTE: this function will resize the sparse tensor and also set `indices`
|
| 243 |
+
// and `values` to empty.
|
| 244 |
+
void resize_and_clear_(
|
| 245 |
+
int64_t sparse_dim,
|
| 246 |
+
int64_t dense_dim,
|
| 247 |
+
IntArrayRef size) {
|
| 248 |
+
TORCH_CHECK(
|
| 249 |
+
allow_tensor_metadata_change(),
|
| 250 |
+
"resize_and_clear_ ",
|
| 251 |
+
err_msg_tensor_metadata_change_not_allowed);
|
| 252 |
+
TORCH_CHECK(
|
| 253 |
+
!has_symbolic_sizes_strides_,
|
| 254 |
+
"resize_and_clear_ called on tensor with symbolic shape")
|
| 255 |
+
TORCH_CHECK(
|
| 256 |
+
sparse_dim + dense_dim == static_cast<int64_t>(size.size()),
|
| 257 |
+
"number of dimensions must be sparse_dim (",
|
| 258 |
+
sparse_dim,
|
| 259 |
+
") + dense_dim (",
|
| 260 |
+
dense_dim,
|
| 261 |
+
"), but got ",
|
| 262 |
+
size.size());
|
| 263 |
+
|
| 264 |
+
set_sizes_and_strides(size, std::vector<int64_t>(size.size()));
|
| 265 |
+
sparse_dim_ = sparse_dim;
|
| 266 |
+
dense_dim_ = dense_dim;
|
| 267 |
+
|
| 268 |
+
auto empty_indices = at::empty({sparse_dim, 0}, indices().options());
|
| 269 |
+
std::vector<int64_t> values_size = {0};
|
| 270 |
+
auto dense_size = sizes().slice(sparse_dim);
|
| 271 |
+
values_size.insert(values_size.end(), dense_size.begin(), dense_size.end());
|
| 272 |
+
auto empty_values = at::empty(values_size, values().options());
|
| 273 |
+
set_indices_and_values_unsafe(empty_indices, empty_values);
|
| 274 |
+
refresh_numel();
|
| 275 |
+
}
|
| 276 |
+
|
| 277 |
+
void set_coalesced(bool coalesced) {
|
| 278 |
+
TORCH_CHECK(
|
| 279 |
+
allow_tensor_metadata_change(),
|
| 280 |
+
"set_coalesced ",
|
| 281 |
+
err_msg_tensor_metadata_change_not_allowed);
|
| 282 |
+
coalesced_ = coalesced;
|
| 283 |
+
}
|
| 284 |
+
|
| 285 |
+
// NOTE: this function is only used internally and not exposed to Python
|
| 286 |
+
// frontend
|
| 287 |
+
void set_nnz_and_narrow(int64_t new_nnz) {
|
| 288 |
+
TORCH_CHECK(
|
| 289 |
+
allow_tensor_metadata_change(),
|
| 290 |
+
"set_nnz_and_narrow ",
|
| 291 |
+
err_msg_tensor_metadata_change_not_allowed);
|
| 292 |
+
AT_ASSERT(new_nnz <= nnz());
|
| 293 |
+
indices_ = indices_.narrow(1, 0, new_nnz);
|
| 294 |
+
values_ = values_.narrow(0, 0, new_nnz);
|
| 295 |
+
if (new_nnz < 2) {
|
| 296 |
+
coalesced_ = true;
|
| 297 |
+
}
|
| 298 |
+
}
|
| 299 |
+
|
| 300 |
+
// Takes indices and values and directly puts them into the sparse tensor, no
|
| 301 |
+
// copy. NOTE: this function is unsafe because it doesn't check whether any
|
| 302 |
+
// indices are out of boundaries of `sizes`, so it should ONLY be used where
|
| 303 |
+
// we know that the indices are guaranteed to be within bounds. This used to
|
| 304 |
+
// be called THSTensor_(_move) NB: This used to be able to avoid a refcount
|
| 305 |
+
// bump, but I was too lazy to make it happen
|
| 306 |
+
void set_indices_and_values_unsafe(
|
| 307 |
+
const Tensor& indices,
|
| 308 |
+
const Tensor& values);
|
| 309 |
+
|
| 310 |
+
template <typename VariableVersion>
|
| 311 |
+
c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach_core(
|
| 312 |
+
VariableVersion&& version_counter,
|
| 313 |
+
bool allow_tensor_metadata_change) const {
|
| 314 |
+
const auto mode_stack_len = c10::impl::TorchDispatchModeTLS::stack_len();
|
| 315 |
+
c10::impl::PyInterpreter&& interpreter = nullptr;
|
| 316 |
+
if (mode_stack_len > 0 &&
|
| 317 |
+
!c10::impl::tls_is_dispatch_key_excluded(DispatchKey::Python)) {
|
| 318 |
+
const auto& cur_torch_dispatch_mode_state =
|
| 319 |
+
c10::impl::TorchDispatchModeTLS::get_stack_at(mode_stack_len - 1);
|
| 320 |
+
interpreter = cur_torch_dispatch_mode_state->pyinterpreter();
|
| 321 |
+
} else if (
|
| 322 |
+
key_set_.has(DispatchKey::Python) &&
|
| 323 |
+
!c10::impl::tls_is_dispatch_key_excluded(DispatchKey::Python)) {
|
| 324 |
+
interpreter = pyobj_slot_.load_pyobj_interpreter();
|
| 325 |
+
} else {
|
| 326 |
+
// otherwise just copy the SparseTensorImpl and not the PyObject.
|
| 327 |
+
auto impl = c10::make_intrusive<SparseTensorImpl>(key_set(), dtype());
|
| 328 |
+
copy_tensor_metadata(
|
| 329 |
+
/*src_sparse_impl=*/this,
|
| 330 |
+
/*dest_sparse_impl=*/impl.get(),
|
| 331 |
+
/*version_counter=*/version_counter,
|
| 332 |
+
/*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
|
| 333 |
+
impl->refresh_numel();
|
| 334 |
+
return impl;
|
| 335 |
+
}
|
| 336 |
+
auto r = interpreter->detach(this);
|
| 337 |
+
r->set_version_counter(std::forward<VariableVersion>(version_counter));
|
| 338 |
+
r->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
|
| 339 |
+
return r;
|
| 340 |
+
}
|
| 341 |
+
|
| 342 |
+
/**
|
| 343 |
+
* Return a TensorImpl that is a shallow-copy of this TensorImpl.
|
| 344 |
+
*
|
| 345 |
+
* For usage of `version_counter` and `allow_tensor_metadata_change`,
|
| 346 |
+
* see NOTE [ TensorImpl Shallow-Copying ].
|
| 347 |
+
*/
|
| 348 |
+
c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
|
| 349 |
+
const c10::VariableVersion& version_counter,
|
| 350 |
+
bool allow_tensor_metadata_change) const override {
|
| 351 |
+
return shallow_copy_and_detach_core(
|
| 352 |
+
version_counter, allow_tensor_metadata_change);
|
| 353 |
+
}
|
| 354 |
+
|
| 355 |
+
/**
|
| 356 |
+
* Return a TensorImpl that is a shallow-copy of this TensorImpl.
|
| 357 |
+
*
|
| 358 |
+
* For usage of `version_counter` and `allow_tensor_metadata_change`,
|
| 359 |
+
* see NOTE [ TensorImpl Shallow-Copying ].
|
| 360 |
+
*/
|
| 361 |
+
c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
|
| 362 |
+
c10::VariableVersion&& version_counter,
|
| 363 |
+
bool allow_tensor_metadata_change) const override {
|
| 364 |
+
return shallow_copy_and_detach_core(
|
| 365 |
+
std::move(version_counter), allow_tensor_metadata_change);
|
| 366 |
+
}
|
| 367 |
+
|
| 368 |
+
/**
|
| 369 |
+
* Shallow-copies data from another TensorImpl into this TensorImpl.
|
| 370 |
+
*
|
| 371 |
+
* For why this function doesn't check this TensorImpl's
|
| 372 |
+
* `allow_tensor_metadata_change_`, see NOTE [ TensorImpl Shallow-Copying ].
|
| 373 |
+
*/
|
| 374 |
+
void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override {
|
| 375 |
+
AT_ASSERT(has_compatible_shallow_copy_type(impl->key_set()));
|
| 376 |
+
auto sparse_impl = static_cast<const SparseTensorImpl*>(impl.get());
|
| 377 |
+
copy_tensor_metadata(
|
| 378 |
+
/*src_sparse_impl=*/sparse_impl,
|
| 379 |
+
/*dest_sparse_impl=*/this,
|
| 380 |
+
/*version_counter=*/version_counter(),
|
| 381 |
+
/*allow_tensor_metadata_change=*/allow_tensor_metadata_change());
|
| 382 |
+
refresh_numel();
|
| 383 |
+
}
|
| 384 |
+
|
| 385 |
+
private:
|
| 386 |
+
explicit SparseTensorImpl(
|
| 387 |
+
at::DispatchKeySet,
|
| 388 |
+
const caffe2::TypeMeta,
|
| 389 |
+
at::Tensor indices,
|
| 390 |
+
at::Tensor values);
|
| 391 |
+
|
| 392 |
+
/**
|
| 393 |
+
* Copy the tensor metadata fields (e.g. sizes / strides / storage pointer /
|
| 394 |
+
* storage_offset) from one TensorImpl to another TensorImpl.
|
| 395 |
+
*
|
| 396 |
+
* For usage of `version_counter` and `allow_tensor_metadata_change`, see NOTE
|
| 397 |
+
* [ TensorImpl Shallow-Copying ].
|
| 398 |
+
*/
|
| 399 |
+
static void copy_tensor_metadata(
|
| 400 |
+
const SparseTensorImpl* src_sparse_impl,
|
| 401 |
+
SparseTensorImpl* dest_sparse_impl,
|
| 402 |
+
c10::VariableVersion version_counter,
|
| 403 |
+
bool allow_tensor_metadata_change) {
|
| 404 |
+
TensorImpl::copy_tensor_metadata(
|
| 405 |
+
src_sparse_impl,
|
| 406 |
+
dest_sparse_impl,
|
| 407 |
+
std::move(version_counter),
|
| 408 |
+
allow_tensor_metadata_change);
|
| 409 |
+
|
| 410 |
+
// Sparse-specific fields
|
| 411 |
+
dest_sparse_impl->sparse_dim_ = src_sparse_impl->sparse_dim();
|
| 412 |
+
dest_sparse_impl->dense_dim_ = src_sparse_impl->dense_dim();
|
| 413 |
+
dest_sparse_impl->indices_ = src_sparse_impl->indices();
|
| 414 |
+
dest_sparse_impl->values_ = src_sparse_impl->values();
|
| 415 |
+
dest_sparse_impl->coalesced_ = src_sparse_impl->coalesced();
|
| 416 |
+
}
|
| 417 |
+
|
| 418 |
+
const char* tensorimpl_type_name() const override;
|
| 419 |
+
};
|
| 420 |
+
|
| 421 |
+
} // namespace at
|
.venv/lib/python3.11/site-packages/torch/include/ATen/StorageUtils.h
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <c10/core/Storage.h>
|
| 4 |
+
#include <c10/core/StorageImpl.h>
|
| 5 |
+
#include <c10/util/intrusive_ptr.h>
|
| 6 |
+
|
| 7 |
+
namespace at {
|
| 8 |
+
|
| 9 |
+
class TensorBase;
|
| 10 |
+
|
| 11 |
+
// Here we define a series of utils to create/manipulate ATen backed
|
| 12 |
+
// c10 storage implementations.
|
| 13 |
+
|
| 14 |
+
/**
|
| 15 |
+
* Create a new shared memory storage impl managed by file descriptor
|
| 16 |
+
*
|
| 17 |
+
* @param size size in bytes
|
| 18 |
+
*/
|
| 19 |
+
C10_EXPORT c10::intrusive_ptr<c10::StorageImpl> new_shm_fd_storage(size_t size);
|
| 20 |
+
|
| 21 |
+
/**
|
| 22 |
+
* Copy src to dst
|
| 23 |
+
* Caller must guarantee the validness of the storage objects
|
| 24 |
+
* during the entire copy process, esp. when it's async.
|
| 25 |
+
*
|
| 26 |
+
* This can probably live in c10 namespace later if needed,
|
| 27 |
+
* but for now keep it in at to keep implementation simple.
|
| 28 |
+
*
|
| 29 |
+
* @param dst dst tensor
|
| 30 |
+
* @param src src tensor
|
| 31 |
+
* @param non_blocking (default false) whether this operation blocks caller
|
| 32 |
+
*/
|
| 33 |
+
C10_EXPORT void storage_copy(
|
| 34 |
+
c10::Storage& dst,
|
| 35 |
+
const c10::Storage& src,
|
| 36 |
+
bool non_blocking = false);
|
| 37 |
+
|
| 38 |
+
/**
|
| 39 |
+
* In place change the storage to shm based.
|
| 40 |
+
*
|
| 41 |
+
* This is only applicable to CPU tensors not already shared.
|
| 42 |
+
* Otherwise, it's a no op to mirror the THP tensor behavior:
|
| 43 |
+
* https://pytorch.org/docs/stable/generated/torch.Tensor.share_memory_.html
|
| 44 |
+
*
|
| 45 |
+
* @param t a tensor
|
| 46 |
+
*/
|
| 47 |
+
C10_EXPORT void share_memory_(TensorBase& t);
|
| 48 |
+
|
| 49 |
+
} // namespace at
|
.venv/lib/python3.11/site-packages/torch/include/ATen/TensorAccessor.h
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <ATen/core/TensorAccessor.h>
|
.venv/lib/python3.11/site-packages/torch/include/ATen/TensorIndexing.h
ADDED
|
@@ -0,0 +1,737 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/ExpandUtils.h>
|
| 4 |
+
#include <ATen/ScalarOps.h>
|
| 5 |
+
#include <ATen/core/Tensor.h>
|
| 6 |
+
#include <ATen/core/TensorBody.h>
|
| 7 |
+
#include <c10/core/SymInt.h>
|
| 8 |
+
#include <c10/util/irange.h>
|
| 9 |
+
#include <optional>
|
| 10 |
+
|
| 11 |
+
#ifndef AT_PER_OPERATOR_HEADERS
|
| 12 |
+
#include <ATen/Functions.h>
|
| 13 |
+
#include <ATen/NativeFunctions.h>
|
| 14 |
+
#else
|
| 15 |
+
#include <ATen/ops/alias.h>
|
| 16 |
+
#include <ATen/ops/empty.h>
|
| 17 |
+
#include <ATen/ops/scalar_tensor.h>
|
| 18 |
+
#include <ATen/ops/zeros.h>
|
| 19 |
+
#endif
|
| 20 |
+
|
| 21 |
+
#include <ATen/core/List.h>
|
| 22 |
+
|
| 23 |
+
#include <utility>
|
| 24 |
+
|
| 25 |
+
namespace at::indexing {
|
| 26 |
+
|
| 27 |
+
constexpr int64_t INDEX_MIN = c10::SymInt::min_representable_int();
|
| 28 |
+
constexpr int64_t INDEX_MAX = -(INDEX_MIN + 1);
|
| 29 |
+
|
| 30 |
+
enum class TensorIndexType { None, Ellipsis, SymInt, Boolean, Slice, Tensor };
|
| 31 |
+
|
| 32 |
+
constexpr std::nullopt_t None = std::nullopt;
|
| 33 |
+
|
| 34 |
+
struct TORCH_API EllipsisIndexType final {
|
| 35 |
+
EllipsisIndexType() = default;
|
| 36 |
+
};
|
| 37 |
+
TORCH_API extern const EllipsisIndexType Ellipsis;
|
| 38 |
+
|
| 39 |
+
struct TORCH_API Slice final {
|
| 40 |
+
public:
|
| 41 |
+
Slice(
|
| 42 |
+
std::optional<c10::SymInt> start_index = std::nullopt,
|
| 43 |
+
std::optional<c10::SymInt> stop_index = std::nullopt,
|
| 44 |
+
std::optional<c10::SymInt> step_index = std::nullopt) {
|
| 45 |
+
if (!step_index.has_value()) {
|
| 46 |
+
step_ = c10::SymInt(1);
|
| 47 |
+
} else {
|
| 48 |
+
step_ = std::move(step_index).value();
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
TORCH_CHECK_VALUE(
|
| 52 |
+
step_.sym_ne(0).expect_true(__FILE__, __LINE__),
|
| 53 |
+
"slice step cannot be zero");
|
| 54 |
+
|
| 55 |
+
if (!start_index.has_value()) {
|
| 56 |
+
start_ = c10::SymInt(step_ < 0 ? INDEX_MAX : 0);
|
| 57 |
+
} else {
|
| 58 |
+
start_ = std::move(start_index).value();
|
| 59 |
+
}
|
| 60 |
+
|
| 61 |
+
if (!stop_index.has_value()) {
|
| 62 |
+
stop_ = c10::SymInt(step_ < 0 ? INDEX_MIN : INDEX_MAX);
|
| 63 |
+
} else {
|
| 64 |
+
stop_ = std::move(stop_index).value();
|
| 65 |
+
}
|
| 66 |
+
}
|
| 67 |
+
|
| 68 |
+
inline c10::SymInt start() const {
|
| 69 |
+
return start_;
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
inline c10::SymInt stop() const {
|
| 73 |
+
return stop_;
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
inline c10::SymInt step() const {
|
| 77 |
+
return step_;
|
| 78 |
+
}
|
| 79 |
+
|
| 80 |
+
private:
|
| 81 |
+
c10::SymInt start_;
|
| 82 |
+
c10::SymInt stop_;
|
| 83 |
+
c10::SymInt step_;
|
| 84 |
+
};
|
| 85 |
+
|
| 86 |
+
TORCH_API std::ostream& operator<<(std::ostream& stream, const Slice& slice);
|
| 87 |
+
|
| 88 |
+
// `at::indexing::TensorIndex` is used for converting C++ tensor indices such as
|
| 89 |
+
// `{None, "...", Ellipsis, 0, true, Slice(1, None, 2), torch::tensor({1, 2})}`
|
| 90 |
+
// into its equivalent `std::vector<TensorIndex>`, so that further tensor
|
| 91 |
+
// indexing operations can be performed using the supplied indices.
|
| 92 |
+
//
|
| 93 |
+
// There is one-to-one correspondence between Python and C++ tensor index types:
|
| 94 |
+
// Python | C++
|
| 95 |
+
// -----------------------------------------------------
|
| 96 |
+
// `None` | `at::indexing::None`
|
| 97 |
+
// `Ellipsis` | `at::indexing::Ellipsis`
|
| 98 |
+
// `...` | `"..."`
|
| 99 |
+
// `123` | `123`
|
| 100 |
+
// `True` / `False` | `true` / `false`
|
| 101 |
+
// `:` | `Slice()` / `Slice(None, None)`
|
| 102 |
+
// `::` | `Slice()` / `Slice(None, None, None)`
|
| 103 |
+
// `1:` | `Slice(1, None)`
|
| 104 |
+
// `1::` | `Slice(1, None, None)`
|
| 105 |
+
// `:3` | `Slice(None, 3)`
|
| 106 |
+
// `:3:` | `Slice(None, 3, None)`
|
| 107 |
+
// `::2` | `Slice(None, None, 2)`
|
| 108 |
+
// `1:3` | `Slice(1, 3)`
|
| 109 |
+
// `1::2` | `Slice(1, None, 2)`
|
| 110 |
+
// `:3:2` | `Slice(None, 3, 2)`
|
| 111 |
+
// `1:3:2` | `Slice(1, 3, 2)`
|
| 112 |
+
// `torch.tensor([1, 2])`) | `torch::tensor({1, 2})`
|
| 113 |
+
struct TORCH_API TensorIndex final {
|
| 114 |
+
// Case 1: `at::indexing::None`
|
| 115 |
+
TensorIndex(std::nullopt_t) : type_(TensorIndexType::None) {}
|
| 116 |
+
|
| 117 |
+
// Case 2: "..." / `at::indexing::Ellipsis`
|
| 118 |
+
TensorIndex(at::indexing::EllipsisIndexType)
|
| 119 |
+
: type_(TensorIndexType::Ellipsis) {}
|
| 120 |
+
TensorIndex(const char* str) : TensorIndex(at::indexing::Ellipsis) {
|
| 121 |
+
TORCH_CHECK_VALUE(
|
| 122 |
+
strcmp(str, "...") == 0,
|
| 123 |
+
"Expected \"...\" to represent an ellipsis index, but got \"",
|
| 124 |
+
str,
|
| 125 |
+
"\"");
|
| 126 |
+
}
|
| 127 |
+
|
| 128 |
+
// Case 3: (Sym) Integer value
|
| 129 |
+
TensorIndex(SymInt integer)
|
| 130 |
+
: integer_(std::move(integer)), type_(TensorIndexType::SymInt) {}
|
| 131 |
+
TensorIndex(int64_t integer) : TensorIndex(SymInt(integer)) {}
|
| 132 |
+
TensorIndex(int integer) : TensorIndex(SymInt(integer)) {}
|
| 133 |
+
|
| 134 |
+
// Case 4: Boolean value
|
| 135 |
+
template <class T, class = std::enable_if_t<std::is_same_v<bool, T>>>
|
| 136 |
+
TensorIndex(T boolean) : boolean_(boolean), type_(TensorIndexType::Boolean) {}
|
| 137 |
+
|
| 138 |
+
// Case 5: Slice represented in `at::indexing::Slice` form
|
| 139 |
+
TensorIndex(Slice slice)
|
| 140 |
+
: slice_(std::move(slice)), type_(TensorIndexType::Slice) {}
|
| 141 |
+
|
| 142 |
+
// Case 6: Tensor value
|
| 143 |
+
TensorIndex(Tensor tensor)
|
| 144 |
+
: tensor_(std::move(tensor)), type_(TensorIndexType::Tensor) {}
|
| 145 |
+
|
| 146 |
+
inline bool is_none() const {
|
| 147 |
+
return type_ == TensorIndexType::None;
|
| 148 |
+
}
|
| 149 |
+
|
| 150 |
+
inline bool is_ellipsis() const {
|
| 151 |
+
return type_ == TensorIndexType::Ellipsis;
|
| 152 |
+
}
|
| 153 |
+
|
| 154 |
+
inline bool is_integer() const {
|
| 155 |
+
return type_ == TensorIndexType::SymInt;
|
| 156 |
+
}
|
| 157 |
+
|
| 158 |
+
inline SymInt integer() const {
|
| 159 |
+
return integer_;
|
| 160 |
+
}
|
| 161 |
+
|
| 162 |
+
inline bool is_boolean() const {
|
| 163 |
+
return type_ == TensorIndexType::Boolean;
|
| 164 |
+
}
|
| 165 |
+
|
| 166 |
+
inline bool boolean() const {
|
| 167 |
+
return boolean_;
|
| 168 |
+
}
|
| 169 |
+
|
| 170 |
+
inline bool is_slice() const {
|
| 171 |
+
return type_ == TensorIndexType::Slice;
|
| 172 |
+
}
|
| 173 |
+
|
| 174 |
+
inline const Slice& slice() const {
|
| 175 |
+
return slice_;
|
| 176 |
+
}
|
| 177 |
+
|
| 178 |
+
inline bool is_tensor() const {
|
| 179 |
+
return type_ == TensorIndexType::Tensor;
|
| 180 |
+
}
|
| 181 |
+
|
| 182 |
+
inline const Tensor& tensor() const {
|
| 183 |
+
return tensor_;
|
| 184 |
+
}
|
| 185 |
+
|
| 186 |
+
private:
|
| 187 |
+
SymInt integer_ = 0;
|
| 188 |
+
bool boolean_ = false;
|
| 189 |
+
Slice slice_;
|
| 190 |
+
Tensor tensor_;
|
| 191 |
+
TensorIndexType type_;
|
| 192 |
+
};
|
| 193 |
+
|
| 194 |
+
TORCH_API std::ostream& operator<<(
|
| 195 |
+
std::ostream& stream,
|
| 196 |
+
const TensorIndex& tensor_index);
|
| 197 |
+
TORCH_API std::ostream& operator<<(
|
| 198 |
+
std::ostream& stream,
|
| 199 |
+
const std::vector<TensorIndex>& tensor_indices);
|
| 200 |
+
|
| 201 |
+
namespace impl {
|
| 202 |
+
inline Tensor applySlice(
|
| 203 |
+
const Tensor& self,
|
| 204 |
+
int64_t dim,
|
| 205 |
+
c10::SymInt start,
|
| 206 |
+
c10::SymInt stop,
|
| 207 |
+
c10::SymInt step,
|
| 208 |
+
bool disable_slice_optimization,
|
| 209 |
+
const at::Device& self_device,
|
| 210 |
+
const std::optional<SymIntArrayRef>& self_sizes) {
|
| 211 |
+
// TODO: implement negative step
|
| 212 |
+
TORCH_CHECK_VALUE(
|
| 213 |
+
step.sym_gt(0).expect_true(__FILE__, __LINE__),
|
| 214 |
+
"step must be greater than zero");
|
| 215 |
+
|
| 216 |
+
// See NOTE [nested tensor size for indexing]
|
| 217 |
+
if (self_sizes.has_value()) {
|
| 218 |
+
// Skip this optimization if we are tracing, as the trace may be polymorphic
|
| 219 |
+
// over the shape of the `self` tensor, and we still want to record
|
| 220 |
+
// the slice.
|
| 221 |
+
SymInt length = (self_device == at::kCPU || self_device == at::kCUDA)
|
| 222 |
+
? (*self_sizes)[dim]
|
| 223 |
+
: self.sym_size(dim);
|
| 224 |
+
if (!disable_slice_optimization &&
|
| 225 |
+
TORCH_GUARD_SIZE_OBLIVIOUS(start.sym_eq(0)) &&
|
| 226 |
+
TORCH_GUARD_SIZE_OBLIVIOUS(length.sym_eq(stop)) && step == 1) {
|
| 227 |
+
return self;
|
| 228 |
+
}
|
| 229 |
+
}
|
| 230 |
+
return self.slice_symint(
|
| 231 |
+
dim, std::move(start), std::move(stop), std::move(step));
|
| 232 |
+
}
|
| 233 |
+
|
| 234 |
+
inline Tensor applySelect(
|
| 235 |
+
const Tensor& self,
|
| 236 |
+
int64_t dim,
|
| 237 |
+
SymInt index,
|
| 238 |
+
int64_t real_dim,
|
| 239 |
+
const at::Device& /*self_device*/,
|
| 240 |
+
const std::optional<SymIntArrayRef>& self_sizes) {
|
| 241 |
+
// See NOTE [nested tensor size for indexing]
|
| 242 |
+
if (self_sizes.has_value()) {
|
| 243 |
+
auto maybe_index = index.maybe_as_int();
|
| 244 |
+
if (maybe_index.has_value()) {
|
| 245 |
+
TORCH_CHECK_INDEX(
|
| 246 |
+
!(maybe_index.value() == 0 && dim == 0 && self_sizes->empty()),
|
| 247 |
+
"invalid index of a 0-dim tensor. ",
|
| 248 |
+
"Use `tensor.item()` in Python or `tensor.item<T>()` in C++ to convert a 0-dim tensor to a number");
|
| 249 |
+
}
|
| 250 |
+
|
| 251 |
+
auto size = (*self_sizes)[dim];
|
| 252 |
+
// Note: `size >= -index` is not equivalent to `size > -1 - index` if index
|
| 253 |
+
// is INT64_MIN For std::numeric_limits<int64_t>::min() result of unary
|
| 254 |
+
// minus is undefined by the standard but in practice is equal to self. On
|
| 255 |
+
// the other hand, indexing wraping is valid for all negative int64_t
|
| 256 |
+
// values, as x[INT64_MIN] is the same as x[INT64_MAX]
|
| 257 |
+
TORCH_CHECK_INDEX(
|
| 258 |
+
size > -1 - index && size > index,
|
| 259 |
+
"index ",
|
| 260 |
+
index,
|
| 261 |
+
" is out of bounds for dimension ",
|
| 262 |
+
real_dim,
|
| 263 |
+
" with size ",
|
| 264 |
+
size);
|
| 265 |
+
}
|
| 266 |
+
|
| 267 |
+
// if the index is negative, do not normalize it because that would fix the
|
| 268 |
+
// index on the current tensor size in the tracer. aten::select also works on
|
| 269 |
+
// negative indices
|
| 270 |
+
return self.select_symint(dim, std::move(index));
|
| 271 |
+
}
|
| 272 |
+
|
| 273 |
+
inline Tensor boolToIndexingTensorCPUOrCUDA(const Tensor& self, bool value) {
|
| 274 |
+
// booleans add a dimension of size 1. true indexes this dimension as if 0:,
|
| 275 |
+
// false as empty.
|
| 276 |
+
if (value) {
|
| 277 |
+
return at::empty({1}, self.options().dtype(kLong)).fill_(0.);
|
| 278 |
+
} else {
|
| 279 |
+
return at::empty({0}, self.options().dtype(kLong));
|
| 280 |
+
}
|
| 281 |
+
}
|
| 282 |
+
|
| 283 |
+
inline Tensor boolToIndexingTensorNonNativeDeviceType(
|
| 284 |
+
const Tensor& self,
|
| 285 |
+
bool value) {
|
| 286 |
+
// booleans add a dimension of size 1. true indexes this dimension as if 0:,
|
| 287 |
+
// false as empty.
|
| 288 |
+
if (value) {
|
| 289 |
+
return at::zeros({1}, self.options().dtype(kLong));
|
| 290 |
+
} else {
|
| 291 |
+
return at::empty({0}, self.options().dtype(kLong));
|
| 292 |
+
}
|
| 293 |
+
}
|
| 294 |
+
|
| 295 |
+
inline Tensor boolToIndexingTensor(
|
| 296 |
+
const Tensor& self,
|
| 297 |
+
bool value,
|
| 298 |
+
const at::Device& self_device) {
|
| 299 |
+
if (self_device == at::kCPU || self_device == at::kCUDA) {
|
| 300 |
+
return boolToIndexingTensorCPUOrCUDA(self, value);
|
| 301 |
+
} else {
|
| 302 |
+
return boolToIndexingTensorNonNativeDeviceType(self, value);
|
| 303 |
+
}
|
| 304 |
+
}
|
| 305 |
+
|
| 306 |
+
inline Tensor scalarToTensorNonNativeDeviceType(
|
| 307 |
+
const Scalar& v,
|
| 308 |
+
const TensorOptions& options) {
|
| 309 |
+
return at::scalar_tensor(v, options);
|
| 310 |
+
}
|
| 311 |
+
|
| 312 |
+
inline void recordTensorIndex(
|
| 313 |
+
const Tensor& tensor,
|
| 314 |
+
std::vector<Tensor>& outIndices,
|
| 315 |
+
int64_t* dim_ptr) {
|
| 316 |
+
// TODO: check scalarType
|
| 317 |
+
outIndices.resize(*dim_ptr + 1);
|
| 318 |
+
outIndices[*dim_ptr] = tensor;
|
| 319 |
+
(*dim_ptr)++;
|
| 320 |
+
};
|
| 321 |
+
|
| 322 |
+
inline c10::List<::std::optional<Tensor>> typeConvertIndices(
|
| 323 |
+
const Tensor& /*self*/,
|
| 324 |
+
std::vector<Tensor>&& indices) {
|
| 325 |
+
c10::List<::std::optional<Tensor>> converted_inds;
|
| 326 |
+
converted_inds.reserve(indices.size());
|
| 327 |
+
for (auto&& i : std::move(indices)) {
|
| 328 |
+
converted_inds.push_back(std::move(i));
|
| 329 |
+
}
|
| 330 |
+
return converted_inds;
|
| 331 |
+
}
|
| 332 |
+
|
| 333 |
+
// NOTE: Why do we mirror instead of replace the `count_specified_dimensions`
|
| 334 |
+
// function in torch/csrc/autograd/python_variable_indexing.cpp? It's because
|
| 335 |
+
// `count_specified_dimensions` is on the hot path of Python tensor multi-dim
|
| 336 |
+
// indexing (i.e. it's called by `applySlicing` which is called by
|
| 337 |
+
// `THPVariable_getitem` / `THPVariable_setitem` when handling indexing of more
|
| 338 |
+
// than one dimension). If we were to merge the Python/C++
|
| 339 |
+
// `count_specified_dimensions` function, on the Python side we would have to
|
| 340 |
+
// construct a `std::vector` container to be consumed by the C++
|
| 341 |
+
// `count_specified_dimensions` function, which adds 100s of nanoseconds
|
| 342 |
+
// overhead and is undesirable.
|
| 343 |
+
inline int64_t count_specified_dimensions(
|
| 344 |
+
const ArrayRef<TensorIndex>& indices) {
|
| 345 |
+
// Count the number of indexed dimensions (everything but ellipsis and None)
|
| 346 |
+
int64_t count = 0;
|
| 347 |
+
for (auto& obj : indices) {
|
| 348 |
+
if (obj.is_tensor()) {
|
| 349 |
+
auto& tensor = obj.tensor();
|
| 350 |
+
if (tensor.scalar_type() == kByte || tensor.scalar_type() == kBool) {
|
| 351 |
+
count += tensor.dim();
|
| 352 |
+
} else {
|
| 353 |
+
count++;
|
| 354 |
+
}
|
| 355 |
+
} else if (!obj.is_none() && !obj.is_ellipsis() && !obj.is_boolean()) {
|
| 356 |
+
count++;
|
| 357 |
+
}
|
| 358 |
+
}
|
| 359 |
+
return count;
|
| 360 |
+
}
|
| 361 |
+
} // namespace impl
|
| 362 |
+
|
| 363 |
+
// NOTE: Many functions below are only for consumption from Python indexing
|
| 364 |
+
// implementation, they include:
|
| 365 |
+
//
|
| 366 |
+
// - `Tensor scalarToTensor(...)`
|
| 367 |
+
// - `IntArrayRef slicePrefix1sSize(...)`
|
| 368 |
+
// - `void copy_to(...)`
|
| 369 |
+
// - `Tensor handleDimInMultiDimIndexing(...)`
|
| 370 |
+
// - `Tensor dispatch_index(...)`
|
| 371 |
+
// - `Tensor dispatch_index_put_(...)`
|
| 372 |
+
// - `Tensor get_item(...)`
|
| 373 |
+
// - `void set_item(...)`
|
| 374 |
+
//
|
| 375 |
+
// The rest of the functions are in `at::indexing::impl` namespace, signifying
|
| 376 |
+
// that they shouldn't be used from Python indexing implementation.
|
| 377 |
+
inline Tensor scalarToTensor(
|
| 378 |
+
const Scalar& v,
|
| 379 |
+
const TensorOptions& options,
|
| 380 |
+
const at::Device& self_device) {
|
| 381 |
+
if (self_device == at::kCPU && !v.isSymbolic()) {
|
| 382 |
+
return at::detail::scalar_tensor_static(
|
| 383 |
+
v, options.dtype_opt()->toScalarType(), self_device);
|
| 384 |
+
} else {
|
| 385 |
+
return impl::scalarToTensorNonNativeDeviceType(v, options);
|
| 386 |
+
}
|
| 387 |
+
}
|
| 388 |
+
|
| 389 |
+
// To match numpy semantics:
|
| 390 |
+
// As a special case for backwards compatibility,
|
| 391 |
+
// strip away unit dimensions from the left of 'src'
|
| 392 |
+
inline SymIntArrayRef slicePrefix1sSize(const SymIntArrayRef& sizes) {
|
| 393 |
+
size_t first_non1_src = sizes.size();
|
| 394 |
+
for (const auto i : c10::irange(sizes.size())) {
|
| 395 |
+
// Unbacked SymInt has different behavior, but this is sound because
|
| 396 |
+
// failing to slice will only ever cause an error, not divergent
|
| 397 |
+
// behavior
|
| 398 |
+
if (!sizes[i].has_hint() || sizes[i] != 1) {
|
| 399 |
+
first_non1_src = i;
|
| 400 |
+
break;
|
| 401 |
+
}
|
| 402 |
+
}
|
| 403 |
+
|
| 404 |
+
return sizes.slice(first_non1_src);
|
| 405 |
+
}
|
| 406 |
+
|
| 407 |
+
inline void copy_to(const Tensor& dst, const Tensor& src) {
|
| 408 |
+
if (dst.sym_sizes().equals(src.sym_sizes())) {
|
| 409 |
+
// A shortcut to avoid generating hard-coded constant sizes during tracing.
|
| 410 |
+
// This is not a perfect solution: when src & dst have different shapes,
|
| 411 |
+
// constants will still appear. Users can workaround that case by
|
| 412 |
+
// dst[index..] = src.reshape(..)
|
| 413 |
+
dst.copy_(src);
|
| 414 |
+
return;
|
| 415 |
+
} else if (src.dim() == 0 && src.device().type() == at::kCPU) {
|
| 416 |
+
dst.fill_(src);
|
| 417 |
+
return;
|
| 418 |
+
}
|
| 419 |
+
auto src_view = src.view_symint(slicePrefix1sSize(src.sym_sizes()));
|
| 420 |
+
c10::MaybeOwned<Tensor> b_src = expand_inplace(dst, src_view, "setitem");
|
| 421 |
+
dst.copy_(*b_src);
|
| 422 |
+
}
|
| 423 |
+
|
| 424 |
+
// See NOTE [ Setting `disable_slice_optimization` when calling C++ tensor
|
| 425 |
+
// indexing functions from Python ]
|
| 426 |
+
inline Tensor handleDimInMultiDimIndexing(
|
| 427 |
+
const Tensor& prev_dim_result,
|
| 428 |
+
const Tensor& original_tensor,
|
| 429 |
+
const TensorIndex& index,
|
| 430 |
+
int64_t* dim_ptr,
|
| 431 |
+
int64_t* specified_dims_ptr,
|
| 432 |
+
int64_t real_dim,
|
| 433 |
+
std::vector<Tensor>& outIndices,
|
| 434 |
+
bool disable_slice_optimization,
|
| 435 |
+
const at::Device& original_tensor_device,
|
| 436 |
+
const std::optional<SymIntArrayRef>& prev_dim_result_sizes) {
|
| 437 |
+
if (index.is_integer()) {
|
| 438 |
+
return impl::applySelect(
|
| 439 |
+
prev_dim_result,
|
| 440 |
+
*dim_ptr,
|
| 441 |
+
index.integer(),
|
| 442 |
+
real_dim,
|
| 443 |
+
original_tensor_device,
|
| 444 |
+
prev_dim_result_sizes);
|
| 445 |
+
} else if (index.is_slice()) {
|
| 446 |
+
Tensor result = impl::applySlice(
|
| 447 |
+
prev_dim_result,
|
| 448 |
+
*dim_ptr,
|
| 449 |
+
index.slice().start(),
|
| 450 |
+
index.slice().stop(),
|
| 451 |
+
index.slice().step(),
|
| 452 |
+
/*disable_slice_optimization=*/disable_slice_optimization,
|
| 453 |
+
original_tensor_device,
|
| 454 |
+
prev_dim_result_sizes);
|
| 455 |
+
(*dim_ptr)++;
|
| 456 |
+
return result;
|
| 457 |
+
} else if (index.is_ellipsis()) {
|
| 458 |
+
(*dim_ptr) += original_tensor.dim() - (*specified_dims_ptr);
|
| 459 |
+
return prev_dim_result;
|
| 460 |
+
} else if (index.is_none()) {
|
| 461 |
+
Tensor result = prev_dim_result.unsqueeze(*dim_ptr);
|
| 462 |
+
(*dim_ptr)++;
|
| 463 |
+
return result;
|
| 464 |
+
} else if (index.is_boolean()) {
|
| 465 |
+
Tensor result = prev_dim_result.unsqueeze(*dim_ptr);
|
| 466 |
+
impl::recordTensorIndex(
|
| 467 |
+
impl::boolToIndexingTensor(
|
| 468 |
+
result, index.boolean(), original_tensor_device),
|
| 469 |
+
outIndices,
|
| 470 |
+
dim_ptr);
|
| 471 |
+
return result;
|
| 472 |
+
} else if (index.is_tensor()) {
|
| 473 |
+
Tensor result = prev_dim_result;
|
| 474 |
+
const Tensor& tensor = index.tensor();
|
| 475 |
+
auto scalar_type = tensor.scalar_type();
|
| 476 |
+
if (tensor.dim() == 0 &&
|
| 477 |
+
at::isIntegralType(scalar_type, /*includeBool=*/true)) {
|
| 478 |
+
if (scalar_type != at::kByte && scalar_type != at::kBool) {
|
| 479 |
+
result = impl::applySelect(
|
| 480 |
+
result,
|
| 481 |
+
*dim_ptr,
|
| 482 |
+
tensor.item<int64_t>(),
|
| 483 |
+
real_dim,
|
| 484 |
+
original_tensor_device,
|
| 485 |
+
prev_dim_result_sizes);
|
| 486 |
+
} else {
|
| 487 |
+
result = result.unsqueeze(*dim_ptr);
|
| 488 |
+
if (scalar_type == at::kBool) {
|
| 489 |
+
impl::recordTensorIndex(
|
| 490 |
+
impl::boolToIndexingTensor(
|
| 491 |
+
result, tensor.item<bool>() != 0, original_tensor_device),
|
| 492 |
+
outIndices,
|
| 493 |
+
dim_ptr);
|
| 494 |
+
} else {
|
| 495 |
+
impl::recordTensorIndex(
|
| 496 |
+
impl::boolToIndexingTensor(
|
| 497 |
+
result, tensor.item<uint8_t>() != 0, original_tensor_device),
|
| 498 |
+
outIndices,
|
| 499 |
+
dim_ptr);
|
| 500 |
+
}
|
| 501 |
+
}
|
| 502 |
+
} else {
|
| 503 |
+
impl::recordTensorIndex(tensor, outIndices, dim_ptr);
|
| 504 |
+
}
|
| 505 |
+
return result;
|
| 506 |
+
} else {
|
| 507 |
+
TORCH_INTERNAL_ASSERT(false, "Invalid TensorIndex type");
|
| 508 |
+
}
|
| 509 |
+
}
|
| 510 |
+
|
| 511 |
+
namespace impl {
|
| 512 |
+
// This mirrors `applySlicing` in
|
| 513 |
+
// torch/csrc/autograd/python_variable_indexing.cpp
|
| 514 |
+
inline Tensor applySlicing(
|
| 515 |
+
const Tensor& self,
|
| 516 |
+
const ArrayRef<TensorIndex>& indices,
|
| 517 |
+
std::vector<Tensor>& outIndices,
|
| 518 |
+
bool disable_slice_optimization,
|
| 519 |
+
const at::Device& self_device,
|
| 520 |
+
const std::optional<SymIntArrayRef>& self_sizes) {
|
| 521 |
+
int64_t dim = 0;
|
| 522 |
+
int64_t specified_dims = impl::count_specified_dimensions(indices);
|
| 523 |
+
|
| 524 |
+
// See NOTE [nested tensor size for indexing]
|
| 525 |
+
if (self_sizes.has_value()) {
|
| 526 |
+
TORCH_CHECK_INDEX(
|
| 527 |
+
specified_dims <= (int64_t)self_sizes->size(),
|
| 528 |
+
"too many indices for tensor of dimension ",
|
| 529 |
+
(int)self_sizes->size());
|
| 530 |
+
}
|
| 531 |
+
|
| 532 |
+
Tensor result = self;
|
| 533 |
+
for (const auto i : c10::irange(indices.size())) {
|
| 534 |
+
auto& obj = indices[i];
|
| 535 |
+
// See NOTE [nested tensor size for indexing]
|
| 536 |
+
std::optional<SymIntArrayRef> result_sizes = result.is_nested()
|
| 537 |
+
? std::optional<SymIntArrayRef>(std::nullopt)
|
| 538 |
+
: std::optional<SymIntArrayRef>(result.sym_sizes());
|
| 539 |
+
result = handleDimInMultiDimIndexing(
|
| 540 |
+
/*prev_dim_result=*/result,
|
| 541 |
+
/*original_tensor=*/self,
|
| 542 |
+
/*index=*/obj,
|
| 543 |
+
/*dim_ptr=*/&dim,
|
| 544 |
+
/*specified_dims_ptr=*/&specified_dims,
|
| 545 |
+
/*real_dim=*/static_cast<int64_t>(i),
|
| 546 |
+
/*outIndices=*/outIndices,
|
| 547 |
+
/*disable_slice_optimization=*/disable_slice_optimization,
|
| 548 |
+
/*original_tensor_device=*/self_device,
|
| 549 |
+
/*prev_dim_result_sizes=*/result_sizes);
|
| 550 |
+
}
|
| 551 |
+
return result;
|
| 552 |
+
}
|
| 553 |
+
} // namespace impl
|
| 554 |
+
|
| 555 |
+
inline Tensor dispatch_index(
|
| 556 |
+
const Tensor& self,
|
| 557 |
+
std::vector<Tensor>&& indices) {
|
| 558 |
+
return self.index(impl::typeConvertIndices(self, std::move(indices)));
|
| 559 |
+
}
|
| 560 |
+
|
| 561 |
+
inline Tensor dispatch_index_put_(
|
| 562 |
+
Tensor& self,
|
| 563 |
+
std::vector<Tensor>&& indices,
|
| 564 |
+
const Tensor& value) {
|
| 565 |
+
return self.index_put_(
|
| 566 |
+
impl::typeConvertIndices(self, std::move(indices)), value);
|
| 567 |
+
}
|
| 568 |
+
|
| 569 |
+
// NOTE [ Setting `disable_slice_optimization` when calling C++ tensor indexing
|
| 570 |
+
// functions from Python ]
|
| 571 |
+
//
|
| 572 |
+
// Question: When should we set `disable_slice_optimization` to `true` when
|
| 573 |
+
// calling C++ tensor indexing functions from Python indexing code?
|
| 574 |
+
//
|
| 575 |
+
// Answer: What "slice optimization" means: when we have a slicing expression
|
| 576 |
+
// like `x[0:5, 0]`, where the sliced tensor was of size 5 in dimension 0, we
|
| 577 |
+
// would skip dispatching the actual slice call as an optimization. However,
|
| 578 |
+
// here are the cases where we DON'T want this optimization:
|
| 579 |
+
//
|
| 580 |
+
// 1. When we are doing 1-D slicing (e.g. `tensor[:]`).
|
| 581 |
+
// Reason: we always return a shallow copy for expressions such as
|
| 582 |
+
// `tensor[:]` / `tensor[...]` / `tensor[:, :]`. (Note that for `tensor[:,
|
| 583 |
+
// :]`, we return an alias of `tensor` by doing the following:
|
| 584 |
+
// ```
|
| 585 |
+
// Tensor sliced = impl::applySlicing(self, indices, tensorIndices,
|
| 586 |
+
// disable_slice_optimization, self_device, self_sizes); if
|
| 587 |
+
// (tensorIndices.empty()) {
|
| 588 |
+
// if (sliced.is_same(self)) {
|
| 589 |
+
// // ensure we return a shallow copy for things like x[...]
|
| 590 |
+
// sliced = at::alias(sliced);
|
| 591 |
+
// }
|
| 592 |
+
// return sliced;
|
| 593 |
+
// }
|
| 594 |
+
// ```)
|
| 595 |
+
// 2. When we are doing JIT tracing.
|
| 596 |
+
// Reason: JIT tracing needs the `self.slice(...)` call to properly trace the
|
| 597 |
+
// slice operation.
|
| 598 |
+
|
| 599 |
+
// This mirrors `THPVariable_getitem` in
|
| 600 |
+
// torch/csrc/autograd/python_variable_indexing.cpp See NOTE [ Setting
|
| 601 |
+
// `disable_slice_optimization` when calling C++ tensor indexing functions from
|
| 602 |
+
// Python ]
|
| 603 |
+
inline Tensor get_item(
|
| 604 |
+
const Tensor& self,
|
| 605 |
+
const ArrayRef<TensorIndex>& indices,
|
| 606 |
+
bool disable_slice_optimization = false) {
|
| 607 |
+
at::Device self_device = self.device();
|
| 608 |
+
// NOTE [nested tensor size for indexing]
|
| 609 |
+
// nested tensor does not have a size (yet) so for now we represent its size
|
| 610 |
+
// as null may need to be changed after we reach a better solution for nested
|
| 611 |
+
// tensor size
|
| 612 |
+
std::optional<SymIntArrayRef> self_sizes = self.is_nested()
|
| 613 |
+
? std::optional<SymIntArrayRef>(std::nullopt)
|
| 614 |
+
: std::optional<SymIntArrayRef>(self.sym_sizes());
|
| 615 |
+
|
| 616 |
+
// handle simple types: integers, slices, none, ellipsis, bool
|
| 617 |
+
if (indices.size() == 1) {
|
| 618 |
+
const TensorIndex& index = indices[0];
|
| 619 |
+
if (index.is_integer()) {
|
| 620 |
+
return impl::applySelect(
|
| 621 |
+
self, 0, index.integer(), 0, self_device, self_sizes);
|
| 622 |
+
} else if (index.is_slice()) {
|
| 623 |
+
return impl::applySlice(
|
| 624 |
+
self,
|
| 625 |
+
0,
|
| 626 |
+
index.slice().start(),
|
| 627 |
+
index.slice().stop(),
|
| 628 |
+
index.slice().step(),
|
| 629 |
+
/*disable_slice_optimization=*/true,
|
| 630 |
+
self_device,
|
| 631 |
+
self_sizes);
|
| 632 |
+
} else if (index.is_none()) {
|
| 633 |
+
return self.unsqueeze(0);
|
| 634 |
+
} else if (index.is_ellipsis()) {
|
| 635 |
+
return at::alias(self);
|
| 636 |
+
} else if (index.is_boolean()) {
|
| 637 |
+
Tensor result = self.unsqueeze(0);
|
| 638 |
+
return dispatch_index(
|
| 639 |
+
result,
|
| 640 |
+
std::vector<Tensor>{impl::boolToIndexingTensor(
|
| 641 |
+
result, index.boolean(), self_device)});
|
| 642 |
+
}
|
| 643 |
+
}
|
| 644 |
+
|
| 645 |
+
std::vector<Tensor> tensorIndices;
|
| 646 |
+
Tensor sliced = impl::applySlicing(
|
| 647 |
+
self,
|
| 648 |
+
indices,
|
| 649 |
+
tensorIndices,
|
| 650 |
+
disable_slice_optimization,
|
| 651 |
+
self_device,
|
| 652 |
+
self_sizes);
|
| 653 |
+
if (tensorIndices.empty()) {
|
| 654 |
+
if (sliced.is_same(self)) {
|
| 655 |
+
// ensure we return a shallow copy for things like x[...]
|
| 656 |
+
sliced = at::alias(sliced);
|
| 657 |
+
}
|
| 658 |
+
return sliced;
|
| 659 |
+
}
|
| 660 |
+
|
| 661 |
+
// indexing by tensors ("advanced" indexing)
|
| 662 |
+
return dispatch_index(sliced, std::move(tensorIndices));
|
| 663 |
+
}
|
| 664 |
+
|
| 665 |
+
// This mirrors `THPVariable_setitem` in
|
| 666 |
+
// torch/csrc/autograd/python_variable_indexing.cpp for "the assigned value is a
|
| 667 |
+
// Tensor" case See NOTE [ Setting `disable_slice_optimization` when calling C++
|
| 668 |
+
// tensor indexing functions from Python ]
|
| 669 |
+
inline void set_item(
|
| 670 |
+
const Tensor& self,
|
| 671 |
+
const ArrayRef<TensorIndex>& indices,
|
| 672 |
+
const Tensor& value,
|
| 673 |
+
bool disable_slice_optimization = false) {
|
| 674 |
+
at::Device self_device = self.device();
|
| 675 |
+
SymIntArrayRef self_sizes = self.sym_sizes();
|
| 676 |
+
|
| 677 |
+
// handle simple types: integers, slices, ellipsis, bool
|
| 678 |
+
if (indices.size() == 1) {
|
| 679 |
+
const TensorIndex& index = indices[0];
|
| 680 |
+
if (index.is_boolean() && !index.boolean()) {
|
| 681 |
+
// do nothing for false (technically we should check the size, but we
|
| 682 |
+
// don't have real 0-sized shapes.
|
| 683 |
+
return;
|
| 684 |
+
} else if (index.is_ellipsis()) {
|
| 685 |
+
copy_to(self, value);
|
| 686 |
+
return;
|
| 687 |
+
} else if (index.is_none() || (index.is_boolean() && index.boolean())) {
|
| 688 |
+
copy_to(self.unsqueeze(0), value);
|
| 689 |
+
return;
|
| 690 |
+
} else if (index.is_integer()) {
|
| 691 |
+
copy_to(
|
| 692 |
+
impl::applySelect(
|
| 693 |
+
self, 0, index.integer(), 0, self_device, self_sizes),
|
| 694 |
+
value);
|
| 695 |
+
return;
|
| 696 |
+
} else if (index.is_slice()) {
|
| 697 |
+
copy_to(
|
| 698 |
+
impl::applySlice(
|
| 699 |
+
self,
|
| 700 |
+
0,
|
| 701 |
+
index.slice().start(),
|
| 702 |
+
index.slice().stop(),
|
| 703 |
+
index.slice().step(),
|
| 704 |
+
/*disable_slice_optimization=*/disable_slice_optimization,
|
| 705 |
+
self_device,
|
| 706 |
+
self_sizes),
|
| 707 |
+
value);
|
| 708 |
+
return;
|
| 709 |
+
}
|
| 710 |
+
}
|
| 711 |
+
|
| 712 |
+
std::vector<Tensor> tensorIndices;
|
| 713 |
+
Tensor sliced = impl::applySlicing(
|
| 714 |
+
self,
|
| 715 |
+
indices,
|
| 716 |
+
tensorIndices,
|
| 717 |
+
disable_slice_optimization,
|
| 718 |
+
self_device,
|
| 719 |
+
self_sizes);
|
| 720 |
+
if (tensorIndices.empty()) {
|
| 721 |
+
copy_to(sliced, value);
|
| 722 |
+
return;
|
| 723 |
+
}
|
| 724 |
+
|
| 725 |
+
SymIntArrayRef valueSizes = value.sym_sizes();
|
| 726 |
+
SymIntArrayRef slicedValueSizes = slicePrefix1sSize(valueSizes);
|
| 727 |
+
Tensor valuesSliced;
|
| 728 |
+
if (!valueSizes.equals(slicedValueSizes)) {
|
| 729 |
+
valuesSliced = value.view_symint(slicedValueSizes);
|
| 730 |
+
} else {
|
| 731 |
+
valuesSliced = value;
|
| 732 |
+
}
|
| 733 |
+
dispatch_index_put_(sliced, std::move(tensorIndices), valuesSliced);
|
| 734 |
+
return;
|
| 735 |
+
}
|
| 736 |
+
|
| 737 |
+
} // namespace at::indexing
|
.venv/lib/python3.11/site-packages/torch/include/ATen/TensorIteratorInternal.h
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <ATen/native/TensorIterator.h>
|
| 3 |
+
#include <c10/util/SmallBuffer.h>
|
| 4 |
+
#include <c10/util/irange.h>
|
| 5 |
+
|
| 6 |
+
namespace at {
|
| 7 |
+
|
| 8 |
+
struct DimCounter {
|
| 9 |
+
DimCounter(IntArrayRef shape, Range range);
|
| 10 |
+
|
| 11 |
+
void increment(const std::array<int64_t, 2>& step);
|
| 12 |
+
bool is_done() const;
|
| 13 |
+
std::array<int64_t, 2> max_2d_step() const;
|
| 14 |
+
|
| 15 |
+
IntArrayRef shape;
|
| 16 |
+
Range range;
|
| 17 |
+
c10::SmallBuffer<int64_t, 4> values;
|
| 18 |
+
int64_t offset;
|
| 19 |
+
};
|
| 20 |
+
|
| 21 |
+
namespace internal {
|
| 22 |
+
|
| 23 |
+
inline void get_data_ptrs(
|
| 24 |
+
char** ptrs,
|
| 25 |
+
ArrayRef<char*> base,
|
| 26 |
+
IntArrayRef strides,
|
| 27 |
+
IntArrayRef counter) {
|
| 28 |
+
const auto ntensors = base.size();
|
| 29 |
+
const auto ndim = counter.size();
|
| 30 |
+
std::copy(base.begin(), base.end(), ptrs);
|
| 31 |
+
for (const auto dim : c10::irange(ndim)) {
|
| 32 |
+
int64_t value = counter[dim];
|
| 33 |
+
for (const auto arg : c10::irange(ntensors)) {
|
| 34 |
+
ptrs[arg] += value * strides[dim * ntensors + arg];
|
| 35 |
+
}
|
| 36 |
+
}
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
inline void serial_for_each(
|
| 40 |
+
IntArrayRef shape,
|
| 41 |
+
IntArrayRef strides,
|
| 42 |
+
char** base_ptrs,
|
| 43 |
+
size_t ntensors,
|
| 44 |
+
typename TensorIteratorBase::loop2d_t loop,
|
| 45 |
+
Range range) {
|
| 46 |
+
const auto ndim = shape.size();
|
| 47 |
+
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
|
| 48 |
+
strides.size() == ntensors * std::max(size_t{2}, ndim));
|
| 49 |
+
|
| 50 |
+
if (ndim <= 1) {
|
| 51 |
+
if (range.begin == 0) {
|
| 52 |
+
loop(base_ptrs, strides.data(), range.size(), 1);
|
| 53 |
+
} else {
|
| 54 |
+
c10::SmallBuffer<char*, 4> ptrs(ntensors);
|
| 55 |
+
get_data_ptrs(ptrs.data(), {base_ptrs, ntensors}, strides, {range.begin});
|
| 56 |
+
loop(ptrs.data(), strides.data(), range.size(), 1);
|
| 57 |
+
}
|
| 58 |
+
} else {
|
| 59 |
+
c10::SmallBuffer<char*, 4> ptrs(ntensors);
|
| 60 |
+
auto counter = DimCounter(shape, range);
|
| 61 |
+
while (!counter.is_done()) {
|
| 62 |
+
get_data_ptrs(
|
| 63 |
+
ptrs.data(), {base_ptrs, ntensors}, strides, counter.values);
|
| 64 |
+
auto step = counter.max_2d_step();
|
| 65 |
+
loop(ptrs.data(), strides.data(), step[0], step[1]);
|
| 66 |
+
counter.increment(step);
|
| 67 |
+
}
|
| 68 |
+
}
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
} // namespace internal
|
| 72 |
+
} // namespace at
|
.venv/lib/python3.11/site-packages/torch/include/ATen/TensorOptions.h
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <c10/core/TensorOptions.h>
|
.venv/lib/python3.11/site-packages/torch/include/ATen/TensorSubclassLikeUtils.h
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <ATen/core/List.h>
|
| 3 |
+
#include <ATen/core/Tensor.h>
|
| 4 |
+
#include <c10/core/impl/TorchDispatchModeTLS.h>
|
| 5 |
+
|
| 6 |
+
#ifndef AT_PER_OPERATOR_HEADERS
|
| 7 |
+
#include <ATen/Functions.h>
|
| 8 |
+
#else
|
| 9 |
+
#include <ATen/ops/equal.h>
|
| 10 |
+
#endif
|
| 11 |
+
|
| 12 |
+
namespace at {
|
| 13 |
+
|
| 14 |
+
// Note [Tensor-subclass-like Tensors]
|
| 15 |
+
// Tensor-subclass-like is defined as:
|
| 16 |
+
// - a Tensor subclass (via __torch_dispatch__ in Python or extending
|
| 17 |
+
// TensorImpl in C++)
|
| 18 |
+
// - anything else that shares the same perils as Tensor subclasses.
|
| 19 |
+
// For example, many Tensor subclasses do not have storage and meta Tensors
|
| 20 |
+
// do not have storage either, so meta Tensors belong here.
|
| 21 |
+
//
|
| 22 |
+
// We should ensure that PyTorch internals supports Tensor-subclass-like
|
| 23 |
+
// objects. In particular, Tensor-subclass-like objects struggle with two
|
| 24 |
+
// classes of operations that are problematic for Tensor subclasses:
|
| 25 |
+
// 1. Because some Tensor subclasses do not have storage, .item() or
|
| 26 |
+
// .data_ptr() calls are not good.
|
| 27 |
+
// 2. Certain in-place operations can eliminate the typing of the Tensor
|
| 28 |
+
// subclass. For example:
|
| 29 |
+
// >>> torch.zeros(input.sizes(), grad.options()).diag().copy_(input)
|
| 30 |
+
// If input is a Tensor subclass, then the above ends up either erroring out
|
| 31 |
+
// or returning a regular non-Tensor-subclass Tensor!
|
| 32 |
+
|
| 33 |
+
constexpr auto kFunctorchWrappedTensors = DispatchKeySet(
|
| 34 |
+
{DispatchKey::FuncTorchGradWrapper,
|
| 35 |
+
DispatchKey::FuncTorchBatched,
|
| 36 |
+
DispatchKey::Functionalize});
|
| 37 |
+
|
| 38 |
+
constexpr auto kTensorSubclassLike =
|
| 39 |
+
kFunctorchWrappedTensors |
|
| 40 |
+
DispatchKeySet(
|
| 41 |
+
{// WARNING: DO NOT put combined backend component + functionality keys
|
| 42 |
+
// here, you will incorrectly always match on the functionality key
|
| 43 |
+
// no matter the backend component
|
| 44 |
+
DispatchKey::Batched,
|
| 45 |
+
DispatchKey::Sparse,
|
| 46 |
+
DispatchKey::SparseCsr,
|
| 47 |
+
DispatchKey::Python}) |
|
| 48 |
+
DispatchKeySet(BackendComponent::MetaBit);
|
| 49 |
+
|
| 50 |
+
inline bool isTensorSubclassLike(const Tensor& tensor) {
|
| 51 |
+
if (c10::impl::dispatch_mode_enabled())
|
| 52 |
+
return true;
|
| 53 |
+
auto key_set = tensor.unsafeGetTensorImpl()->key_set();
|
| 54 |
+
return !(key_set & kTensorSubclassLike).empty();
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
inline bool areAnyTensorSubclassLike(TensorList tensors) {
|
| 58 |
+
if (c10::impl::dispatch_mode_enabled())
|
| 59 |
+
return true;
|
| 60 |
+
return std::any_of(tensors.begin(), tensors.end(), isTensorSubclassLike);
|
| 61 |
+
}
|
| 62 |
+
|
| 63 |
+
inline bool areAnyOptionalTensorSubclassLike(
|
| 64 |
+
const c10::List<std::optional<Tensor>>& tensors) {
|
| 65 |
+
if (c10::impl::dispatch_mode_enabled())
|
| 66 |
+
return true;
|
| 67 |
+
return std::any_of(
|
| 68 |
+
tensors.begin(),
|
| 69 |
+
tensors.end(),
|
| 70 |
+
[](const std::optional<Tensor>& opt_tensor) {
|
| 71 |
+
return (
|
| 72 |
+
opt_tensor.has_value() && isTensorSubclassLike(opt_tensor.value()));
|
| 73 |
+
});
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
// Helper function to deal testing truthfulness of a scalar tensor
|
| 77 |
+
// in a Composite Compliant manner.
|
| 78 |
+
// NOTE: This function expects a scalar tensor of boolean dtype.
|
| 79 |
+
// Eg.
|
| 80 |
+
// Non-Composite Compliant Pattern : (t == 0).all().item<bool>()
|
| 81 |
+
// Composite Compliant Patter : is_salar_tensor_true((t == 0).all())
|
| 82 |
+
inline bool is_scalar_tensor_true(const Tensor& t) {
|
| 83 |
+
TORCH_INTERNAL_ASSERT(t.dim() == 0)
|
| 84 |
+
TORCH_INTERNAL_ASSERT(t.scalar_type() == kBool)
|
| 85 |
+
return at::equal(t, t.new_ones({}, t.options()));
|
| 86 |
+
}
|
| 87 |
+
|
| 88 |
+
} // namespace at
|
.venv/lib/python3.11/site-packages/torch/include/ATen/TensorUtils.h
ADDED
|
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/DimVector.h>
|
| 4 |
+
#include <ATen/EmptyTensor.h>
|
| 5 |
+
#include <ATen/Tensor.h>
|
| 6 |
+
#include <ATen/TensorGeometry.h>
|
| 7 |
+
#include <ATen/Utils.h>
|
| 8 |
+
|
| 9 |
+
#include <utility>
|
| 10 |
+
|
| 11 |
+
// These functions are NOT in Utils.h, because this file has a dep on Tensor.h
|
| 12 |
+
|
| 13 |
+
#define TORCH_CHECK_TENSOR_ALL(cond, ...) \
|
| 14 |
+
TORCH_CHECK((cond)._is_all_true().item<bool>(), __VA_ARGS__);
|
| 15 |
+
|
| 16 |
+
namespace at {
|
| 17 |
+
|
| 18 |
+
// The following are utility functions for checking that arguments
|
| 19 |
+
// make sense. These are particularly useful for native functions,
|
| 20 |
+
// which do NO argument checking by default.
|
| 21 |
+
|
| 22 |
+
struct TORCH_API TensorArg {
|
| 23 |
+
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
|
| 24 |
+
const Tensor& tensor;
|
| 25 |
+
const char* name;
|
| 26 |
+
int pos; // 1-indexed
|
| 27 |
+
TensorArg(const Tensor& tensor, const char* name, int pos)
|
| 28 |
+
: tensor(tensor), name(name), pos(pos) {}
|
| 29 |
+
// Try to mitigate any possibility of dangling reference to temporaries.
|
| 30 |
+
// NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved)
|
| 31 |
+
TensorArg(Tensor&& tensor, const char* name, int pos) = delete;
|
| 32 |
+
const Tensor* operator->() const {
|
| 33 |
+
return &tensor;
|
| 34 |
+
}
|
| 35 |
+
const Tensor& operator*() const {
|
| 36 |
+
return tensor;
|
| 37 |
+
}
|
| 38 |
+
};
|
| 39 |
+
|
| 40 |
+
struct TORCH_API TensorGeometryArg {
|
| 41 |
+
TensorGeometry tensor;
|
| 42 |
+
const char* name;
|
| 43 |
+
int pos; // 1-indexed
|
| 44 |
+
/* implicit */ TensorGeometryArg(TensorArg arg)
|
| 45 |
+
: tensor(TensorGeometry{arg.tensor}), name(arg.name), pos(arg.pos) {}
|
| 46 |
+
TensorGeometryArg(TensorGeometry tensor, const char* name, int pos)
|
| 47 |
+
: tensor(std::move(tensor)), name(name), pos(pos) {}
|
| 48 |
+
const TensorGeometry* operator->() const {
|
| 49 |
+
return &tensor;
|
| 50 |
+
}
|
| 51 |
+
const TensorGeometry& operator*() const {
|
| 52 |
+
return tensor;
|
| 53 |
+
}
|
| 54 |
+
};
|
| 55 |
+
|
| 56 |
+
// A string describing which function did checks on its input
|
| 57 |
+
// arguments.
|
| 58 |
+
// TODO: Consider generalizing this into a call stack.
|
| 59 |
+
using CheckedFrom = const char*;
|
| 60 |
+
|
| 61 |
+
// The undefined convention: singular operators assume their arguments
|
| 62 |
+
// are defined, but functions which take multiple tensors will
|
| 63 |
+
// implicitly filter out undefined tensors (to make it easier to perform
|
| 64 |
+
// tests which should apply if the tensor is defined, and should not
|
| 65 |
+
// otherwise.)
|
| 66 |
+
//
|
| 67 |
+
// NB: This means that the n-ary operators take lists of TensorArg,
|
| 68 |
+
// not TensorGeometryArg, because the Tensor to TensorGeometry
|
| 69 |
+
// conversion will blow up if you have undefined tensors.
|
| 70 |
+
|
| 71 |
+
TORCH_API std::ostream& operator<<(
|
| 72 |
+
std::ostream& out,
|
| 73 |
+
const TensorGeometryArg& t);
|
| 74 |
+
TORCH_API void checkDim(
|
| 75 |
+
CheckedFrom c,
|
| 76 |
+
const Tensor& tensor,
|
| 77 |
+
const char* name,
|
| 78 |
+
int pos, // 1-indexed
|
| 79 |
+
int64_t dim);
|
| 80 |
+
TORCH_API void checkDim(CheckedFrom c, const TensorGeometryArg& t, int64_t dim);
|
| 81 |
+
// NB: this is an inclusive-exclusive range
|
| 82 |
+
TORCH_API void checkDimRange(
|
| 83 |
+
CheckedFrom c,
|
| 84 |
+
const TensorGeometryArg& t,
|
| 85 |
+
int64_t dim_start,
|
| 86 |
+
int64_t dim_end);
|
| 87 |
+
TORCH_API void checkSameDim(
|
| 88 |
+
CheckedFrom c,
|
| 89 |
+
const TensorGeometryArg& t1,
|
| 90 |
+
const TensorGeometryArg& t2);
|
| 91 |
+
TORCH_API void checkContiguous(CheckedFrom c, const TensorGeometryArg& t);
|
| 92 |
+
TORCH_API void checkAllContiguous(CheckedFrom c, at::ArrayRef<TensorArg> ts);
|
| 93 |
+
TORCH_API void checkSize(
|
| 94 |
+
CheckedFrom c,
|
| 95 |
+
const TensorGeometryArg& t,
|
| 96 |
+
IntArrayRef sizes);
|
| 97 |
+
TORCH_API void checkSize_symint(
|
| 98 |
+
CheckedFrom c,
|
| 99 |
+
const TensorGeometryArg& t,
|
| 100 |
+
c10::SymIntArrayRef sizes);
|
| 101 |
+
TORCH_API void checkSize(
|
| 102 |
+
CheckedFrom c,
|
| 103 |
+
const TensorGeometryArg& t,
|
| 104 |
+
int64_t dim,
|
| 105 |
+
int64_t size);
|
| 106 |
+
TORCH_API void checkSize_symint(
|
| 107 |
+
CheckedFrom c,
|
| 108 |
+
const TensorGeometryArg& t,
|
| 109 |
+
int64_t dim,
|
| 110 |
+
const c10::SymInt& size);
|
| 111 |
+
TORCH_API void checkNumel(
|
| 112 |
+
CheckedFrom c,
|
| 113 |
+
const TensorGeometryArg& t,
|
| 114 |
+
int64_t numel);
|
| 115 |
+
TORCH_API void checkSameNumel(
|
| 116 |
+
CheckedFrom c,
|
| 117 |
+
const TensorArg& t1,
|
| 118 |
+
const TensorArg& t2);
|
| 119 |
+
TORCH_API void checkAllSameNumel(CheckedFrom c, ArrayRef<TensorArg> tensors);
|
| 120 |
+
TORCH_API void checkScalarType(CheckedFrom c, const TensorArg& t, ScalarType s);
|
| 121 |
+
TORCH_API void checkScalarTypes(
|
| 122 |
+
CheckedFrom c,
|
| 123 |
+
const TensorArg& t,
|
| 124 |
+
at::ArrayRef<ScalarType> l);
|
| 125 |
+
TORCH_API void checkSameGPU(
|
| 126 |
+
CheckedFrom c,
|
| 127 |
+
const TensorArg& t1,
|
| 128 |
+
const TensorArg& t2);
|
| 129 |
+
TORCH_API void checkAllSameGPU(CheckedFrom c, ArrayRef<TensorArg> tensors);
|
| 130 |
+
TORCH_API void checkSameType(
|
| 131 |
+
CheckedFrom c,
|
| 132 |
+
const TensorArg& t1,
|
| 133 |
+
const TensorArg& t2);
|
| 134 |
+
TORCH_API void checkAllSameType(CheckedFrom c, ArrayRef<TensorArg> tensors);
|
| 135 |
+
TORCH_API void checkSameSize(
|
| 136 |
+
CheckedFrom c,
|
| 137 |
+
const TensorArg& t1,
|
| 138 |
+
const TensorArg& t2);
|
| 139 |
+
TORCH_API void checkAllSameSize(CheckedFrom c, ArrayRef<TensorArg> tensors);
|
| 140 |
+
TORCH_API void checkDefined(CheckedFrom c, const TensorArg& t);
|
| 141 |
+
TORCH_API void checkAllDefined(CheckedFrom c, at::ArrayRef<TensorArg> t);
|
| 142 |
+
|
| 143 |
+
// FixMe: does TensorArg slow things down?
|
| 144 |
+
TORCH_API void checkBackend(
|
| 145 |
+
CheckedFrom c,
|
| 146 |
+
at::ArrayRef<Tensor> t,
|
| 147 |
+
at::Backend backend);
|
| 148 |
+
|
| 149 |
+
TORCH_API void checkDeviceType(
|
| 150 |
+
CheckedFrom c,
|
| 151 |
+
at::ArrayRef<Tensor> tensors,
|
| 152 |
+
at::DeviceType device_type);
|
| 153 |
+
|
| 154 |
+
TORCH_API void checkLayout(CheckedFrom c, const Tensor& t, Layout layout);
|
| 155 |
+
|
| 156 |
+
TORCH_API void checkLayout(
|
| 157 |
+
CheckedFrom c,
|
| 158 |
+
at::ArrayRef<Tensor> tensors,
|
| 159 |
+
at::Layout layout);
|
| 160 |
+
|
| 161 |
+
// Methods for getting data_ptr if tensor is defined
|
| 162 |
+
TORCH_API void* maybe_data_ptr(const Tensor& tensor);
|
| 163 |
+
TORCH_API void* maybe_data_ptr(const TensorArg& tensor);
|
| 164 |
+
|
| 165 |
+
TORCH_API void check_dim_size(
|
| 166 |
+
const Tensor& tensor,
|
| 167 |
+
int64_t dim,
|
| 168 |
+
int64_t dim_size,
|
| 169 |
+
int64_t size);
|
| 170 |
+
|
| 171 |
+
namespace detail {
|
| 172 |
+
TORCH_API std::vector<int64_t> defaultStrides(IntArrayRef sizes);
|
| 173 |
+
|
| 174 |
+
TORCH_API std::optional<std::vector<int64_t>> computeStride(
|
| 175 |
+
IntArrayRef oldshape,
|
| 176 |
+
IntArrayRef oldstride,
|
| 177 |
+
IntArrayRef newshape);
|
| 178 |
+
|
| 179 |
+
TORCH_API std::optional<SymDimVector> computeStride(
|
| 180 |
+
c10::SymIntArrayRef oldshape,
|
| 181 |
+
c10::SymIntArrayRef oldstride,
|
| 182 |
+
c10::SymIntArrayRef newshape);
|
| 183 |
+
|
| 184 |
+
TORCH_API std::optional<DimVector> computeStride(
|
| 185 |
+
IntArrayRef oldshape,
|
| 186 |
+
IntArrayRef oldstride,
|
| 187 |
+
const DimVector& newshape);
|
| 188 |
+
|
| 189 |
+
} // namespace detail
|
| 190 |
+
} // namespace at
|
.venv/lib/python3.11/site-packages/torch/include/ATen/TypeDefault.h
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/Dimname.h>
|
| 4 |
+
#include <c10/core/MemoryFormat.h>
|
| 5 |
+
#include <c10/core/QScheme.h>
|
| 6 |
+
#include <c10/core/Scalar.h>
|
| 7 |
+
#include <c10/core/TensorOptions.h>
|
| 8 |
+
#include <c10/macros/Export.h>
|
| 9 |
+
#include <c10/util/ArrayRef.h>
|
| 10 |
+
#include <c10/util/intrusive_ptr.h>
|
| 11 |
+
|
| 12 |
+
namespace c10 {
|
| 13 |
+
struct Storage;
|
| 14 |
+
}
|
| 15 |
+
|
| 16 |
+
namespace at {
|
| 17 |
+
|
| 18 |
+
class Tensor;
|
| 19 |
+
using TensorList = ArrayRef<Tensor>;
|
| 20 |
+
|
| 21 |
+
class Context;
|
| 22 |
+
struct Generator;
|
| 23 |
+
|
| 24 |
+
struct Quantizer;
|
| 25 |
+
// This is temporary typedef to enable Quantizer in aten native function API
|
| 26 |
+
// we'll remove them when we are actually exposing Quantizer class
|
| 27 |
+
// to frontend
|
| 28 |
+
using ConstQuantizerPtr = const c10::intrusive_ptr<Quantizer>&;
|
| 29 |
+
|
| 30 |
+
} // namespace at
|
.venv/lib/python3.11/site-packages/torch/include/ATen/Utils.h
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/EmptyTensor.h>
|
| 4 |
+
#include <ATen/Formatting.h>
|
| 5 |
+
#include <ATen/core/ATenGeneral.h>
|
| 6 |
+
#include <ATen/core/Generator.h>
|
| 7 |
+
#include <c10/core/ScalarType.h>
|
| 8 |
+
#include <c10/core/StorageImpl.h>
|
| 9 |
+
#include <c10/core/UndefinedTensorImpl.h>
|
| 10 |
+
#include <c10/util/ArrayRef.h>
|
| 11 |
+
#include <c10/util/Exception.h>
|
| 12 |
+
#include <c10/util/accumulate.h>
|
| 13 |
+
#include <c10/util/irange.h>
|
| 14 |
+
|
| 15 |
+
#include <algorithm>
|
| 16 |
+
|
| 17 |
+
#define AT_DISALLOW_COPY_AND_ASSIGN(TypeName) \
|
| 18 |
+
TypeName(const TypeName&) = delete; \
|
| 19 |
+
void operator=(const TypeName&) = delete
|
| 20 |
+
|
| 21 |
+
namespace at {
|
| 22 |
+
|
| 23 |
+
TORCH_API int _crash_if_asan(int);
|
| 24 |
+
|
| 25 |
+
// Converts a TensorList (i.e. ArrayRef<Tensor> to vector of TensorImpl*)
|
| 26 |
+
// NB: This is ONLY used by legacy TH bindings, and ONLY used by cat.
|
| 27 |
+
// Once cat is ported entirely to ATen this can be deleted!
|
| 28 |
+
inline std::vector<TensorImpl*> checked_dense_tensor_list_unwrap(
|
| 29 |
+
ArrayRef<Tensor> tensors,
|
| 30 |
+
const char* name,
|
| 31 |
+
int pos,
|
| 32 |
+
c10::DeviceType device_type,
|
| 33 |
+
ScalarType scalar_type) {
|
| 34 |
+
std::vector<TensorImpl*> unwrapped;
|
| 35 |
+
unwrapped.reserve(tensors.size());
|
| 36 |
+
for (const auto i : c10::irange(tensors.size())) {
|
| 37 |
+
const auto& expr = tensors[i];
|
| 38 |
+
if (expr.layout() != Layout::Strided) {
|
| 39 |
+
AT_ERROR(
|
| 40 |
+
"Expected dense tensor but got ",
|
| 41 |
+
expr.layout(),
|
| 42 |
+
" for sequence element ",
|
| 43 |
+
i,
|
| 44 |
+
" in sequence argument at position #",
|
| 45 |
+
pos,
|
| 46 |
+
" '",
|
| 47 |
+
name,
|
| 48 |
+
"'");
|
| 49 |
+
}
|
| 50 |
+
if (expr.device().type() != device_type) {
|
| 51 |
+
AT_ERROR(
|
| 52 |
+
"Expected object of device type ",
|
| 53 |
+
device_type,
|
| 54 |
+
" but got device type ",
|
| 55 |
+
expr.device().type(),
|
| 56 |
+
" for sequence element ",
|
| 57 |
+
i,
|
| 58 |
+
" in sequence argument at position #",
|
| 59 |
+
pos,
|
| 60 |
+
" '",
|
| 61 |
+
name,
|
| 62 |
+
"'");
|
| 63 |
+
}
|
| 64 |
+
if (expr.scalar_type() != scalar_type) {
|
| 65 |
+
AT_ERROR(
|
| 66 |
+
"Expected object of scalar type ",
|
| 67 |
+
scalar_type,
|
| 68 |
+
" but got scalar type ",
|
| 69 |
+
expr.scalar_type(),
|
| 70 |
+
" for sequence element ",
|
| 71 |
+
i,
|
| 72 |
+
" in sequence argument at position #",
|
| 73 |
+
pos,
|
| 74 |
+
" '",
|
| 75 |
+
name,
|
| 76 |
+
"'");
|
| 77 |
+
}
|
| 78 |
+
unwrapped.emplace_back(expr.unsafeGetTensorImpl());
|
| 79 |
+
}
|
| 80 |
+
return unwrapped;
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
template <size_t N>
|
| 84 |
+
std::array<int64_t, N> check_intlist(
|
| 85 |
+
ArrayRef<int64_t> list,
|
| 86 |
+
const char* name,
|
| 87 |
+
int pos) {
|
| 88 |
+
if (list.empty()) {
|
| 89 |
+
// TODO: is this necessary? We used to treat nullptr-vs-not in IntList
|
| 90 |
+
// differently with strides as a way of faking optional.
|
| 91 |
+
list = {};
|
| 92 |
+
}
|
| 93 |
+
auto res = std::array<int64_t, N>();
|
| 94 |
+
if (list.size() == 1 && N > 1) {
|
| 95 |
+
res.fill(list[0]);
|
| 96 |
+
return res;
|
| 97 |
+
}
|
| 98 |
+
if (list.size() != N) {
|
| 99 |
+
AT_ERROR(
|
| 100 |
+
"Expected a list of ",
|
| 101 |
+
N,
|
| 102 |
+
" ints but got ",
|
| 103 |
+
list.size(),
|
| 104 |
+
" for argument #",
|
| 105 |
+
pos,
|
| 106 |
+
" '",
|
| 107 |
+
name,
|
| 108 |
+
"'");
|
| 109 |
+
}
|
| 110 |
+
std::copy_n(list.begin(), N, res.begin());
|
| 111 |
+
return res;
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
using at::detail::check_size_nonnegative;
|
| 115 |
+
|
| 116 |
+
namespace detail {
|
| 117 |
+
|
| 118 |
+
template <typename T>
|
| 119 |
+
TORCH_API Tensor tensor_cpu(ArrayRef<T> values, const TensorOptions& options);
|
| 120 |
+
|
| 121 |
+
template <typename T>
|
| 122 |
+
TORCH_API Tensor
|
| 123 |
+
tensor_backend(ArrayRef<T> values, const TensorOptions& options);
|
| 124 |
+
|
| 125 |
+
template <typename T>
|
| 126 |
+
TORCH_API Tensor
|
| 127 |
+
tensor_complex_cpu(ArrayRef<T> values, const TensorOptions& options);
|
| 128 |
+
|
| 129 |
+
template <typename T>
|
| 130 |
+
TORCH_API Tensor
|
| 131 |
+
tensor_complex_backend(ArrayRef<T> values, const TensorOptions& options);
|
| 132 |
+
} // namespace detail
|
| 133 |
+
|
| 134 |
+
} // namespace at
|
.venv/lib/python3.11/site-packages/torch/include/ATen/VmapGeneratedPlumbing.h
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
.venv/lib/python3.11/site-packages/torch/include/ATen/cpp_custom_type_hack.h
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 2 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 3 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 4 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 5 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 6 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 7 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 8 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 9 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 10 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 11 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 12 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 13 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 14 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 15 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 16 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 17 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 18 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 19 |
+
|
| 20 |
+
// YOU ARE IN THE WRONG PLACE! TURN BACK NOW!
|
| 21 |
+
|
| 22 |
+
// This code was a temporary hack to enable embedding arbitrary C++ structures
|
| 23 |
+
// into Tensors. THIS IS UNSAFE AND IS NOT SUPPORTED. IF YOU USE THIS CODE,
|
| 24 |
+
// IT __WILL__ BREAK.
|
| 25 |
+
|
| 26 |
+
// This code has been superseded by custom classes:
|
| 27 |
+
// https://pytorch.org/tutorials/advanced/torch_script_custom_classes.html
|
| 28 |
+
|
| 29 |
+
// Please use custom classes and **DO NOT ADD MORE CALLSITES TO THINGS DEFINED
|
| 30 |
+
// IN THIS FILE**.
|
| 31 |
+
|
| 32 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 33 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 34 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 35 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 36 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 37 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 38 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 39 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 40 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 41 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 42 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 43 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 44 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 45 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 46 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 47 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 48 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 49 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 50 |
+
|
| 51 |
+
#include <ATen/TracerMode.h>
|
| 52 |
+
#include <ATen/core/Tensor.h>
|
| 53 |
+
|
| 54 |
+
#ifndef AT_PER_OPERATOR_HEADERS
|
| 55 |
+
#include <ATen/Functions.h>
|
| 56 |
+
#else
|
| 57 |
+
#include <ATen/ops/empty.h>
|
| 58 |
+
#endif
|
| 59 |
+
|
| 60 |
+
namespace at::cpp_custom_type_hack {
|
| 61 |
+
|
| 62 |
+
template <typename T>
|
| 63 |
+
[[deprecated(
|
| 64 |
+
"Use custom classes instead: "
|
| 65 |
+
"https://pytorch.org/tutorials/advanced/torch_script_custom_classes.html")]] bool
|
| 66 |
+
isa(const Tensor& packed) {
|
| 67 |
+
return (packed.scalar_type() == kByte) &&
|
| 68 |
+
(packed.storage().data_ptr().get_deleter() ==
|
| 69 |
+
caffe2::TypeMeta::Make<T>().deleteFn());
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
template <typename T>
|
| 73 |
+
[[deprecated(
|
| 74 |
+
"Use custom classes instead: "
|
| 75 |
+
"https://pytorch.org/tutorials/advanced/torch_script_custom_classes.html")]] T&
|
| 76 |
+
cast(const Tensor& packed) {
|
| 77 |
+
TORCH_CHECK(
|
| 78 |
+
packed.scalar_type() == kByte, "Expected temporary cpp type wrapper");
|
| 79 |
+
TORCH_CHECK(
|
| 80 |
+
packed.storage().data_ptr().get_deleter() ==
|
| 81 |
+
caffe2::TypeMeta::Make<T>().deleteFn(),
|
| 82 |
+
"Expected temporary cpp type wrapper of type ",
|
| 83 |
+
caffe2::TypeMeta::TypeName<T>());
|
| 84 |
+
return *reinterpret_cast<T*>(packed.storage().data_ptr().get());
|
| 85 |
+
}
|
| 86 |
+
|
| 87 |
+
template <typename T>
|
| 88 |
+
[[deprecated(
|
| 89 |
+
"Use custom classes instead: "
|
| 90 |
+
"https://pytorch.org/tutorials/advanced/torch_script_custom_classes.html")]] Tensor
|
| 91 |
+
create(std::unique_ptr<T> ptr, TensorOptions options) {
|
| 92 |
+
// None of this should trace, so turn off Tracer dispatching
|
| 93 |
+
at::AutoDispatchBelowADInplaceOrView guard; // TODO: remove
|
| 94 |
+
at::tracer::impl::NoTracerDispatchMode tracer_guard;
|
| 95 |
+
|
| 96 |
+
// We store this instance away in a Tensor and register a deleter function
|
| 97 |
+
// so that we do not leak memory. On the other side, we pull out the storage's
|
| 98 |
+
// data_ptr and get the right typed pointer.
|
| 99 |
+
void* raw_ptr = ptr.release();
|
| 100 |
+
at::DataPtr at_ptr(
|
| 101 |
+
raw_ptr, raw_ptr, caffe2::TypeMeta::Make<T>().deleteFn(), at::kCPU);
|
| 102 |
+
|
| 103 |
+
// size doesn't really matter, but we can align it to the actual size
|
| 104 |
+
// returning variables because one likely want to use this hack from python
|
| 105 |
+
auto retval = at::empty({sizeof(T)}, options.device(kCPU).dtype(at::kByte));
|
| 106 |
+
retval.storage().set_data_ptr_noswap(std::move(at_ptr));
|
| 107 |
+
return retval;
|
| 108 |
+
}
|
| 109 |
+
|
| 110 |
+
} // namespace at::cpp_custom_type_hack
|
.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/ATenCUDAGeneral.h
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <cuda.h>
|
| 4 |
+
#include <cuda_runtime.h>
|
| 5 |
+
#include <cuda_fp16.h>
|
| 6 |
+
|
| 7 |
+
#include <c10/macros/Export.h>
|
| 8 |
+
|
| 9 |
+
// Use TORCH_CUDA_CPP_API or TORCH_CUDA_CU_API for exports from this folder
|
.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/ApplyGridUtils.cuh
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <ATen/cuda/CUDAContext.h>
|
| 2 |
+
|
| 3 |
+
#include <cuda_runtime.h>
|
| 4 |
+
|
| 5 |
+
namespace at::cuda {
|
| 6 |
+
|
| 7 |
+
/**
|
| 8 |
+
Computes ceil(a / b)
|
| 9 |
+
*/
|
| 10 |
+
template <typename T>
|
| 11 |
+
__host__ __device__ __forceinline__ T ATenCeilDiv(T a, T b) {
|
| 12 |
+
return (a + b - 1) / b;
|
| 13 |
+
}
|
| 14 |
+
|
| 15 |
+
namespace {
|
| 16 |
+
|
| 17 |
+
// Threads per block for our apply kernel
|
| 18 |
+
// FIXME: use occupancy calculator instead
|
| 19 |
+
constexpr uint32_t AT_APPLY_THREADS_PER_BLOCK = 512;
|
| 20 |
+
constexpr uint32_t AT_APPLY_BLOCKS_PER_SM = 4;
|
| 21 |
+
|
| 22 |
+
template <int step = 1>
|
| 23 |
+
inline bool getApplyGrid(uint64_t totalElements, dim3& grid, c10::DeviceIndex curDevice, int max_threads_per_block=AT_APPLY_THREADS_PER_BLOCK) {
|
| 24 |
+
if (curDevice == -1) return false;
|
| 25 |
+
uint64_t numel_per_thread = static_cast<uint64_t>(max_threads_per_block) * static_cast<uint64_t>(step);
|
| 26 |
+
uint64_t numBlocks = ATenCeilDiv(totalElements, numel_per_thread);
|
| 27 |
+
uint64_t maxGridX = at::cuda::getDeviceProperties(curDevice)->maxGridSize[0];
|
| 28 |
+
if (numBlocks > maxGridX)
|
| 29 |
+
numBlocks = maxGridX;
|
| 30 |
+
grid = dim3(numBlocks);
|
| 31 |
+
return true;
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
constexpr int getApplyBlocksPerSM() {
|
| 35 |
+
return AT_APPLY_BLOCKS_PER_SM;
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
constexpr int getApplyBlockSize() {
|
| 39 |
+
return AT_APPLY_THREADS_PER_BLOCK;
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
inline dim3 getApplyBlock(int max_threads_per_block=AT_APPLY_THREADS_PER_BLOCK) {
|
| 43 |
+
return dim3(max_threads_per_block);
|
| 44 |
+
}
|
| 45 |
+
|
| 46 |
+
} // anonymous namespace
|
| 47 |
+
} // namespace at::cuda
|
.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/Atomic.cuh
ADDED
|
@@ -0,0 +1,514 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <cuda.h>
|
| 4 |
+
#include <c10/util/Half.h>
|
| 5 |
+
#include <c10/util/BFloat16.h>
|
| 6 |
+
|
| 7 |
+
#include <ATen/NumericUtils.h>
|
| 8 |
+
|
| 9 |
+
#if !(defined(USE_ROCM) || ((defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))))
|
| 10 |
+
#include <cuda_bf16.h>
|
| 11 |
+
#endif
|
| 12 |
+
|
| 13 |
+
template <typename T>
|
| 14 |
+
struct AtomicFPOp;
|
| 15 |
+
|
| 16 |
+
template <>
|
| 17 |
+
struct AtomicFPOp<at::Half> {
|
| 18 |
+
template <typename func_t>
|
| 19 |
+
inline __device__ at::Half operator() (at::Half *address, at::Half val, const func_t& func) {
|
| 20 |
+
unsigned int * address_as_ui =
|
| 21 |
+
(unsigned int *) ((char *)address - ((size_t)address & 2));
|
| 22 |
+
unsigned int old = *address_as_ui;
|
| 23 |
+
unsigned int assumed;
|
| 24 |
+
|
| 25 |
+
at::Half hsum;
|
| 26 |
+
do {
|
| 27 |
+
assumed = old;
|
| 28 |
+
hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
|
| 29 |
+
hsum = func(hsum, val);
|
| 30 |
+
old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x;
|
| 31 |
+
old = atomicCAS(address_as_ui, assumed, old);
|
| 32 |
+
} while (assumed != old);
|
| 33 |
+
hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
|
| 34 |
+
return hsum;
|
| 35 |
+
}
|
| 36 |
+
};
|
| 37 |
+
|
| 38 |
+
template <>
|
| 39 |
+
struct AtomicFPOp<at::BFloat16> {
|
| 40 |
+
template <typename func_t>
|
| 41 |
+
inline __device__ at::BFloat16 operator() (at::BFloat16 *address, at::BFloat16 val, const func_t& func) {
|
| 42 |
+
unsigned int * address_as_ui =
|
| 43 |
+
(unsigned int *) ((char *)address - ((size_t)address & 2));
|
| 44 |
+
unsigned int old = *address_as_ui;
|
| 45 |
+
unsigned int assumed;
|
| 46 |
+
|
| 47 |
+
at::BFloat16 bsum;
|
| 48 |
+
do {
|
| 49 |
+
assumed = old;
|
| 50 |
+
bsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
|
| 51 |
+
bsum = func(bsum, val);
|
| 52 |
+
old = (size_t)address & 2 ? (old & 0xffff) | (bsum.x << 16) : (old & 0xffff0000) | bsum.x;
|
| 53 |
+
old = atomicCAS(address_as_ui, assumed, old);
|
| 54 |
+
} while (assumed != old);
|
| 55 |
+
bsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
|
| 56 |
+
return bsum.x;
|
| 57 |
+
}
|
| 58 |
+
};
|
| 59 |
+
|
| 60 |
+
template <>
|
| 61 |
+
struct AtomicFPOp<double> {
|
| 62 |
+
template <typename func_t>
|
| 63 |
+
inline __device__ double operator() (double * address, double val, const func_t& func) {
|
| 64 |
+
unsigned long long int* address_as_ull = (unsigned long long int*)address;
|
| 65 |
+
unsigned long long int old = *address_as_ull;
|
| 66 |
+
unsigned long long int assumed;
|
| 67 |
+
|
| 68 |
+
do {
|
| 69 |
+
assumed = old;
|
| 70 |
+
old = atomicCAS(address_as_ull, assumed, func(val, assumed));
|
| 71 |
+
// Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
|
| 72 |
+
} while (assumed != old);
|
| 73 |
+
|
| 74 |
+
return __longlong_as_double(old);
|
| 75 |
+
}
|
| 76 |
+
};
|
| 77 |
+
|
| 78 |
+
#define ATOMIC_INTEGER_IMPL(NAME) \
|
| 79 |
+
template <typename T, size_t n> \
|
| 80 |
+
struct Atomic##NAME##IntegerImpl; \
|
| 81 |
+
\
|
| 82 |
+
template<typename T> \
|
| 83 |
+
struct Atomic##NAME##IntegerImpl<T, 1> { \
|
| 84 |
+
template <typename func_t> \
|
| 85 |
+
inline __device__ void operator()(T *address, T val, const func_t& func) { \
|
| 86 |
+
size_t offset = (size_t)address & 3; \
|
| 87 |
+
uint32_t * address_as_ui = (uint32_t *)((char *)address - offset); \
|
| 88 |
+
uint32_t old = *address_as_ui; \
|
| 89 |
+
uint32_t shift = offset * 8; \
|
| 90 |
+
uint32_t old_byte; \
|
| 91 |
+
uint32_t newval; \
|
| 92 |
+
uint32_t assumed; \
|
| 93 |
+
\
|
| 94 |
+
do { \
|
| 95 |
+
assumed = old; \
|
| 96 |
+
old_byte = (old >> shift) & 0xff; \
|
| 97 |
+
newval = static_cast<uint8_t>(func(val, static_cast<T>(old_byte))); \
|
| 98 |
+
newval = (old & ~(0x000000ff << shift)) | (newval << shift); \
|
| 99 |
+
old = atomicCAS(address_as_ui, assumed, newval); \
|
| 100 |
+
} while (assumed != old); \
|
| 101 |
+
} \
|
| 102 |
+
}; \
|
| 103 |
+
\
|
| 104 |
+
template<typename T> \
|
| 105 |
+
struct Atomic##NAME##IntegerImpl<T, 2> { \
|
| 106 |
+
template <typename func_t> \
|
| 107 |
+
inline __device__ void operator()(T *address, T val, const func_t& func) { \
|
| 108 |
+
size_t offset = (size_t)address & 2; \
|
| 109 |
+
uint32_t * address_as_ui = (uint32_t *)((char *)address - offset); \
|
| 110 |
+
bool is_32_align = offset; \
|
| 111 |
+
uint32_t old = *address_as_ui; \
|
| 112 |
+
uint32_t old_bytes; \
|
| 113 |
+
uint32_t newval; \
|
| 114 |
+
uint32_t assumed; \
|
| 115 |
+
\
|
| 116 |
+
do { \
|
| 117 |
+
assumed = old; \
|
| 118 |
+
old_bytes = is_32_align ? old >> 16 : old & 0xffff; \
|
| 119 |
+
newval = static_cast<uint16_t>(func(val, static_cast<T>(old_bytes))); \
|
| 120 |
+
newval = is_32_align ? (old & 0xffff) | (newval << 16) : (old & 0xffff0000) | newval; \
|
| 121 |
+
old = atomicCAS(address_as_ui, assumed, newval); \
|
| 122 |
+
} while (assumed != old); \
|
| 123 |
+
} \
|
| 124 |
+
}; \
|
| 125 |
+
\
|
| 126 |
+
template<typename T> \
|
| 127 |
+
struct Atomic##NAME##IntegerImpl<T, 4> { \
|
| 128 |
+
template <typename func_t> \
|
| 129 |
+
inline __device__ void operator()(T *address, T val, const func_t& func) { \
|
| 130 |
+
uint32_t * address_as_ui = (uint32_t *) (address); \
|
| 131 |
+
uint32_t old = *address_as_ui; \
|
| 132 |
+
uint32_t newval; \
|
| 133 |
+
uint32_t assumed; \
|
| 134 |
+
\
|
| 135 |
+
do { \
|
| 136 |
+
assumed = old; \
|
| 137 |
+
newval = static_cast<uint32_t>(func(val, static_cast<T>(old))); \
|
| 138 |
+
old = atomicCAS(address_as_ui, assumed, newval); \
|
| 139 |
+
} while (assumed != old); \
|
| 140 |
+
} \
|
| 141 |
+
}; \
|
| 142 |
+
\
|
| 143 |
+
template<typename T> \
|
| 144 |
+
struct Atomic##NAME##IntegerImpl<T, 8> { \
|
| 145 |
+
template <typename func_t> \
|
| 146 |
+
inline __device__ void operator()(T *address, T val, const func_t& func) { \
|
| 147 |
+
unsigned long long * address_as_ui = (unsigned long long *) (address); \
|
| 148 |
+
unsigned long long old = *address_as_ui; \
|
| 149 |
+
unsigned long long newval; \
|
| 150 |
+
unsigned long long assumed; \
|
| 151 |
+
\
|
| 152 |
+
do { \
|
| 153 |
+
assumed = old; \
|
| 154 |
+
newval = static_cast<uint64_t>(func(val, static_cast<T>(old))); \
|
| 155 |
+
old = atomicCAS(address_as_ui, assumed, newval); \
|
| 156 |
+
} while (assumed != old); \
|
| 157 |
+
} \
|
| 158 |
+
};
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
# define GPU_ATOMIC_INTEGER(NAME, OP, DTYPE) \
|
| 162 |
+
inline __device__ void gpuAtomic##NAME(DTYPE *address, DTYPE val) { \
|
| 163 |
+
Atomic##NAME##IntegerImpl<DTYPE, sizeof(DTYPE)>()(address, \
|
| 164 |
+
val, \
|
| 165 |
+
[](DTYPE a, DTYPE b) { \
|
| 166 |
+
return OP; \
|
| 167 |
+
}); \
|
| 168 |
+
} \
|
| 169 |
+
|
| 170 |
+
ATOMIC_INTEGER_IMPL(Add)
|
| 171 |
+
GPU_ATOMIC_INTEGER(Add, a || b, bool)
|
| 172 |
+
|
| 173 |
+
// Don't instantiate gpuAtomicAdd with the macro as it seems non-standard (see int32, int64)
|
| 174 |
+
inline __device__ void gpuAtomicAdd(uint8_t *address, uint8_t val) {
|
| 175 |
+
AtomicAddIntegerImpl<uint8_t, sizeof(uint8_t)>()(address,
|
| 176 |
+
val,
|
| 177 |
+
[](uint8_t a, uint8_t b) {
|
| 178 |
+
return a + b;
|
| 179 |
+
});
|
| 180 |
+
}
|
| 181 |
+
|
| 182 |
+
inline __device__ void gpuAtomicAdd(int8_t *address, int8_t val) {
|
| 183 |
+
AtomicAddIntegerImpl<int8_t, sizeof(int8_t)>()(address,
|
| 184 |
+
val,
|
| 185 |
+
[](int8_t a, int8_t b) {
|
| 186 |
+
return a + b;
|
| 187 |
+
});
|
| 188 |
+
}
|
| 189 |
+
|
| 190 |
+
inline __device__ void gpuAtomicAdd(int16_t *address, int16_t val) {
|
| 191 |
+
AtomicAddIntegerImpl<int16_t, sizeof(int16_t)>()(address,
|
| 192 |
+
val,
|
| 193 |
+
[](int16_t a, int16_t b) {
|
| 194 |
+
return a + b;
|
| 195 |
+
});
|
| 196 |
+
}
|
| 197 |
+
|
| 198 |
+
inline __device__ int32_t gpuAtomicAdd(int32_t *address, int32_t val) {
|
| 199 |
+
return atomicAdd(address, val);
|
| 200 |
+
}
|
| 201 |
+
|
| 202 |
+
inline __device__ void gpuAtomicAdd(int64_t *address, int64_t val) {
|
| 203 |
+
#if defined(USE_ROCM)
|
| 204 |
+
__atomic_fetch_add(address, val, __ATOMIC_RELAXED);
|
| 205 |
+
#else
|
| 206 |
+
static_assert(sizeof(unsigned long long int) == sizeof(int64_t), "bitwidth change is not allowed");
|
| 207 |
+
atomicAdd(reinterpret_cast<unsigned long long int *>(address), static_cast<unsigned long long int>(val));
|
| 208 |
+
#endif
|
| 209 |
+
}
|
| 210 |
+
|
| 211 |
+
inline __device__ at::Half gpuAtomicAdd(at::Half *address, at::Half val) {
|
| 212 |
+
#if defined(USE_ROCM) || ((defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700)))
|
| 213 |
+
return AtomicFPOp<at::Half>()(address, val,
|
| 214 |
+
[](at::Half hsum, at::Half val) {
|
| 215 |
+
return hsum + val;
|
| 216 |
+
});
|
| 217 |
+
#else
|
| 218 |
+
return atomicAdd(reinterpret_cast<__half*>(address), val);
|
| 219 |
+
#endif
|
| 220 |
+
}
|
| 221 |
+
|
| 222 |
+
inline __device__ at::BFloat16 gpuAtomicAdd(at::BFloat16 *address, at::BFloat16 val) {
|
| 223 |
+
#if defined(USE_ROCM) || ((defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)))
|
| 224 |
+
return AtomicFPOp<at::BFloat16>()(address, val,
|
| 225 |
+
[](at::BFloat16 bsum, at::BFloat16 val) {
|
| 226 |
+
return bsum + val;
|
| 227 |
+
});
|
| 228 |
+
#else
|
| 229 |
+
__nv_bfloat16 r = atomicAdd(reinterpret_cast<__nv_bfloat16*>(address), *reinterpret_cast<__nv_bfloat16*>(&val));
|
| 230 |
+
return *reinterpret_cast<c10::BFloat16*>(&r);
|
| 231 |
+
#endif
|
| 232 |
+
}
|
| 233 |
+
|
| 234 |
+
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 600)
|
| 235 |
+
// from CUDA C Programmic Guide
|
| 236 |
+
inline __device__ double atomicAdd(double* address, double val)
|
| 237 |
+
#if defined(__clang__) && defined(__CUDA__)
|
| 238 |
+
#pragma GCC diagnostic push
|
| 239 |
+
#pragma GCC diagnostic ignored "-Wgcc-compat"
|
| 240 |
+
__attribute__((enable_if(true, "")))
|
| 241 |
+
#pragma GCC diagnostic pop
|
| 242 |
+
#endif
|
| 243 |
+
{
|
| 244 |
+
|
| 245 |
+
return AtomicFPOp<double>()(address, val,
|
| 246 |
+
[](double val, unsigned long long int assumed) {
|
| 247 |
+
return __double_as_longlong(val + __longlong_as_double(assumed));
|
| 248 |
+
});
|
| 249 |
+
}
|
| 250 |
+
#elif defined(USE_ROCM) || !(defined(__CUDA_ARCH__))
|
| 251 |
+
|
| 252 |
+
/* Note [hip-clang differences to hcc]
|
| 253 |
+
* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 254 |
+
* The upcoming hip-clang compiler for ROCm differs from hcc in a few details.
|
| 255 |
+
* It exports the __HIP__ macro, we can hence differentiate between hcc and
|
| 256 |
+
* hip-clang. In the below, hcc only received support for atomicAdd with double
|
| 257 |
+
* typing after work week 18312. hip-clang had support from the first version.
|
| 258 |
+
* In general, the code-visible differences between hip-clang and hcc will be
|
| 259 |
+
* minimal.
|
| 260 |
+
*/
|
| 261 |
+
|
| 262 |
+
#if defined(USE_ROCM) && __hcc_workweek__ < 18312 && !__HIP__
|
| 263 |
+
// This needs to be defined for the host side pass
|
| 264 |
+
inline __device__ double atomicAdd(double *address, double val) { }
|
| 265 |
+
#endif
|
| 266 |
+
#endif
|
| 267 |
+
|
| 268 |
+
inline __device__ double gpuAtomicAdd(double *address, double val) {
|
| 269 |
+
return atomicAdd(address, val);
|
| 270 |
+
}
|
| 271 |
+
|
| 272 |
+
inline __device__ float gpuAtomicAdd(float *address, float val) {
|
| 273 |
+
return atomicAdd(address, val);
|
| 274 |
+
}
|
| 275 |
+
|
| 276 |
+
template<typename T>
|
| 277 |
+
inline __device__ void gpuAtomicAdd(c10::complex<T> *address, c10::complex<T> val) {
|
| 278 |
+
gpuAtomicAdd(&address->real_, val.real_);
|
| 279 |
+
gpuAtomicAdd(&address->imag_, val.imag_);
|
| 280 |
+
}
|
| 281 |
+
|
| 282 |
+
/* Note [gpuAtomicAdd vs atomicAdd]
|
| 283 |
+
* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 284 |
+
* Some extensions such as torchvision call atomicAdd()
|
| 285 |
+
* directly and require non-library provided data type support. Only for these, we
|
| 286 |
+
* continue to provide atomicAdd overloads.
|
| 287 |
+
*/
|
| 288 |
+
inline __device__ at::Half atomicAdd(at::Half *address, at::Half val) {
|
| 289 |
+
return gpuAtomicAdd(address, val);
|
| 290 |
+
}
|
| 291 |
+
|
| 292 |
+
inline __device__ at::BFloat16 atomicAdd(at::BFloat16 *address, at::BFloat16 val) {
|
| 293 |
+
return gpuAtomicAdd(address, val);
|
| 294 |
+
}
|
| 295 |
+
|
| 296 |
+
inline __device__ void atomicAdd(uint8_t *address, uint8_t val) {
|
| 297 |
+
gpuAtomicAdd(address, val);
|
| 298 |
+
}
|
| 299 |
+
|
| 300 |
+
inline __device__ void atomicAdd(int8_t *address, int8_t val) {
|
| 301 |
+
gpuAtomicAdd(address, val);
|
| 302 |
+
}
|
| 303 |
+
|
| 304 |
+
inline __device__ void atomicAdd(int16_t *address, int16_t val) {
|
| 305 |
+
gpuAtomicAdd(address, val);
|
| 306 |
+
}
|
| 307 |
+
|
| 308 |
+
inline __device__ void atomicAdd(int64_t *address, int64_t val) {
|
| 309 |
+
gpuAtomicAdd(address, val);
|
| 310 |
+
}
|
| 311 |
+
|
| 312 |
+
inline __device__ void atomicAdd(bool *address, bool val) {
|
| 313 |
+
gpuAtomicAdd(address, val);
|
| 314 |
+
}
|
| 315 |
+
|
| 316 |
+
/* Note [explicitly non-returning atomics]
|
| 317 |
+
* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 318 |
+
* AMD's MI100 (gfx908) provides an optimized fp32 atomicAdd, exposed via atomicAddNoRet().
|
| 319 |
+
* Due to compiler limitations, callers must opt-in to guarantee the optimized instruction.
|
| 320 |
+
* This non-returning atomicAddNoRet cannot be used to implement the returning atomicAdd,
|
| 321 |
+
* therefore we need a new API 'gpuAtomicAddNoReturn'.
|
| 322 |
+
*/
|
| 323 |
+
template<typename T>
|
| 324 |
+
inline __device__ void gpuAtomicAddNoReturn(c10::complex<T> *address, c10::complex<T> val) { gpuAtomicAdd(address, val); }
|
| 325 |
+
inline __device__ void gpuAtomicAddNoReturn(uint8_t *address, uint8_t val) { gpuAtomicAdd(address, val); }
|
| 326 |
+
inline __device__ void gpuAtomicAddNoReturn(int8_t *address, int8_t val) { gpuAtomicAdd(address, val); }
|
| 327 |
+
inline __device__ void gpuAtomicAddNoReturn(int16_t *address, int16_t val) { gpuAtomicAdd(address, val); }
|
| 328 |
+
inline __device__ void gpuAtomicAddNoReturn(int32_t *address, int32_t val) { gpuAtomicAdd(address, val); }
|
| 329 |
+
inline __device__ void gpuAtomicAddNoReturn(int64_t *address, int64_t val) { gpuAtomicAdd(address, val); }
|
| 330 |
+
inline __device__ void gpuAtomicAddNoReturn(bool *address, bool val) { gpuAtomicAdd(address, val); }
|
| 331 |
+
inline __device__ void gpuAtomicAddNoReturn(at::Half *address, at::Half val) { gpuAtomicAdd(address, val); }
|
| 332 |
+
inline __device__ void gpuAtomicAddNoReturn(at::BFloat16 *address, at::BFloat16 val) { gpuAtomicAdd(address, val); }
|
| 333 |
+
inline __device__ void gpuAtomicAddNoReturn(double *address, double val) { gpuAtomicAdd(address, val); }
|
| 334 |
+
|
| 335 |
+
/* Special case fp32 atomic. */
|
| 336 |
+
#if defined(USE_ROCM)
|
| 337 |
+
inline __device__ void gpuAtomicAddNoReturn(float *address, float val) {
|
| 338 |
+
#if defined(__gfx908__)
|
| 339 |
+
atomicAddNoRet(address, val);
|
| 340 |
+
#else
|
| 341 |
+
(void)unsafeAtomicAdd(address, val);
|
| 342 |
+
#endif
|
| 343 |
+
}
|
| 344 |
+
#else
|
| 345 |
+
inline __device__ void gpuAtomicAddNoReturn(float *address, float val) { gpuAtomicAdd(address, val); }
|
| 346 |
+
#endif
|
| 347 |
+
|
| 348 |
+
// Atomic multiplication implementation.
|
| 349 |
+
|
| 350 |
+
ATOMIC_INTEGER_IMPL(Mul)
|
| 351 |
+
GPU_ATOMIC_INTEGER(Mul, a * b, uint8_t)
|
| 352 |
+
GPU_ATOMIC_INTEGER(Mul, a * b, int8_t)
|
| 353 |
+
GPU_ATOMIC_INTEGER(Mul, a * b, int16_t)
|
| 354 |
+
GPU_ATOMIC_INTEGER(Mul, a * b, int32_t)
|
| 355 |
+
GPU_ATOMIC_INTEGER(Mul, a * b, int64_t)
|
| 356 |
+
|
| 357 |
+
inline __device__ at::Half gpuAtomicMul(at::Half * address, at::Half val) {
|
| 358 |
+
return AtomicFPOp<at::Half>()(address, val,
|
| 359 |
+
[](at::Half bsum, at::Half val) {
|
| 360 |
+
return bsum * val;
|
| 361 |
+
});
|
| 362 |
+
}
|
| 363 |
+
|
| 364 |
+
inline __device__ at::BFloat16 gpuAtomicMul(at::BFloat16 * address, at::BFloat16 val) {
|
| 365 |
+
return AtomicFPOp<at::BFloat16>()(address, val,
|
| 366 |
+
[](at::BFloat16 bsum, at::BFloat16 val) {
|
| 367 |
+
return bsum * val;
|
| 368 |
+
});
|
| 369 |
+
}
|
| 370 |
+
|
| 371 |
+
inline __device__ double gpuAtomicMul(double * address, double val) {
|
| 372 |
+
return AtomicFPOp<double>()(address, val,
|
| 373 |
+
[](double val, unsigned long long int assumed) {
|
| 374 |
+
return __double_as_longlong(val * __longlong_as_double(assumed));
|
| 375 |
+
});
|
| 376 |
+
}
|
| 377 |
+
|
| 378 |
+
// Dont use a templated function for this since the addition function defaults to the CUDA built-in.
|
| 379 |
+
inline __device__ float gpuAtomicMul (float * address, float val) {
|
| 380 |
+
unsigned int* address_as_ull = (unsigned int*)address;
|
| 381 |
+
unsigned int old = *address_as_ull;
|
| 382 |
+
unsigned int assumed;
|
| 383 |
+
|
| 384 |
+
do {
|
| 385 |
+
assumed = old;
|
| 386 |
+
old = atomicCAS(address_as_ull, assumed,
|
| 387 |
+
__float_as_int(val *
|
| 388 |
+
__int_as_float(assumed)));
|
| 389 |
+
|
| 390 |
+
// Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
|
| 391 |
+
} while (assumed != old);
|
| 392 |
+
|
| 393 |
+
return __int_as_float(old);
|
| 394 |
+
}
|
| 395 |
+
|
| 396 |
+
// Atomic maximum implementation.
|
| 397 |
+
|
| 398 |
+
template <typename T>
|
| 399 |
+
__host__ __device__ T safe_max(T a, T b) {
|
| 400 |
+
#if defined(__HIPCC__)
|
| 401 |
+
// TODO: remove this special case for HIP when issue is fixed:
|
| 402 |
+
// https://github.com/ROCm-Developer-Tools/HIP/issues/2209
|
| 403 |
+
T max = at::_isnan(a) ? a : (at::_isnan(b) ? b : std::max<T>(a, b));
|
| 404 |
+
#else
|
| 405 |
+
T max = at::_isnan(b) ? b : std::max<T>(a, b);
|
| 406 |
+
#endif
|
| 407 |
+
|
| 408 |
+
return max;
|
| 409 |
+
}
|
| 410 |
+
|
| 411 |
+
ATOMIC_INTEGER_IMPL(Max)
|
| 412 |
+
GPU_ATOMIC_INTEGER(Max, safe_max(a, b), uint8_t)
|
| 413 |
+
GPU_ATOMIC_INTEGER(Max, safe_max(a, b), int8_t)
|
| 414 |
+
GPU_ATOMIC_INTEGER(Max, safe_max(a, b), int16_t)
|
| 415 |
+
GPU_ATOMIC_INTEGER(Max, safe_max(a, b), int32_t)
|
| 416 |
+
GPU_ATOMIC_INTEGER(Max, safe_max(a, b), int64_t)
|
| 417 |
+
|
| 418 |
+
inline __device__ at::Half gpuAtomicMax(at::Half * address, at::Half val) {
|
| 419 |
+
return AtomicFPOp<at::Half>()(address, val,
|
| 420 |
+
[](at::Half bsum, at::Half val) {
|
| 421 |
+
return safe_max(bsum, val);
|
| 422 |
+
});
|
| 423 |
+
}
|
| 424 |
+
|
| 425 |
+
inline __device__ at::BFloat16 gpuAtomicMax(at::BFloat16 * address, at::BFloat16 val) {
|
| 426 |
+
return AtomicFPOp<at::BFloat16>()(address, val,
|
| 427 |
+
[](at::BFloat16 bsum, at::BFloat16 val) {
|
| 428 |
+
return safe_max(bsum, val);
|
| 429 |
+
});
|
| 430 |
+
}
|
| 431 |
+
|
| 432 |
+
inline __device__ double gpuAtomicMax(double * address, double val) {
|
| 433 |
+
return AtomicFPOp<double>()(address, val,
|
| 434 |
+
[](double val, unsigned long long int assumed) {
|
| 435 |
+
return __double_as_longlong(safe_max(val, __longlong_as_double(assumed)));
|
| 436 |
+
});
|
| 437 |
+
}
|
| 438 |
+
|
| 439 |
+
// Dont use a templated function for this since the addition function defaults to the CUDA built-in.
|
| 440 |
+
inline __device__ float gpuAtomicMax(float * address, float val) {
|
| 441 |
+
unsigned int* address_as_ull = (unsigned int*)address;
|
| 442 |
+
unsigned int old = *address_as_ull;
|
| 443 |
+
unsigned int assumed;
|
| 444 |
+
|
| 445 |
+
do {
|
| 446 |
+
assumed = old;
|
| 447 |
+
old = atomicCAS(address_as_ull, assumed,
|
| 448 |
+
__float_as_int(safe_max(val, __int_as_float(assumed))));
|
| 449 |
+
|
| 450 |
+
// Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
|
| 451 |
+
} while (assumed != old);
|
| 452 |
+
|
| 453 |
+
return __int_as_float(old);
|
| 454 |
+
}
|
| 455 |
+
|
| 456 |
+
// Atomic minimum implementation.
|
| 457 |
+
|
| 458 |
+
template <typename T>
|
| 459 |
+
__host__ __device__ T safe_min(T a, T b) {
|
| 460 |
+
#if defined(__HIPCC__)
|
| 461 |
+
// TODO: remove this special case for HIP when issue is fixed:
|
| 462 |
+
// https://github.com/ROCm-Developer-Tools/HIP/issues/2209
|
| 463 |
+
T min = at::_isnan(a) ? a : (at::_isnan(b) ? b : std::min<T>(a, b));
|
| 464 |
+
#else
|
| 465 |
+
T min = at::_isnan(b) ? b : std::min<T>(a, b);
|
| 466 |
+
#endif
|
| 467 |
+
|
| 468 |
+
return min;
|
| 469 |
+
}
|
| 470 |
+
|
| 471 |
+
ATOMIC_INTEGER_IMPL(Min)
|
| 472 |
+
GPU_ATOMIC_INTEGER(Min, safe_min(a, b), uint8_t)
|
| 473 |
+
GPU_ATOMIC_INTEGER(Min, safe_min(a, b), int8_t)
|
| 474 |
+
GPU_ATOMIC_INTEGER(Min, safe_min(a, b), int16_t)
|
| 475 |
+
GPU_ATOMIC_INTEGER(Min, safe_min(a, b), int32_t)
|
| 476 |
+
GPU_ATOMIC_INTEGER(Min, safe_min(a, b), int64_t)
|
| 477 |
+
|
| 478 |
+
inline __device__ at::Half gpuAtomicMin(at::Half * address, at::Half val) {
|
| 479 |
+
return AtomicFPOp<at::Half>()(address, val,
|
| 480 |
+
[](at::Half bsum, at::Half val) {
|
| 481 |
+
return safe_min(bsum, val);
|
| 482 |
+
});
|
| 483 |
+
}
|
| 484 |
+
|
| 485 |
+
inline __device__ at::BFloat16 gpuAtomicMin(at::BFloat16 * address, at::BFloat16 val) {
|
| 486 |
+
return AtomicFPOp<at::BFloat16>()(address, val,
|
| 487 |
+
[](at::BFloat16 bsum, at::BFloat16 val) {
|
| 488 |
+
return safe_min(bsum, val);
|
| 489 |
+
});
|
| 490 |
+
}
|
| 491 |
+
|
| 492 |
+
inline __device__ double gpuAtomicMin(double * address, double val) {
|
| 493 |
+
return AtomicFPOp<double>()(address, val,
|
| 494 |
+
[](double val, unsigned long long int assumed) {
|
| 495 |
+
return __double_as_longlong(safe_min(val, __longlong_as_double(assumed)));
|
| 496 |
+
});
|
| 497 |
+
}
|
| 498 |
+
|
| 499 |
+
// Dont use a templated function for this since the addition function defaults to the CUDA built-in.
|
| 500 |
+
inline __device__ float gpuAtomicMin(float * address, float val) {
|
| 501 |
+
unsigned int* address_as_ull = (unsigned int*)address;
|
| 502 |
+
unsigned int old = *address_as_ull;
|
| 503 |
+
unsigned int assumed;
|
| 504 |
+
|
| 505 |
+
do {
|
| 506 |
+
assumed = old;
|
| 507 |
+
old = atomicCAS(address_as_ull, assumed,
|
| 508 |
+
__float_as_int(safe_min(val, __int_as_float(assumed))));
|
| 509 |
+
|
| 510 |
+
// Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
|
| 511 |
+
} while (assumed != old);
|
| 512 |
+
|
| 513 |
+
return __int_as_float(old);
|
| 514 |
+
}
|
.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDAApplyUtils.cuh
ADDED
|
@@ -0,0 +1,537 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/cuda/ApplyGridUtils.cuh>
|
| 4 |
+
#include <ATen/cuda/detail/IndexUtils.cuh>
|
| 5 |
+
#include <ATen/core/TensorBase.h>
|
| 6 |
+
#include <ATen/ceil_div.h>
|
| 7 |
+
#include <ATen/cuda/Atomic.cuh>
|
| 8 |
+
#include <ATen/cuda/CUDAContext.h>
|
| 9 |
+
#include <c10/macros/Macros.h>
|
| 10 |
+
#include <ATen/native/Copy.h>
|
| 11 |
+
|
| 12 |
+
#include <math.h>
|
| 13 |
+
|
| 14 |
+
//
|
| 15 |
+
// This file contains pointwise operation functions and kernels that
|
| 16 |
+
// work on both contiguous and non-contiguous tensor arguments of
|
| 17 |
+
// arbitrary (up to MAX_CUTORCH_DIMS) dimensioned arguments without
|
| 18 |
+
// copying or temporary storage.
|
| 19 |
+
//
|
| 20 |
+
|
| 21 |
+
/*
|
| 22 |
+
NOTE [ CUDA_tensor_applyN helpers ]
|
| 23 |
+
|
| 24 |
+
The following CUDA_tensor_applyN (where N currently can be 1, 2, 3, or 4)
|
| 25 |
+
functions apply a pointwise operator to N tensor(s).
|
| 26 |
+
|
| 27 |
+
The calling convention is
|
| 28 |
+
|
| 29 |
+
1. The template arguments should be, sequentially,
|
| 30 |
+
- First N typename args specify the scalar types of each of the N tensors.
|
| 31 |
+
- (Optional) `int step` arg specifies the number of elements processed
|
| 32 |
+
together at the same time.
|
| 33 |
+
Default is 1.
|
| 34 |
+
- A usually omitted (i.e., inferred) typename arg specifies the type of the
|
| 35 |
+
function/functor applied on `N * step` values in each iteration of each
|
| 36 |
+
CUDA thread.
|
| 37 |
+
2. The arguments should be, sequentially,
|
| 38 |
+
- N tensors
|
| 39 |
+
- op: a function/functor that processes `N * step` values at the same time.
|
| 40 |
+
- If `step == 1`, it must have signature
|
| 41 |
+
`void(*)(scalar1_t&, scalar2_t&, ..., scalarN_t&)`, where
|
| 42 |
+
`scalar*_t`s are the first N typename template args, and the inputs
|
| 43 |
+
are the `N` values from the `N` tensors retrieved at a common index.
|
| 44 |
+
- Otherwise, it must must have signature
|
| 45 |
+
void(*)(int n, scalar1_t&, scalar1_t&, ..., scalar1_t&, // repeat `step` times
|
| 46 |
+
scalar2_t&, scalar2_t&, ..., scalar2_t&, // repeat `step` times
|
| 47 |
+
...,
|
| 48 |
+
scalarN_t&, scalarN_t&, ..., scalarN_t&) // repeat `step` times
|
| 49 |
+
Different from `step == 1` case, it processes `N * step` values taken
|
| 50 |
+
from `step` common indices. Moreover, the first input `n` represents the
|
| 51 |
+
number of valid indices (it will always have `0 < n <= step`). It will
|
| 52 |
+
almost always be `step`, but at the boundary we may not have full `step`
|
| 53 |
+
elements and `n` can be a lesser value.
|
| 54 |
+
|
| 55 |
+
E.g., if `step == 4` and `N == 2`, `op` could be
|
| 56 |
+
|
| 57 |
+
[](int n, scalar1_t &u1, scalar1_t &u2, scalar1_t &u3, scalar1_t &u4,
|
| 58 |
+
scalar2_t &v1, scalar2_t &v2, scalar2_t &v3, scalar2_t &v4) {
|
| 59 |
+
// Only process u1, ..., un and v1, ..., vn.
|
| 60 |
+
// So if `n == 3`, `u4` and `v4` need not to be considered.
|
| 61 |
+
}
|
| 62 |
+
|
| 63 |
+
In both cases, the references can actually be const, but at least one of
|
| 64 |
+
them should be non-const in order to write the output.
|
| 65 |
+
- (Optional, but recommended) N TensorArgType args that specify for each
|
| 66 |
+
tensor whether `op` reads AND writes ] (i.e., TensorArgType::ReadWrite),
|
| 67 |
+
or only reads (i.e., TensorArgType::ReadOnly).
|
| 68 |
+
Default is TensorArgType::ReadWrite for first Tensor, and
|
| 69 |
+
TensorArgType::ReadOnly for the rest.
|
| 70 |
+
|
| 71 |
+
E.g.,
|
| 72 |
+
|
| 73 |
+
to compute a = b^2 for a and b of same dtype, we can call
|
| 74 |
+
|
| 75 |
+
CUDA_tensor_apply2<scalar, scalar>(
|
| 76 |
+
a, b,
|
| 77 |
+
[] __device__ (scalar &a_val, const scalar &b_val) { a_val = b_val * b_val; }
|
| 78 |
+
);
|
| 79 |
+
|
| 80 |
+
to work on 2 values at the same time, we can call
|
| 81 |
+
|
| 82 |
+
CUDA_tensor_apply2<scalar1, scalar2, 2>(
|
| 83 |
+
a, b,
|
| 84 |
+
[] __device__ (int n, scalar1 &a_val1, scalar1 &a_val2,
|
| 85 |
+
const scalar2 &b_val1, const scalar2 &b_val2) {
|
| 86 |
+
// call special vectorized op here, or just do elementwise and enjoy unrolling...
|
| 87 |
+
// if n == 1, only process a_val1 and b_val1
|
| 88 |
+
}
|
| 89 |
+
);
|
| 90 |
+
*/
|
| 91 |
+
|
| 92 |
+
namespace at::cuda {
|
| 93 |
+
|
| 94 |
+
// TODO: combine with TensorArg? So far that's been for debugging, and this is functional...
|
| 95 |
+
enum class TensorArgType { ReadWrite, ReadOnly };
|
| 96 |
+
|
| 97 |
+
namespace {
|
| 98 |
+
|
| 99 |
+
// Rearrange dimensions for pointwise operations so that strides are in
|
| 100 |
+
// decreasing order as much as possible, so that kernels have better memory
|
| 101 |
+
// access patterns.
|
| 102 |
+
//
|
| 103 |
+
// For example, consider a binary operation on two "transposed" 2-dim tensors:
|
| 104 |
+
// sizes: 256 512
|
| 105 |
+
// aInfo->strides: 1 256
|
| 106 |
+
// bInfo->strides: 1 256
|
| 107 |
+
//
|
| 108 |
+
// Given this, each concurrent memory access inside kernelPointwiseApply2() is
|
| 109 |
+
// exactly 256 elements apart, resulting in poor performance.
|
| 110 |
+
//
|
| 111 |
+
// This function exchanges dimensions so that memory access is contiguous:
|
| 112 |
+
// sizes: 512 256
|
| 113 |
+
// aInfo->strides: 256 1
|
| 114 |
+
// bInfo->strides: 256 1
|
| 115 |
+
//
|
| 116 |
+
// (Actually, it becomes even better because now collapseDims() can turn each
|
| 117 |
+
// input into one contiguous array.)
|
| 118 |
+
//
|
| 119 |
+
// In general, given M (<=4) TensorInfo's with N dimensions, we can view each
|
| 120 |
+
// strides[i] (0 <= i < N) as an M-tuple. Given each pair i < j, we exchange
|
| 121 |
+
// strides[i] and [j] if
|
| 122 |
+
// (1) strides[i][k] < strides[j][k] for some k (0 <= k < M)
|
| 123 |
+
// (exchanging them will benefit input #k), and
|
| 124 |
+
// (2) strides[i][k] <= strieds[j][k] for all k
|
| 125 |
+
// (exchanging them will not make any input worse).
|
| 126 |
+
template <typename T1, typename IndexType,
|
| 127 |
+
typename T2 = void, typename T3 = void, typename T4 = void>
|
| 128 |
+
inline void rearrangeDims(detail::TensorInfo<T1, IndexType>* aInfo,
|
| 129 |
+
detail::TensorInfo<T2, IndexType>* bInfo = nullptr,
|
| 130 |
+
detail::TensorInfo<T3, IndexType>* cInfo = nullptr,
|
| 131 |
+
detail::TensorInfo<T4, IndexType>* dInfo = nullptr) {
|
| 132 |
+
int numInfos = 1;
|
| 133 |
+
int dims = aInfo->dims;
|
| 134 |
+
IndexType *sizes[4] = { aInfo->sizes, };
|
| 135 |
+
IndexType *strides[4] = { aInfo->strides, };
|
| 136 |
+
|
| 137 |
+
if (bInfo != nullptr) {
|
| 138 |
+
++numInfos;
|
| 139 |
+
if (bInfo->dims != dims) return;
|
| 140 |
+
sizes[1] = bInfo->sizes;
|
| 141 |
+
strides[1] = bInfo->strides;
|
| 142 |
+
}
|
| 143 |
+
|
| 144 |
+
if (cInfo != nullptr) {
|
| 145 |
+
++numInfos;
|
| 146 |
+
if (cInfo->dims != dims) return;
|
| 147 |
+
sizes[2] = cInfo->sizes;
|
| 148 |
+
strides[2] = cInfo->strides;
|
| 149 |
+
}
|
| 150 |
+
|
| 151 |
+
if (dInfo != nullptr) {
|
| 152 |
+
++numInfos;
|
| 153 |
+
if (dInfo->dims != dims) return;
|
| 154 |
+
sizes[3] = dInfo->sizes;
|
| 155 |
+
strides[3] = dInfo->strides;
|
| 156 |
+
}
|
| 157 |
+
|
| 158 |
+
// Bail out if sizes do not match: we are using "deprecated pointwise
|
| 159 |
+
// behavior" among tensors of different shapes but same number of elements.
|
| 160 |
+
for (int i = 1; i < numInfos; ++i) {
|
| 161 |
+
for (int j = 0; j < dims; ++j) {
|
| 162 |
+
if (sizes[i][j] != sizes[0][j]) return;
|
| 163 |
+
}
|
| 164 |
+
}
|
| 165 |
+
|
| 166 |
+
for (int i = 0; i < dims - 1; ++i) {
|
| 167 |
+
// No need to consider dimensions of size 1.
|
| 168 |
+
if (sizes[0][i] == 1) continue;
|
| 169 |
+
|
| 170 |
+
for (int j = i + 1; j < dims; ++j) {
|
| 171 |
+
if (sizes[0][j] == 1) continue;
|
| 172 |
+
|
| 173 |
+
// Compare the relative sizes of strides between dim #i and dim #j.
|
| 174 |
+
bool hasIncreasingStrides = false;
|
| 175 |
+
bool hasDecreasingStrides = false;
|
| 176 |
+
|
| 177 |
+
for (int k = 0; k < numInfos; k++) {
|
| 178 |
+
IndexType stride_i = strides[k][i];
|
| 179 |
+
IndexType stride_j = strides[k][j];
|
| 180 |
+
if (stride_i < stride_j) {
|
| 181 |
+
hasIncreasingStrides = true;
|
| 182 |
+
} else if (stride_i > stride_j) {
|
| 183 |
+
hasDecreasingStrides = true;
|
| 184 |
+
}
|
| 185 |
+
}
|
| 186 |
+
|
| 187 |
+
if (hasIncreasingStrides && !hasDecreasingStrides) {
|
| 188 |
+
for (int k = 0; k < numInfos; k++) {
|
| 189 |
+
IndexType size = sizes[k][i];
|
| 190 |
+
sizes[k][i] = sizes[k][j];
|
| 191 |
+
sizes[k][j] = size;
|
| 192 |
+
|
| 193 |
+
IndexType stride = strides[k][i];
|
| 194 |
+
strides[k][i] = strides[k][j];
|
| 195 |
+
strides[k][j] = stride;
|
| 196 |
+
}
|
| 197 |
+
}
|
| 198 |
+
}
|
| 199 |
+
}
|
| 200 |
+
}
|
| 201 |
+
|
| 202 |
+
// The `remaining_steps` argument is used to support Op that operates on
|
| 203 |
+
// multiple elements at the same time. Generally, the strategy of ApplyOpN is to
|
| 204 |
+
// 1. Initialize `remaining_steps = step`, where `step` is the template arg of
|
| 205 |
+
// CUDA_tensor_applyN helpers. The input arg `n` to `apply()` represents the
|
| 206 |
+
// number of elements in bound for this call. It will almost always equal to
|
| 207 |
+
// `step` except at boundaries.
|
| 208 |
+
// 2. If `remaining_steps > 0` convert the current linearIndex to offset (if in
|
| 209 |
+
// bound), and recursively call `ApplyOpN` with `remaining_steps - 1`.
|
| 210 |
+
// 3. At `remaining_steps = 0`,
|
| 211 |
+
// if `step = 1`, call `op(tensor1_val, tensor2_val, ...)`;
|
| 212 |
+
// if `step > 1`, call `op(n, tensor1_val1, tensor1_val2, ..., tesor1_valstep,
|
| 213 |
+
// tensor2_val1, tensor2_val2, ..., tesor2_valstep,
|
| 214 |
+
// ...
|
| 215 |
+
// tensorN_val1, tensorN_val2, ..., tesorN_valstep);`
|
| 216 |
+
//
|
| 217 |
+
// See NOTE [ CUDA_tensor_applyN helpers ] above for how Op may look like.
|
| 218 |
+
|
| 219 |
+
template <typename Op,
|
| 220 |
+
typename scalar,
|
| 221 |
+
typename IndexType,
|
| 222 |
+
int ADims,
|
| 223 |
+
int remaining_steps,
|
| 224 |
+
typename... Offsets>
|
| 225 |
+
struct ApplyOp1 {
|
| 226 |
+
__device__ __forceinline__
|
| 227 |
+
static void apply(detail::TensorInfo<scalar, IndexType> &a, const Op &op, int n,
|
| 228 |
+
IndexType linearIndex, Offsets... aOffsets) {
|
| 229 |
+
// Convert `linearIndex` into an offset of `a`
|
| 230 |
+
const IndexType aOffset = sizeof...(Offsets) < n ?
|
| 231 |
+
detail::IndexToOffset<scalar, IndexType, ADims>::get(linearIndex, a) : 0;
|
| 232 |
+
|
| 233 |
+
ApplyOp1<Op, scalar, IndexType, ADims, remaining_steps - 1, const IndexType, Offsets...>::apply(
|
| 234 |
+
a, op, n, linearIndex + 1, aOffsets..., aOffset
|
| 235 |
+
);
|
| 236 |
+
}
|
| 237 |
+
};
|
| 238 |
+
|
| 239 |
+
// Specialize `step=1` case (i.e., `remaining_steps=0` and `len(Offsets)=1`).
|
| 240 |
+
// We don't need to pass in how many elements need to processed in this case.
|
| 241 |
+
template <typename Op,
|
| 242 |
+
typename scalar,
|
| 243 |
+
typename IndexType,
|
| 244 |
+
int ADims,
|
| 245 |
+
typename Offset>
|
| 246 |
+
struct ApplyOp1<Op, scalar, IndexType, ADims, 0, Offset> {
|
| 247 |
+
__device__ __forceinline__
|
| 248 |
+
static void apply(detail::TensorInfo<scalar, IndexType> &a, const Op &op,
|
| 249 |
+
int n, IndexType linearIndex, Offset offset) {
|
| 250 |
+
op(a.data[offset]);
|
| 251 |
+
}
|
| 252 |
+
};
|
| 253 |
+
|
| 254 |
+
template <typename Op,
|
| 255 |
+
typename scalar,
|
| 256 |
+
typename IndexType,
|
| 257 |
+
int ADims,
|
| 258 |
+
typename... Offsets>
|
| 259 |
+
struct ApplyOp1<Op, scalar, IndexType, ADims, 0, Offsets...> {
|
| 260 |
+
__device__ __forceinline__
|
| 261 |
+
static void apply(detail::TensorInfo<scalar, IndexType> &a, const Op &op, int n,
|
| 262 |
+
IndexType linearIndex, Offsets... offsets) {
|
| 263 |
+
op(n, a.data[offsets]...);
|
| 264 |
+
}
|
| 265 |
+
};
|
| 266 |
+
|
| 267 |
+
template <typename Op,
|
| 268 |
+
typename scalar,
|
| 269 |
+
typename IndexType,
|
| 270 |
+
int ADims,
|
| 271 |
+
int step>
|
| 272 |
+
#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
|
| 273 |
+
C10_LAUNCH_BOUNDS_2(AT_APPLY_THREADS_PER_BLOCK, AT_APPLY_BLOCKS_PER_SM)
|
| 274 |
+
#endif
|
| 275 |
+
__global__ void kernelPointwiseApply1(detail::TensorInfo<scalar, IndexType> a,
|
| 276 |
+
IndexType totalElements, const Op op) {
|
| 277 |
+
for (IndexType linearIndex = (blockIdx.x * blockDim.x + threadIdx.x) * step;
|
| 278 |
+
linearIndex < totalElements;
|
| 279 |
+
linearIndex += gridDim.x * blockDim.x * step) {
|
| 280 |
+
ApplyOp1<Op, scalar, IndexType, ADims, step>::apply(
|
| 281 |
+
a, op, ::min(step, static_cast<int>(totalElements - linearIndex)), linearIndex);
|
| 282 |
+
}
|
| 283 |
+
}
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
template <typename Op,
|
| 287 |
+
typename scalar1,
|
| 288 |
+
typename scalar2,
|
| 289 |
+
typename IndexType,
|
| 290 |
+
int ADims,
|
| 291 |
+
int BDims,
|
| 292 |
+
int remaining_steps,
|
| 293 |
+
typename... Offsets>
|
| 294 |
+
struct ApplyOp2 {
|
| 295 |
+
__device__ __forceinline__
|
| 296 |
+
static void apply(detail::TensorInfo<scalar1, IndexType> &a,
|
| 297 |
+
detail::TensorInfo<scalar2, IndexType> &b,
|
| 298 |
+
const Op &op, int64_t n, IndexType linearIndex,
|
| 299 |
+
Offsets... aOffsets, Offsets... bOffsets) {
|
| 300 |
+
// Convert `linearIndex` into an offset of `a`
|
| 301 |
+
const IndexType aOffset = static_cast<int64_t>(sizeof...(Offsets)) < n ?
|
| 302 |
+
detail::IndexToOffset<scalar1, IndexType, ADims>::get(linearIndex, a) : 0;
|
| 303 |
+
|
| 304 |
+
// Convert `linearIndex` into an offset of `b`
|
| 305 |
+
const IndexType bOffset = static_cast<int64_t>(sizeof...(Offsets)) < n ?
|
| 306 |
+
detail::IndexToOffset<scalar2, IndexType, BDims>::get(linearIndex, b) : 0;
|
| 307 |
+
|
| 308 |
+
ApplyOp2<Op, scalar1, scalar2, IndexType, ADims, BDims, remaining_steps - 1, const IndexType, Offsets...>::apply(
|
| 309 |
+
a, b, op, n, linearIndex + 1, aOffsets..., aOffset, bOffsets..., bOffset
|
| 310 |
+
);
|
| 311 |
+
}
|
| 312 |
+
};
|
| 313 |
+
|
| 314 |
+
// Specialize `step=1` case (i.e., `remaining_steps=0` and `len(Offsets)=1`).
|
| 315 |
+
// We don't need to pass in how many elements need to processed in this case.
|
| 316 |
+
template <typename Op,
|
| 317 |
+
typename scalar1,
|
| 318 |
+
typename scalar2,
|
| 319 |
+
typename IndexType,
|
| 320 |
+
int ADims,
|
| 321 |
+
int BDims,
|
| 322 |
+
typename Offset>
|
| 323 |
+
struct ApplyOp2<Op, scalar1, scalar2, IndexType, ADims, BDims, 0, Offset> {
|
| 324 |
+
__device__ __forceinline__
|
| 325 |
+
static void apply(detail::TensorInfo<scalar1, IndexType> &a,
|
| 326 |
+
detail::TensorInfo<scalar2, IndexType> &b,
|
| 327 |
+
const Op &op, int /*n*/, IndexType /*linearIndex*/,
|
| 328 |
+
Offset aOffset, Offset bOffset) {
|
| 329 |
+
op(a.data[aOffset], b.data[bOffset]);
|
| 330 |
+
}
|
| 331 |
+
};
|
| 332 |
+
|
| 333 |
+
template <typename Op,
|
| 334 |
+
typename scalar1,
|
| 335 |
+
typename scalar2,
|
| 336 |
+
typename IndexType,
|
| 337 |
+
int ADims,
|
| 338 |
+
int BDims,
|
| 339 |
+
typename... Offsets>
|
| 340 |
+
struct ApplyOp2<Op, scalar1, scalar2, IndexType, ADims, BDims, 0, Offsets...> {
|
| 341 |
+
__device__ __forceinline__
|
| 342 |
+
static void apply(detail::TensorInfo<scalar1, IndexType> &a,
|
| 343 |
+
detail::TensorInfo<scalar2, IndexType> &b,
|
| 344 |
+
const Op &op, int n, IndexType linearIndex,
|
| 345 |
+
Offsets... aOffsets, Offsets... bOffsets) {
|
| 346 |
+
op(n, a.data[aOffsets]..., b.data[bOffsets]...);
|
| 347 |
+
}
|
| 348 |
+
};
|
| 349 |
+
|
| 350 |
+
template <typename Op,
|
| 351 |
+
typename scalar1,
|
| 352 |
+
typename scalar2,
|
| 353 |
+
typename IndexType,
|
| 354 |
+
int ADims, int BDims,
|
| 355 |
+
int step,
|
| 356 |
+
int max_threads_per_block=AT_APPLY_THREADS_PER_BLOCK,
|
| 357 |
+
int min_blocks_per_sm=AT_APPLY_BLOCKS_PER_SM>
|
| 358 |
+
#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
|
| 359 |
+
C10_LAUNCH_BOUNDS_2(max_threads_per_block, min_blocks_per_sm)
|
| 360 |
+
#endif
|
| 361 |
+
__global__ void
|
| 362 |
+
kernelPointwiseApply2(detail::TensorInfo<scalar1, IndexType> a,
|
| 363 |
+
detail::TensorInfo<scalar2, IndexType> b,
|
| 364 |
+
IndexType totalElements,
|
| 365 |
+
const Op op) {
|
| 366 |
+
for (IndexType linearIndex = (blockIdx.x * blockDim.x + threadIdx.x) * step;
|
| 367 |
+
linearIndex < totalElements;
|
| 368 |
+
linearIndex += gridDim.x * blockDim.x * step) {
|
| 369 |
+
ApplyOp2<Op, scalar1, scalar2, IndexType, ADims, BDims, step>::apply(
|
| 370 |
+
a, b, op, ::min(step, static_cast<int>(totalElements - linearIndex)),
|
| 371 |
+
linearIndex);
|
| 372 |
+
}
|
| 373 |
+
}
|
| 374 |
+
|
| 375 |
+
} // anonymous namespace
|
| 376 |
+
|
| 377 |
+
template <typename scalar1, typename scalar2, int step, typename Op,
|
| 378 |
+
int max_threads_per_block=AT_APPLY_THREADS_PER_BLOCK,
|
| 379 |
+
int min_blocks_per_sm=AT_APPLY_BLOCKS_PER_SM>
|
| 380 |
+
inline bool CUDA_tensor_apply2(at::TensorBase a,
|
| 381 |
+
at::TensorBase b,
|
| 382 |
+
const Op op,
|
| 383 |
+
TensorArgType aType = TensorArgType::ReadWrite,
|
| 384 |
+
TensorArgType bType = TensorArgType::ReadOnly) {
|
| 385 |
+
TORCH_CHECK(a.device().is_cuda() && b.device().is_cuda(),
|
| 386 |
+
"CUDA_tensor_apply2: Expected tensors to have CUDA DeviceType, but got "
|
| 387 |
+
"tensors with type ", a.device().type(), " and ", b.device().type());
|
| 388 |
+
int64_t totalElements = a.numel();
|
| 389 |
+
|
| 390 |
+
if (totalElements != b.numel()) {
|
| 391 |
+
return false;
|
| 392 |
+
}
|
| 393 |
+
|
| 394 |
+
if (a.dim() > MAX_TENSORINFO_DIMS ||
|
| 395 |
+
b.dim() > MAX_TENSORINFO_DIMS) {
|
| 396 |
+
return false;
|
| 397 |
+
}
|
| 398 |
+
|
| 399 |
+
if (a.numel() == 0) {
|
| 400 |
+
// Empty tensor; do nothing
|
| 401 |
+
return true;
|
| 402 |
+
}
|
| 403 |
+
const dim3 block = getApplyBlock(max_threads_per_block);
|
| 404 |
+
|
| 405 |
+
dim3 grid;
|
| 406 |
+
auto curDevice = current_device();
|
| 407 |
+
if (curDevice == -1) return false;
|
| 408 |
+
if (!getApplyGrid<step>(totalElements, grid, curDevice, max_threads_per_block)) {
|
| 409 |
+
return false;
|
| 410 |
+
}
|
| 411 |
+
|
| 412 |
+
/*
|
| 413 |
+
Expands readable/writable tensors whose indices may be "overlapped."
|
| 414 |
+
This ensures that each element of the tensor is operated on once and only
|
| 415 |
+
once.
|
| 416 |
+
*/
|
| 417 |
+
TensorBase oldA;
|
| 418 |
+
TensorBase oldB;
|
| 419 |
+
|
| 420 |
+
if (aType == TensorArgType::ReadWrite && detail::maybeOverlappingIndices(a)) {
|
| 421 |
+
// Must perform in contiguous space
|
| 422 |
+
oldA = std::exchange(a, a.contiguous());
|
| 423 |
+
}
|
| 424 |
+
if (bType == TensorArgType::ReadWrite && detail::maybeOverlappingIndices(b)) {
|
| 425 |
+
// Must perform in contiguous space
|
| 426 |
+
oldB = std::exchange(b, b.contiguous());
|
| 427 |
+
}
|
| 428 |
+
|
| 429 |
+
// It is possible that the tensor dimensions are able to be collapsed,
|
| 430 |
+
// and thus we can reduce the actual code complexity of the copy by
|
| 431 |
+
// exploiting this knowledge statically, since the div/mod is the
|
| 432 |
+
// most expensive part of the operation, more so than memory accesses.
|
| 433 |
+
// For instance, when copying a non-contiguous to a contiguous tensor
|
| 434 |
+
// (or vice versa), the contiguous tensor can be collapsed to one
|
| 435 |
+
// dimension, and the loop to translate the linear index to the array
|
| 436 |
+
// index can be similarly collapsed. That is what this unrolling is for.
|
| 437 |
+
|
| 438 |
+
#define HANDLE_CASE(TYPE, A, B) \
|
| 439 |
+
kernelPointwiseApply2<Op, \
|
| 440 |
+
scalar1, \
|
| 441 |
+
scalar2, \
|
| 442 |
+
TYPE, A, B, step, \
|
| 443 |
+
max_threads_per_block, \
|
| 444 |
+
min_blocks_per_sm> \
|
| 445 |
+
<<<grid, block, 0, at::cuda::getCurrentCUDAStream(curDevice)>>>( \
|
| 446 |
+
aInfo, bInfo, static_cast<TYPE>(totalElements), op); \
|
| 447 |
+
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
| 448 |
+
|
| 449 |
+
#define HANDLE_B_CASE(TYPE, A, B) { \
|
| 450 |
+
switch (B) { \
|
| 451 |
+
case 1: \
|
| 452 |
+
HANDLE_CASE(TYPE, A, 1); \
|
| 453 |
+
break; \
|
| 454 |
+
case 2: \
|
| 455 |
+
HANDLE_CASE(TYPE, A, 2); \
|
| 456 |
+
break; \
|
| 457 |
+
default: \
|
| 458 |
+
HANDLE_CASE(TYPE, A, -1); \
|
| 459 |
+
break; \
|
| 460 |
+
} \
|
| 461 |
+
}
|
| 462 |
+
|
| 463 |
+
#define HANDLE_A_CASE(TYPE, A, B) { \
|
| 464 |
+
switch (A) { \
|
| 465 |
+
case 1: \
|
| 466 |
+
HANDLE_B_CASE(TYPE, 1, B); \
|
| 467 |
+
break; \
|
| 468 |
+
case 2: \
|
| 469 |
+
HANDLE_B_CASE(TYPE, 2, B); \
|
| 470 |
+
break; \
|
| 471 |
+
default: \
|
| 472 |
+
HANDLE_B_CASE(TYPE, -1, B); \
|
| 473 |
+
break; \
|
| 474 |
+
} \
|
| 475 |
+
}
|
| 476 |
+
|
| 477 |
+
if (detail::canUse32BitIndexMath(a) &&
|
| 478 |
+
detail::canUse32BitIndexMath(b)) {
|
| 479 |
+
detail::TensorInfo<scalar1, unsigned int> aInfo =
|
| 480 |
+
detail::getTensorInfo<scalar1, unsigned int>(a);
|
| 481 |
+
|
| 482 |
+
detail::TensorInfo<scalar2, unsigned int> bInfo =
|
| 483 |
+
detail::getTensorInfo<scalar2, unsigned int>(b);
|
| 484 |
+
rearrangeDims(&aInfo, &bInfo);
|
| 485 |
+
aInfo.collapseDims();
|
| 486 |
+
bInfo.collapseDims();
|
| 487 |
+
|
| 488 |
+
HANDLE_A_CASE(unsigned int, aInfo.dims, bInfo.dims);
|
| 489 |
+
} else {
|
| 490 |
+
detail::TensorInfo<scalar1, uint64_t> aInfo =
|
| 491 |
+
detail::getTensorInfo<scalar1, uint64_t>(a);
|
| 492 |
+
|
| 493 |
+
detail::TensorInfo<scalar2, uint64_t> bInfo =
|
| 494 |
+
detail::getTensorInfo<scalar2, uint64_t>(b);
|
| 495 |
+
rearrangeDims(&aInfo, &bInfo);
|
| 496 |
+
aInfo.collapseDims();
|
| 497 |
+
bInfo.collapseDims();
|
| 498 |
+
|
| 499 |
+
/*
|
| 500 |
+
Only instantiates the all 1D special case and the fallback all nD case for
|
| 501 |
+
large (64-bit indexed) tensors to reduce compilation time.
|
| 502 |
+
*/
|
| 503 |
+
if (aInfo.dims == 1 && bInfo.dims == 1) {
|
| 504 |
+
HANDLE_CASE(uint64_t, 1, 1);
|
| 505 |
+
} else {
|
| 506 |
+
HANDLE_CASE(uint64_t, -1, -1);
|
| 507 |
+
}
|
| 508 |
+
}
|
| 509 |
+
#undef HANDLE_CASE
|
| 510 |
+
#undef HANDLE_B_CASE
|
| 511 |
+
#undef HANDLE_A_CASE
|
| 512 |
+
|
| 513 |
+
if (oldA.defined()) {
|
| 514 |
+
at::native::copy_ignoring_overlaps(oldA, a);
|
| 515 |
+
}
|
| 516 |
+
|
| 517 |
+
if (oldB.defined()) {
|
| 518 |
+
at::native::copy_ignoring_overlaps(oldB, b);
|
| 519 |
+
}
|
| 520 |
+
|
| 521 |
+
return true;
|
| 522 |
+
}
|
| 523 |
+
|
| 524 |
+
/* Provides default step = 1 to CUDA_tensor_apply2. */
|
| 525 |
+
template <typename scalar1, typename scalar2, typename Op,
|
| 526 |
+
int max_threads_per_block=AT_APPLY_THREADS_PER_BLOCK,
|
| 527 |
+
int min_blocks_per_sm=AT_APPLY_BLOCKS_PER_SM>
|
| 528 |
+
inline bool CUDA_tensor_apply2(const at::TensorBase &a,
|
| 529 |
+
const at::TensorBase &b,
|
| 530 |
+
const Op op,
|
| 531 |
+
TensorArgType aType = TensorArgType::ReadWrite,
|
| 532 |
+
TensorArgType bType = TensorArgType::ReadOnly) {
|
| 533 |
+
return CUDA_tensor_apply2<scalar1, scalar2, 1, Op,
|
| 534 |
+
max_threads_per_block, min_blocks_per_sm>(a, b, op, aType, bType);
|
| 535 |
+
}
|
| 536 |
+
|
| 537 |
+
} // namespace at::cuda
|
.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDABlas.h
ADDED
|
@@ -0,0 +1,358 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
/*
|
| 3 |
+
Provides a subset of CUDA BLAS functions as templates:
|
| 4 |
+
|
| 5 |
+
gemm<Dtype>(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c,
|
| 6 |
+
ldc)
|
| 7 |
+
|
| 8 |
+
gemv<Dtype>(transa, m, n, alpha, a, lda, x, incx, beta, y, incy)
|
| 9 |
+
|
| 10 |
+
dot<Dtype>(n, x, incx, y, incy, result)
|
| 11 |
+
|
| 12 |
+
where Dtype is double, float, at::Half or at::BFloat16 (ROCm, NOT for dot).
|
| 13 |
+
The functions are available in at::cuda::blas namespace.
|
| 14 |
+
*/
|
| 15 |
+
|
| 16 |
+
#include <ATen/cuda/CUDAContext.h>
|
| 17 |
+
#include <ATen/OpMathType.h>
|
| 18 |
+
|
| 19 |
+
namespace at::cuda::blas {
|
| 20 |
+
|
| 21 |
+
// RAII guard that sets the CuBLAS pointer mode and restores it to
|
| 22 |
+
// its previous value when the guard is destroyed
|
| 23 |
+
class PointerModeGuard {
|
| 24 |
+
public:
|
| 25 |
+
PointerModeGuard(cublasHandle_t handle, cublasPointerMode_t mode) :
|
| 26 |
+
handle(handle) {
|
| 27 |
+
TORCH_CUDABLAS_CHECK(cublasGetPointerMode(handle, &previous_mode));
|
| 28 |
+
TORCH_CUDABLAS_CHECK(cublasSetPointerMode(handle, mode));
|
| 29 |
+
}
|
| 30 |
+
|
| 31 |
+
~PointerModeGuard() {
|
| 32 |
+
cublasSetPointerMode(handle, previous_mode);
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
private:
|
| 36 |
+
cublasHandle_t handle;
|
| 37 |
+
cublasPointerMode_t previous_mode;
|
| 38 |
+
};
|
| 39 |
+
|
| 40 |
+
/* LEVEL 3 BLAS FUNCTIONS */
|
| 41 |
+
|
| 42 |
+
#define CUDABLAS_GEMM_ARGTYPES(Dtype) \
|
| 43 |
+
char transa, char transb, int64_t m, int64_t n, int64_t k, at::opmath_type<Dtype> alpha, \
|
| 44 |
+
const Dtype *a, int64_t lda, const Dtype *b, int64_t ldb, at::opmath_type<Dtype> beta,\
|
| 45 |
+
Dtype *c, int64_t ldc
|
| 46 |
+
|
| 47 |
+
#define CUDABLAS_GEMM_ARGS(Dtype) transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc
|
| 48 |
+
|
| 49 |
+
template <typename Dtype>
|
| 50 |
+
inline void gemm(CUDABLAS_GEMM_ARGTYPES(Dtype)) {
|
| 51 |
+
static_assert(false&&sizeof(Dtype),"at::cuda::blas::gemm: not implemented");
|
| 52 |
+
}
|
| 53 |
+
|
| 54 |
+
template <>
|
| 55 |
+
void gemm<double>(CUDABLAS_GEMM_ARGTYPES(double));
|
| 56 |
+
template <>
|
| 57 |
+
void gemm<float>(CUDABLAS_GEMM_ARGTYPES(float));
|
| 58 |
+
template <>
|
| 59 |
+
void gemm<c10::complex<double>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<double>));
|
| 60 |
+
template <>
|
| 61 |
+
void gemm<c10::complex<float>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<float>));
|
| 62 |
+
template <>
|
| 63 |
+
void gemm<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half));
|
| 64 |
+
template <>
|
| 65 |
+
void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16));
|
| 66 |
+
|
| 67 |
+
template <typename Dtype>
|
| 68 |
+
inline void gemm_internal(CUDABLAS_GEMM_ARGTYPES(Dtype)) {
|
| 69 |
+
static_assert(false&&sizeof(Dtype),"at::cuda::blas::gemm_internal: not implemented");
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
template <>
|
| 73 |
+
void gemm_internal<double>(CUDABLAS_GEMM_ARGTYPES(double));
|
| 74 |
+
template <>
|
| 75 |
+
void gemm_internal<float>(CUDABLAS_GEMM_ARGTYPES(float));
|
| 76 |
+
template <>
|
| 77 |
+
void gemm_internal<c10::complex<double>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<double>));
|
| 78 |
+
template <>
|
| 79 |
+
void gemm_internal<c10::complex<float>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<float>));
|
| 80 |
+
template <>
|
| 81 |
+
void gemm_internal<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half));
|
| 82 |
+
template <>
|
| 83 |
+
void gemm_internal<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16));
|
| 84 |
+
|
| 85 |
+
enum GEMMAndBiasActivationEpilogue {
|
| 86 |
+
None,
|
| 87 |
+
RELU,
|
| 88 |
+
GELU,
|
| 89 |
+
};
|
| 90 |
+
|
| 91 |
+
// NOTE: GELU activation is not supported prior to CUDA 11.4 and will
|
| 92 |
+
// do nothing if passed in that case.
|
| 93 |
+
template <typename Dtype>
|
| 94 |
+
void gemm_and_bias(
|
| 95 |
+
bool transpose_mat1,
|
| 96 |
+
bool transpose_mat2,
|
| 97 |
+
int64_t m,
|
| 98 |
+
int64_t n,
|
| 99 |
+
int64_t k,
|
| 100 |
+
at::opmath_type<Dtype> alpha_val,
|
| 101 |
+
const Dtype* mat1_ptr,
|
| 102 |
+
int64_t mat1_ld,
|
| 103 |
+
const Dtype* mat2_ptr,
|
| 104 |
+
int64_t mat2_ld,
|
| 105 |
+
const Dtype* bias,
|
| 106 |
+
Dtype* result_ptr,
|
| 107 |
+
int64_t result_ld,
|
| 108 |
+
GEMMAndBiasActivationEpilogue activation = GEMMAndBiasActivationEpilogue::None);
|
| 109 |
+
|
| 110 |
+
void int8_gemm(
|
| 111 |
+
bool transpose_mat1,
|
| 112 |
+
bool transpose_mat2,
|
| 113 |
+
int64_t m,
|
| 114 |
+
int64_t n,
|
| 115 |
+
int64_t k,
|
| 116 |
+
const int8_t* mat1_ptr,
|
| 117 |
+
int64_t mat1_ld,
|
| 118 |
+
const int8_t* mat2_ptr,
|
| 119 |
+
int64_t mat2_ld,
|
| 120 |
+
int32_t* result_ptr,
|
| 121 |
+
int64_t result_ld);
|
| 122 |
+
|
| 123 |
+
void scaled_gemm(
|
| 124 |
+
char transa,
|
| 125 |
+
char transb,
|
| 126 |
+
int64_t m,
|
| 127 |
+
int64_t n,
|
| 128 |
+
int64_t k,
|
| 129 |
+
const void* mat1_ptr,
|
| 130 |
+
const void* mat1_scale_ptr,
|
| 131 |
+
int64_t mat1_ld,
|
| 132 |
+
ScalarType mat1_dtype,
|
| 133 |
+
const void* mat2_ptr,
|
| 134 |
+
const void* mat2_scale_ptr,
|
| 135 |
+
int64_t mat2_ld,
|
| 136 |
+
ScalarType mat2_dtype,
|
| 137 |
+
const void* bias_ptr,
|
| 138 |
+
ScalarType bias_dtype,
|
| 139 |
+
void* result_ptr,
|
| 140 |
+
const void* result_scale_ptr,
|
| 141 |
+
int64_t result_ld,
|
| 142 |
+
ScalarType result_dtype,
|
| 143 |
+
void* amax_ptr,
|
| 144 |
+
bool use_fast_accum);
|
| 145 |
+
|
| 146 |
+
#define CUDABLAS_BGEMM_ARGTYPES(Dtype) \
|
| 147 |
+
char transa, char transb, int64_t m, int64_t n, int64_t k, at::opmath_type<Dtype> alpha, \
|
| 148 |
+
const Dtype *a, int64_t lda, int64_t stridea, \
|
| 149 |
+
const Dtype *b, int64_t ldb, int64_t strideb, \
|
| 150 |
+
at::opmath_type<Dtype> beta, Dtype *c, int64_t ldc, int64_t stridec, int64_t num_batches
|
| 151 |
+
|
| 152 |
+
#define CUDABLAS_BGEMM_ARGS(Dtype) \
|
| 153 |
+
transa, transb, m, n, k, alpha, a, lda, stridea, b, ldb, strideb, beta, c, ldc, stridec, num_batches
|
| 154 |
+
|
| 155 |
+
template <typename Dtype>
|
| 156 |
+
inline void bgemm(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
|
| 157 |
+
static_assert(false&&sizeof(Dtype),"at::cuda::blas::bgemm: not implemented");
|
| 158 |
+
}
|
| 159 |
+
|
| 160 |
+
template <>
|
| 161 |
+
void bgemm<double>(CUDABLAS_BGEMM_ARGTYPES(double));
|
| 162 |
+
template <>
|
| 163 |
+
void bgemm<float>(CUDABLAS_BGEMM_ARGTYPES(float));
|
| 164 |
+
template <>
|
| 165 |
+
void bgemm<c10::complex<double>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<double>));
|
| 166 |
+
template <>
|
| 167 |
+
void bgemm<c10::complex<float>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<float>));
|
| 168 |
+
template <>
|
| 169 |
+
void bgemm<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half));
|
| 170 |
+
template <>
|
| 171 |
+
void bgemm<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16));
|
| 172 |
+
|
| 173 |
+
template <typename Dtype>
|
| 174 |
+
inline void bgemm_internal(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
|
| 175 |
+
static_assert(false&&sizeof(Dtype),"at::cuda::blas::bgemm_internal: not implemented");
|
| 176 |
+
}
|
| 177 |
+
|
| 178 |
+
template <>
|
| 179 |
+
void bgemm_internal<double>(CUDABLAS_BGEMM_ARGTYPES(double));
|
| 180 |
+
template <>
|
| 181 |
+
void bgemm_internal<float>(CUDABLAS_BGEMM_ARGTYPES(float));
|
| 182 |
+
template <>
|
| 183 |
+
void bgemm_internal<c10::complex<double>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<double>));
|
| 184 |
+
template <>
|
| 185 |
+
void bgemm_internal<c10::complex<float>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<float>));
|
| 186 |
+
template <>
|
| 187 |
+
void bgemm_internal<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half));
|
| 188 |
+
template <>
|
| 189 |
+
void bgemm_internal<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16));
|
| 190 |
+
|
| 191 |
+
#define CUDABLAS_TRSM_ARGTYPES(Dtype) \
|
| 192 |
+
cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, \
|
| 193 |
+
cublasOperation_t trans, cublasDiagType_t diag, int m, int n, \
|
| 194 |
+
const Dtype *alpha, const Dtype *A, int lda, Dtype *B, int ldb
|
| 195 |
+
|
| 196 |
+
template <typename Dtype>
|
| 197 |
+
inline void trsm(CUDABLAS_TRSM_ARGTYPES(Dtype)) {
|
| 198 |
+
static_assert(false&&sizeof(Dtype), "at::cuda::blas::trsm: not implemented");
|
| 199 |
+
}
|
| 200 |
+
|
| 201 |
+
template <>
|
| 202 |
+
TORCH_CUDA_CU_API void trsm<float>(CUDABLAS_TRSM_ARGTYPES(float));
|
| 203 |
+
template <>
|
| 204 |
+
TORCH_CUDA_CU_API void trsm<double>(CUDABLAS_TRSM_ARGTYPES(double));
|
| 205 |
+
template <>
|
| 206 |
+
TORCH_CUDA_CU_API void trsm<c10::complex<float>>(CUDABLAS_TRSM_ARGTYPES(c10::complex<float>));
|
| 207 |
+
template <>
|
| 208 |
+
TORCH_CUDA_CU_API void trsm<c10::complex<double>>(CUDABLAS_TRSM_ARGTYPES(c10::complex<double>));
|
| 209 |
+
|
| 210 |
+
#define CUDABLAS_TRSM_BATCHED_ARGTYPES(Dtype) \
|
| 211 |
+
cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, \
|
| 212 |
+
cublasOperation_t trans, cublasDiagType_t diag, int m, int n, \
|
| 213 |
+
const Dtype *alpha, Dtype *A[], int lda, Dtype *B[], int ldb, \
|
| 214 |
+
int batchCount
|
| 215 |
+
|
| 216 |
+
template <typename Dtype>
|
| 217 |
+
inline void trsmBatched(CUDABLAS_TRSM_BATCHED_ARGTYPES(Dtype)) {
|
| 218 |
+
static_assert(false&&sizeof(Dtype), "at::cuda::blas::trsmBatched: not implemented");
|
| 219 |
+
}
|
| 220 |
+
|
| 221 |
+
template <>
|
| 222 |
+
TORCH_CUDA_CU_API void trsmBatched<float>(CUDABLAS_TRSM_BATCHED_ARGTYPES(float));
|
| 223 |
+
template <>
|
| 224 |
+
TORCH_CUDA_CU_API void trsmBatched<double>(CUDABLAS_TRSM_BATCHED_ARGTYPES(double));
|
| 225 |
+
template <>
|
| 226 |
+
TORCH_CUDA_CU_API void trsmBatched<c10::complex<float>>(CUDABLAS_TRSM_BATCHED_ARGTYPES(c10::complex<float>));
|
| 227 |
+
template <>
|
| 228 |
+
TORCH_CUDA_CU_API void trsmBatched<c10::complex<double>>(CUDABLAS_TRSM_BATCHED_ARGTYPES(c10::complex<double>));
|
| 229 |
+
|
| 230 |
+
/* LEVEL 2 BLAS FUNCTIONS */
|
| 231 |
+
|
| 232 |
+
#define CUDABLAS_GEMV_ARGTYPES(Dtype) \
|
| 233 |
+
char trans, int64_t m, int64_t n, Dtype alpha, const Dtype *a, int64_t lda, \
|
| 234 |
+
const Dtype *x, int64_t incx, Dtype beta, Dtype *y, int64_t incy
|
| 235 |
+
|
| 236 |
+
template <typename Dtype>
|
| 237 |
+
inline void gemv(CUDABLAS_GEMV_ARGTYPES(Dtype)) {
|
| 238 |
+
static_assert(false&&sizeof(Dtype), "at::cuda::blas::gemv: not implemented");
|
| 239 |
+
}
|
| 240 |
+
|
| 241 |
+
template <>
|
| 242 |
+
void gemv<double>(CUDABLAS_GEMV_ARGTYPES(double));
|
| 243 |
+
template <>
|
| 244 |
+
void gemv<float>(CUDABLAS_GEMV_ARGTYPES(float));
|
| 245 |
+
template <>
|
| 246 |
+
void gemv<c10::complex<double>>(CUDABLAS_GEMV_ARGTYPES(c10::complex<double>));
|
| 247 |
+
template <>
|
| 248 |
+
void gemv<c10::complex<float>>(CUDABLAS_GEMV_ARGTYPES(c10::complex<float>));
|
| 249 |
+
template <>
|
| 250 |
+
void gemv<at::Half>(CUDABLAS_GEMV_ARGTYPES(at::Half));
|
| 251 |
+
template <>
|
| 252 |
+
void gemv<at::BFloat16>(CUDABLAS_GEMV_ARGTYPES(at::BFloat16));
|
| 253 |
+
|
| 254 |
+
/* LEVEL 1 BLAS FUNCTIONS */
|
| 255 |
+
|
| 256 |
+
#define CUDABLAS_DOT_ARGTYPES(Dtype) \
|
| 257 |
+
cublasHandle_t handle, int n, const Dtype *x, int incx, const Dtype *y, \
|
| 258 |
+
int incy, Dtype *result
|
| 259 |
+
|
| 260 |
+
template <typename Dtype>
|
| 261 |
+
inline void dot(CUDABLAS_DOT_ARGTYPES(Dtype)) {
|
| 262 |
+
static_assert(false&&sizeof(Dtype),"at::cuda::blas::dot: not implemented");
|
| 263 |
+
}
|
| 264 |
+
|
| 265 |
+
template <>
|
| 266 |
+
void dot<double>(CUDABLAS_DOT_ARGTYPES(double));
|
| 267 |
+
template <>
|
| 268 |
+
void dot<float>(CUDABLAS_DOT_ARGTYPES(float));
|
| 269 |
+
template <>
|
| 270 |
+
void dot<at::Half>(CUDABLAS_DOT_ARGTYPES(at::Half));
|
| 271 |
+
template <>
|
| 272 |
+
void dot<at::BFloat16>(CUDABLAS_DOT_ARGTYPES(at::BFloat16));
|
| 273 |
+
template <>
|
| 274 |
+
void dot<c10::complex<double>>(CUDABLAS_DOT_ARGTYPES(c10::complex<double>));
|
| 275 |
+
template <>
|
| 276 |
+
void dot<c10::complex<float>>(CUDABLAS_DOT_ARGTYPES(c10::complex<float>));
|
| 277 |
+
|
| 278 |
+
template <typename Dtype>
|
| 279 |
+
inline void vdot(CUDABLAS_DOT_ARGTYPES(Dtype)) {
|
| 280 |
+
static_assert(false&&sizeof(Dtype),"at::cuda::blas::vdot: not implemented");
|
| 281 |
+
}
|
| 282 |
+
|
| 283 |
+
template <>
|
| 284 |
+
void vdot<c10::complex<float>>(CUDABLAS_DOT_ARGTYPES(c10::complex<float>));
|
| 285 |
+
template <>
|
| 286 |
+
void vdot<c10::complex<double>>(CUDABLAS_DOT_ARGTYPES(c10::complex<double>));
|
| 287 |
+
|
| 288 |
+
#define CUDABLAS_GETRS_ARGTYPES(Dtype) \
|
| 289 |
+
cublasHandle_t handle, cublasOperation_t trans, \
|
| 290 |
+
int n, int nrhs, Dtype** dA_array, int lda, int* ipiv_array, \
|
| 291 |
+
Dtype** dB_array, int ldb, int* info_array, int batchsize
|
| 292 |
+
|
| 293 |
+
template<class Dtype>
|
| 294 |
+
void getrsBatched(CUDABLAS_GETRS_ARGTYPES(Dtype)) {
|
| 295 |
+
static_assert(false&&sizeof(Dtype),"at::cuda::blas::getrsBatched: not implemented");
|
| 296 |
+
}
|
| 297 |
+
template<>
|
| 298 |
+
TORCH_CUDA_CU_API void getrsBatched<float>(CUDABLAS_GETRS_ARGTYPES(float));
|
| 299 |
+
template<>
|
| 300 |
+
TORCH_CUDA_CU_API void getrsBatched<double>(CUDABLAS_GETRS_ARGTYPES(double));
|
| 301 |
+
template<>
|
| 302 |
+
TORCH_CUDA_CU_API void getrsBatched<c10::complex<float>>(CUDABLAS_GETRS_ARGTYPES(c10::complex<float>));
|
| 303 |
+
template<>
|
| 304 |
+
TORCH_CUDA_CU_API void getrsBatched<c10::complex<double>>(CUDABLAS_GETRS_ARGTYPES(c10::complex<double>));
|
| 305 |
+
|
| 306 |
+
#define CUDABLAS_GEQRF_BATCHED_ARGTYPES(Dtype) \
|
| 307 |
+
cublasHandle_t handle, int m, int n, Dtype **A_array, int lda, \
|
| 308 |
+
Dtype **tau_array, int *info, int batchsize
|
| 309 |
+
|
| 310 |
+
template <class Dtype>
|
| 311 |
+
void geqrfBatched(CUDABLAS_GEQRF_BATCHED_ARGTYPES(Dtype)) {
|
| 312 |
+
static_assert(false&&sizeof(Dtype), "at::cuda::blas::geqrfBatched: not implemented");
|
| 313 |
+
}
|
| 314 |
+
template <>
|
| 315 |
+
TORCH_CUDA_CU_API void geqrfBatched<float>(CUDABLAS_GEQRF_BATCHED_ARGTYPES(float));
|
| 316 |
+
template <>
|
| 317 |
+
TORCH_CUDA_CU_API void geqrfBatched<double>(CUDABLAS_GEQRF_BATCHED_ARGTYPES(double));
|
| 318 |
+
template <>
|
| 319 |
+
TORCH_CUDA_CU_API void geqrfBatched<c10::complex<double>>(
|
| 320 |
+
CUDABLAS_GEQRF_BATCHED_ARGTYPES(c10::complex<double>));
|
| 321 |
+
template <>
|
| 322 |
+
TORCH_CUDA_CU_API void geqrfBatched<c10::complex<float>>(
|
| 323 |
+
CUDABLAS_GEQRF_BATCHED_ARGTYPES(c10::complex<float>));
|
| 324 |
+
|
| 325 |
+
#define CUDABLAS_GETRF_ARGTYPES(Dtype) \
|
| 326 |
+
int n, Dtype** dA_array, int ldda, int* ipiv_array, int* info_array, int batchsize
|
| 327 |
+
|
| 328 |
+
template<class Dtype>
|
| 329 |
+
void getrfBatched(CUDABLAS_GETRF_ARGTYPES(Dtype)) {
|
| 330 |
+
TORCH_CHECK(false, "at::cuda::blas::getrfBatched: not implemented");
|
| 331 |
+
}
|
| 332 |
+
template<>
|
| 333 |
+
TORCH_CUDA_CU_API void getrfBatched<float>(CUDABLAS_GETRF_ARGTYPES(float));
|
| 334 |
+
template<>
|
| 335 |
+
TORCH_CUDA_CU_API void getrfBatched<double>(CUDABLAS_GETRF_ARGTYPES(double));
|
| 336 |
+
template<>
|
| 337 |
+
TORCH_CUDA_CU_API void getrfBatched<c10::complex<double>>(CUDABLAS_GETRF_ARGTYPES(c10::complex<double>));
|
| 338 |
+
template<>
|
| 339 |
+
TORCH_CUDA_CU_API void getrfBatched<c10::complex<float>>(CUDABLAS_GETRF_ARGTYPES(c10::complex<float>));
|
| 340 |
+
|
| 341 |
+
#define CUDABLAS_GELS_BATCHED_ARGTYPES(Dtype) \
|
| 342 |
+
cublasHandle_t handle, cublasOperation_t trans, int m, int n, int nrhs, Dtype** dA_array, int ldda, Dtype** dC_array, int lddc, int* info, int *devInfoArray, int batchSize
|
| 343 |
+
|
| 344 |
+
template <class Dtype>
|
| 345 |
+
void gelsBatched(CUDABLAS_GELS_BATCHED_ARGTYPES(Dtype)) {
|
| 346 |
+
static_assert(false&&sizeof(Dtype),"at::cuda::blas::gelsBatched: not implemented");
|
| 347 |
+
}
|
| 348 |
+
|
| 349 |
+
template<>
|
| 350 |
+
TORCH_CUDA_CU_API void gelsBatched<double>(CUDABLAS_GELS_BATCHED_ARGTYPES(double));
|
| 351 |
+
template<>
|
| 352 |
+
TORCH_CUDA_CU_API void gelsBatched<float>(CUDABLAS_GELS_BATCHED_ARGTYPES(float));
|
| 353 |
+
template<>
|
| 354 |
+
TORCH_CUDA_CU_API void gelsBatched<c10::complex<double>>(CUDABLAS_GELS_BATCHED_ARGTYPES(c10::complex<double>));
|
| 355 |
+
template<>
|
| 356 |
+
TORCH_CUDA_CU_API void gelsBatched<c10::complex<float>>(CUDABLAS_GELS_BATCHED_ARGTYPES(c10::complex<float>));
|
| 357 |
+
|
| 358 |
+
} // namespace at::cuda::blas
|
.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDAConfig.h
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// Test these using #if AT_CUDNN_ENABLED(), not #ifdef, so that it's
|
| 4 |
+
// obvious if you forgot to include Config.h
|
| 5 |
+
// c.f. https://stackoverflow.com/questions/33759787/generating-an-error-if-checked-boolean-macro-is-not-defined
|
| 6 |
+
//
|
| 7 |
+
// NB: This header MUST NOT be included from other headers; it should
|
| 8 |
+
// only be included from C++ files.
|
| 9 |
+
#define AT_CUDNN_ENABLED() 1
|
| 10 |
+
#define AT_CUSPARSELT_ENABLED() 1
|
| 11 |
+
#define AT_ROCM_ENABLED() 0
|
| 12 |
+
#define AT_MAGMA_ENABLED() 1
|
| 13 |
+
|
| 14 |
+
// Needed for hipMAGMA to correctly identify implementation
|
| 15 |
+
#if (AT_ROCM_ENABLED() && AT_MAGMA_ENABLED())
|
| 16 |
+
#define HAVE_HIP 1
|
| 17 |
+
#endif
|
| 18 |
+
|
| 19 |
+
#define NVCC_FLAGS_EXTRA "-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86;-gencode;arch=compute_90,code=sm_90"
|
.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDAContext.h
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/cuda/CUDAContextLight.h>
|
| 4 |
+
|
| 5 |
+
// Preserved for BC, as many files depend on these includes
|
| 6 |
+
#include <ATen/Context.h>
|
| 7 |
+
#include <c10/cuda/CUDAStream.h>
|
| 8 |
+
#include <c10/util/Logging.h>
|
| 9 |
+
#include <ATen/cuda/Exceptions.h>
|