Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +2 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cuda_runtime/lib/libcudart.so.11.0 +3 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/cudagraph_trees.cpython-311.pyc +3 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/CompositeRandomAccessor.h +34 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/DispatchStub.h +315 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Distance.h +20 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/FractionalMaxPooling.h +80 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/FunctionOfAMatrixUtils.h +20 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/GridSampler.h +298 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/GridSamplerUtils.h +109 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Lerp.h +46 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/LinearAlgebraUtils.h +623 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/MathBitFallThroughLists.h +71 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Sorting.h +28 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/SpectralOpsUtils.h +84 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorTransformations.h +30 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/GridSamplerKernel.h +34 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/IndexKernelUtils.h +88 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/LogAddExp.h +61 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/PixelShuffleKernel.h +14 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/SampledAddmmKernel.h +12 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/UpSampleKernelAVXAntialias.h +1376 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/WeightNormKernel.h +20 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/mixed_data_type.h +41 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/moments_utils.h +206 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/TensorModeKernel.cuh +435 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/nested/NestedTensorBinaryOps.h +16 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/vol2col.h +109 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_cast_Long_native.h +21 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_cufft_clear_plan_cache_compositeimplicitautograd_dispatch.h +23 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_fft_c2c_ops.h +39 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_foreach_expm1_ops.h +50 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_foreach_log1p_cuda_dispatch.h +24 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_foreach_reciprocal_ops.h +50 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_indices_copy.h +39 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_linalg_svd_meta_dispatch.h +25 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_lstm_mps.h +39 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_nested_get_values_copy_native.h +22 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_sparse_mask_projection_ops.h +39 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_spdiags.h +39 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_test_optional_intlist_compositeexplicitautograd_dispatch.h +24 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_upsample_bilinear2d_aa_cpu_dispatch.h +28 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/abs_ops.h +50 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/aminmax.h +39 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/arcsinh_compositeimplicitautograd_dispatch.h +26 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/avg_pool2d_ops.h +39 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/batch_norm_compositeimplicitautograd_dispatch.h +23 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/block_diag_native.h +22 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/cat_compositeimplicitautograd_dispatch.h +25 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/cat_native.h +32 -0
.gitattributes
CHANGED
|
@@ -78,3 +78,5 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cublas/lib
|
|
| 78 |
tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/distlib/t64.exe filter=lfs diff=lfs merge=lfs -text
|
| 79 |
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cudnn/lib/libcudnn.so.8 filter=lfs diff=lfs merge=lfs -text
|
| 80 |
tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/distlib/t64-arm.exe filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
| 78 |
tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/distlib/t64.exe filter=lfs diff=lfs merge=lfs -text
|
| 79 |
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cudnn/lib/libcudnn.so.8 filter=lfs diff=lfs merge=lfs -text
|
| 80 |
tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/distlib/t64-arm.exe filter=lfs diff=lfs merge=lfs -text
|
| 81 |
+
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cuda_runtime/lib/libcudart.so.11.0 filter=lfs diff=lfs merge=lfs -text
|
| 82 |
+
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/cudagraph_trees.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cuda_runtime/lib/libcudart.so.11.0
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d0da41ae1323cf4eeb610123d69d7714124cfe5ebfcc4e45f02b910e51c57ee6
|
| 3 |
+
size 679264
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/cudagraph_trees.cpython-311.pyc
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f112096a5626f67e200c68699bf622cf45f14ef9d7136d8c68afda693609bcdb
|
| 3 |
+
size 106203
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/CompositeRandomAccessor.h
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/native/CompositeRandomAccessorCommon.h>
|
| 4 |
+
|
| 5 |
+
namespace at::native {
|
| 6 |
+
|
| 7 |
+
struct TupleInfoCPU {
|
| 8 |
+
template <typename ...Types>
|
| 9 |
+
using tuple = std::tuple<Types...>;
|
| 10 |
+
|
| 11 |
+
template <typename ...Types>
|
| 12 |
+
static constexpr auto tie(Types&... args) noexcept {
|
| 13 |
+
return std::tie(args...);
|
| 14 |
+
}
|
| 15 |
+
};
|
| 16 |
+
|
| 17 |
+
template <typename KeyAccessor, typename ValueAccessor>
|
| 18 |
+
using CompositeRandomAccessorCPU =
|
| 19 |
+
CompositeRandomAccessor<KeyAccessor, ValueAccessor, TupleInfoCPU>;
|
| 20 |
+
|
| 21 |
+
template <typename Values, typename References>
|
| 22 |
+
void swap(
|
| 23 |
+
references_holder<Values, References> rh1,
|
| 24 |
+
references_holder<Values, References> rh2
|
| 25 |
+
) {
|
| 26 |
+
return std::swap(rh1.data(), rh2.data());
|
| 27 |
+
}
|
| 28 |
+
|
| 29 |
+
template <int N, typename Values, typename References>
|
| 30 |
+
auto get(references_holder<Values, References> rh) -> decltype(std::get<N>(rh.data())) {
|
| 31 |
+
return std::get<N>(rh.data());
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/DispatchStub.h
ADDED
|
@@ -0,0 +1,315 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <c10/core/DeviceType.h>
|
| 4 |
+
#include <c10/macros/Macros.h>
|
| 5 |
+
|
| 6 |
+
#include <atomic>
|
| 7 |
+
#include <utility>
|
| 8 |
+
|
| 9 |
+
// Implements instruction set specific function dispatch.
|
| 10 |
+
//
|
| 11 |
+
// Kernels that may make use of specialized instruction sets (e.g. AVX2) are
|
| 12 |
+
// compiled multiple times with different compiler flags (e.g. -mavx2). A
|
| 13 |
+
// DispatchStub contains a table of function pointers for a kernel. At runtime,
|
| 14 |
+
// the fastest available kernel is chosen based on the features reported by
|
| 15 |
+
// cpuinfo.
|
| 16 |
+
//
|
| 17 |
+
// Example:
|
| 18 |
+
//
|
| 19 |
+
// In native/MyKernel.h:
|
| 20 |
+
// using fn_type = void(*)(const Tensor& x);
|
| 21 |
+
// DECLARE_DISPATCH(fn_type, stub);
|
| 22 |
+
//
|
| 23 |
+
// In native/MyKernel.cpp
|
| 24 |
+
// DEFINE_DISPATCH(stub);
|
| 25 |
+
//
|
| 26 |
+
// In native/cpu/MyKernel.cpp:
|
| 27 |
+
// namespace {
|
| 28 |
+
// // use anonymous namespace so that different cpu versions won't conflict
|
| 29 |
+
// void kernel(const Tensor& x) { ... }
|
| 30 |
+
// }
|
| 31 |
+
// REGISTER_DISPATCH(stub, &kernel);
|
| 32 |
+
//
|
| 33 |
+
// To call:
|
| 34 |
+
// stub(kCPU, tensor);
|
| 35 |
+
//
|
| 36 |
+
// TODO: CPU instruction set selection should be folded into whatever
|
| 37 |
+
// the main dispatch mechanism is.
|
| 38 |
+
|
| 39 |
+
// ignore warnings about DispatchStub::DEFAULT, AVX, AVX2 defined elsewhere
|
| 40 |
+
C10_CLANG_DIAGNOSTIC_PUSH()
|
| 41 |
+
C10_CLANG_DIAGNOSTIC_IGNORE("-Wundefined-var-template")
|
| 42 |
+
|
| 43 |
+
namespace at::native {
|
| 44 |
+
|
| 45 |
+
enum class CPUCapability {
|
| 46 |
+
DEFAULT = 0,
|
| 47 |
+
#if defined(HAVE_VSX_CPU_DEFINITION)
|
| 48 |
+
VSX = 1,
|
| 49 |
+
#elif defined(HAVE_ZVECTOR_CPU_DEFINITION)
|
| 50 |
+
ZVECTOR = 1,
|
| 51 |
+
#else
|
| 52 |
+
AVX2 = 1,
|
| 53 |
+
AVX512 = 2,
|
| 54 |
+
#endif
|
| 55 |
+
NUM_OPTIONS
|
| 56 |
+
};
|
| 57 |
+
|
| 58 |
+
CPUCapability get_cpu_capability();
|
| 59 |
+
|
| 60 |
+
template <typename FnPtr, typename T>
|
| 61 |
+
struct DispatchStub;
|
| 62 |
+
|
| 63 |
+
/**
|
| 64 |
+
* The sole purpose of this class is to outline methods that don't need to be
|
| 65 |
+
* specialized or otherwise inlined and duplicated (by the compiler due to
|
| 66 |
+
* template expansion), since it causes size bloat if there are a significant
|
| 67 |
+
* number of specialization of the DispatchStub<> class.
|
| 68 |
+
*/
|
| 69 |
+
struct TORCH_API DispatchStubImpl {
|
| 70 |
+
void* get_call_ptr(
|
| 71 |
+
c10::DeviceType device_type
|
| 72 |
+
, void *DEFAULT
|
| 73 |
+
#ifdef HAVE_AVX512_CPU_DEFINITION
|
| 74 |
+
, void *AVX512
|
| 75 |
+
#endif
|
| 76 |
+
#ifdef HAVE_AVX2_CPU_DEFINITION
|
| 77 |
+
, void *AVX2
|
| 78 |
+
#endif
|
| 79 |
+
#ifdef HAVE_VSX_CPU_DEFINITION
|
| 80 |
+
, void *VSX
|
| 81 |
+
#endif
|
| 82 |
+
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
|
| 83 |
+
, void *ZVECTOR
|
| 84 |
+
#endif
|
| 85 |
+
);
|
| 86 |
+
|
| 87 |
+
/**
|
| 88 |
+
* The CPU Dispatch actual method is chosen in decreasing order of preference by
|
| 89 |
+
* DispatchStubImpl::choose_cpu_impl() in case none is found by
|
| 90 |
+
* DispatchStubImpl::get_call_ptr() in cpu_dispatch_ptr.
|
| 91 |
+
*/
|
| 92 |
+
void* choose_cpu_impl(
|
| 93 |
+
void *DEFAULT
|
| 94 |
+
#ifdef HAVE_AVX512_CPU_DEFINITION
|
| 95 |
+
, void *AVX512
|
| 96 |
+
#endif
|
| 97 |
+
#ifdef HAVE_AVX2_CPU_DEFINITION
|
| 98 |
+
, void *AVX2
|
| 99 |
+
#endif
|
| 100 |
+
#ifdef HAVE_VSX_CPU_DEFINITION
|
| 101 |
+
, void *VSX
|
| 102 |
+
#endif
|
| 103 |
+
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
|
| 104 |
+
, void *ZVECTOR
|
| 105 |
+
#endif
|
| 106 |
+
);
|
| 107 |
+
|
| 108 |
+
// Fixing dispatch error in Windows debug builds.
|
| 109 |
+
// See https://github.com/pytorch/pytorch/issues/22681 for more details.
|
| 110 |
+
#if defined(_MSC_VER) && defined(_DEBUG)
|
| 111 |
+
std::atomic<void*> cpu_dispatch_ptr;
|
| 112 |
+
void* cuda_dispatch_ptr;
|
| 113 |
+
void* hip_dispatch_ptr;
|
| 114 |
+
void* mps_dispatch_ptr;
|
| 115 |
+
void* privateuse1_dispatch_ptr;
|
| 116 |
+
#else
|
| 117 |
+
std::atomic<void*> cpu_dispatch_ptr{nullptr};
|
| 118 |
+
void* cuda_dispatch_ptr = nullptr;
|
| 119 |
+
void* hip_dispatch_ptr = nullptr;
|
| 120 |
+
void* mps_dispatch_ptr = nullptr;
|
| 121 |
+
void* privateuse1_dispatch_ptr = nullptr;
|
| 122 |
+
#endif
|
| 123 |
+
};
|
| 124 |
+
|
| 125 |
+
template <typename rT, typename T, typename... Args>
|
| 126 |
+
struct DispatchStub<rT (*)(Args...), T> {
|
| 127 |
+
using FnPtr = rT (*) (Args...);
|
| 128 |
+
|
| 129 |
+
DispatchStub() = default;
|
| 130 |
+
DispatchStub(const DispatchStub&) = delete;
|
| 131 |
+
DispatchStub& operator=(const DispatchStub&) = delete;
|
| 132 |
+
|
| 133 |
+
private:
|
| 134 |
+
FnPtr get_call_ptr(c10::DeviceType device_type) {
|
| 135 |
+
return reinterpret_cast<FnPtr>(
|
| 136 |
+
impl.get_call_ptr(device_type
|
| 137 |
+
, reinterpret_cast<void*>(DEFAULT)
|
| 138 |
+
#ifdef HAVE_AVX512_CPU_DEFINITION
|
| 139 |
+
, reinterpret_cast<void*>(AVX512)
|
| 140 |
+
#endif
|
| 141 |
+
#ifdef HAVE_AVX2_CPU_DEFINITION
|
| 142 |
+
, reinterpret_cast<void*>(AVX2)
|
| 143 |
+
#endif
|
| 144 |
+
#ifdef HAVE_VSX_CPU_DEFINITION
|
| 145 |
+
, reinterpret_cast<void*>(VSX)
|
| 146 |
+
#endif
|
| 147 |
+
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
|
| 148 |
+
, reinterpret_cast<void*>(ZVECTOR)
|
| 149 |
+
#endif
|
| 150 |
+
)
|
| 151 |
+
);
|
| 152 |
+
}
|
| 153 |
+
|
| 154 |
+
public:
|
| 155 |
+
template <typename... ArgTypes>
|
| 156 |
+
rT operator()(c10::DeviceType device_type, ArgTypes&&... args) {
|
| 157 |
+
FnPtr call_ptr = get_call_ptr(device_type);
|
| 158 |
+
return (*call_ptr)(std::forward<ArgTypes>(args)...);
|
| 159 |
+
}
|
| 160 |
+
|
| 161 |
+
void set_cuda_dispatch_ptr(FnPtr fn_ptr) {
|
| 162 |
+
impl.cuda_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
|
| 163 |
+
}
|
| 164 |
+
|
| 165 |
+
void set_hip_dispatch_ptr(FnPtr fn_ptr) {
|
| 166 |
+
impl.hip_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
|
| 167 |
+
}
|
| 168 |
+
|
| 169 |
+
void set_mps_dispatch_ptr(FnPtr fn_ptr) {
|
| 170 |
+
impl.mps_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
void set_privateuse1_dispatch_ptr(FnPtr fn_ptr) {
|
| 174 |
+
impl.privateuse1_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
|
| 175 |
+
}
|
| 176 |
+
|
| 177 |
+
static TORCH_API FnPtr DEFAULT;
|
| 178 |
+
#ifdef HAVE_AVX512_CPU_DEFINITION
|
| 179 |
+
static TORCH_API FnPtr AVX512;
|
| 180 |
+
#endif
|
| 181 |
+
#ifdef HAVE_AVX2_CPU_DEFINITION
|
| 182 |
+
static TORCH_API FnPtr AVX2;
|
| 183 |
+
#endif
|
| 184 |
+
#ifdef HAVE_VSX_CPU_DEFINITION
|
| 185 |
+
static TORCH_API FnPtr VSX;
|
| 186 |
+
#endif
|
| 187 |
+
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
|
| 188 |
+
static TORCH_API FnPtr ZVECTOR;
|
| 189 |
+
#endif
|
| 190 |
+
private:
|
| 191 |
+
DispatchStubImpl impl;
|
| 192 |
+
};
|
| 193 |
+
|
| 194 |
+
namespace {
|
| 195 |
+
template <typename DispatchStub>
|
| 196 |
+
struct RegisterCUDADispatch {
|
| 197 |
+
RegisterCUDADispatch(DispatchStub &stub, typename DispatchStub::FnPtr value) {
|
| 198 |
+
stub.set_cuda_dispatch_ptr(value);
|
| 199 |
+
}
|
| 200 |
+
};
|
| 201 |
+
|
| 202 |
+
template <typename DispatchStub>
|
| 203 |
+
struct RegisterMPSDispatch {
|
| 204 |
+
RegisterMPSDispatch(DispatchStub &stub, typename DispatchStub::FnPtr value) {
|
| 205 |
+
stub.set_mps_dispatch_ptr(value);
|
| 206 |
+
}
|
| 207 |
+
};
|
| 208 |
+
|
| 209 |
+
template <typename DispatchStub>
|
| 210 |
+
struct RegisterHIPDispatch {
|
| 211 |
+
RegisterHIPDispatch(DispatchStub &stub, typename DispatchStub::FnPtr value) {
|
| 212 |
+
// TODO: make this point at hip_dispatch_ptr
|
| 213 |
+
stub.set_cuda_dispatch_ptr(value);
|
| 214 |
+
}
|
| 215 |
+
};
|
| 216 |
+
|
| 217 |
+
template <typename DispatchStub>
|
| 218 |
+
struct RegisterPRIVATEUSE1Dispatch {
|
| 219 |
+
RegisterPRIVATEUSE1Dispatch(DispatchStub &stub, typename DispatchStub::FnPtr value) {
|
| 220 |
+
stub.set_privateuse1_dispatch_ptr(value);
|
| 221 |
+
}
|
| 222 |
+
};
|
| 223 |
+
|
| 224 |
+
} // anonymous namespace
|
| 225 |
+
// Compiler will complain if you put things like std::tuple<Tensor, Tensor> in
|
| 226 |
+
// the `fn` argument of DECLARE_DISPATCH. Some possible workarounds, e.g.,
|
| 227 |
+
// adding parentheses and using helper struct to get rid of the parentheses, do
|
| 228 |
+
// not work with MSVC. So do a `using`-declaration if you need to pass in such
|
| 229 |
+
// `fn`, e.g., grid_sampler_2d_backward_cpu_kernel in GridSampleKernel.h.
|
| 230 |
+
#define DECLARE_DISPATCH(fn, name) \
|
| 231 |
+
struct name : DispatchStub<fn, name> { \
|
| 232 |
+
name() = default; \
|
| 233 |
+
name(const name&) = delete; \
|
| 234 |
+
name& operator=(const name&) = delete; \
|
| 235 |
+
}; \
|
| 236 |
+
extern TORCH_API struct name name
|
| 237 |
+
|
| 238 |
+
#define DEFINE_DISPATCH(name) struct name name
|
| 239 |
+
|
| 240 |
+
#define REGISTER_ARCH_DISPATCH(name, arch, fn) \
|
| 241 |
+
template <> name::FnPtr TORCH_API DispatchStub<name::FnPtr, struct name>::arch = fn;
|
| 242 |
+
|
| 243 |
+
#ifdef HAVE_AVX512_CPU_DEFINITION
|
| 244 |
+
#define REGISTER_AVX512_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, AVX512, fn)
|
| 245 |
+
#else
|
| 246 |
+
#define REGISTER_AVX512_DISPATCH(name, fn)
|
| 247 |
+
#endif
|
| 248 |
+
|
| 249 |
+
#ifdef HAVE_AVX2_CPU_DEFINITION
|
| 250 |
+
#define REGISTER_AVX2_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, AVX2, fn)
|
| 251 |
+
#else
|
| 252 |
+
#define REGISTER_AVX2_DISPATCH(name, fn)
|
| 253 |
+
#endif
|
| 254 |
+
|
| 255 |
+
#ifdef HAVE_VSX_CPU_DEFINITION
|
| 256 |
+
#define REGISTER_VSX_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, VSX, fn)
|
| 257 |
+
#else
|
| 258 |
+
#define REGISTER_VSX_DISPATCH(name, fn)
|
| 259 |
+
#endif
|
| 260 |
+
|
| 261 |
+
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
|
| 262 |
+
#define REGISTER_ZVECTOR_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, ZVECTOR, fn)
|
| 263 |
+
#else
|
| 264 |
+
#define REGISTER_ZVECTOR_DISPATCH(name, fn)
|
| 265 |
+
#endif
|
| 266 |
+
|
| 267 |
+
// Macro to register the same kernel for all CPU arch types. This is useful
|
| 268 |
+
// if a kernel does not benefit from being recompiled across different arch types.
|
| 269 |
+
#define REGISTER_ALL_CPU_DISPATCH(name, fn) \
|
| 270 |
+
REGISTER_ARCH_DISPATCH(name, DEFAULT, fn) \
|
| 271 |
+
REGISTER_AVX512_DISPATCH(name, fn) \
|
| 272 |
+
REGISTER_AVX2_DISPATCH(name, fn) \
|
| 273 |
+
REGISTER_VSX_DISPATCH(name, fn) \
|
| 274 |
+
REGISTER_ZVECTOR_DISPATCH(name, fn)
|
| 275 |
+
|
| 276 |
+
#define REGISTER_NO_CPU_DISPATCH(name) \
|
| 277 |
+
REGISTER_ALL_CPU_DISPATCH(name, nullptr)
|
| 278 |
+
|
| 279 |
+
#define REGISTER_CUDA_DISPATCH(name, fn) \
|
| 280 |
+
static RegisterCUDADispatch<struct name> name ## __register(name, fn);
|
| 281 |
+
|
| 282 |
+
#define REGISTER_HIP_DISPATCH(name, fn) \
|
| 283 |
+
static RegisterHIPDispatch<struct name> name ## __register(name, fn);
|
| 284 |
+
|
| 285 |
+
#define REGISTER_MPS_DISPATCH(name, fn) \
|
| 286 |
+
static RegisterMPSDispatch<struct name> name ## __register(name, fn);
|
| 287 |
+
|
| 288 |
+
#define REGISTER_PRIVATEUSE1_DISPATCH(name, fn) \
|
| 289 |
+
static RegisterPRIVATEUSE1Dispatch<struct name> name ## __register(name, fn);
|
| 290 |
+
|
| 291 |
+
// NB: This macro must be used in an actual 'cu' file; if you try using
|
| 292 |
+
// it from a 'cpp' file it will not work!
|
| 293 |
+
#if defined(__CUDACC__)
|
| 294 |
+
#define REGISTER_DISPATCH(name, fn) REGISTER_CUDA_DISPATCH(name, fn)
|
| 295 |
+
#elif defined(__HIPCC__)
|
| 296 |
+
// TODO: cut this over to HIP dispatch once we stop pretending that CUDA
|
| 297 |
+
// is HIP in the PyTorch HIPify build.
|
| 298 |
+
#define REGISTER_DISPATCH(name, fn) REGISTER_CUDA_DISPATCH(name, fn)
|
| 299 |
+
// #define REGISTER_DISPATCH(name, fn) REGISTER_HIP_DISPATCH(name, fn)
|
| 300 |
+
#elif defined(__OBJC__) && defined(USE_MPS)
|
| 301 |
+
// NB: this macro must be used from a 'mm' file in order to dispatch a MPS kernel
|
| 302 |
+
#define REGISTER_DISPATCH(name, fn) REGISTER_MPS_DISPATCH(name, fn)
|
| 303 |
+
#elif defined(CPU_CAPABILITY)
|
| 304 |
+
// REGISTER_DISPATCH now dispatches an AVX512 kernel to nullptr but registers other dispatches.
|
| 305 |
+
// ALSO_REGISTER_AVX512_DISPATCH should be used for ensuring AVX512 dispatch, among others.
|
| 306 |
+
#ifdef CPU_CAPABILITY_AVX512
|
| 307 |
+
#define REGISTER_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, nullptr)
|
| 308 |
+
#else
|
| 309 |
+
#define REGISTER_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn)
|
| 310 |
+
#endif
|
| 311 |
+
#define ALSO_REGISTER_AVX512_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn)
|
| 312 |
+
#endif
|
| 313 |
+
} // namespace at::native
|
| 314 |
+
|
| 315 |
+
C10_CLANG_DIAGNOSTIC_POP()
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Distance.h
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/native/DispatchStub.h>
|
| 4 |
+
|
| 5 |
+
namespace at {
|
| 6 |
+
class Tensor;
|
| 7 |
+
|
| 8 |
+
namespace native {
|
| 9 |
+
|
| 10 |
+
using pdist_forward_fn = void(*)(Tensor&, const Tensor&, const double p);
|
| 11 |
+
using pdist_backward_fn = void(*)(Tensor&, const Tensor&, const Tensor&, const double p, const Tensor&);
|
| 12 |
+
using cdist_fn = void(*)(Tensor&, const Tensor&, const Tensor&, const double p);
|
| 13 |
+
using cdist_backward_fn = void(*)(Tensor&, const Tensor&, const Tensor&, const Tensor&, const double p, const Tensor&);
|
| 14 |
+
|
| 15 |
+
DECLARE_DISPATCH(pdist_forward_fn, pdist_forward_stub);
|
| 16 |
+
DECLARE_DISPATCH(pdist_backward_fn, pdist_backward_stub);
|
| 17 |
+
DECLARE_DISPATCH(cdist_fn, cdist_stub);
|
| 18 |
+
DECLARE_DISPATCH(cdist_backward_fn, cdist_backward_stub);
|
| 19 |
+
|
| 20 |
+
}} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/FractionalMaxPooling.h
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <ATen/core/Tensor.h>
|
| 3 |
+
#include <ATen/TensorUtils.h>
|
| 4 |
+
#include <c10/util/irange.h>
|
| 5 |
+
|
| 6 |
+
namespace at::native {
|
| 7 |
+
|
| 8 |
+
template<typename scalar_t>
|
| 9 |
+
static inline std::vector<int> generate_intervals(
|
| 10 |
+
scalar_t sample,
|
| 11 |
+
int64_t inputSize,
|
| 12 |
+
int64_t outputSize,
|
| 13 |
+
int64_t poolSize) {
|
| 14 |
+
std::vector<int> sequence(outputSize);
|
| 15 |
+
if (outputSize > 1) {
|
| 16 |
+
scalar_t alpha = static_cast<scalar_t>(inputSize - poolSize) /
|
| 17 |
+
static_cast<scalar_t>(outputSize - 1);
|
| 18 |
+
|
| 19 |
+
for (const auto i : c10::irange(outputSize - 1)) {
|
| 20 |
+
sequence[i] =
|
| 21 |
+
static_cast<int>((i + sample) * alpha) - static_cast<int>(sample * alpha);
|
| 22 |
+
}
|
| 23 |
+
}
|
| 24 |
+
if (outputSize > 0) {
|
| 25 |
+
sequence[outputSize - 1] = inputSize - poolSize;
|
| 26 |
+
}
|
| 27 |
+
return sequence;
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
template <int64_t ndim>
|
| 31 |
+
static inline void fractional_max_pool_check_shape(
|
| 32 |
+
const Tensor& input,
|
| 33 |
+
const Tensor& randomSamples) {
|
| 34 |
+
|
| 35 |
+
TORCH_CHECK(
|
| 36 |
+
input.scalar_type() == randomSamples.scalar_type(),
|
| 37 |
+
"Expect _random_samples to have the same dtype as input");
|
| 38 |
+
|
| 39 |
+
int64_t ndimension = randomSamples.ndimension();
|
| 40 |
+
TORCH_CHECK(
|
| 41 |
+
ndimension == 3,
|
| 42 |
+
"Expect _random_samples to have 3 dimensions, got ", ndimension);
|
| 43 |
+
|
| 44 |
+
int64_t N = randomSamples.size(0);
|
| 45 |
+
int64_t C = randomSamples.size(1);
|
| 46 |
+
int64_t D = randomSamples.size(2);
|
| 47 |
+
|
| 48 |
+
int64_t input_batch, input_channel;
|
| 49 |
+
if (ndim == 2) {
|
| 50 |
+
// fractional_max_pool2d
|
| 51 |
+
if (input.ndimension() == 3) {
|
| 52 |
+
input_batch = 1;
|
| 53 |
+
input_channel = input.size(0);
|
| 54 |
+
} else {
|
| 55 |
+
input_batch = input.size(0);
|
| 56 |
+
input_channel = input.size(1);
|
| 57 |
+
}
|
| 58 |
+
} else {
|
| 59 |
+
// factional_max_pool3d
|
| 60 |
+
if (input.ndimension() == 4) {
|
| 61 |
+
input_batch = 1;
|
| 62 |
+
input_channel = input.size(0);
|
| 63 |
+
} else {
|
| 64 |
+
input_batch = input.size(0);
|
| 65 |
+
input_channel = input.size(1);
|
| 66 |
+
}
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
TORCH_CHECK(
|
| 70 |
+
N >= input_batch,
|
| 71 |
+
"Expect _random_samples.size(0) no less then input batch size.");
|
| 72 |
+
TORCH_CHECK(
|
| 73 |
+
C == input_channel,
|
| 74 |
+
"Expect _random_samples.size(1) equals to input channel size.");
|
| 75 |
+
TORCH_CHECK(
|
| 76 |
+
D == ndim,
|
| 77 |
+
"Expect _random_samples.size(2) equals to ", ndim, "; got ", D, ".");
|
| 78 |
+
}
|
| 79 |
+
|
| 80 |
+
} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/FunctionOfAMatrixUtils.h
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/native/DispatchStub.h>
|
| 4 |
+
#include <cstdint>
|
| 5 |
+
|
| 6 |
+
namespace at {
|
| 7 |
+
struct TensorIterator;
|
| 8 |
+
|
| 9 |
+
namespace native {
|
| 10 |
+
|
| 11 |
+
using _compute_linear_combination_fn = void(*)(
|
| 12 |
+
TensorIterator& iter,
|
| 13 |
+
int64_t in_stride,
|
| 14 |
+
int64_t coeff_stride,
|
| 15 |
+
int64_t num_summations
|
| 16 |
+
);
|
| 17 |
+
|
| 18 |
+
DECLARE_DISPATCH(_compute_linear_combination_fn, _compute_linear_combination_stub);
|
| 19 |
+
|
| 20 |
+
}} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/GridSampler.h
ADDED
|
@@ -0,0 +1,298 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <algorithm>
|
| 4 |
+
#include <cmath>
|
| 5 |
+
#include <cstdint>
|
| 6 |
+
#include <utility>
|
| 7 |
+
|
| 8 |
+
#include <ATen/native/GridSamplerUtils.h>
|
| 9 |
+
|
| 10 |
+
namespace at::native {
|
| 11 |
+
|
| 12 |
+
using detail::GridSamplerInterpolation;
|
| 13 |
+
using detail::GridSamplerPadding;
|
| 14 |
+
|
| 15 |
+
// Maps a normalized coordinate in [-1, 1] onto pixel-index space, where
// pixel idx covers the area (idx - 0.5, idx + 0.5).
// align_corners == true:  -1/+1 hit the centers of the corner pixels,
//                         so the output range is [0, size - 1]
//                         (scale factor (size - 1) / 2).
// align_corners == false: -1/+1 hit the image edges,
//                         so the output range is [-0.5, size - 0.5]
//                         (scale factor size / 2).
template <typename scalar_t>
static inline scalar_t grid_sampler_unnormalize(scalar_t coord, int64_t size,
                                                bool align_corners) {
  return align_corners
      ? ((coord + 1) / 2) * (size - 1)      // [-1, 1] -> [0, size - 1]
      : ((coord + 1) * size - 1) / 2;       // [-1, 1] -> [-0.5, size - 0.5]
}
|
| 36 |
+
|
| 37 |
+
// Same mapping as grid_sampler_unnormalize, but additionally reports the
// (constant) derivative `d output / d coord` through `grad_in`.
// Used by the grid_sampler backward pass.
template <typename scalar_t>
static inline scalar_t grid_sampler_unnormalize_set_grad(scalar_t coord, int64_t size,
                                                         bool align_corners, scalar_t *grad_in) {
  if (align_corners) {
    // [-1, 1] -> [0, size - 1]
    *grad_in = static_cast<scalar_t>(size - 1) / 2;
    return ((coord + 1) / 2) * (size - 1);
  }
  // [-1, 1] -> [-0.5, size - 0.5]
  *grad_in = static_cast<scalar_t>(size) / 2;
  return ((coord + 1) * size - 1) / 2;
}
|
| 54 |
+
|
| 55 |
+
// Clamps a coordinate into the valid pixel range [0, clip_limit - 1].
template<typename scalar_t>
static inline scalar_t clip_coordinates(scalar_t in, int64_t clip_limit) {
  const auto zero = static_cast<scalar_t>(0);
  const auto upper = static_cast<scalar_t>(clip_limit - 1);
  return std::min(upper, std::max(in, zero));
}
|
| 60 |
+
|
| 61 |
+
// Like clip_coordinates, but also reports `d output / d in` via `grad_in`.
// Used by the grid_sampler backward pass.  Note: the borders themselves are
// treated as out of bounds (gradient 0), which matters for correctness of
// the gradient computation.
template<typename scalar_t>
static inline scalar_t clip_coordinates_set_grad(scalar_t in, int64_t clip_limit,
                                                 scalar_t *grad_in) {
  const auto lo = static_cast<scalar_t>(0);
  if (in <= lo) {
    *grad_in = static_cast<scalar_t>(0);
    return lo;
  }
  const auto hi = static_cast<scalar_t>(clip_limit - 1);
  if (in >= hi) {
    *grad_in = static_cast<scalar_t>(0);
    return hi;
  }
  *grad_in = static_cast<scalar_t>(1);
  return in;
}
|
| 83 |
+
|
| 84 |
+
// Folds a coordinate back into [low, high] (inclusive) by repeated
// reflection at the bounds.  The bounds arrive doubled (twice_low,
// twice_high) so that half-integer bounds can be passed as ints.
template<typename scalar_t>
static inline scalar_t reflect_coordinates(scalar_t in, int64_t twice_low,
                                           int64_t twice_high) {
  if (twice_low == twice_high) {
    // Degenerate (zero-width) interval: everything collapses to 0.
    return static_cast<scalar_t>(0);
  }
  const scalar_t low = static_cast<scalar_t>(twice_low) / 2;
  const scalar_t span = static_cast<scalar_t>(twice_high - twice_low) / 2;
  const scalar_t dist = std::fabs(in - low);
  // fmod keeps the sign of its first argument, which is non-negative here.
  const scalar_t extra = std::fmod(dist, span);
  const int n_flips = static_cast<int>(std::floor(dist / span));
  // An even number of reflections lands going "forward", odd lands "backward".
  return (n_flips % 2 == 0) ? extra + low : span - extra + low;
}
|
| 105 |
+
|
| 106 |
+
// Like reflect_coordinates, but also reports `d output / d in` via
// `grad_in` (always -1, 0 or +1).  Used by the grid_sampler backward pass.
template<typename scalar_t>
static inline scalar_t reflect_coordinates_set_grad(scalar_t in, int64_t twice_low,
                                                    int64_t twice_high, scalar_t *grad_in) {
  if (twice_low == twice_high) {
    // Degenerate interval: constant output, zero gradient.
    *grad_in = static_cast<scalar_t>(0);
    return static_cast<scalar_t>(0);
  }
  const scalar_t low = static_cast<scalar_t>(twice_low) / 2;
  const scalar_t span = static_cast<scalar_t>(twice_high - twice_low) / 2;
  scalar_t dist = in - low;
  int sign = 1;
  if (dist < static_cast<scalar_t>(0)) {
    // Mirror to the non-negative side; remember the flip for the gradient.
    sign = -1;
    dist = -dist;
  }
  // fmod keeps the sign of its first argument, which is non-negative here.
  const scalar_t extra = std::fmod(dist, span);
  const int n_flips = static_cast<int>(std::floor(dist / span));
  if (n_flips % 2 == 0) {
    *grad_in = static_cast<scalar_t>(sign);
    return extra + low;
  }
  // Odd number of reflections reverses the direction of travel.
  *grad_in = static_cast<scalar_t>(-sign);
  return span - extra + low;
}
|
| 138 |
+
|
| 139 |
+
// Mapping the out-of-boundary points back into boundary
|
| 140 |
+
// This would only affect padding_mode=border or reflection
|
| 141 |
+
template<typename scalar_t>
|
| 142 |
+
static inline scalar_t compute_coordinates(scalar_t coord, int64_t size,
|
| 143 |
+
GridSamplerPadding padding_mode,
|
| 144 |
+
bool align_corners) {
|
| 145 |
+
if (padding_mode == GridSamplerPadding::Border) {
|
| 146 |
+
// clip coordinates to image borders
|
| 147 |
+
coord = clip_coordinates(coord, size);
|
| 148 |
+
} else if (padding_mode == GridSamplerPadding::Reflection) {
|
| 149 |
+
// reflect coordinates by image borders
|
| 150 |
+
if (align_corners) {
|
| 151 |
+
coord = reflect_coordinates(coord, 0, 2*(size - 1));
|
| 152 |
+
} else {
|
| 153 |
+
coord = reflect_coordinates(coord, -1, 2*size - 1);
|
| 154 |
+
}
|
| 155 |
+
// clip coordinates to image borders
|
| 156 |
+
coord = clip_coordinates(coord, size);
|
| 157 |
+
}
|
| 158 |
+
return coord;
|
| 159 |
+
}
|
| 160 |
+
|
| 161 |
+
// Computes the pixel source index value for a grid coordinate
|
| 162 |
+
template <typename scalar_t>
|
| 163 |
+
static inline scalar_t grid_sampler_compute_source_index(
|
| 164 |
+
scalar_t coord,
|
| 165 |
+
int64_t size,
|
| 166 |
+
GridSamplerPadding padding_mode,
|
| 167 |
+
bool align_corners) {
|
| 168 |
+
coord = grid_sampler_unnormalize(coord, size, align_corners);
|
| 169 |
+
coord = compute_coordinates(coord, size, padding_mode, align_corners);
|
| 170 |
+
return coord;
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
// grid_sampler_compute_source_index_set_grad works similarly to
|
| 174 |
+
// grid_sampler_compute_source_index except that it also returns the
|
| 175 |
+
// `d output / d input` via pointer argument `grad_in`.
|
| 176 |
+
// This is useful in the backward pass of grid_sampler.
|
| 177 |
+
template <typename scalar_t>
|
| 178 |
+
static inline scalar_t grid_sampler_compute_source_index_set_grad(
|
| 179 |
+
scalar_t coord,
|
| 180 |
+
int64_t size,
|
| 181 |
+
GridSamplerPadding padding_mode,
|
| 182 |
+
bool align_corners,
|
| 183 |
+
scalar_t *grad_in) {
|
| 184 |
+
scalar_t grad_clip, grad_refl;
|
| 185 |
+
coord = grid_sampler_unnormalize_set_grad(coord, size, align_corners, grad_in);
|
| 186 |
+
if (padding_mode == GridSamplerPadding::Border) {
|
| 187 |
+
// clip coordinates to image borders
|
| 188 |
+
coord = clip_coordinates_set_grad(coord, size, &grad_clip);
|
| 189 |
+
*grad_in = (*grad_in) * grad_clip;
|
| 190 |
+
} else if (padding_mode == GridSamplerPadding::Reflection) {
|
| 191 |
+
// reflect coordinates by image borders
|
| 192 |
+
if (align_corners) {
|
| 193 |
+
coord = reflect_coordinates_set_grad(coord, 0, 2*(size - 1), &grad_refl);
|
| 194 |
+
} else {
|
| 195 |
+
coord = reflect_coordinates_set_grad(coord, -1, 2*size - 1, &grad_refl);
|
| 196 |
+
}
|
| 197 |
+
// clip coordinates to image borders
|
| 198 |
+
coord = clip_coordinates_set_grad(coord, size, &grad_clip);
|
| 199 |
+
*grad_in = (*grad_in) * grad_refl * grad_clip;
|
| 200 |
+
}
|
| 201 |
+
return coord;
|
| 202 |
+
}
|
| 203 |
+
|
| 204 |
+
// True when (h, w) is a valid index into an H x W image.
static inline bool within_bounds_2d(int64_t h, int64_t w, int64_t H, int64_t W) {
  const bool h_ok = (h >= 0) && (h < H);
  const bool w_ok = (w >= 0) && (w < W);
  return h_ok && w_ok;
}
|
| 207 |
+
|
| 208 |
+
// True when (d, h, w) is a valid index into a D x H x W volume.
static inline bool within_bounds_3d(int64_t d, int64_t h, int64_t w, int64_t D, int64_t H, int64_t W) {
  const bool d_ok = (d >= 0) && (d < D);
  const bool h_ok = (h >= 0) && (h < H);
  const bool w_ok = (w >= 0) && (w < W);
  return d_ok && h_ok && w_ok;
}
|
| 211 |
+
|
| 212 |
+
template<typename scalar_t>
|
| 213 |
+
static inline scalar_t get_value_bounded(
|
| 214 |
+
scalar_t* data,
|
| 215 |
+
scalar_t x,
|
| 216 |
+
scalar_t y,
|
| 217 |
+
int64_t W,
|
| 218 |
+
int64_t H,
|
| 219 |
+
int64_t sW,
|
| 220 |
+
int64_t sH,
|
| 221 |
+
GridSamplerPadding padding_mode,
|
| 222 |
+
bool align_corners) {
|
| 223 |
+
|
| 224 |
+
x = compute_coordinates(x, W, padding_mode, align_corners);
|
| 225 |
+
y = compute_coordinates(y, H, padding_mode, align_corners);
|
| 226 |
+
|
| 227 |
+
int64_t ix = static_cast<int64_t>(x);
|
| 228 |
+
int64_t iy = static_cast<int64_t>(y);
|
| 229 |
+
|
| 230 |
+
if (within_bounds_2d(iy, ix, H, W)) {
|
| 231 |
+
return data[iy * sH + ix * sW];
|
| 232 |
+
}
|
| 233 |
+
return static_cast<scalar_t>(0);
|
| 234 |
+
}
|
| 235 |
+
|
| 236 |
+
template<typename scalar_t>
|
| 237 |
+
static inline void safe_add_2d(scalar_t *data, int64_t h, int64_t w,
|
| 238 |
+
int64_t sH, int64_t sW, int64_t H, int64_t W,
|
| 239 |
+
scalar_t delta) {
|
| 240 |
+
if (within_bounds_2d(h, w, H, W)) {
|
| 241 |
+
data[h * sH + w * sW] += delta;
|
| 242 |
+
}
|
| 243 |
+
}
|
| 244 |
+
|
| 245 |
+
template<typename scalar_t>
|
| 246 |
+
static inline void safe_add_3d(scalar_t *data, int64_t d, int64_t h, int64_t w,
|
| 247 |
+
int64_t sD, int64_t sH, int64_t sW,
|
| 248 |
+
int64_t D, int64_t H, int64_t W,
|
| 249 |
+
scalar_t delta) {
|
| 250 |
+
if (within_bounds_3d(d, h, w, D, H, W)) {
|
| 251 |
+
data[d * sD + h * sH + w * sW] += delta;
|
| 252 |
+
}
|
| 253 |
+
}
|
| 254 |
+
|
| 255 |
+
template<typename scalar_t>
|
| 256 |
+
static inline void add_value_bounded(
|
| 257 |
+
scalar_t* data,
|
| 258 |
+
scalar_t x,
|
| 259 |
+
scalar_t y,
|
| 260 |
+
int64_t W,
|
| 261 |
+
int64_t H,
|
| 262 |
+
int64_t sW,
|
| 263 |
+
int64_t sH,
|
| 264 |
+
scalar_t delta,
|
| 265 |
+
GridSamplerPadding padding_mode,
|
| 266 |
+
bool align_corners) {
|
| 267 |
+
|
| 268 |
+
x = compute_coordinates(x, W, padding_mode, align_corners);
|
| 269 |
+
y = compute_coordinates(y, H, padding_mode, align_corners);
|
| 270 |
+
|
| 271 |
+
int64_t ix = static_cast<int64_t>(x);
|
| 272 |
+
int64_t iy = static_cast<int64_t>(y);
|
| 273 |
+
|
| 274 |
+
safe_add_2d(data, iy, ix, sH, sW, H, W, delta);
|
| 275 |
+
}
|
| 276 |
+
|
| 277 |
+
// Derivative of the cubic convolution coefficients, i.e. `d coeff / d t`.
// Must stay in sync with the forward computation in
// aten/src/ATen/native/UpSample.h:get_cubic_upsample_coefficients.
template<typename scalar_t>
static inline void get_cubic_coefficients_grad(
    scalar_t coeffs[4],
    scalar_t t) {

  const scalar_t A = -0.75;  // cubic convolution constant

  scalar_t u;
  u = -1 - t;  // outer tap: 1 < |u| < 2
  coeffs[0] = (-3 * A * u - 10 * A ) * u - 8 * A;
  u = -t;      // inner tap: |u| <= 1
  coeffs[1] = (-3 * (A + 2) * u - 2 * (A + 3)) * u;
  u = 1 - t;   // inner tap: |u| <= 1
  coeffs[2] = (3 * (A + 2) * u - 2 * (A + 3)) * u;
  u = 2 - t;   // outer tap: 1 < |u| < 2
  coeffs[3] = (3 * A * u - 10 * A) * u + 8 * A;
}
|
| 297 |
+
|
| 298 |
+
} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/GridSamplerUtils.h
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// See NOTE: [Tensor vs. TensorBase]
|
| 4 |
+
// https://github.com/pytorch/pytorch/pull/66979
|
| 5 |
+
#include <ATen/core/TensorBase.h>
|
| 6 |
+
#include <ATen/native/TensorProperties.h>
|
| 7 |
+
#include <ATen/native/CanUse32BitIndexMath.h>
|
| 8 |
+
|
| 9 |
+
namespace at::native {
|
| 10 |
+
|
| 11 |
+
namespace detail {
|
| 12 |
+
|
| 13 |
+
enum class GridSamplerInterpolation {Bilinear, Nearest, Bicubic};
|
| 14 |
+
enum class GridSamplerPadding {Zeros, Border, Reflection};
|
| 15 |
+
|
| 16 |
+
} // namespace detail
|
| 17 |
+
|
| 18 |
+
using detail::GridSamplerInterpolation;
|
| 19 |
+
using detail::GridSamplerPadding;
|
| 20 |
+
|
| 21 |
+
namespace {
|
| 22 |
+
|
| 23 |
+
// See NOTE [ grid_sampler Native Functions ].
// Shape/layout validation shared by the 2d and 3d grid_sampler variants:
// both tensors must be defined, on the same device, strided, with equal
// batch sizes, a grid whose last dim matches the number of spatial dims,
// and non-empty spatial dims on the input.  Throws via TORCH_CHECK.
void check_grid_sampler_common(
  const TensorBase& input,
  const TensorBase& grid
) {
  auto input_opt = input.options();
  auto grid_opt = grid.options();

  TORCH_CHECK(
    input.defined(),
    "grid_sampler(): expected input to not be undefined");
  TORCH_CHECK(
    grid.defined(),
    "grid_sampler(): expected grid to not be undefined");
  // Both tensors must live on the same device.
  TORCH_CHECK(
    input_opt.device() == grid_opt.device(),
    "grid_sampler(): expected input and grid to be on same device, but input "
    "is on ", input_opt.device(), " and grid is on ", grid_opt.device());
  // Only dense (strided) layouts are supported.
  TORCH_CHECK(
    input_opt.layout() == kStrided && grid_opt.layout() == kStrided,
    "grid_sampler(): expected input and grid to have torch.strided layout, but "
    "input has ", input_opt.layout(), " and grid has ", grid_opt.layout());
  TORCH_CHECK(
    input.size(0) == grid.size(0),
    "grid_sampler(): expected grid and input to have same batch size, but got "
    "input with sizes ", input.sizes(), " and grid with sizes ", grid.sizes());
  // grid's last dimension holds one coordinate per spatial dimension.
  TORCH_CHECK(
    grid.size(-1) == input.dim() - 2,
    "grid_sampler(): expected grid to have size ", input.dim() - 2, " in last "
    "dimension, but got grid with sizes ", grid.sizes());

  // Spatial dimensions start at index 2 (after batch and channel).
  for (const auto i : c10::irange(2, input.dim())) {
    TORCH_CHECK(input.size(i) > 0,
      "grid_sampler(): expected input to have non-empty spatial dimensions, "
      "but input has sizes ", input.sizes(), " with dimension ", i, " being "
      "empty");
  }
}
|
| 61 |
+
|
| 62 |
+
// See NOTE [ grid_sampler Native Functions ].
// 2d-specific check: input and grid must both be 4-D (N, C, H, W) /
// (N, H_out, W_out, 2).  Call check_grid_sampler_common first for the
// shared validation.
void check_grid_sampler_2d(
  const TensorBase& input,
  const TensorBase& grid
) {
  TORCH_CHECK(
    input.dim() == 4 && input.dim() == grid.dim(),
    "grid_sampler(): expected 4D input and grid with same number of "
    "dimensions, but got input with sizes ", input.sizes(),
    " and grid with sizes ", grid.sizes());
}
|
| 73 |
+
|
| 74 |
+
// See NOTE [ grid_sampler Native Functions ].
// 3d-specific check: input and grid must both be 5-D, and bicubic
// interpolation is rejected (it is only implemented for 4-D inputs).
void check_grid_sampler_3d(
  const TensorBase& input,
  const TensorBase& grid,
  int64_t interpolation_mode
) {
  TORCH_CHECK(
    input.dim() == 5 && input.dim() == grid.dim(),
    "grid_sampler(): expected 5D input and grid with same number of "
    "dimensions, but got input with sizes ", input.sizes(),
    " and grid with sizes ", grid.sizes());
  TORCH_CHECK(
    !(input.dim() == 5 &&
      static_cast<GridSamplerInterpolation>(interpolation_mode) ==
        GridSamplerInterpolation::Bicubic),
    "grid_sampler(): bicubic interpolation only supports 4D input");
}
|
| 91 |
+
|
| 92 |
+
// See NOTE [ grid_sampler Native Functions ].
// cudnn does not support inputs larger than 1024.
// Decides whether the cuDNN grid sampler kernel can be used: both tensors
// must be cuDNN-acceptable and 32-bit-indexable, the input must be 4-D,
// and its channel dimension (dim 1) must not exceed 1024.
bool cond_cudnn_grid_sampler(
  const TensorBase& input,
  const TensorBase& grid
) {
  return (
    at::native::cudnn_is_acceptable(input) &&
    at::native::cudnn_is_acceptable(grid) &&
    at::native::canUse32BitIndexMath(input) &&
    at::native::canUse32BitIndexMath(grid) &&
    input.dim() == 4 &&
    input.sym_size(1) <= 1024);
}
|
| 106 |
+
|
| 107 |
+
} // anonymous namespace
|
| 108 |
+
|
| 109 |
+
} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Lerp.h
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/native/DispatchStub.h>
|
| 4 |
+
#include <ATen/OpMathType.h>
|
| 5 |
+
#include <ATen/TensorIterator.h>
|
| 6 |
+
#include <c10/core/Scalar.h>
|
| 7 |
+
|
| 8 |
+
namespace at::native {
|
| 9 |
+
|
| 10 |
+
template <typename scalar_t>
|
| 11 |
+
C10_HOST_DEVICE C10_ALWAYS_INLINE bool is_lerp_weight_small(scalar_t weight) {
|
| 12 |
+
return std::abs(weight) < scalar_t(0.5);
|
| 13 |
+
}
|
| 14 |
+
template <typename scalar_t>
|
| 15 |
+
C10_HOST_DEVICE C10_ALWAYS_INLINE bool is_lerp_weight_small(c10::complex<scalar_t> weight) {
|
| 16 |
+
// Avoid the sqrt in abs(weight)
|
| 17 |
+
return (weight.real() * weight.real() + weight.imag() * weight.imag()) < scalar_t(0.25);
|
| 18 |
+
}
|
| 19 |
+
|
| 20 |
+
template <typename scalar_t, typename weight_t>
|
| 21 |
+
C10_HOST_DEVICE C10_ALWAYS_INLINE scalar_t lerp(scalar_t self_, scalar_t end_, weight_t weight_) {
|
| 22 |
+
using opmath_t = at::opmath_type<scalar_t>;
|
| 23 |
+
using opmath_weight_t = at::opmath_type<weight_t>;
|
| 24 |
+
|
| 25 |
+
opmath_t self = self_;
|
| 26 |
+
opmath_t end = end_;
|
| 27 |
+
opmath_weight_t weight = weight_;
|
| 28 |
+
|
| 29 |
+
// Conditional for better numeric. This has been discussed in
|
| 30 |
+
// https://github.com/pytorch/pytorch/pull/18871
|
| 31 |
+
return is_lerp_weight_small(weight)
|
| 32 |
+
? self + weight * (end - self)
|
| 33 |
+
: end - (end - self) * (opmath_t(1) - weight);
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
using lerp_fn_scalar = void (*)(
|
| 37 |
+
at::TensorIteratorBase& iter,
|
| 38 |
+
const Scalar& weight);
|
| 39 |
+
|
| 40 |
+
using lerp_fn_tensor = void (*)(
|
| 41 |
+
at::TensorIteratorBase& iter);
|
| 42 |
+
|
| 43 |
+
DECLARE_DISPATCH(lerp_fn_scalar, lerp_kernel_scalar_weight);
|
| 44 |
+
DECLARE_DISPATCH(lerp_fn_tensor, lerp_kernel_tensor_weight);
|
| 45 |
+
|
| 46 |
+
} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/LinearAlgebraUtils.h
ADDED
|
@@ -0,0 +1,623 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <c10/core/ScalarType.h>
|
| 4 |
+
#include <c10/util/irange.h>
|
| 5 |
+
#include <c10/util/Exception.h>
|
| 6 |
+
#include <c10/util/strides.h>
|
| 7 |
+
#include <ATen/core/Tensor.h>
|
| 8 |
+
#include <ATen/ExpandUtils.h>
|
| 9 |
+
#include <ATen/TensorUtils.h>
|
| 10 |
+
#include <ATen/native/TensorIterator.h>
|
| 11 |
+
#include <ATen/native/TransposeType.h>
|
| 12 |
+
#include <limits>
|
| 13 |
+
#include <type_traits>
|
| 14 |
+
#include <sstream>
|
| 15 |
+
#include <cstring>
|
| 16 |
+
#include <cctype>
|
| 17 |
+
|
| 18 |
+
#ifndef AT_PER_OPERATOR_HEADERS
|
| 19 |
+
#include <ATen/Functions.h>
|
| 20 |
+
#else
|
| 21 |
+
#include <ATen/ops/arange.h>
|
| 22 |
+
#include <ATen/ops/empty.h>
|
| 23 |
+
#include <ATen/ops/empty_like.h>
|
| 24 |
+
#include <ATen/ops/empty_strided.h>
|
| 25 |
+
#include <ATen/ops/zeros.h>
|
| 26 |
+
#endif
|
| 27 |
+
|
| 28 |
+
namespace at::native {
|
| 29 |
+
|
| 30 |
+
// Returns `tensor` borrowed when its conjugate bit is unset; otherwise
// materializes the conjugation so callers can read plain values.
static inline c10::MaybeOwned<Tensor> expect_resolved_conj(const Tensor& tensor) {
  return tensor.is_conj()
      ? c10::MaybeOwned<Tensor>::owned(tensor.resolve_conj())
      : c10::MaybeOwned<Tensor>::borrowed(tensor);
}
|
| 37 |
+
|
| 38 |
+
// Strides for a contiguously laid-out batch of matrices.  With
// f_contig == true the result describes C-contiguous batches of
// F-contiguous (column-major) matrices; otherwise plain C-contiguous
// strides.
static inline DimVector batched_matrix_contiguous_strides(
    const IntArrayRef sizes,
    const bool f_contig = false) {
  auto result = c10::contiguous_strides(sizes);
  const auto ndim = result.size();

  if (f_contig && ndim >= 2) {
    // Rewrite the strides of the trailing two dims so that each matrix is
    // column-major while the batch dims stay C-contiguous.
    result[ndim - 1] = std::max(sizes[ndim - 2], static_cast<int64_t>(1));
    result[ndim - 2] = 1;
  }
  return result;
}
|
| 54 |
+
|
| 55 |
+
/*
|
| 56 |
+
* Clones a Tensor so that the following conditions hold:
|
| 57 |
+
* If we think of a Tensor of having size (B, M, N), where B is any number
|
| 58 |
+
* of batch dimensions, then:
|
| 59 |
+
* - Each (M, N) matrix is in column major form
|
| 60 |
+
* - Let Tensor P have size (B, M, N) and Q have size (B, M', N').
|
| 61 |
+
* Then when laid out in memory, the M by N matrix starting at
|
| 62 |
+
* P.data_ptr()[B * M * N] is of the same corresponding batch as the M' by N'
|
| 63 |
+
* matrix starting at Q.data_ptr()[B * M' * N'].
|
| 64 |
+
*/
|
| 65 |
+
static inline Tensor cloneBatchedColumnMajor(const Tensor& src) {
|
| 66 |
+
// If src is already in batched column major format, then
|
| 67 |
+
// this will be efficient (no reordering of the data will occur)
|
| 68 |
+
// because the first transpose will make the tensor contiguous,
|
| 69 |
+
// and cloning a contiguous tensor is fast.
|
| 70 |
+
auto result = src.mT().clone(at::MemoryFormat::Contiguous);
|
| 71 |
+
result.transpose_(-2, -1);
|
| 72 |
+
return result;
|
| 73 |
+
}
|
| 74 |
+
|
| 75 |
+
/*
|
| 76 |
+
* contig chooses between C-contig (true) and F-contig (false)
|
| 77 |
+
*/
|
| 78 |
+
static inline c10::MaybeOwned<Tensor> borrow_else_clone(const bool cond, const Tensor& borrow, const Tensor& clone, const bool contig) {
|
| 79 |
+
return cond ? c10::MaybeOwned<Tensor>::borrowed(borrow)
|
| 80 |
+
: c10::MaybeOwned<Tensor>::owned(contig ? clone.clone(MemoryFormat::Contiguous)
|
| 81 |
+
: cloneBatchedColumnMajor(clone));
|
| 82 |
+
}
|
| 83 |
+
|
| 84 |
+
/*
|
| 85 |
+
* This method is designed to be a faster alternative to
|
| 86 |
+
* `cloneBatchedColumnMajor` with some additional features,
|
| 87 |
+
* namely:
|
| 88 |
+
* 1. It uses `copy` instead of `clone` which could be much faster.
|
| 89 |
+
* 2. `nrows` parameter used to create inputs with the number of rows larger
|
| 90 |
+
* than the original input, which is required for some LAPACK/MAGMA methods.
|
| 91 |
+
* 3. `desired_batch_size` is used to create copies with the batch size
|
| 92 |
+
* which is either the original batch size of the input, or its larger
|
| 93 |
+
* broadcasted shape.
|
| 94 |
+
*/
|
| 95 |
+
static inline Tensor copyBatchedColumnMajor(const Tensor& src, int64_t nrows = -1,
    at::OptionalIntArrayRef desired_batch_sizes = c10::nullopt) {
  // Default: keep the same number of rows as the source.
  nrows = (nrows == -1) ? src.size(-2) : nrows;
  // Batch dims come either from the caller (e.g. a broadcasted shape) or
  // directly from src's leading dims.
  auto copy_sizes = desired_batch_sizes.has_value()
    ? desired_batch_sizes.value().vec()
    : IntArrayRef(src.sizes().data(), src.dim() - 2).vec();
  copy_sizes.insert(copy_sizes.end(), {nrows, src.size(-1)});
  // F-contiguous strides: column-major matrices in C-contiguous batches.
  const auto copy_strides = batched_matrix_contiguous_strides(copy_sizes, /*f-contig*/true);
  auto copy = at::empty_strided(copy_sizes, copy_strides, src.options());
  // Only the leading src.size(-2) rows are filled from src; when
  // nrows > src.size(-2) the extra rows are left uninitialized
  // (empty_strided does not zero) for LAPACK/MAGMA workspace use.
  copy.narrow(-2, 0, src.size(-2)).copy_(src);
  return copy;
}
|
| 107 |
+
|
| 108 |
+
/*
|
| 109 |
+
* Given batches of matrices with arbitrary batch dim,
|
| 110 |
+
* computes the number of batches.
|
| 111 |
+
*/
|
| 112 |
+
static inline int64_t batchCount(const Tensor& batched_matrices) {
|
| 113 |
+
int64_t result = 1;
|
| 114 |
+
for (int64_t i = 0; i < batched_matrices.ndimension() - 2; i++) {
|
| 115 |
+
result *= batched_matrices.size(i);
|
| 116 |
+
}
|
| 117 |
+
return result;
|
| 118 |
+
}
|
| 119 |
+
|
| 120 |
+
// Computes the number of elements of a matrix in a batched matrix tensor
|
| 121 |
+
static inline int64_t matrixStride(const Tensor& batched_matrices) {
|
| 122 |
+
return batched_matrices.size(-1) * batched_matrices.size(-2);
|
| 123 |
+
}
|
| 124 |
+
|
| 125 |
+
// Validates input shapes for operations on batches of square matrices (inverse, cholesky, symeig, eig)
// Checks that `A` is (a batch of) matrices, i.e. has at least 2 dims.
// `f_name`/`arg_name` only feed the error message.
static inline void checkIsMatrix(const Tensor& A, const char* const f_name, const char* const arg_name = "A") {
  TORCH_CHECK(A.dim() >= 2, f_name, ": The input tensor ", arg_name, " must have at least 2 dimensions.");
}
|
| 129 |
+
// Checks that `self` is (a batch of) square matrices.  Uses symbolic sizes
// so the check also works under symbolic shape tracing.
static inline void squareCheckInputs(const Tensor& self, const char* const f_name, const char* const arg_name = "A") {
  checkIsMatrix(self, f_name, arg_name);
  TORCH_CHECK(self.sym_size(-1) == self.sym_size(-2),
              f_name,
              ": ", arg_name, " must be batches of square matrices, "
              "but they are ", self.sym_size(-2), " by ", self.sym_size(-1), " matrices");
}
|
| 136 |
+
|
| 137 |
+
// Shape validation for solver-style ops: `A` must be (batched) square and
// its matching dimension must agree with `B` for the equation being solved
// (AX = B when `left`, XA = B otherwise).
static inline void checkInputsSolver(const Tensor& A,
                                     const Tensor& B,
                                     const bool left,
                                     const char* const f_name) {
  squareCheckInputs(A, f_name, "A");
  checkIsMatrix(B, f_name, "B");
  // For AX = B the row counts must match; for XA = B the column counts.
  TORCH_CHECK(left ? A.size(-2) == B.size(-2) : A.size(-1) == B.size(-1),
              f_name, ": Incompatible shapes of A and B for the equation ",
              left ? "AX = B" : "XA = B",
              " (", A.size(-2), "x", A.size(-1), " and ", B.size(-2), "x", B.size(-1), ")");
}
|
| 148 |
+
|
| 149 |
+
static inline bool is_row_or_column_contiguous(const Tensor& t) {
|
| 150 |
+
// This could be made more general, similar to how it's checked in matmul, which would allow to
|
| 151 |
+
// ellide the copy with strides such as (6, 12, 1, 3) or (3, 1, 9), but this is quite tricky.
|
| 152 |
+
// We choose to be conservative for simplicity
|
| 153 |
+
return t.is_contiguous() || t.transpose(-2, -1).is_contiguous();
|
| 154 |
+
}
|
| 155 |
+
|
| 156 |
+
// Maps (contiguity, conjugation) flags to a BLAS/LAPACK transpose argument.
// A contiguous *and* conjugated input has no valid mapping and is reported
// as an internal error.
static inline TransposeType to_transpose_type(const bool contig, const bool conj) {
  if (!conj) {
    return contig ? TransposeType::NoTranspose : TransposeType::Transpose;
  }
  if (contig) { TORCH_INTERNAL_ASSERT(false, "Invalid transpose type"); }
  return TransposeType::ConjTranspose;
}
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
// This function is designed to be used with linear algebra methods that minimize
|
| 168 |
+
// L(ax - b) = 0, where L is generally the identity map (`solve`, for example)
|
| 169 |
+
// or the L2 norm (`lstsq`).
|
| 170 |
+
// It is expected that `a` and `b` are contiguous tensors of column-major matrices
|
| 171 |
+
// (so that a.view({-1, a.size(-2), a.size(-1)}) succeeds, same for `b`),
|
| 172 |
+
// with the following additional properties:
|
| 173 |
+
//
|
| 174 |
+
// 1. a.dim() == b.dim()
|
| 175 |
+
// 2. a.shape[:-2] broadcasts over b.shape[:-2]
|
| 176 |
+
// 3. a.size(i) <= b.size(i) for i=0,..., a.dim() - 3 (only for batch dimensions)
|
| 177 |
+
//
|
| 178 |
+
// MAGMA/LAPACK modify tensor `a` in-place, and the main goal of this method
|
| 179 |
+
// is to be memory efficient, which means that if there exists an index i such that
|
| 180 |
+
// a.shape[i] < b.shape[i], 0 <= i <= a.dim() - 3,
|
| 181 |
+
// then instead of materializing copies of `a` in the broadcasted shape, we keep
|
| 182 |
+
// a buffer copy of `a` along with flags that check whether specific batch dimension
|
| 183 |
+
// indices for `a` were already accessed. If they were, we copy the data from the buffer
|
| 184 |
+
// into `a`. The number of copies does not exceed
|
| 185 |
+
// prod(max(a.shape[:-2], b.shape[:-2]) - a.shape[:-2] + 1)
|
| 186 |
+
// and this value is attained by tensors with non-empty batch dimensions.
|
| 187 |
+
//
|
| 188 |
+
// func_t `f` is a callable that is being supplied with
|
| 189 |
+
// scalar_t* a_working_ptr, scalar_t* b_working_ptr, int64_t a_linear_batch_idx.
|
| 190 |
+
// a_working_ptr and b_working_ptr can directly be passed to LAPACK/MAGMA routines,
|
| 191 |
+
// and a_linear_batch_idx is an index in the 3d representation which corresponds to
|
| 192 |
+
// the memory a_working_ptr points to, in other words:
|
| 193 |
+
// a_working_ptr == a.view({-1, a.size(-2), a.size(-1)}.select(0, a_linear_batch_idx).data_ptr<scalar_t>();
|
| 194 |
+
// a_linear_batch_idx is useful to store metadata related to `a`, such as, for example,
|
| 195 |
+
// its rank or singular values (see linalg_lstsq).
|
| 196 |
+
template<typename scalar_t, typename func_t>
void batch_iterator_with_broadcasting(const Tensor& a, const Tensor& b, const func_t& f) {
  // Batch shapes: all dimensions except the trailing two (matrix) dims.
  IntArrayRef a_batch_sizes(a.sizes().data(), a.dim() - 2);
  IntArrayRef b_batch_sizes(b.sizes().data(), b.dim() - 2);

  // Linear batch indices laid out in each operand's batch shape, so that the
  // TensorIterator below handles the broadcasting between them for us.
  auto a_linear_batch_idx = at::arange(batchCount(a)).view(a_batch_sizes);
  auto b_linear_batch_idx = at::arange(batchCount(b)).view(b_batch_sizes);

  TensorIterator iter = TensorIteratorConfig()
    .set_check_mem_overlap(false)
    .check_all_same_dtype(false)
    .resize_outputs(false)
    .add_output(b_linear_batch_idx)
    .add_input(a_linear_batch_idx)
    .build();

  // 3d (flattened-batch) views; assumes `a`/`b` are viewable this way (see
  // the contract in the comment above this function).
  auto m = a.size(-2);
  auto n = a.size(-1);
  auto a_3d = a.view({batchCount(a), m, n});
  auto b_3d = b.view({batchCount(b), b.size(-2), b.size(-1)});

  auto a_broadcasts_over_b = (a_batch_sizes != b_batch_sizes);
  Tensor a_buffer, a_was_accessed, a_buffer_3d;
  // No-op unless `a` actually broadcasts over `b`.
  std::function<void(int64_t)> check_if_copy_needed_for_a
    = [](int64_t /*a_curr_linear_batch_idx*/){};
  if (a_broadcasts_over_b) {
    // Keep a pristine copy of `a`: LAPACK/MAGMA mutate `a` in-place, so any
    // batch element of `a` that is visited more than once must be restored
    // from this buffer before being reused.
    a_buffer = at::empty_strided(a.sizes(), a.strides(), a.options())
      .copy_(a);
    a_was_accessed = at::zeros(batchCount(a), at::kBool);
    a_buffer_3d = a_buffer.view({batchCount(a), m, n});
    check_if_copy_needed_for_a = [&](int64_t a_curr_linear_batch_idx) {
      auto* a_was_accessed_flag = a_was_accessed
        .select(0, a_curr_linear_batch_idx)
        .data_ptr<bool>();
      if (!(*a_was_accessed_flag)) {
        // First visit: just mark it; the data in `a` is still pristine.
        *a_was_accessed_flag = true;
      }
      else {
        // Revisit: restore this batch element of `a` from the buffer copy.
        a_3d.select(0, a_curr_linear_batch_idx)
          .copy_(a_buffer_3d.select(0, a_curr_linear_batch_idx));
      }
    };
  }

  // Serial loop over the broadcast batch: fetch the linear indices for the
  // current (a, b) pair, restore `a` if needed, then hand raw pointers to `f`.
  auto loop = [&](char** data, const int64_t* strides, int64_t nelems) {
    auto* b_batch_idx_ptr = data[0];
    auto* a_batch_idx_ptr = data[1];

    for (const auto elem C10_UNUSED : c10::irange(nelems)) {
      auto b_curr_linear_batch_idx = *reinterpret_cast<int64_t*>(b_batch_idx_ptr);
      auto a_curr_linear_batch_idx = *reinterpret_cast<int64_t*>(a_batch_idx_ptr);

      check_if_copy_needed_for_a(a_curr_linear_batch_idx);

      auto* a_working_ptr = a_3d.select(0, a_curr_linear_batch_idx)
        .data_ptr<scalar_t>();
      auto* b_working_ptr = b_3d.select(0, b_curr_linear_batch_idx)
        .data_ptr<scalar_t>();
      f(a_working_ptr, b_working_ptr, a_curr_linear_batch_idx);

      b_batch_idx_ptr += strides[0];
      a_batch_idx_ptr += strides[1];
    }
  };
  // serial_for_each: the a-restore bookkeeping above is not thread-safe.
  iter.serial_for_each(loop, {0, batchCount(b)});
}
|
| 262 |
+
|
| 263 |
+
// Returns the epsilon value for floating types except half
|
| 264 |
+
static inline double _get_epsilon(const ScalarType& sc_type) {
|
| 265 |
+
switch (sc_type) {
|
| 266 |
+
case at::ScalarType::Float:
|
| 267 |
+
return static_cast<double>(std::numeric_limits<float>::epsilon());
|
| 268 |
+
case at::ScalarType::Double:
|
| 269 |
+
return std::numeric_limits<double>::epsilon();
|
| 270 |
+
default:
|
| 271 |
+
AT_ERROR("This function doesn't handle types other than float and double");
|
| 272 |
+
}
|
| 273 |
+
}
|
| 274 |
+
|
| 275 |
+
// Validates input shapes and devices
// for linear solve methods (solve, cholesky_solve, lu_solve, triangular_solve)
// Here `self` plays the role of b in A x = b.
static inline void linearSolveCheckInputs(const Tensor& self, const Tensor& A, const char* name) {
  // Both operands must live on the same device...
  TORCH_CHECK(self.device() == A.device(),
              "Expected b and A to be on the same device, but found b on ",
              self.device(), " and A on ", A.device(), " instead.");

  // ...and share a dtype.
  TORCH_CHECK(self.scalar_type() == A.scalar_type(),
              "Expected b and A to have the same dtype, but found b of type ",
              self.scalar_type(), " and A of type ", A.scalar_type(), " instead.");

  // A must be (a batch of) square matrices.
  TORCH_CHECK(A.size(-1) == A.size(-2),
              "A must be batches of square matrices, "
              "but they are ", A.size(-2), " by ", A.size(-1), " matrices");

  // The shared dimension of A and b must agree for A x = b to make sense.
  TORCH_CHECK(A.size(-1) == self.size(-2),
              "Incompatible matrix sizes for ", name, ": each A "
              "matrix is ", A.size(-1), " by ", A.size(-1),
              " but each b matrix is ", self.size(-2), " by ", self.size(-1));
}
|
| 295 |
+
|
| 296 |
+
static inline void checkFloatingOrComplex(const Tensor& t, const char* const f_name, const bool allow_low_precision_dtypes=true) {
|
| 297 |
+
auto dtype = t.scalar_type();
|
| 298 |
+
TORCH_CHECK((at::isFloatingType(dtype) || at::isComplexType(dtype)),
|
| 299 |
+
f_name, ": Expected a floating point or complex tensor as input. Got ", dtype);
|
| 300 |
+
if (!allow_low_precision_dtypes) {
|
| 301 |
+
TORCH_CHECK(dtype == kFloat || dtype == kDouble || dtype == kComplexFloat || dtype == kComplexDouble,
|
| 302 |
+
f_name, ": Low precision dtypes not supported. Got ", dtype);
|
| 303 |
+
}
|
| 304 |
+
}
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
// Checks if all the Tensors in a TensorList are of the same dimensions
|
| 308 |
+
static inline void checkAllSameDim(TensorList tensors, int64_t dim) {
|
| 309 |
+
for (auto &t : tensors) {
|
| 310 |
+
TORCH_CHECK(t.dim() == dim, "Tensor dimension is ", t.dim(), ", expected ", dim, " instead.");
|
| 311 |
+
}
|
| 312 |
+
}
|
| 313 |
+
|
| 314 |
+
static inline std::tuple<std::vector<int64_t>, std::vector<int64_t>> _linalg_broadcast_batch_dims(const Tensor& arg1, const Tensor& arg2) {
|
| 315 |
+
// broadcast the batch dimensions of arg1 and arg2.
|
| 316 |
+
IntArrayRef arg1_batch_sizes(arg1.sizes().data(), arg1.ndimension() - 2);
|
| 317 |
+
IntArrayRef arg2_batch_sizes(arg2.sizes().data(), arg2.ndimension() - 2);
|
| 318 |
+
std::vector<int64_t> expand_batch_portion = infer_size(arg1_batch_sizes, arg2_batch_sizes);
|
| 319 |
+
|
| 320 |
+
std::vector<int64_t> arg1_expand_size({expand_batch_portion});
|
| 321 |
+
arg1_expand_size.insert(arg1_expand_size.end(), { arg1.size(-2), arg1.size(-1) });
|
| 322 |
+
|
| 323 |
+
std::vector<int64_t> arg2_expand_size({expand_batch_portion});
|
| 324 |
+
arg2_expand_size.insert(arg2_expand_size.end(), { arg2.size(-2), arg2.size(-1) });
|
| 325 |
+
return std::make_tuple(std::move(arg1_expand_size), std::move(arg2_expand_size));
|
| 326 |
+
}
|
| 327 |
+
|
| 328 |
+
static inline std::tuple<Tensor,Tensor> _linalg_broadcast_batch_dims(const Tensor& arg1, const Tensor& arg2, const char* name) {
|
| 329 |
+
// If there's no name we assume we don't want to check the errors
|
| 330 |
+
if (name != nullptr) {
|
| 331 |
+
linearSolveCheckInputs(arg1, arg2, name);
|
| 332 |
+
}
|
| 333 |
+
|
| 334 |
+
auto [arg1_expand_size, arg2_expand_size] = at::native::_linalg_broadcast_batch_dims(arg1, arg2);
|
| 335 |
+
|
| 336 |
+
auto arg1_broadcasted = arg1_expand_size == arg1.sizes() ? arg1 : arg1.expand(arg1_expand_size);
|
| 337 |
+
auto arg2_broadcasted = arg2_expand_size == arg2.sizes() ? arg2 : arg2.expand(arg2_expand_size);
|
| 338 |
+
return std::make_tuple(arg1_broadcasted, arg2_broadcasted);
|
| 339 |
+
}
|
| 340 |
+
|
| 341 |
+
static inline std::vector<int64_t> broadcast_batch_size(const Tensor& t1, const Tensor& t2, int64_t n_batch_dims) {
|
| 342 |
+
IntArrayRef t1_batch_sizes(t1.sizes().data(), n_batch_dims);
|
| 343 |
+
IntArrayRef t2_batch_sizes(t2.sizes().data(), n_batch_dims);
|
| 344 |
+
auto broadcasted_batch_sizes = infer_size(t1_batch_sizes, t2_batch_sizes);
|
| 345 |
+
return broadcasted_batch_sizes;
|
| 346 |
+
}
|
| 347 |
+
|
| 348 |
+
// Return a permutation with the given axes moved to the end.
|
| 349 |
+
static inline Tensor _move_to_end(const Tensor& self, IntArrayRef axes) {
|
| 350 |
+
const std::vector<int64_t> a = axes.vec();
|
| 351 |
+
const int64_t ndim = self.ndimension();
|
| 352 |
+
std::vector<int64_t> perm;
|
| 353 |
+
|
| 354 |
+
for (const auto i : c10::irange(ndim)) {
|
| 355 |
+
auto it = std::find(a.begin(), a.end(), i);
|
| 356 |
+
if (it == a.end()) {
|
| 357 |
+
perm.push_back(i);
|
| 358 |
+
}
|
| 359 |
+
}
|
| 360 |
+
for (auto i : a) {
|
| 361 |
+
perm.push_back(i);
|
| 362 |
+
}
|
| 363 |
+
|
| 364 |
+
TORCH_CHECK((int64_t)perm.size() == ndim,
|
| 365 |
+
"duplicate or invalid axis in 'dim' argument for tensor with ndim==", ndim);
|
| 366 |
+
|
| 367 |
+
return self.permute(perm);
|
| 368 |
+
}
|
| 369 |
+
|
| 370 |
+
// parse the "mode" param in linalg_qr: return a tuple of bools (compute_q, reduced)
|
| 371 |
+
static inline std::tuple<bool, bool> _parse_qr_mode(c10::string_view mode) {
|
| 372 |
+
bool compute_q;
|
| 373 |
+
bool reduced;
|
| 374 |
+
if (mode == "reduced") {
|
| 375 |
+
compute_q = true;
|
| 376 |
+
reduced = true;
|
| 377 |
+
} else if (mode == "complete") {
|
| 378 |
+
compute_q = true;
|
| 379 |
+
reduced = false;
|
| 380 |
+
} else if (mode == "r") {
|
| 381 |
+
compute_q = false;
|
| 382 |
+
reduced = true; // this is actually irrelevant in this mode
|
| 383 |
+
} else {
|
| 384 |
+
TORCH_CHECK(false, "qr received unrecognized mode '", mode,
|
| 385 |
+
"' but expected one of 'reduced' (default), 'r', or 'complete'");
|
| 386 |
+
}
|
| 387 |
+
return std::make_tuple(compute_q, reduced);
|
| 388 |
+
}
|
| 389 |
+
|
| 390 |
+
// Function to compute sizes, strides and the extra columns for the Q matrix in the QR Decomposition
|
| 391 |
+
static inline std::tuple<DimVector, DimVector, int64_t> _compute_geometry_for_Q(
|
| 392 |
+
const Tensor& input,
|
| 393 |
+
bool reduced) {
|
| 394 |
+
int64_t m = input.size(-2), n = input.size(-1);
|
| 395 |
+
int64_t n_columns_q;
|
| 396 |
+
|
| 397 |
+
// We need to compute the required size of Q based on the `reduced` option
|
| 398 |
+
DimVector q_sizes(input.sizes());
|
| 399 |
+
if (!reduced && m > n) {
|
| 400 |
+
q_sizes[input.dim() - 1] = m;
|
| 401 |
+
n_columns_q = m;
|
| 402 |
+
} else {
|
| 403 |
+
q_sizes[input.dim() - 1] = n;
|
| 404 |
+
n_columns_q = std::min(m, n);
|
| 405 |
+
}
|
| 406 |
+
auto q_strides = batched_matrix_contiguous_strides(q_sizes, /*f-contig*/true);
|
| 407 |
+
return std::make_tuple(q_sizes, q_strides, n_columns_q);
|
| 408 |
+
}
|
| 409 |
+
|
| 410 |
+
static inline bool svd_uses_cusolver(const Tensor& A) {
|
| 411 |
+
// if cusolver is available, it is used unconditionally
|
| 412 |
+
return A.is_cuda()
|
| 413 |
+
&& at::globalContext().hasCuSOLVER()
|
| 414 |
+
&& at::globalContext().linalgPreferredBackend() != at::LinalgBackend::Magma;
|
| 415 |
+
}
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
// Function used instead of .to so that the original strides are retained
|
| 419 |
+
// .to doesn't retain strides and make the output tensor contiguous
|
| 420 |
+
static inline Tensor same_stride_to(const Tensor& original_tensor, const at::TensorOptions& options) {
|
| 421 |
+
auto strided_to = at::empty_strided(original_tensor.sizes(),
|
| 422 |
+
original_tensor.strides(),
|
| 423 |
+
options);
|
| 424 |
+
strided_to.copy_(original_tensor);
|
| 425 |
+
return strided_to;
|
| 426 |
+
}
|
| 427 |
+
|
| 428 |
+
// Creates a dimension permutation array that can be given to `at::permute()`, which will shift
|
| 429 |
+
// the two specified dimensions to the end of a tensor, without changing the order of
|
| 430 |
+
// the other dimensions. `dim1` will be placed at the very end, and `dim0` will be
|
| 431 |
+
// placed just to the left of it.
|
| 432 |
+
//
|
| 433 |
+
// For instance, given a 4-D tensor, dimensions 1 and 3 can be shifted to the end by
|
| 434 |
+
// calling `create_dim_backshift_permutation(1, 3, 4)`. The resulting vector will
|
| 435 |
+
// be `vec(0, 2, 1, 3)`.
|
| 436 |
+
static inline std::vector<int64_t> create_dim_backshift_permutation(int64_t dim0, int64_t dim1, int64_t ndim) {
|
| 437 |
+
TORCH_CHECK(
|
| 438 |
+
(dim0 != dim1) && (dim0 < ndim) && (dim0 >= 0) && (dim1 < ndim) && (dim1 >= 0),
|
| 439 |
+
"duplicate or invalid dimensions");
|
| 440 |
+
std::vector<int64_t> permutation(ndim);
|
| 441 |
+
int64_t cur_permuted_dim = 0;
|
| 442 |
+
for (const auto dim_ind : c10::irange(ndim)) {
|
| 443 |
+
if ((dim_ind != dim0) && (dim_ind != dim1)) {
|
| 444 |
+
permutation[cur_permuted_dim++] = dim_ind;
|
| 445 |
+
}
|
| 446 |
+
}
|
| 447 |
+
permutation[cur_permuted_dim++] = dim0;
|
| 448 |
+
permutation[cur_permuted_dim] = dim1;
|
| 449 |
+
return permutation;
|
| 450 |
+
}
|
| 451 |
+
|
| 452 |
+
// Creates a dimension permutation array that can be given to `at::permute()`, which
// will reverse a given permutation.
// The reverse permutation array is created by swapping the indices and their
// associated values from the given permutation array.
//
// `permutation` must be a valid permutation of [0, ndim), otherwise the
// indexing below is out of bounds. It is taken by const reference: the input
// is only read, so the previous pass-by-value copy was unnecessary.
static inline std::vector<int64_t> create_reverse_permutation(const std::vector<int64_t>& permutation) {
  const int64_t ndim = static_cast<int64_t>(permutation.size());
  std::vector<int64_t> reverse_permutation(ndim);
  for (int64_t dim_ind = 0; dim_ind < ndim; ++dim_ind) {
    // The value at position p tells where dim `dim_ind` came from.
    reverse_permutation[permutation[dim_ind]] = dim_ind;
  }
  return reverse_permutation;
}
|
| 464 |
+
|
| 465 |
+
// Compute R-work array size for MAGMA/LAPACK cgesdd/zgesdd
// See https://github.com/Reference-LAPACK/lapack/blob/122506cd8b6ce050a200920c3d4c0b153b150fd8/SRC/cgesdd.f#L186
static inline int64_t computeLRWorkDim(const char jobz, int64_t m, int64_t n) {
  const auto mn = std::min(m, n);
  const auto mx = std::max(m, n);
  // jobz == 'N': singular values only, smaller workspace suffices.
  if (jobz == 'N') {
#ifdef __APPLE__
    // According to `vecLib.framework/Headers/clapack.h` Accelerate.framework is based on LAPACK 3.2.1
    return 7 * mn;
#else
    // These setting is valid for on LAPACK 3.6+
    return 5 * mn;
#endif
  }
  const auto square_term = 5 * mn * mn + 5 * mn;
  if (mx > 10 * mn) {
    return square_term;
  }
  return std::max(square_term, 2 * mx * mn + 2 * mn * mn + mn);
}
|
| 484 |
+
|
| 485 |
+
// This function checks whether the uplo argument input is valid
|
| 486 |
+
// Allowed strings are "u", "U", "l", "L"
|
| 487 |
+
static inline void checkUplo(const c10::string_view uplo) {
|
| 488 |
+
// To use std::toupper safely with plain chars (or signed chars), the argument should first be converted to unsigned char
|
| 489 |
+
char uplo_uppercase = static_cast<char>(std::toupper(static_cast<unsigned char>(uplo[0])));
|
| 490 |
+
TORCH_CHECK(uplo.size() == 1 && (uplo_uppercase == 'U' || uplo_uppercase == 'L'),
|
| 491 |
+
"Expected UPLO argument to be 'L' or 'U', but got ", uplo);
|
| 492 |
+
}
|
| 493 |
+
|
| 494 |
+
static inline void checkSameDevice(const std::string& fn_name, Tensor result, Tensor input, const std::string& result_name = "result") {
|
| 495 |
+
TORCH_CHECK(
|
| 496 |
+
result.device() == input.device(),
|
| 497 |
+
fn_name,
|
| 498 |
+
": Expected ", result_name, " and input tensors to be on the same device, but got ",
|
| 499 |
+
result_name, " on ", result.device(), " and input on ", input.device());
|
| 500 |
+
}
|
| 501 |
+
|
| 502 |
+
// Check the dtype of result and input tensors (for _out variants).
|
| 503 |
+
// Most linear algebra functions have the same dtype for input and output
|
| 504 |
+
// (either floating or complex type input), so we can check whether input's dtype can be casted to result's dtype.
|
| 505 |
+
// According to https://github.com/pytorch/pytorch/wiki/Developer-FAQ#how-does-out-work-in-pytorch
|
| 506 |
+
// c10::canCast is used for checking the "safe copy" dtype requirements.
|
| 507 |
+
static inline void checkLinalgCompatibleDtype(const std::string& fn_name, Tensor result, Tensor input, const std::string& result_name = "result") {
|
| 508 |
+
bool can_cast = c10::canCast(input.scalar_type(), result.scalar_type());
|
| 509 |
+
TORCH_CHECK(
|
| 510 |
+
can_cast,
|
| 511 |
+
fn_name,
|
| 512 |
+
": Expected ", result_name, " to be safely castable from ", input.scalar_type(), " dtype, but got ",
|
| 513 |
+
result_name, " with dtype ", result.scalar_type());
|
| 514 |
+
}
|
| 515 |
+
|
| 516 |
+
// Alternatively, we can check whether the specific expected output type (result_type) can be safely casted to out tensor dtype (out_type)
|
| 517 |
+
static inline void checkLinalgCompatibleDtype(const std::string& fn_name, ScalarType out_type, ScalarType result_type, const std::string& out_name = "result") {
|
| 518 |
+
bool can_cast = c10::canCast(result_type, out_type);
|
| 519 |
+
TORCH_CHECK(
|
| 520 |
+
can_cast,
|
| 521 |
+
fn_name,
|
| 522 |
+
": Expected ", out_name, " to be safely castable from ", result_type, " dtype, but got ",
|
| 523 |
+
out_name, " with dtype ", out_type);
|
| 524 |
+
}
|
| 525 |
+
|
| 526 |
+
// Rejects complex-valued tolerance tensors; tolerances must be real-valued.
static inline void checkNotComplexTolerance(const Tensor& tol, const c10::string_view f_name, const c10::string_view tol_name) {
  const auto tol_dtype = tol.scalar_type();
  TORCH_CHECK(!at::isComplexType(tol_dtype),
              f_name, ": ", tol_name, " tensor of complex type is not supported. Got ", tol_dtype);
}
|
| 530 |
+
|
| 531 |
+
/*
|
| 532 |
+
Two types of 'other' tensors are supported when solving
|
| 533 |
+
a system of linear equations matmul(input, x) = other:
|
| 534 |
+
* 1-dimensional (1D) tensor or batch of 1D tensors (vector case)
|
| 535 |
+
* 2-dimensional (2D) tensor or batch of 2D tensors (matrix case).
|
| 536 |
+
The original torch.solve supported only the matrix case, while NumPy works for both cases.
|
| 537 |
+
For the batched input we need to be able to distinguish them.
|
| 538 |
+
Let input.shape = (batch_dimensions, m, n), then 'other' is of vector type if other.shape == (batch_dimensions, m).
|
| 539 |
+
This rule is compatible with NumPy, see https://github.com/numpy/numpy/blob/v1.20.0/numpy/linalg/linalg.py#L384-L389
|
| 540 |
+
*/
|
| 541 |
+
static inline bool linalg_solve_is_vector_rhs(const Tensor& input, const Tensor& other) {
|
| 542 |
+
auto expected_batched_rhs_shape = SymIntArrayRef(input.sym_sizes().data(), input.dim() - 1); // input.shape[:-1]
|
| 543 |
+
bool vector_case = other.dim() == 1 || (input.dim() - 1 == other.dim() && other.sym_sizes().equals(expected_batched_rhs_shape));
|
| 544 |
+
return vector_case;
|
| 545 |
+
}
|
| 546 |
+
|
| 547 |
+
/*
|
| 548 |
+
Computes linear indices for a tensor with original_shape to access its elements like it was a materialized broadcast tensor.
|
| 549 |
+
*/
|
| 550 |
+
static inline Tensor get_linear_indices(int64_t numel, IntArrayRef original_shape, IntArrayRef broadcast_shape) {
|
| 551 |
+
TensorOptions options = at::TensorOptions().dtype(at::kLong).device(at::kCPU);
|
| 552 |
+
return at::arange(numel, options).view(original_shape).broadcast_to(broadcast_shape).contiguous();
|
| 553 |
+
}
|
| 554 |
+
|
| 555 |
+
class BroadcastLinearIndices {
|
| 556 |
+
private:
|
| 557 |
+
Tensor linear_indices_;
|
| 558 |
+
bool is_broadcasting_;
|
| 559 |
+
|
| 560 |
+
public:
|
| 561 |
+
BroadcastLinearIndices(
|
| 562 |
+
int64_t numel,
|
| 563 |
+
IntArrayRef original_shape,
|
| 564 |
+
IntArrayRef broadcast_shape) : is_broadcasting_(!original_shape.equals(broadcast_shape)) {
|
| 565 |
+
// The assumption is that the broadcast_shape is a materialized broadcast
|
| 566 |
+
// shape of the original_shape. We need to compute the linear indices
|
| 567 |
+
// compatible with the original_shape to access the elements in the original
|
| 568 |
+
// tensor corresponding to the broadcast tensor.
|
| 569 |
+
if (is_broadcasting_) {
|
| 570 |
+
linear_indices_ =
|
| 571 |
+
get_linear_indices(numel, original_shape, broadcast_shape);
|
| 572 |
+
}
|
| 573 |
+
}
|
| 574 |
+
int64_t operator()(int64_t broadcast_linear_index) {
|
| 575 |
+
return is_broadcasting_
|
| 576 |
+
? linear_indices_.data_ptr<int64_t>()[broadcast_linear_index]
|
| 577 |
+
: broadcast_linear_index;
|
| 578 |
+
}
|
| 579 |
+
};
|
| 580 |
+
|
| 581 |
+
static inline bool is_blas_compatible_column_major_order(const Tensor& input) {
|
| 582 |
+
IntArrayRef input_strides = input.strides();
|
| 583 |
+
IntArrayRef input_sizes = input.sizes();
|
| 584 |
+
auto ndim = input.dim();
|
| 585 |
+
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(ndim >= 2);
|
| 586 |
+
if (ndim > 3) {
|
| 587 |
+
return input.transpose(-2, -1).is_contiguous();
|
| 588 |
+
}
|
| 589 |
+
auto leading_dimension = input_strides[ndim - 1];
|
| 590 |
+
auto rows = input_sizes[ndim - 2];
|
| 591 |
+
bool batch_stride_compatible = true;
|
| 592 |
+
if (ndim == 3) {
|
| 593 |
+
auto cols = input_sizes[ndim - 1];
|
| 594 |
+
batch_stride_compatible =
|
| 595 |
+
input_strides[ndim - 3] >= leading_dimension * cols;
|
| 596 |
+
}
|
| 597 |
+
return (input_strides[ndim - 2] == 1) &&
|
| 598 |
+
(leading_dimension >= std::max<int64_t>(1, rows)) &&
|
| 599 |
+
batch_stride_compatible;
|
| 600 |
+
}
|
| 601 |
+
|
| 602 |
+
static inline bool is_blas_compatible_row_major_order(const Tensor& input) {
|
| 603 |
+
IntArrayRef input_strides = input.strides();
|
| 604 |
+
IntArrayRef input_sizes = input.sizes();
|
| 605 |
+
auto ndim = input.dim();
|
| 606 |
+
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(ndim >= 2);
|
| 607 |
+
if (ndim > 3) {
|
| 608 |
+
return input.is_contiguous();
|
| 609 |
+
}
|
| 610 |
+
auto leading_dimension = input_strides[ndim - 2];
|
| 611 |
+
auto cols = input_sizes[ndim - 1];
|
| 612 |
+
bool batch_stride_compatible = true;
|
| 613 |
+
if (ndim == 3) {
|
| 614 |
+
auto rows = input_sizes[ndim - 2];
|
| 615 |
+
batch_stride_compatible =
|
| 616 |
+
input_strides[ndim - 3] >= leading_dimension * rows;
|
| 617 |
+
}
|
| 618 |
+
return (input_strides[ndim - 1] == 1) &&
|
| 619 |
+
(leading_dimension >= std::max<int64_t>(1, cols)) &&
|
| 620 |
+
batch_stride_compatible;
|
| 621 |
+
}
|
| 622 |
+
|
| 623 |
+
} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/MathBitFallThroughLists.h
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
namespace at {
|
| 4 |
+
// views and their in-place version ops
// Registers fallthrough kernels (torch::CppFunction::makeFallthrough()) for
// view-like ops on the dispatcher object `m`, so these ops skip this
// dispatch key and fall through to the next one.
// NOTE(review): comments must stay outside the macro body — a `//` comment on
// a continuation line would swallow the trailing backslash.
#define TORCH_VIEW_FNS(m) \
  m.impl("as_strided_", torch::CppFunction::makeFallthrough()); \
  m.impl("detach", torch::CppFunction::makeFallthrough()); \
  m.impl("detach_", torch::CppFunction::makeFallthrough()); \
  m.impl("diagonal", torch::CppFunction::makeFallthrough()); \
  m.impl("expand", torch::CppFunction::makeFallthrough()); \
  m.impl("expand_as", torch::CppFunction::makeFallthrough()); \
  m.impl("movedim.int", torch::CppFunction::makeFallthrough()); \
  m.impl("movedim.intlist", torch::CppFunction::makeFallthrough()); \
  m.impl("narrow", torch::CppFunction::makeFallthrough()); \
  m.impl("permute", torch::CppFunction::makeFallthrough()); \
  m.impl("select.Dimname", torch::CppFunction::makeFallthrough()); \
  m.impl("select.int", torch::CppFunction::makeFallthrough()); \
  m.impl("squeeze", torch::CppFunction::makeFallthrough()); \
  m.impl("squeeze_", torch::CppFunction::makeFallthrough()); \
  m.impl("transpose.int", torch::CppFunction::makeFallthrough()); \
  m.impl("transpose.Dimname", torch::CppFunction::makeFallthrough()); \
  m.impl("transpose_", torch::CppFunction::makeFallthrough()); \
  m.impl("t", torch::CppFunction::makeFallthrough()); \
  m.impl("t_", torch::CppFunction::makeFallthrough()); \
  m.impl("real", torch::CppFunction::makeFallthrough()); \
  m.impl("imag", torch::CppFunction::makeFallthrough()); \
  m.impl("view_as_real", torch::CppFunction::makeFallthrough()); \
  m.impl("unflatten.int", torch::CppFunction::makeFallthrough()); \
  m.impl("unflatten.Dimname", torch::CppFunction::makeFallthrough()); \
  m.impl("unfold", torch::CppFunction::makeFallthrough()); \
  m.impl("unsqueeze", torch::CppFunction::makeFallthrough()); \
  m.impl("unsqueeze_", torch::CppFunction::makeFallthrough()); \
  m.impl("view_as", torch::CppFunction::makeFallthrough()); \
  m.impl("unbind.int", torch::CppFunction::makeFallthrough()); \
  m.impl("unbind.Dimname", torch::CppFunction::makeFallthrough()); \
  m.impl("split.Tensor", torch::CppFunction::makeFallthrough()); \
  m.impl("split_with_sizes", torch::CppFunction::makeFallthrough()); \
  m.impl("swapaxes", torch::CppFunction::makeFallthrough()); \
  m.impl("swapdims", torch::CppFunction::makeFallthrough()); \
  m.impl("chunk", torch::CppFunction::makeFallthrough()); \
  m.impl("reshape", torch::CppFunction::makeFallthrough()); \
  m.impl("alias", torch::CppFunction::makeFallthrough()); \
  m.impl("hsplit.int", torch::CppFunction::makeFallthrough()); \
  m.impl("hsplit.array", torch::CppFunction::makeFallthrough()); \
  m.impl("dsplit.int", torch::CppFunction::makeFallthrough()); \
  m.impl("dsplit.array", torch::CppFunction::makeFallthrough()); \
  m.impl("vsplit.int", torch::CppFunction::makeFallthrough()); \
  m.impl("vsplit.array", torch::CppFunction::makeFallthrough()); \
  m.impl("conj", torch::CppFunction::makeFallthrough()); \
  m.impl("_conj", torch::CppFunction::makeFallthrough()); \
  m.impl("_unsafe_view", torch::CppFunction::makeFallthrough()); \
  m.impl("resize_", torch::CppFunction::makeFallthrough());
|
| 53 |
+
|
| 54 |
+
// Registers fallthrough kernels for tensor factory functions and cheap
// metadata queries (sizes/strides/dtype predicates) on the dispatcher
// object `m`, so they skip this dispatch key.
#define TENSOR_UTILITIES_AND_CONSTRUCTORS(m) \
  m.impl("empty_like", torch::CppFunction::makeFallthrough()); \
  m.impl("empty.memory_format", torch::CppFunction::makeFallthrough()); \
  m.impl("empty.out", torch::CppFunction::makeFallthrough()); \
  m.impl("empty_strided", torch::CppFunction::makeFallthrough()); \
  m.impl("full_like", torch::CppFunction::makeFallthrough()); \
  m.impl("stride.int", torch::CppFunction::makeFallthrough()); \
  m.impl("stride.Dimname", torch::CppFunction::makeFallthrough()); \
  m.impl("size.int", torch::CppFunction::makeFallthrough()); \
  m.impl("size.Dimname", torch::CppFunction::makeFallthrough()); \
  m.impl("is_complex", torch::CppFunction::makeFallthrough()); \
  m.impl("is_floating_point", torch::CppFunction::makeFallthrough()); \
  m.impl("requires_grad_", torch::CppFunction::makeFallthrough());
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
// Fallthrough registrations for as_strided/view, kept separate from
// TORCH_VIEW_FNS — presumably because not every registrar wants these two;
// confirm against the call sites.
#define TORCH_VIEW_FNS_NATIVE_FN_REGISTRATION(m) \
  m.impl("as_strided", torch::CppFunction::makeFallthrough()); \
  m.impl("view", torch::CppFunction::makeFallthrough());
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Sorting.h
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/native/DispatchStub.h>
|
| 4 |
+
#include <cstdint>
|
| 5 |
+
|
| 6 |
+
namespace at {
|
| 7 |
+
class TensorBase;
|
| 8 |
+
}
|
| 9 |
+
|
| 10 |
+
namespace at::native {
|
| 11 |
+
|
| 12 |
+
// Strategies for producing a value when a requested quantile falls between
// two data points.
// NOTE(review): per-value semantics inferred from the names (they mirror
// NumPy's quantile `interpolation` modes) — confirm in the quantile kernels.
enum class QUANTILE_INTERPOLATION_MODE : uint8_t {
  LINEAR,
  LOWER,
  HIGHER,
  MIDPOINT,
  NEAREST
};
|
| 19 |
+
|
| 20 |
+
// Kernel signatures for the sort and top-k dispatch stubs.
// NOTE(review): parameter meanings are not shown here — likely
// (values, indices, self, dim, ...) with the trailing bools being
// descending/stable (sort) and largest/sorted (topk); confirm against the
// CPU/CUDA kernel registrations.
using sort_fn = void(*)(const TensorBase&, const TensorBase&, const TensorBase&, int64_t, bool, bool);
using topk_fn = void(*)(const TensorBase&, const TensorBase&, const TensorBase&, int64_t, int64_t, bool, bool);

// Per-backend kernel registries (see ATen/native/DispatchStub.h).
DECLARE_DISPATCH(sort_fn, sort_stub);
DECLARE_DISPATCH(topk_fn, topk_stub);

// Helper shared by the sorting ops; presumably fills `indices` with
// sequential indices along `dim` — confirm in the implementation.
void _fill_indices(const TensorBase &indices, int64_t dim);
|
| 27 |
+
|
| 28 |
+
} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/SpectralOpsUtils.h
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <string>
|
| 4 |
+
#include <stdexcept>
|
| 5 |
+
#include <sstream>
|
| 6 |
+
#include <c10/core/ScalarType.h>
|
| 7 |
+
#include <c10/util/ArrayRef.h>
|
| 8 |
+
#include <c10/util/Exception.h>
|
| 9 |
+
#include <ATen/native/DispatchStub.h>
|
| 10 |
+
#include <ATen/core/TensorBase.h>
|
| 11 |
+
|
| 12 |
+
namespace at::native {
|
| 13 |
+
|
| 14 |
+
// Normalization types used in _fft_with_size
|
| 15 |
+
enum class fft_norm_mode {
|
| 16 |
+
none, // No normalization
|
| 17 |
+
by_root_n, // Divide by sqrt(signal_size)
|
| 18 |
+
by_n, // Divide by signal_size
|
| 19 |
+
};
|
| 20 |
+
|
| 21 |
+
// NOTE [ Fourier Transform Conjugate Symmetry ]
|
| 22 |
+
//
|
| 23 |
+
// Real-to-complex Fourier transform satisfies the conjugate symmetry. That is,
|
| 24 |
+
// assuming X is the transformed K-dimensionsal signal, we have
|
| 25 |
+
//
|
| 26 |
+
// X[i_1, ..., i_K] = X[j_i, ..., j_K]*,
|
| 27 |
+
//
|
| 28 |
+
// where j_k = (N_k - i_k) mod N_k, N_k being the signal size at dim k,
|
| 29 |
+
// * is the conjugate operator.
|
| 30 |
+
//
|
| 31 |
+
// Therefore, in such cases, FFT libraries return only roughly half of the
|
| 32 |
+
// values to avoid redundancy:
|
| 33 |
+
//
|
| 34 |
+
// X[:, :, ..., :floor(N / 2) + 1]
|
| 35 |
+
//
|
| 36 |
+
// This is also the assumption in cuFFT and MKL. In ATen SpectralOps, such
|
| 37 |
+
// halved signal will also be returned by default (flag onesided=True).
|
| 38 |
+
// The following infer_ft_real_to_complex_onesided_size function calculates the
|
| 39 |
+
// onesided size from the twosided size.
|
| 40 |
+
//
|
| 41 |
+
// Note that this loses some information about the size of signal at last
|
| 42 |
+
// dimension. E.g., both 11 and 10 maps to 6. Hence, the following
|
| 43 |
+
// infer_ft_complex_to_real_onesided_size function takes in optional parameter
|
| 44 |
+
// to infer the twosided size from given onesided size.
|
| 45 |
+
//
|
| 46 |
+
// cuFFT doc: http://docs.nvidia.com/cuda/cufft/index.html#multi-dimensional
|
| 47 |
+
// MKL doc: https://software.intel.com/en-us/mkl-developer-reference-c-dfti-complex-storage-dfti-real-storage-dfti-conjugate-even-storage#CONJUGATE_EVEN_STORAGE
|
| 48 |
+
|
| 49 |
+
inline int64_t infer_ft_real_to_complex_onesided_size(int64_t real_size) {
|
| 50 |
+
return (real_size / 2) + 1;
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
inline int64_t infer_ft_complex_to_real_onesided_size(int64_t complex_size,
|
| 54 |
+
int64_t expected_size=-1) {
|
| 55 |
+
int64_t base = (complex_size - 1) * 2;
|
| 56 |
+
if (expected_size < 0) {
|
| 57 |
+
return base + 1;
|
| 58 |
+
} else if (base == expected_size) {
|
| 59 |
+
return base;
|
| 60 |
+
} else if (base + 1 == expected_size) {
|
| 61 |
+
return base + 1;
|
| 62 |
+
} else {
|
| 63 |
+
std::ostringstream ss;
|
| 64 |
+
ss << "expected real signal size " << expected_size << " is incompatible "
|
| 65 |
+
<< "with onesided complex frequency size " << complex_size;
|
| 66 |
+
AT_ERROR(ss.str());
|
| 67 |
+
}
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
using fft_fill_with_conjugate_symmetry_fn =
|
| 71 |
+
void (*)(ScalarType dtype, IntArrayRef mirror_dims, IntArrayRef half_sizes,
|
| 72 |
+
IntArrayRef in_strides, const void* in_data,
|
| 73 |
+
IntArrayRef out_strides, void* out_data);
|
| 74 |
+
DECLARE_DISPATCH(fft_fill_with_conjugate_symmetry_fn, fft_fill_with_conjugate_symmetry_stub);
|
| 75 |
+
|
| 76 |
+
// In real-to-complex transform, cuFFT and MKL only fill half of the values
|
| 77 |
+
// due to conjugate symmetry. This function fills in the other half of the full
|
| 78 |
+
// fft by using the Hermitian symmetry in the signal.
|
| 79 |
+
// self should be the shape of the full signal and dims.back() should be the
|
| 80 |
+
// one-sided dimension.
|
| 81 |
+
// See NOTE [ Fourier Transform Conjugate Symmetry ]
|
| 82 |
+
TORCH_API void _fft_fill_with_conjugate_symmetry_(const Tensor& self, IntArrayRef dims);
|
| 83 |
+
|
| 84 |
+
} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorTransformations.h
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <ATen/core/Tensor.h>
|
| 2 |
+
|
| 3 |
+
#ifndef AT_PER_OPERATOR_HEADERS
|
| 4 |
+
#include <ATen/Functions.h>
|
| 5 |
+
#else
|
| 6 |
+
#include <ATen/ops/roll.h>
|
| 7 |
+
#endif
|
| 8 |
+
|
| 9 |
+
#include <c10/util/Exception.h>
|
| 10 |
+
|
| 11 |
+
namespace at::native {
|
| 12 |
+
|
| 13 |
+
static inline Tensor roll_common(const Tensor& self, IntArrayRef shifts, IntArrayRef dims) {
|
| 14 |
+
TORCH_CHECK(!shifts.empty(), "`shifts` required");
|
| 15 |
+
if (dims.empty() && shifts.size() == 1) {
|
| 16 |
+
auto flattened = self.contiguous().view(self.numel());
|
| 17 |
+
return roll(flattened, shifts[0], 0).view(self.sizes());
|
| 18 |
+
}
|
| 19 |
+
TORCH_CHECK(
|
| 20 |
+
shifts.size() == dims.size(),
|
| 21 |
+
"shifts and dimensions must align. shifts: ", shifts.size(), ", dims:", dims.size()
|
| 22 |
+
);
|
| 23 |
+
AT_ASSERT(dims.size() > 1);
|
| 24 |
+
auto tail_shifts = shifts.slice(1);
|
| 25 |
+
auto tail_dims = dims.slice(1);
|
| 26 |
+
auto first_dim_rolled = roll(self, shifts[0], dims[0]);
|
| 27 |
+
return at::roll(first_dim_rolled, tail_shifts, tail_dims);
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/GridSamplerKernel.h
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/native/DispatchStub.h>
|
| 4 |
+
|
| 5 |
+
#include <array>
|
| 6 |
+
#include <cstdint>
|
| 7 |
+
|
| 8 |
+
namespace at {
|
| 9 |
+
class TensorBase;
|
| 10 |
+
}
|
| 11 |
+
|
| 12 |
+
namespace at { namespace native {
|
| 13 |
+
|
| 14 |
+
using forward_2d_fn = void (*) (
|
| 15 |
+
const TensorBase &output,
|
| 16 |
+
const TensorBase &input,
|
| 17 |
+
const TensorBase &grid,
|
| 18 |
+
int64_t interpolation_mode,
|
| 19 |
+
int64_t padding_mode,
|
| 20 |
+
bool align_corners);
|
| 21 |
+
using backward_2d_fn = void (*) (
|
| 22 |
+
const TensorBase &grad_input,
|
| 23 |
+
const TensorBase &grad_grid,
|
| 24 |
+
const TensorBase &grad_output,
|
| 25 |
+
const TensorBase &input,
|
| 26 |
+
const TensorBase &grid,
|
| 27 |
+
int64_t interpolation_mode,
|
| 28 |
+
int64_t padding_mode,
|
| 29 |
+
bool align_corners,
|
| 30 |
+
std::array<bool, 2> output_mask);
|
| 31 |
+
DECLARE_DISPATCH(forward_2d_fn, grid_sampler_2d_cpu_kernel);
|
| 32 |
+
DECLARE_DISPATCH(backward_2d_fn, grid_sampler_2d_backward_cpu_kernel);
|
| 33 |
+
|
| 34 |
+
}} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/IndexKernelUtils.h
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <ATen/native/TensorIterator.h>
|
| 3 |
+
#include <c10/util/irange.h>
|
| 4 |
+
|
| 5 |
+
namespace at {
|
| 6 |
+
namespace native {
|
| 7 |
+
|
| 8 |
+
namespace {
|
| 9 |
+
static bool is_constant_index(int ntensor, const int64_t* strides) {
|
| 10 |
+
AT_ASSERT(ntensor >= 3);
|
| 11 |
+
for (const auto arg : c10::irange(2, ntensor)) {
|
| 12 |
+
if (strides[arg] != 0) {
|
| 13 |
+
return false;
|
| 14 |
+
}
|
| 15 |
+
}
|
| 16 |
+
return true;
|
| 17 |
+
}
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
struct Indexer {
|
| 21 |
+
Indexer(int64_t num_indexers, char** indexers, const int64_t* indexer_strides,
|
| 22 |
+
IntArrayRef original_sizes, IntArrayRef original_strides)
|
| 23 |
+
: num_indexers(num_indexers)
|
| 24 |
+
, indexers(indexers)
|
| 25 |
+
, indexer_strides(indexer_strides)
|
| 26 |
+
, original_strides(original_strides.data())
|
| 27 |
+
, original_sizes(original_sizes.data()) {
|
| 28 |
+
AT_ASSERT(static_cast<int64_t>(original_strides.size()) == num_indexers);
|
| 29 |
+
AT_ASSERT(static_cast<int64_t>(original_sizes.size()) == num_indexers);
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
int64_t num_indexers;
|
| 33 |
+
char** indexers;
|
| 34 |
+
const int64_t* indexer_strides;
|
| 35 |
+
const int64_t* original_strides;
|
| 36 |
+
const int64_t* original_sizes;
|
| 37 |
+
|
| 38 |
+
int64_t get(int64_t idx) {
|
| 39 |
+
int64_t offset = 0;
|
| 40 |
+
for (const auto j : c10::irange(num_indexers)) {
|
| 41 |
+
int64_t value = *(int64_t*)&indexers[j][idx * indexer_strides[j]];
|
| 42 |
+
int64_t size = original_sizes[j];
|
| 43 |
+
TORCH_CHECK_INDEX(value >= -size && value < size,
|
| 44 |
+
"index ", value, " is out of bounds for dimension ", j, " with size ", size);
|
| 45 |
+
if (value < 0) {
|
| 46 |
+
value += size;
|
| 47 |
+
}
|
| 48 |
+
offset += value * original_strides[j];
|
| 49 |
+
}
|
| 50 |
+
return offset;
|
| 51 |
+
}
|
| 52 |
+
};
|
| 53 |
+
} // anonymous namespace
|
| 54 |
+
|
| 55 |
+
template <typename scalar_t, typename func_t>
|
| 56 |
+
void cpu_index_kernel(TensorIteratorBase& iter, IntArrayRef index_size, IntArrayRef index_stride,
|
| 57 |
+
const func_t& f, bool serial_execution=false)
|
| 58 |
+
{
|
| 59 |
+
int ntensor = iter.ntensors();
|
| 60 |
+
// When launch the index parallel version, set a relative small grain size less than the INTERNAL::GRAIN_SIZE
|
| 61 |
+
// to make the whole available thread numbers get more balanced work load and a better cache location.
|
| 62 |
+
// The grain size here is chosen by the op benchmark to overcome the thread launch overhead
|
| 63 |
+
const int index_parallel_grain_size = 3000;
|
| 64 |
+
auto loop = [&](char** data, const int64_t* strides, int64_t n) {
|
| 65 |
+
auto indexer = Indexer(ntensor - 2, &data[2], &strides[2], index_size, index_stride);
|
| 66 |
+
char* dst = data[0];
|
| 67 |
+
char* src = data[1];
|
| 68 |
+
if (is_constant_index(ntensor, strides)) {
|
| 69 |
+
// specialization for when every element uses the same index
|
| 70 |
+
int64_t offset = indexer.get(0);
|
| 71 |
+
for (const auto i : c10::irange(n)) {
|
| 72 |
+
f(dst + strides[0] * i, src + strides[1] * i, offset);
|
| 73 |
+
}
|
| 74 |
+
} else {
|
| 75 |
+
for (const auto i : c10::irange(n)) {
|
| 76 |
+
int64_t offset = indexer.get(i);
|
| 77 |
+
f(dst + strides[0] * i, src + strides[1] * i, offset);
|
| 78 |
+
}
|
| 79 |
+
}
|
| 80 |
+
};
|
| 81 |
+
if (serial_execution) {
|
| 82 |
+
iter.serial_for_each(loop, {0, iter.numel()});
|
| 83 |
+
} else {
|
| 84 |
+
iter.for_each(loop, index_parallel_grain_size);
|
| 85 |
+
}
|
| 86 |
+
}
|
| 87 |
+
} // at
|
| 88 |
+
} // native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/LogAddExp.h
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <c10/util/complex.h>
|
| 4 |
+
#include <ATen/NumericUtils.h>
|
| 5 |
+
|
| 6 |
+
namespace at { namespace native {
|
| 7 |
+
inline namespace CPU_CAPABILITY {
|
| 8 |
+
|
| 9 |
+
// custom min and max to be used in logcumsumexp for complex arguments
|
| 10 |
+
template <typename scalar_t>
|
| 11 |
+
std::pair<c10::complex<scalar_t>, c10::complex<scalar_t>> _logcumsumexp_minmax(c10::complex<scalar_t> x, c10::complex<scalar_t> y) {
|
| 12 |
+
if (at::_isnan(y)) { // either real is nan or imag is nan
|
| 13 |
+
return std::make_pair(y, y);
|
| 14 |
+
} else if (at::_isnan(x)) { // either real is nan or imag is nan
|
| 15 |
+
return std::make_pair(x, x);
|
| 16 |
+
} else {
|
| 17 |
+
return (x.real() < y.real()) ? std::make_pair(x, y) : std::make_pair(y, x);
|
| 18 |
+
}
|
| 19 |
+
}
|
| 20 |
+
|
| 21 |
+
template <typename scalar_t>
|
| 22 |
+
scalar_t _log_add_exp_helper(scalar_t x, scalar_t y) {
|
| 23 |
+
// Reference : https://www.tensorflow.org/api_docs/python/tf/math/cumulative_logsumexp
|
| 24 |
+
scalar_t min = at::_isnan(y) ? y : std::min(x, y); // std::min returns first arg if one of the args is nan
|
| 25 |
+
scalar_t max = at::_isnan(y) ? y : std::max(x, y); // std::max returns first arg if one of the args is nan
|
| 26 |
+
if (min != max || std::isfinite(min)) {
|
| 27 |
+
// nan will be propagated here
|
| 28 |
+
return std::log1p(std::exp(min - max)) + max;
|
| 29 |
+
} else {
|
| 30 |
+
// special case to correctly handle infinite cases
|
| 31 |
+
return x;
|
| 32 |
+
}
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
template <typename scalar_t>
|
| 36 |
+
c10::complex<scalar_t> _log_add_exp_helper(const c10::complex<scalar_t>& x, const c10::complex<scalar_t>& y) {
|
| 37 |
+
auto [min, max] = _logcumsumexp_minmax<scalar_t>(x, y);
|
| 38 |
+
auto min_real = std::real(min);
|
| 39 |
+
auto max_real = std::real(max);
|
| 40 |
+
|
| 41 |
+
if (at::_isnan(min)) { // either real is nan or imag is nan
|
| 42 |
+
// handling the "infectious" NaNs
|
| 43 |
+
return {std::numeric_limits<scalar_t>::quiet_NaN(), std::numeric_limits<scalar_t>::quiet_NaN()};
|
| 44 |
+
} else if (!std::isfinite(min_real) && (min_real == max_real)) {
|
| 45 |
+
if (min_real < 0) {
|
| 46 |
+
// handle the -inf case, the imaginary part here does not really matter as the exp(value)
|
| 47 |
+
// will be around 0.0 and the angle (i.e. the imaginary part) cannot be determined.
|
| 48 |
+
// It does not matter if we're taking the exp of this value
|
| 49 |
+
return min;
|
| 50 |
+
} else {
|
| 51 |
+
// handle the +inf case, we don't need the special precision for log1p for small values
|
| 52 |
+
// and to avoid producing nan in case of real(max) == real(min) == +inf
|
| 53 |
+
return std::log(std::exp(min) + std::exp(max));
|
| 54 |
+
}
|
| 55 |
+
} else {
|
| 56 |
+
return std::log1p(std::exp(min - max)) + max;
|
| 57 |
+
}
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
} // end namespace
|
| 61 |
+
}} //end at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/PixelShuffleKernel.h
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <ATen/native/DispatchStub.h>
|
| 3 |
+
|
| 4 |
+
namespace at {
|
| 5 |
+
class TensorBase;
|
| 6 |
+
}
|
| 7 |
+
|
| 8 |
+
namespace at { namespace native {
|
| 9 |
+
|
| 10 |
+
using pixel_shuffle_fn = void(*)(TensorBase&, const TensorBase&, int64_t);
|
| 11 |
+
DECLARE_DISPATCH(pixel_shuffle_fn, pixel_shuffle_kernel);
|
| 12 |
+
DECLARE_DISPATCH(pixel_shuffle_fn, pixel_unshuffle_kernel);
|
| 13 |
+
|
| 14 |
+
}} // at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/SampledAddmmKernel.h
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/core/Tensor.h>
|
| 4 |
+
#include <ATen/native/DispatchStub.h>
|
| 5 |
+
|
| 6 |
+
namespace at { namespace native {
|
| 7 |
+
|
| 8 |
+
using sampled_addmm_sparse_csr_fn = void(*)(const Tensor&, const Tensor&, const Scalar&, const Scalar&, const Tensor&);
|
| 9 |
+
|
| 10 |
+
DECLARE_DISPATCH(sampled_addmm_sparse_csr_fn, sampled_addmm_sparse_csr_stub);
|
| 11 |
+
|
| 12 |
+
}} // at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/UpSampleKernelAVXAntialias.h
ADDED
|
@@ -0,0 +1,1376 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
The Python Imaging Library (PIL) is
|
| 3 |
+
|
| 4 |
+
Copyright © 1997-2011 by Secret Labs AB
|
| 5 |
+
Copyright © 1995-2011 by Fredrik Lundh
|
| 6 |
+
|
| 7 |
+
Pillow is the friendly PIL fork. It is
|
| 8 |
+
|
| 9 |
+
Copyright © 2010-2022 by Alex Clark and contributors
|
| 10 |
+
|
| 11 |
+
Like PIL, Pillow is licensed under the open source HPND License
|
| 12 |
+
*/
|
| 13 |
+
|
| 14 |
+
// This code is heavily inspired from PILLOW-SIMD's implementation:
|
| 15 |
+
// https://github.com/uploadcare/pillow-simd/blob/simd/master/src/libImaging/Resample.c
|
| 16 |
+
|
| 17 |
+
#pragma once
|
| 18 |
+
#ifdef CPU_CAPABILITY_AVX2
|
| 19 |
+
// TODO: This file only supports AVX2. We could split the AVX kernels into
|
| 20 |
+
// smaller logical blocks in order to port them into the Vec.h logic. This would
|
| 21 |
+
// allow to support other vectorization architectures and perhaps also support
|
| 22 |
+
// the non-vectorized fallback (we'd need to make sure it's not slower than the
|
| 23 |
+
// current fallback).
|
| 24 |
+
|
| 25 |
+
#include <ATen/core/Tensor.h>
|
| 26 |
+
#include <ATen/cpu/vec/intrinsics.h>
|
| 27 |
+
#include <c10/util/irange.h>
|
| 28 |
+
|
| 29 |
+
#ifndef AT_PER_OPERATOR_HEADERS
|
| 30 |
+
#include <ATen/Functions.h>
|
| 31 |
+
#else
|
| 32 |
+
#include <ATen/ops/empty.h>
|
| 33 |
+
#endif
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
namespace {
|
| 37 |
+
|
| 38 |
+
static inline __m128i mm_cvtsi32_si128(const uint8_t* C10_RESTRICT ptr, bool i32_aligned) {
|
| 39 |
+
int32_t v;
|
| 40 |
+
if (i32_aligned) {
|
| 41 |
+
v = *(const int32_t*)ptr;
|
| 42 |
+
} else {
|
| 43 |
+
std::memcpy(&v, ptr, 4);
|
| 44 |
+
}
|
| 45 |
+
return _mm_cvtsi32_si128(v);
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
static inline __m128i mm_cvtepu8_epi32(const uint8_t* C10_RESTRICT ptr, bool i32_aligned) {
|
| 49 |
+
return _mm_cvtepu8_epi32(mm_cvtsi32_si128(ptr, i32_aligned));
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
static inline void _write_endline_rgb_as_uint32(
|
| 53 |
+
uint8_t* C10_RESTRICT output,
|
| 54 |
+
uint32_t data
|
| 55 |
+
) {
|
| 56 |
+
// data is (R G B X), output is (X1 X2 X3 | R1 B1 G1 R2 ...)
|
| 57 |
+
// Here we explicitly set X as R1
|
| 58 |
+
uint8_t* data_ptr = reinterpret_cast<uint8_t*>(&data);
|
| 59 |
+
data_ptr[3] = output[3];
|
| 60 |
+
std::memcpy(output, data_ptr, 4);
|
| 61 |
+
}
|
| 62 |
+
|
| 63 |
+
at::Tensor unpack_rgb(const at::Tensor& packed_tensor) {
|
| 64 |
+
// Convert a "packed" tensor (typically RGBRGBRGB if channels_last) into
|
| 65 |
+
// RGBARGBARGBA format where A is hard-coded to 0. Each pixel is encoded
|
| 66 |
+
// into as 32 bits. This generalizes to num_channels <= 4 and also works for
|
| 67 |
+
// non-channels_last tensors.
|
| 68 |
+
|
| 69 |
+
const uint8_t* packed = (const uint8_t*)packed_tensor.data_ptr<uint8_t>();
|
| 70 |
+
auto num_pixels = packed_tensor.size(1) * packed_tensor.size(2);
|
| 71 |
+
auto num_channels = packed_tensor.size(0);
|
| 72 |
+
|
| 73 |
+
constexpr int rgba_size = 4;
|
| 74 |
+
auto unpacked_tensor = at::empty({rgba_size, packed_tensor.size(1), packed_tensor.size(2)}, at::CPU(at::kByte));
|
| 75 |
+
uint8_t* unpacked = (uint8_t*) unpacked_tensor.data_ptr<uint8_t>();
|
| 76 |
+
|
| 77 |
+
auto stride_i = packed_tensor.stride(2);
|
| 78 |
+
auto stride_j = packed_tensor.stride(0);
|
| 79 |
+
|
| 80 |
+
for (const auto i : c10::irange(num_pixels)) {
|
| 81 |
+
for (const auto j : c10::irange(rgba_size)) {
|
| 82 |
+
unpacked[rgba_size * i + j] = (j < num_channels) ? packed[stride_i * i + stride_j * j] : 0;
|
| 83 |
+
}
|
| 84 |
+
}
|
| 85 |
+
return unpacked_tensor;
|
| 86 |
+
}
|
| 87 |
+
|
| 88 |
+
void pack_rgb(
|
| 89 |
+
const at::Tensor& unpacked_tensor, // IN
|
| 90 |
+
const at::Tensor& packed_tensor // OUT
|
| 91 |
+
) {
|
| 92 |
+
// Convert from unpacked channels last 3-channels or 4-channels tensor into original data layout.
|
| 93 |
+
|
| 94 |
+
uint8_t* unpacked = (uint8_t*)unpacked_tensor.data_ptr<uint8_t>();
|
| 95 |
+
uint8_t* packed = (uint8_t*)packed_tensor.data_ptr<uint8_t>();
|
| 96 |
+
auto num_pixels = packed_tensor.size(1) * packed_tensor.size(2);
|
| 97 |
+
auto num_channels = packed_tensor.size(0);
|
| 98 |
+
|
| 99 |
+
auto unpacked_increment = unpacked_tensor.size(0);
|
| 100 |
+
auto packed_increment = packed_tensor.stride(2);
|
| 101 |
+
auto packed_stride = packed_tensor.stride(0);
|
| 102 |
+
|
| 103 |
+
TORCH_INTERNAL_ASSERT(unpacked_increment == 3 || unpacked_increment == 4);
|
| 104 |
+
|
| 105 |
+
for (const auto i C10_UNUSED : c10::irange(num_pixels)) {
|
| 106 |
+
for (const auto j : c10::irange(num_channels)) {
|
| 107 |
+
packed[j * packed_stride] = unpacked[j];
|
| 108 |
+
}
|
| 109 |
+
unpacked += unpacked_increment;
|
| 110 |
+
packed += packed_increment;
|
| 111 |
+
}
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
void ImagingResampleHorizontalConvolution8u4x(
|
| 115 |
+
uint8_t* C10_RESTRICT lineOut0,
|
| 116 |
+
uint8_t* C10_RESTRICT lineOut1,
|
| 117 |
+
uint8_t* C10_RESTRICT lineOut2,
|
| 118 |
+
uint8_t* C10_RESTRICT lineOut3,
|
| 119 |
+
int64_t out_xsize,
|
| 120 |
+
const uint8_t* C10_RESTRICT lineIn0,
|
| 121 |
+
const uint8_t* C10_RESTRICT lineIn1,
|
| 122 |
+
const uint8_t* C10_RESTRICT lineIn2,
|
| 123 |
+
const uint8_t* C10_RESTRICT lineIn3,
|
| 124 |
+
int64_t in_xsize,
|
| 125 |
+
const int64_t* idx_ptr_xmin,
|
| 126 |
+
const int64_t* idx_ptr_size,
|
| 127 |
+
const int16_t* kk,
|
| 128 |
+
int kmax,
|
| 129 |
+
unsigned int coefs_precision,
|
| 130 |
+
int64_t num_channels,
|
| 131 |
+
bool is_last_line);
|
| 132 |
+
|
| 133 |
+
void ImagingResampleHorizontalConvolution8u(
|
| 134 |
+
uint8_t* C10_RESTRICT lineOut,
|
| 135 |
+
int64_t out_xsize,
|
| 136 |
+
const uint8_t* C10_RESTRICT lineIn,
|
| 137 |
+
int64_t in_xsize,
|
| 138 |
+
const int64_t* idx_ptr_xmin,
|
| 139 |
+
const int64_t* idx_ptr_size,
|
| 140 |
+
const int16_t* kk,
|
| 141 |
+
int kmax,
|
| 142 |
+
unsigned int coefs_precision,
|
| 143 |
+
int64_t num_channels,
|
| 144 |
+
bool is_last_line);
|
| 145 |
+
|
| 146 |
+
void ImagingResampleVerticalConvolution8u(
|
| 147 |
+
uint8_t* C10_RESTRICT lineOut,
|
| 148 |
+
const uint8_t* C10_RESTRICT lineIn,
|
| 149 |
+
int64_t xsize,
|
| 150 |
+
int64_t ids_min,
|
| 151 |
+
int64_t ids_size,
|
| 152 |
+
const int16_t* k,
|
| 153 |
+
unsigned int coefs_precision,
|
| 154 |
+
int64_t num_channels);
|
| 155 |
+
|
| 156 |
+
template<int num_channels>
|
| 157 |
+
void ImagingResampleHorizontal(
|
| 158 |
+
const at::Tensor & unpacked_output,
|
| 159 |
+
const at::Tensor & unpacked_input,
|
| 160 |
+
int ksize,
|
| 161 |
+
const std::vector<at::Tensor>& horiz_indices_weights,
|
| 162 |
+
unsigned int horiz_weights_precision) {
|
| 163 |
+
|
| 164 |
+
// Interpolation horizontal pass: we compute x-axis (image width) interpolation outputs.
|
| 165 |
+
|
| 166 |
+
// Input data is stored as
|
| 167 |
+
// input = [r[0], g[0], b[0], a[0], r[1], g[1], b[1], a[1], r[2], g[2], b[2], a[2], ...]
|
| 168 |
+
// Weights are float values computed for each output pixel and rescaled to uint16:
|
| 169 |
+
// weights[i] = [w[i, 0], w[i, 1], ..., w[i, K-1]]
|
| 170 |
+
// We want to compute the output as following:
|
| 171 |
+
// output = [oR[0], oG[0], oB[0], oA[0], oR[1], oG[1], oB[1], oA[1], ...]
|
| 172 |
+
// where
|
| 173 |
+
// oR[yoffset + i] = r[yoffset + xmin[i]] * w[i, 0] + ... + r[yoffset + xmin[i] + K-1] * w[i, K-1]
|
| 174 |
+
// oG[yoffset + i] = g[yoffset + xmin[i]] * w[i, 0] + ... + g[yoffset + xmin[i] + K-1] * w[i, K-1]
|
| 175 |
+
// oB[yoffset + i] = b[yoffset + xmin[i]] * w[i, 0] + ... + b[yoffset + xmin[i] + K-1] * w[i, K-1]
|
| 176 |
+
//
|
| 177 |
+
|
| 178 |
+
// TODO: we may want to merge that into the fallback code (currently called
|
| 179 |
+
// basic_loop_aa_horizontal<uint8_t>)
|
| 180 |
+
// Although this may not be needed if / when we port all this code to use
|
| 181 |
+
// Vec.h since this would potentially give us another fall-back implem
|
| 182 |
+
|
| 183 |
+
const int16_t* kk = (int16_t*)(horiz_indices_weights[3].data_ptr<double>());
|
| 184 |
+
|
| 185 |
+
auto xout = unpacked_output.size(2);
|
| 186 |
+
auto yout = unpacked_output.size(1);
|
| 187 |
+
auto xin = unpacked_input.size(2);
|
| 188 |
+
TORCH_INTERNAL_ASSERT(num_channels == unpacked_input.size(0));
|
| 189 |
+
|
| 190 |
+
const int64_t* idx_ptr_xmin = horiz_indices_weights[0].data_ptr<int64_t>();
|
| 191 |
+
const int64_t* idx_ptr_size = horiz_indices_weights[1].data_ptr<int64_t>();
|
| 192 |
+
|
| 193 |
+
uint8_t* unpacked_output_p = unpacked_output.data_ptr<uint8_t>();
|
| 194 |
+
const uint8_t* unpacked_input_p = unpacked_input.data_ptr<uint8_t>();
|
| 195 |
+
|
| 196 |
+
int64_t yy = 0;
|
| 197 |
+
auto xout_stride = xout * num_channels;
|
| 198 |
+
auto xin_stride = xin * num_channels;
|
| 199 |
+
for (; yy < yout - 3; yy += 4) {
|
| 200 |
+
ImagingResampleHorizontalConvolution8u4x(
|
| 201 |
+
unpacked_output_p + yy * xout_stride,
|
| 202 |
+
unpacked_output_p + (yy + 1) * xout_stride,
|
| 203 |
+
unpacked_output_p + (yy + 2) * xout_stride,
|
| 204 |
+
unpacked_output_p + (yy + 3) * xout_stride,
|
| 205 |
+
xout,
|
| 206 |
+
unpacked_input_p + yy * xin_stride,
|
| 207 |
+
unpacked_input_p + (yy + 1) * xin_stride,
|
| 208 |
+
unpacked_input_p + (yy + 2) * xin_stride,
|
| 209 |
+
unpacked_input_p + (yy + 3) * xin_stride,
|
| 210 |
+
xin,
|
| 211 |
+
idx_ptr_xmin,
|
| 212 |
+
idx_ptr_size,
|
| 213 |
+
kk,
|
| 214 |
+
ksize,
|
| 215 |
+
horiz_weights_precision,
|
| 216 |
+
num_channels,
|
| 217 |
+
yy + 3 == yout - 1);
|
| 218 |
+
}
|
| 219 |
+
for (; yy < yout; yy++) {
|
| 220 |
+
ImagingResampleHorizontalConvolution8u(
|
| 221 |
+
unpacked_output_p + yy * xout_stride,
|
| 222 |
+
xout,
|
| 223 |
+
unpacked_input_p + yy * xin_stride,
|
| 224 |
+
xin,
|
| 225 |
+
idx_ptr_xmin,
|
| 226 |
+
idx_ptr_size,
|
| 227 |
+
kk,
|
| 228 |
+
ksize,
|
| 229 |
+
horiz_weights_precision,
|
| 230 |
+
num_channels,
|
| 231 |
+
yy == yout - 1);
|
| 232 |
+
}
|
| 233 |
+
}
|
| 234 |
+
|
| 235 |
+
void ImagingResampleVertical(
|
| 236 |
+
const at::Tensor & unpacked_output,
|
| 237 |
+
const at::Tensor & unpacked_input,
|
| 238 |
+
int ksize,
|
| 239 |
+
const std::vector<at::Tensor>& vert_indices_weights,
|
| 240 |
+
unsigned int vert_weights_precision) {
|
| 241 |
+
|
| 242 |
+
// Interpolation vertical pass: we compute y-axis interpolation outputs.
|
| 243 |
+
// Input data is stored as
|
| 244 |
+
// input = [r[0], g[0], b[0], a[0], r[1], g[1], b[1], a[1], r[2], g[2], b[2], a[2], ...]
|
| 245 |
+
// Weights are float values computed for each output pixel and rescaled to uint16:
|
| 246 |
+
// weights[i] = [w[i, 0], w[i, 1], ..., w[i, K-1]]
|
| 247 |
+
// We want to compute the output as following:
|
| 248 |
+
// output = [oR[0], oG[0], oB[0], oA[0], oR[1], oG[1], oB[1], oA[1], ...]
|
| 249 |
+
// where
|
| 250 |
+
// oR[xoffset + i] = r[xoffset + ymin[i]] * w[i, 0] + ... + r[xoffset + ymin[i] + (K-1) * xsize] * w[i, K-1]
|
| 251 |
+
// oG[xoffset + i] = g[xoffset + ymin[i]] * w[i, 0] + ... + g[xoffset + ymin[i] + (K-1) * xsize] * w[i, K-1]
|
| 252 |
+
// oB[xoffset + i] = b[xoffset + ymin[i]] * w[i, 0] + ... + b[xoffset + ymin[i] + (K-1) * xsize] * w[i, K-1]
|
| 253 |
+
|
| 254 |
+
// TODO: we may want to merge that into the fallback code (currently called
|
| 255 |
+
// basic_loop_aa_vertical<uint8_t>)
|
| 256 |
+
// Although this may not be needed if / when we port all this code to use
|
| 257 |
+
// Vec.h since this would potentially give us another fall-back implem
|
| 258 |
+
const int16_t* kk = (int16_t*)(vert_indices_weights[3].data_ptr<double>());
|
| 259 |
+
|
| 260 |
+
const int64_t* idx_ptr_xmin = vert_indices_weights[0].data_ptr<int64_t>();
|
| 261 |
+
const int64_t* idx_ptr_size = vert_indices_weights[1].data_ptr<int64_t>();
|
| 262 |
+
|
| 263 |
+
uint8_t* unpacked_output_p = unpacked_output.data_ptr<uint8_t>();
|
| 264 |
+
const uint8_t* unpacked_input_p = unpacked_input.data_ptr<uint8_t>();
|
| 265 |
+
|
| 266 |
+
auto xout = unpacked_output.size(2);
|
| 267 |
+
auto yout = unpacked_output.size(1);
|
| 268 |
+
const auto num_channels = unpacked_input.size(0);
|
| 269 |
+
TORCH_INTERNAL_ASSERT(num_channels == unpacked_output.size(0));
|
| 270 |
+
|
| 271 |
+
auto xout_stride = xout * num_channels;
|
| 272 |
+
for (const auto yy : c10::irange(yout)) {
|
| 273 |
+
const auto* k = &kk[yy * ksize];
|
| 274 |
+
auto ids_min = idx_ptr_xmin[yy];
|
| 275 |
+
auto ids_size = idx_ptr_size[yy];
|
| 276 |
+
ImagingResampleVerticalConvolution8u(
|
| 277 |
+
unpacked_output_p + yy * xout_stride,
|
| 278 |
+
unpacked_input_p,
|
| 279 |
+
xout,
|
| 280 |
+
ids_min,
|
| 281 |
+
ids_size,
|
| 282 |
+
k,
|
| 283 |
+
vert_weights_precision,
|
| 284 |
+
num_channels);
|
| 285 |
+
}
|
| 286 |
+
}
|
| 287 |
+
|
| 288 |
+
// This is the only public entry point in this file. It supports bilinear or bicubic
|
| 289 |
+
// mode for uint8 dtype when C <= 4, with or without antialias. The
|
| 290 |
+
// implem is based on PIL-SIMD.
|
| 291 |
+
// Its equivalent implementation (fallback) for when AVX isn't supported or when
|
| 292 |
+
// C > 4 is separable_upsample_generic_Nd_kernel_impl() There are a bunch of
|
| 293 |
+
// future improvement that can be done: look for the TODOs in this file.
|
| 294 |
+
// For details on how the weights are computed and how the multiplications are
|
| 295 |
+
// run on int (instead of float weights), see
|
| 296 |
+
// [ Weights computation for uint8_t and multiplication trick ]
|
| 297 |
+
// For details on how the AVX kernels are implemented, see
|
| 298 |
+
// https://gist.github.com/NicolasHug/47c97d731f05eaad5694c173849b86f5
|
| 299 |
+
// See also [ Support for antialias=False as a subcase of antialias=True ] to
|
| 300 |
+
// learn more about how the antialias=False case is computed. The same holds
|
| 301 |
+
// here: all these kernels are general enough to handle an arbitrary number of
|
| 302 |
+
// weights, but when aa=False they could be optimized further.
|
| 303 |
+
template <typename scale_type, class F>
|
| 304 |
+
void upsample_avx_bilinear_bicubic_uint8(
|
| 305 |
+
const at::Tensor& input_,
|
| 306 |
+
const at::Tensor& output,
|
| 307 |
+
bool align_corners,
|
| 308 |
+
const scale_type& scales,
|
| 309 |
+
bool antialias) {
|
| 310 |
+
auto batch_size = input_.size(0);
|
| 311 |
+
auto num_channels = input_.size(1);
|
| 312 |
+
auto xin = input_.size(3);
|
| 313 |
+
auto yin = input_.size(2);
|
| 314 |
+
auto xout = output.size(3);
|
| 315 |
+
auto yout = output.size(2);
|
| 316 |
+
|
| 317 |
+
if (xin == xout && yin == yout) {
|
| 318 |
+
output.copy_(input_);
|
| 319 |
+
return;
|
| 320 |
+
}
|
| 321 |
+
|
| 322 |
+
at::Tensor input = input_;
|
| 323 |
+
if (!(input.is_contiguous() || input.is_contiguous(at::MemoryFormat::ChannelsLast))) {
|
| 324 |
+
// If input is not contiguous with memory format channels first or channels last,
|
| 325 |
+
// we explicitly convert the input to contiguous channels last memory format.
|
| 326 |
+
// This simplifies the rest of the code and let us assume that the format is only contiguous channels first or channels last,
|
| 327 |
+
// Most tensors going through this `if` block won't need to go through unpacking, but those having C < 3 may
|
| 328 |
+
// have to (this means 2 copies are made). We could avoid the extra copy by handling non-contiguous input
|
| 329 |
+
// directly within unpack_rgb() and pack_rgb(), but initial attempts showed that this is fairly complex.
|
| 330 |
+
input = input.contiguous(at::MemoryFormat::ChannelsLast);
|
| 331 |
+
}
|
| 332 |
+
|
| 333 |
+
auto need_horizontal = xout != xin;
|
| 334 |
+
auto need_vertical = yout != yin;
|
| 335 |
+
|
| 336 |
+
int ksize_horiz, ksize_vert;
|
| 337 |
+
std::vector<at::Tensor> horiz_indices_weights, vert_indices_weights;
|
| 338 |
+
unsigned int horiz_weights_precision, vert_weights_precision;
|
| 339 |
+
|
| 340 |
+
bool skip_unpacking = (num_channels == 3 || num_channels == 4) && input.is_contiguous(at::MemoryFormat::ChannelsLast);
|
| 341 |
+
bool skip_packing = (num_channels == 3 || num_channels == 4) && output.is_contiguous(at::MemoryFormat::ChannelsLast);
|
| 342 |
+
|
| 343 |
+
if (need_horizontal) {
|
| 344 |
+
int interp_dim = 3;
|
| 345 |
+
auto stride = (skip_unpacking) ? num_channels : 4;
|
| 346 |
+
std::tie(horiz_indices_weights, ksize_horiz, horiz_weights_precision) =
|
| 347 |
+
F::compute_index_ranges_int16_weights(
|
| 348 |
+
/*input_size=*/xin,
|
| 349 |
+
/*output_size=*/xout,
|
| 350 |
+
/*stride=*/stride,
|
| 351 |
+
/*ndims=*/4,
|
| 352 |
+
/*reshape_dim=*/interp_dim,
|
| 353 |
+
/*align_corners=*/align_corners,
|
| 354 |
+
/*opt_scale=*/scales[interp_dim - 2],
|
| 355 |
+
/*antialias=*/antialias,
|
| 356 |
+
/*align_i32=*/true);
|
| 357 |
+
}
|
| 358 |
+
|
| 359 |
+
if (need_vertical) {
|
| 360 |
+
int interp_dim = 2;
|
| 361 |
+
auto stride = (skip_unpacking) ? num_channels * xout : 4 * xout;
|
| 362 |
+
std::tie(vert_indices_weights, ksize_vert, vert_weights_precision) =
|
| 363 |
+
F::compute_index_ranges_int16_weights(
|
| 364 |
+
/*input_size=*/yin,
|
| 365 |
+
/*output_size=*/yout,
|
| 366 |
+
/*stride=*/stride,
|
| 367 |
+
/*ndims=*/4,
|
| 368 |
+
/*reshape_dim=*/interp_dim,
|
| 369 |
+
/*align_corners=*/align_corners,
|
| 370 |
+
/*opt_scale=*/scales[interp_dim - 2],
|
| 371 |
+
/*antialias=*/antialias,
|
| 372 |
+
/*align_i32=*/true);
|
| 373 |
+
}
|
| 374 |
+
|
| 375 |
+
at::Tensor buffer_horiz, buffer_vert;
|
| 376 |
+
// Minor optimization: we can avoid allocating an extra buffer if we're performing
|
| 377 |
+
// horizontal-only or vertical-only interpolation, and if the tensor doesn't
|
| 378 |
+
// need repacking
|
| 379 |
+
if (need_horizontal && (need_vertical || !skip_packing)) {
|
| 380 |
+
auto c = (skip_unpacking) ? num_channels : 4;
|
| 381 |
+
buffer_horiz = at::empty({c, yin, xout}, input.options());
|
| 382 |
+
}
|
| 383 |
+
if (need_vertical && !skip_packing) {
|
| 384 |
+
auto c = (skip_unpacking) ? num_channels : 4;
|
| 385 |
+
buffer_vert = at::empty({c, yout, xout}, input.options());
|
| 386 |
+
}
|
| 387 |
+
|
| 388 |
+
for (const auto i : c10::irange(batch_size)) {
|
| 389 |
+
|
| 390 |
+
at::Tensor unpacked_input = (skip_unpacking) ? input[i] : unpack_rgb(input[i]);
|
| 391 |
+
at::Tensor unpacked_output;
|
| 392 |
+
|
| 393 |
+
if (need_horizontal) {
|
| 394 |
+
at::Tensor unpacked_output_temp = (need_vertical || !skip_packing) ? buffer_horiz : output[i];
|
| 395 |
+
|
| 396 |
+
if (skip_unpacking && num_channels == 3) {
|
| 397 |
+
ImagingResampleHorizontal<3>(
|
| 398 |
+
unpacked_output_temp,
|
| 399 |
+
unpacked_input,
|
| 400 |
+
ksize_horiz,
|
| 401 |
+
horiz_indices_weights,
|
| 402 |
+
horiz_weights_precision);
|
| 403 |
+
} else {
|
| 404 |
+
ImagingResampleHorizontal<4>(
|
| 405 |
+
unpacked_output_temp,
|
| 406 |
+
unpacked_input,
|
| 407 |
+
ksize_horiz,
|
| 408 |
+
horiz_indices_weights,
|
| 409 |
+
horiz_weights_precision);
|
| 410 |
+
}
|
| 411 |
+
unpacked_output = unpacked_input = unpacked_output_temp;
|
| 412 |
+
}
|
| 413 |
+
if (need_vertical) {
|
| 414 |
+
unpacked_output = (skip_packing) ? output[i] : buffer_vert;
|
| 415 |
+
|
| 416 |
+
ImagingResampleVertical(
|
| 417 |
+
unpacked_output,
|
| 418 |
+
unpacked_input,
|
| 419 |
+
ksize_vert,
|
| 420 |
+
vert_indices_weights,
|
| 421 |
+
vert_weights_precision
|
| 422 |
+
);
|
| 423 |
+
}
|
| 424 |
+
|
| 425 |
+
TORCH_INTERNAL_ASSERT(unpacked_output.defined());
|
| 426 |
+
|
| 427 |
+
if (!skip_packing) {
|
| 428 |
+
pack_rgb(unpacked_output, output[i]);
|
| 429 |
+
}
|
| 430 |
+
}
|
| 431 |
+
}
|
| 432 |
+
|
| 433 |
+
void ImagingResampleHorizontalConvolution8u4x(
|
| 434 |
+
uint8_t* C10_RESTRICT lineOut0,
|
| 435 |
+
uint8_t* C10_RESTRICT lineOut1,
|
| 436 |
+
uint8_t* C10_RESTRICT lineOut2,
|
| 437 |
+
uint8_t* C10_RESTRICT lineOut3,
|
| 438 |
+
int64_t out_xsize,
|
| 439 |
+
const uint8_t* C10_RESTRICT lineIn0,
|
| 440 |
+
const uint8_t* C10_RESTRICT lineIn1,
|
| 441 |
+
const uint8_t* C10_RESTRICT lineIn2,
|
| 442 |
+
const uint8_t* C10_RESTRICT lineIn3,
|
| 443 |
+
int64_t in_xsize,
|
| 444 |
+
const int64_t* idx_ptr_xmin,
|
| 445 |
+
const int64_t* idx_ptr_size,
|
| 446 |
+
const int16_t* kk,
|
| 447 |
+
int kmax,
|
| 448 |
+
unsigned int coefs_precision,
|
| 449 |
+
int64_t num_channels,
|
| 450 |
+
bool is_last_line) {
|
| 451 |
+
|
| 452 |
+
// Interpolation horizontal pass processing together 4 vertical lines.
|
| 453 |
+
// - Input data format is RGBA or RGB with R,G,B,A being uint8. In case of RGBA
|
| 454 |
+
// we can encode 4 values as a single uint32 value.
|
| 455 |
+
// - We split the size of weight vector for a given output index as a sum:
|
| 456 |
+
// ids_size = num_blocks_4 * 4 + num_blocks_2 * 2 + num_blocks_1.
|
| 457 |
+
// - We load and process 4 weights values in a loop ("block 4") then we process 2 weights values
|
| 458 |
+
// in another loop ("block 2") and finally we process 1 weights value in the final loop ("block 1").
|
| 459 |
+
|
| 460 |
+
// Define shuffling masks (low/high) for num_channels 4 and 3
|
| 461 |
+
// Mask low casts lower half of each lane to epi16 and reorder RGBARGBA -> RRGGBBAA:
|
| 462 |
+
// [r1 g1 b1 a1 r2 g2 b2 a2 ... | R1 G1 B1 A1 R2 G2 B2 A2 ... ] ->
|
| 463 |
+
// [r1 0 r2 0 g1 0 g2 0 b1 0 b2 0 a1 0 a2 0 | R1 0 R2 0 G1 0 G2 0 B1 0 B2 0 A1 0 A2 0]
|
| 464 |
+
// Mask high casts upper half of each lane to epi16 and reorder RGBARGBA -> RRGGBBAA::
|
| 465 |
+
// [ ... r3 g3 b3 a3 r4 g4 b4 a4 | ... R3 G3 B3 A3 R4 G4 B4 A4 ] ->
|
| 466 |
+
// [r3 0 r4 0 g3 0 g4 0 b3 0 b4 0 a3 0 a4 0 | R3 0 R4 0 G3 0 G4 0 B3 0 B4 0 A3 0 A4 0]
|
| 467 |
+
|
| 468 |
+
const auto mask_low_c4 = _mm256_set_epi8(
|
| 469 |
+
-1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0,
|
| 470 |
+
-1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0);
|
| 471 |
+
const auto mask_high_c4 = _mm256_set_epi8(
|
| 472 |
+
-1, 15, -1, 11, -1, 14, -1, 10, -1, 13, -1, 9, -1, 12, -1, 8,
|
| 473 |
+
-1, 15, -1, 11, -1, 14, -1, 10, -1, 13, -1, 9, -1, 12, -1, 8);
|
| 474 |
+
const auto mask_low_c3 = _mm256_set_epi8(
|
| 475 |
+
-1, -1, -1, -1, -1, 5, -1, 2, -1, 4, -1, 1, -1, 3, -1, 0,
|
| 476 |
+
-1, -1, -1, -1, -1, 5, -1, 2, -1, 4, -1, 1, -1, 3, -1, 0);
|
| 477 |
+
const auto mask_high_c3 = _mm256_set_epi8(
|
| 478 |
+
-1, -1, -1, -1, -1, 11, -1, 8, -1, 10, -1, 7, -1, 9, -1, 6,
|
| 479 |
+
-1, -1, -1, -1, -1, 11, -1, 8, -1, 10, -1, 7, -1, 9, -1, 6);
|
| 480 |
+
|
| 481 |
+
const auto mask_low = (num_channels == 3) ? mask_low_c3 : mask_low_c4;
|
| 482 |
+
const auto mask_high = (num_channels == 3) ? mask_high_c3 : mask_high_c4;
|
| 483 |
+
|
| 484 |
+
const auto stride = num_channels * sizeof(uint8_t);
|
| 485 |
+
|
| 486 |
+
TORCH_INTERNAL_ASSERT(stride == 3 || stride == 4);
|
| 487 |
+
|
| 488 |
+
// out_xsize = output width, out_x = output x index
|
| 489 |
+
// ids_min is the input offset index corresponding to out_x
|
| 490 |
+
// ids_size is the interpolation size for out_x
|
| 491 |
+
|
| 492 |
+
// Let's precompute ids_size limits for block 4 and block 2.
|
| 493 |
+
//
|
| 494 |
+
// In block 4 (4 means we process 4 weight values together), we read input data
|
| 495 |
+
// with _mm_loadu_si128, i.e. 16 bytes, per one line:
|
| 496 |
+
// lineIn0 + stride * (i + ids_min) + 16 <= lineIn0 + stride * (ids_size + ids_min)
|
| 497 |
+
// --> i <= ids_size - 16.0 / stride
|
| 498 |
+
// Strict boundary:
|
| 499 |
+
// --> i < ids_size + 1 - int(ceil(16.0 / stride)) = ids_size - b4_delta
|
| 500 |
+
// Soft boundary for reading inside the buffer except its boundaries:
|
| 501 |
+
// --> i < ids_size + 1 - int(16.0 / stride) = ids_size - b4_delta_soft
|
| 502 |
+
// RGBA: b4_delta = b4_delta_soft = 3
|
| 503 |
+
// RGB : b4_delta = 5
|
| 504 |
+
// RGB : b4_delta_soft = 4
|
| 505 |
+
const auto b4_delta = (stride == 4) ? 3 : ((is_last_line) ? 5 : 4);
|
| 506 |
+
|
| 507 |
+
// In block 2 (2 means we process 2 weights values together), we read input data
|
| 508 |
+
// with _mm_loadl_epi64, i.e. 8 bytes, per one line:
|
| 509 |
+
// lineIn0 + stride * (i + ids_min) + 8 <= lineIn0 + stride * (ids_size + ids_min)
|
| 510 |
+
// --> i <= ids_size - 8.0 / stride
|
| 511 |
+
// Strict boundary:
|
| 512 |
+
// --> i < ids_size + 1 - int(ceil(8.0 / stride)) = ids_size - b2_delta
|
| 513 |
+
// Soft boundary for reading inside the buffer except its boundaries:
|
| 514 |
+
// --> i < ids_size + 1 - int(8.0 / stride) = ids_size - b2_delta_soft
|
| 515 |
+
// RGBA: b2_delta = b2_delta_soft = 1
|
| 516 |
+
// RGB : b2_delta = 2
|
| 517 |
+
// RGB : b2_delta_soft = 1
|
| 518 |
+
const auto b2_delta = (stride == 4) ? 1 : ((is_last_line) ? 2 : 1);
|
| 519 |
+
|
| 520 |
+
const auto max_out_x_strided = out_xsize * stride;
|
| 521 |
+
const auto max_in_x_strided = in_xsize * stride;
|
| 522 |
+
|
| 523 |
+
const auto zero = _mm256_setzero_si256();
|
| 524 |
+
const auto initial = _mm256_set1_epi32(1 << (coefs_precision - 1));
|
| 525 |
+
|
| 526 |
+
for (const auto out_x : c10::irange(out_xsize)) {
|
| 527 |
+
const auto ids_min = idx_ptr_xmin[out_x];
|
| 528 |
+
const auto ids_size = idx_ptr_size[out_x];
|
| 529 |
+
const auto * k = &kk[out_x * kmax];
|
| 530 |
+
int64_t i = 0;
|
| 531 |
+
|
| 532 |
+
auto sss0 = initial;
|
| 533 |
+
auto sss1 = initial;
|
| 534 |
+
|
| 535 |
+
const auto * lineIn0_min = lineIn0 + ids_min;
|
| 536 |
+
const auto * lineIn1_min = lineIn1 + ids_min;
|
| 537 |
+
const auto * lineIn2_min = lineIn2 + ids_min;
|
| 538 |
+
const auto * lineIn3_min = lineIn3 + ids_min;
|
| 539 |
+
|
| 540 |
+
// block 4
|
| 541 |
+
for (; i < ids_size - b4_delta; i += 4) {
|
| 542 |
+
// Load 4 values from weight vector
|
| 543 |
+
// mmk0 = [wl_0 wh_0 wl_1 wh_1 wl_0 wh_0 wl_1 wh_1 ...]
|
| 544 |
+
// mmk1 = [wl_2 wh_2 wl_3 wh_3 wl_2 wh_2 wl_3 wh_3 ...]
|
| 545 |
+
const auto mmk0 = _mm256_set1_epi32(*(int32_t*)&k[i]);
|
| 546 |
+
const auto mmk1 = _mm256_set1_epi32(*(int32_t*)&k[i + 2]);
|
| 547 |
+
|
| 548 |
+
// RGBA: Load 8 pixels (4 per line) from input lines 0 and 1:
|
| 549 |
+
// source = [
|
| 550 |
+
// r0 g0 b0 a0 r1 g1 b1 a1 r2 g2 b2 a2 r3 g3 b3 a3
|
| 551 |
+
// R0 G0 B0 A0 R1 G1 B1 A1 R2 G2 B2 A2 R3 G3 B3 A3
|
| 552 |
+
// ]
|
| 553 |
+
// RGB: Load 10 pixels (5 per line)
|
| 554 |
+
// source = [
|
| 555 |
+
// r0 g0 b0 r1 g1 b1 r2 g2 b2 r3 g3 b3 r4 g4 b4 r5
|
| 556 |
+
// R0 G0 B0 R1 G1 B1 R2 G2 B2 R3 G3 B3 R4 G4 B4 R5
|
| 557 |
+
// ]
|
| 558 |
+
auto source = _mm256_inserti128_si256(_mm256_castsi128_si256(
|
| 559 |
+
_mm_loadu_si128((__m128i *) (lineIn0_min + stride * i))),
|
| 560 |
+
_mm_loadu_si128((__m128i *) (lineIn1_min + stride * i)), 1);
|
| 561 |
+
|
| 562 |
+
// Apply mask_low:
|
| 563 |
+
// RGBA:
|
| 564 |
+
// [r0 0 r1 0 g0 0 g1 0 b0 0 b1 0 a0 0 a1 0 | R0 0 R1 0 G0 0 G1 0 B0 0 B1 0 A0 0 A1 0]
|
| 565 |
+
// RGB:
|
| 566 |
+
// [r0 0 r1 0 g0 0 g1 0 b0 0 b1 0 0 0 0 0 | R0 0 R1 0 G0 0 G1 0 B0 0 B1 0 0 0 0 0]
|
| 567 |
+
auto pix1 = _mm256_shuffle_epi8(source, mask_low);
|
| 568 |
+
// Compute output value as C += w0 * C0 + w1 * C1 for each channel in 32-bit precision
|
| 569 |
+
sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix1, mmk0));
|
| 570 |
+
|
| 571 |
+
// Apply mask_high:
|
| 572 |
+
// RGBA:
|
| 573 |
+
// [r2 0 r3 0 g2 0 g3 0 b2 0 b3 0 a2 0 a3 0 | R2 0 R3 0 G2 0 G3 0 B2 0 B3 0 A2 0 A3 0]
|
| 574 |
+
// RGB:
|
| 575 |
+
// [r2 0 r3 0 g2 0 g3 0 b2 0 b3 0 0 0 0 0 | R2 0 R3 0 G2 0 G3 0 B2 0 B3 0 0 0 0 0]
|
| 576 |
+
auto pix2 = _mm256_shuffle_epi8(source, mask_high);
|
| 577 |
+
// Compute output value as C += w2 * C2 + w3 * C3 for each channel in 32-bit precision
|
| 578 |
+
sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix2, mmk1));
|
| 579 |
+
|
| 580 |
+
// Same as above to next lines 2 and 3:
|
| 581 |
+
auto source2 = _mm256_inserti128_si256(_mm256_castsi128_si256(
|
| 582 |
+
_mm_loadu_si128((__m128i *) (lineIn2_min + stride * i))),
|
| 583 |
+
_mm_loadu_si128((__m128i *) (lineIn3_min + stride * i)), 1);
|
| 584 |
+
auto pix3 = _mm256_shuffle_epi8(source2, mask_low);
|
| 585 |
+
sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix3, mmk0));
|
| 586 |
+
auto pix4 = _mm256_shuffle_epi8(source2, mask_high);
|
| 587 |
+
sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix4, mmk1));
|
| 588 |
+
}
|
| 589 |
+
|
| 590 |
+
// block 2
|
| 591 |
+
for (; i < ids_size - b2_delta; i += 2) {
|
| 592 |
+
// Load 2 values from weight vector
|
| 593 |
+
// mmk = [wl_0 wh_0 wl_1 wh_1 wl_0 wh_0 wl_1 wh_1 ...]
|
| 594 |
+
const auto mmk = _mm256_set1_epi32(*(int32_t*)&k[i]);
|
| 595 |
+
|
| 596 |
+
// Load 4 pixels (2 per line) from input lines 0 and 1:
|
| 597 |
+
// RGBA: source1 = [
|
| 598 |
+
// r0 g0 b0 a0 r1 g1 b1 a1 0 0 0 0 0 0 0 0
|
| 599 |
+
// R0 G0 B0 A0 R1 G1 B1 A1 0 0 0 0 0 0 0 0
|
| 600 |
+
// ]
|
| 601 |
+
// RGB: source1 = [
|
| 602 |
+
// r0 g0 b0 r1 g1 b1 r2 0 0 0 0 0 0 0 0
|
| 603 |
+
// R0 G0 B0 R1 G1 B1 R2 0 0 0 0 0 0 0 0
|
| 604 |
+
// ]
|
| 605 |
+
auto source1 = _mm256_inserti128_si256(_mm256_castsi128_si256(
|
| 606 |
+
_mm_loadl_epi64((__m128i *) (lineIn0_min + stride * i))),
|
| 607 |
+
_mm_loadl_epi64((__m128i *) (lineIn1_min + stride * i)), 1);
|
| 608 |
+
// Apply mask_low:
|
| 609 |
+
// RGBA:
|
| 610 |
+
// [r0 0 r1 0 g0 0 g1 0 b0 0 b1 0 a0 0 a1 0 | R0 0 R1 0 G0 0 G1 0 B0 0 B1 0 A0 0 A1 0]
|
| 611 |
+
// RGB:
|
| 612 |
+
// [r0 0 r1 0 g0 0 g1 0 b0 0 b1 0 0 0 0 0 | R0 0 R1 0 G0 0 G1 0 B0 0 B1 0 0 0 0 0]
|
| 613 |
+
auto pix1 = _mm256_shuffle_epi8(source1, mask_low);
|
| 614 |
+
// Compute output value as C += w0 * C0 + w1 * C1 for each channel in 32-bit precision
|
| 615 |
+
sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix1, mmk));
|
| 616 |
+
|
| 617 |
+
// Same as above for lines 2 and 3:
|
| 618 |
+
auto source2 = _mm256_inserti128_si256(_mm256_castsi128_si256(
|
| 619 |
+
_mm_loadl_epi64((__m128i *) (lineIn2_min + stride * i))),
|
| 620 |
+
_mm_loadl_epi64((__m128i *) (lineIn3_min + stride * i)), 1);
|
| 621 |
+
auto pix2 = _mm256_shuffle_epi8(source2, mask_low);
|
| 622 |
+
sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix2, mmk));
|
| 623 |
+
}
|
| 624 |
+
|
| 625 |
+
// block 1
|
| 626 |
+
const auto i32_aligned = num_channels == 4;
|
| 627 |
+
for (; i < ids_size - 1; i++) {
|
| 628 |
+
// Load 1 value from weight vector
|
| 629 |
+
// mmk = [wl_0 wh_0 0 0 wl_0 wh_0 0 0 ...]
|
| 630 |
+
const auto mmk = _mm256_set1_epi32(k[i]);
|
| 631 |
+
|
| 632 |
+
// Load 2 pixels (one per line) from input lines 0 and 1:
|
| 633 |
+
// RGBA: pix1 = [
|
| 634 |
+
// r0 0 0 0 g0 0 0 0 b0 0 0 0 a0 0 0 0
|
| 635 |
+
// R0 0 0 0 G0 0 0 0 B0 0 0 0 A0 0 0 0
|
| 636 |
+
// ]
|
| 637 |
+
// RGB: pix1 = [
|
| 638 |
+
// r0 0 0 0 g0 0 0 0 b0 0 0 0 r1 0 0 0
|
| 639 |
+
// R0 0 0 0 G0 0 0 0 B0 0 0 0 R1 0 0 0
|
| 640 |
+
// ]
|
| 641 |
+
auto pix1 = _mm256_inserti128_si256(_mm256_castsi128_si256(
|
| 642 |
+
mm_cvtepu8_epi32(lineIn0_min + stride * i, i32_aligned)),
|
| 643 |
+
mm_cvtepu8_epi32(lineIn1_min + stride * i, i32_aligned), 1);
|
| 644 |
+
// Compute output value as C += w0 * C0 for each channel in 32-bit precision
|
| 645 |
+
sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix1, mmk));
|
| 646 |
+
|
| 647 |
+
// Same as above for lines 2 and 3
|
| 648 |
+
auto pix2 = _mm256_inserti128_si256(_mm256_castsi128_si256(
|
| 649 |
+
mm_cvtepu8_epi32(lineIn2_min + stride * i, i32_aligned)),
|
| 650 |
+
mm_cvtepu8_epi32(lineIn3_min + stride * i, i32_aligned), 1);
|
| 651 |
+
sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix2, mmk));
|
| 652 |
+
}
|
| 653 |
+
|
| 654 |
+
if (i == ids_size - 1) {
|
| 655 |
+
// last element
|
| 656 |
+
auto mmk = _mm256_set1_epi32(k[i]);
|
| 657 |
+
// For num_channels == 3 (3 bytes = one pixel) we tolerate to read 4 bytes
|
| 658 |
+
// lines 0, 1 and 2 wont go out of allocated memory bounds
|
| 659 |
+
auto pix = _mm256_inserti128_si256(_mm256_castsi128_si256(
|
| 660 |
+
mm_cvtepu8_epi32(lineIn0_min + stride * i, i32_aligned)),
|
| 661 |
+
mm_cvtepu8_epi32(lineIn1_min + stride * i, i32_aligned), 1);
|
| 662 |
+
sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix, mmk));
|
| 663 |
+
|
| 664 |
+
auto p0 = mm_cvtepu8_epi32(lineIn2_min + stride * i, i32_aligned);
|
| 665 |
+
__m128i p1;
|
| 666 |
+
if (num_channels == 3 && C10_UNLIKELY(is_last_line && ids_min + stride * i + 4 >= max_in_x_strided)) {
|
| 667 |
+
uint8_t input[4];
|
| 668 |
+
std::memcpy(input, lineIn3_min + stride * i, 3);
|
| 669 |
+
p1 = mm_cvtepu8_epi32(input, true);
|
| 670 |
+
} else {
|
| 671 |
+
p1 = mm_cvtepu8_epi32(lineIn3_min + stride * i, i32_aligned);
|
| 672 |
+
}
|
| 673 |
+
auto pix2 = _mm256_inserti128_si256(_mm256_castsi128_si256(p0), p1, 1);
|
| 674 |
+
sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix2, mmk));
|
| 675 |
+
}
|
| 676 |
+
|
| 677 |
+
// Convert fixed point values back to integers (truncating)
|
| 678 |
+
sss0 = _mm256_srai_epi32(sss0, coefs_precision);
|
| 679 |
+
sss1 = _mm256_srai_epi32(sss1, coefs_precision);
|
| 680 |
+
// Convert packed signed 32-bit integers to packed 16-bit integers using signed saturation
|
| 681 |
+
// (a a a a b b b b c c c c d d d d) -> (a a b b c c d d 0 0 0 0 0 0 0 0)
|
| 682 |
+
sss0 = _mm256_packs_epi32(sss0, zero);
|
| 683 |
+
sss1 = _mm256_packs_epi32(sss1, zero);
|
| 684 |
+
// Convert packed signed 16-bit integers to packed 8-bit integers using unsigned saturation
|
| 685 |
+
// (a a b b c c d d) -> (a b c d 0 0 0 0)
|
| 686 |
+
sss0 = _mm256_packus_epi16(sss0, zero);
|
| 687 |
+
sss1 = _mm256_packus_epi16(sss1, zero);
|
| 688 |
+
|
| 689 |
+
// Write the output into single uint32
|
| 690 |
+
// (a b c d) -> x_uint32
|
| 691 |
+
auto o0 = _mm_cvtsi128_si32(_mm256_castsi256_si128(sss0));
|
| 692 |
+
auto o1 = _mm_cvtsi128_si32(_mm256_extracti128_si256(sss0, 1));
|
| 693 |
+
auto o2 = _mm_cvtsi128_si32(_mm256_castsi256_si128(sss1));
|
| 694 |
+
auto o3 = _mm_cvtsi128_si32(_mm256_extracti128_si256(sss1, 1));
|
| 695 |
+
|
| 696 |
+
const auto out_x_strided = stride * out_x;
|
| 697 |
+
|
| 698 |
+
if (num_channels == 3 && C10_UNLIKELY(out_x_strided + 4 >= max_out_x_strided)) {
|
| 699 |
+
// Memcpy 4-bytes is faster than 3-bytes and this is a boundary case when we want to write
|
| 700 |
+
// 4 bytes (R G B | X) to the output buffer (X1 X2 X3 | R1).
|
| 701 |
+
// The 4th byte in the register (X) has a garbage value and 4th byte in the output buffer (R1) has a correct
|
| 702 |
+
// value which was previously computed by another line. In other words, it means that we can not overwrite
|
| 703 |
+
// it by simply writing 4 bytes from the register to the output. We'll do the following:
|
| 704 |
+
// v----------|
|
| 705 |
+
// Output = [... X1 X2 X3 | R1 G1 B1 R2 ...]
|
| 706 |
+
// First, we write R1 value to the 4th byte of (R G B | X) -> (R G B | R1)
|
| 707 |
+
// Second, we write 4 bytes from the register to the output: (X1 X2 X3 | R1) -> (R G B | R1)
|
| 708 |
+
// Output = [... R G B | R1 G1 B1 R2 ...]
|
| 709 |
+
|
| 710 |
+
_write_endline_rgb_as_uint32(lineOut0 + out_x_strided, o0);
|
| 711 |
+
_write_endline_rgb_as_uint32(lineOut1 + out_x_strided, o1);
|
| 712 |
+
_write_endline_rgb_as_uint32(lineOut2 + out_x_strided, o2);
|
| 713 |
+
|
| 714 |
+
if (C10_UNLIKELY(is_last_line)) {
|
| 715 |
+
// When we handle the last line, we can not access the next 4 bytes
|
| 716 |
+
// as they are out of memory bounds.
|
| 717 |
+
std::memcpy(lineOut3 + out_x_strided, (uint8_t *) &o3, num_channels);
|
| 718 |
+
} else {
|
| 719 |
+
_write_endline_rgb_as_uint32(lineOut3 + out_x_strided, o3);
|
| 720 |
+
}
|
| 721 |
+
} else if (num_channels == 3) {
|
| 722 |
+
// Memcpy 4-bytes is faster than 3-bytes and here
|
| 723 |
+
// we simply write 4 bytes (... R G B X 0 0 0 0 0 ...) where X is a garbage value
|
| 724 |
+
// that we will overwrite on the next iteration: (... R G B R G B X 0 0 ...)
|
| 725 |
+
std::memcpy(lineOut0 + out_x_strided, (uint8_t *) &o0, 4);
|
| 726 |
+
std::memcpy(lineOut1 + out_x_strided, (uint8_t *) &o1, 4);
|
| 727 |
+
std::memcpy(lineOut2 + out_x_strided, (uint8_t *) &o2, 4);
|
| 728 |
+
std::memcpy(lineOut3 + out_x_strided, (uint8_t *) &o3, 4);
|
| 729 |
+
} else {
|
| 730 |
+
// num_channels = 4 -> lineOutX + out_x_strided should be uint32 aligned
|
| 731 |
+
*(uint32_t *)(lineOut0 + out_x_strided) = o0;
|
| 732 |
+
*(uint32_t *)(lineOut1 + out_x_strided) = o1;
|
| 733 |
+
*(uint32_t *)(lineOut2 + out_x_strided) = o2;
|
| 734 |
+
*(uint32_t *)(lineOut3 + out_x_strided) = o3;
|
| 735 |
+
}
|
| 736 |
+
}
|
| 737 |
+
}
|
| 738 |
+
|
| 739 |
+
void ImagingResampleHorizontalConvolution8u(
|
| 740 |
+
uint8_t* C10_RESTRICT lineOut,
|
| 741 |
+
int64_t out_xsize,
|
| 742 |
+
const uint8_t* C10_RESTRICT lineIn,
|
| 743 |
+
int64_t in_xsize,
|
| 744 |
+
const int64_t* idx_ptr_xmin,
|
| 745 |
+
const int64_t* idx_ptr_size,
|
| 746 |
+
const int16_t* kk,
|
| 747 |
+
int kmax,
|
| 748 |
+
unsigned int coefs_precision,
|
| 749 |
+
int64_t num_channels,
|
| 750 |
+
bool is_last_line) {
|
| 751 |
+
|
| 752 |
+
// Interpolation horizontal pass processing only one vertical line.
|
| 753 |
+
// - Input data format is RGBA or RGB with R,G,B,A being uint8. In case of RGBA
|
| 754 |
+
// we can encode 4 values as a single uint32 value.
|
| 755 |
+
// - We split the size of weight vector for a given output index as a sum:
|
| 756 |
+
// ids_size = num_blocks_8 * 8 + num_blocks_4 * 4 + num_blocks_2 * 2 + num_blocks_1
|
| 757 |
+
// - We load and process 8 weights values in a loop ("block 8") then 4 weights and 2 weights values in
|
| 758 |
+
// in another loops ("block 4" and "block 2") and finally we process 1 weight value in the final loop ("block 1").
|
| 759 |
+
|
| 760 |
+
// Define various shuffling masks
|
| 761 |
+
const auto kmask_low = _mm256_set_epi8(
|
| 762 |
+
11, 10, 9, 8, 11, 10, 9, 8, 11, 10, 9, 8, 11, 10, 9, 8,
|
| 763 |
+
3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0);
|
| 764 |
+
const auto kmask_high = _mm256_set_epi8(
|
| 765 |
+
15, 14, 13, 12, 15, 14, 13, 12, 15, 14, 13, 12, 15, 14, 13, 12,
|
| 766 |
+
7, 6, 5, 4, 7, 6, 5, 4, 7, 6, 5, 4, 7, 6, 5, 4);
|
| 767 |
+
const auto kmask_hl = _mm256_set_epi8(
|
| 768 |
+
7, 6, 5, 4, 7, 6, 5, 4, 7, 6, 5, 4, 7, 6, 5, 4,
|
| 769 |
+
3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0);
|
| 770 |
+
|
| 771 |
+
const auto mask_low_c4 = _mm256_set_epi8(
|
| 772 |
+
-1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0,
|
| 773 |
+
-1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0);
|
| 774 |
+
const auto mask_high_c4 = _mm256_set_epi8(
|
| 775 |
+
-1, 15, -1, 11, -1, 14, -1, 10, -1, 13, -1, 9, -1, 12, -1, 8,
|
| 776 |
+
-1, 15, -1, 11, -1, 14, -1, 10, -1, 13, -1, 9, -1, 12, -1, 8);
|
| 777 |
+
const auto mask_low_c3 = _mm256_set_epi8(
|
| 778 |
+
-1, -1, -1, -1, -1, 5, -1, 2, -1, 4, -1, 1, -1, 3, -1, 0,
|
| 779 |
+
-1, -1, -1, -1, -1, 5, -1, 2, -1, 4, -1, 1, -1, 3, -1, 0);
|
| 780 |
+
const auto mask_high_c3 = _mm256_set_epi8(
|
| 781 |
+
-1, -1, -1, -1, -1, 11, -1, 8, -1, 10, -1, 7, -1, 9, -1, 6,
|
| 782 |
+
-1, -1, -1, -1, -1, 11, -1, 8, -1, 10, -1, 7, -1, 9, -1, 6);
|
| 783 |
+
const auto mask_hl_c3 = _mm256_set_epi8(
|
| 784 |
+
-1, -1, -1, -1, -1, 11, -1, 8, -1, 10, -1, 7, -1, 9, -1, 6,
|
| 785 |
+
-1, -1, -1, -1, -1, 5, -1, 2, -1, 4, -1, 1, -1, 3, -1, 0);
|
| 786 |
+
const auto mask_hl_c4 = _mm256_set_epi8(
|
| 787 |
+
-1, 15, -1, 11, -1, 14, -1, 10, -1, 13, -1, 9, -1, 12, -1, 8,
|
| 788 |
+
-1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0);
|
| 789 |
+
|
| 790 |
+
const auto mask_low128_c3 = _mm_set_epi8(
|
| 791 |
+
-1, -1, -1, -1, -1, 5, -1, 2, -1, 4, -1, 1, -1, 3, -1, 0);
|
| 792 |
+
const auto mask_low128_c4 = _mm_set_epi8(
|
| 793 |
+
-1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0);
|
| 794 |
+
|
| 795 |
+
const auto mask_low = (num_channels == 3) ? mask_low_c3 : mask_low_c4;
|
| 796 |
+
const auto mask_high = (num_channels == 3) ? mask_high_c3 : mask_high_c4;
|
| 797 |
+
const auto mask_hl = (num_channels == 3) ? mask_hl_c3 : mask_hl_c4;
|
| 798 |
+
const auto mask_low128 = (num_channels == 3) ? mask_low128_c3 : mask_low128_c4;
|
| 799 |
+
|
| 800 |
+
// out_xsize = output width, out_x = output x index
|
| 801 |
+
// ids_min is the input offset index corresponding to out_x
|
| 802 |
+
// ids_size is the interpolation size for out_x
|
| 803 |
+
|
| 804 |
+
const auto stride = num_channels * sizeof(uint8_t);
|
| 805 |
+
const auto zero = _mm_setzero_si128();
|
| 806 |
+
|
| 807 |
+
TORCH_INTERNAL_ASSERT(stride == 3 || stride == 4);
|
| 808 |
+
|
| 809 |
+
// Let's precompute ids_size limits for block 8, block 4 and block 2
|
| 810 |
+
//
|
| 811 |
+
// In block 8 (8 means we process 8 weight values together), we read at
|
| 812 |
+
// most 32 bytes input data (16 + 16 bytes for RGBA and 12 + 16 bytes for RGB)
|
| 813 |
+
// lineIn + stride * (i + ids_min) + 32 <= lineIn + stride * (ids_size + ids_min)
|
| 814 |
+
// --> i <= ids_size - 32.0 / stride
|
| 815 |
+
// Strict boundary:
|
| 816 |
+
// --> i < ids_size + 1 - int(ceil(32.0 / stride)) = ids_size - b8_delta
|
| 817 |
+
// Soft boundary for reading inside the buffer except its boundaries:
|
| 818 |
+
// --> i < ids_size + 1 - int(32.0 / stride) = ids_size - b8_delta_soft
|
| 819 |
+
// RGBA: b8_delta = b8_delta_soft = 7
|
| 820 |
+
// RGB : b8_delta = 10
|
| 821 |
+
// RGB : b8_delta_soft = 9
|
| 822 |
+
const auto b8_delta = (stride == 4) ? 7 : ((is_last_line) ? 10 : 9);
|
| 823 |
+
|
| 824 |
+
// In block 4 (4 means we process 4 weight values together), we read
|
| 825 |
+
// 16 bytes of input data.
|
| 826 |
+
// lineIn + stride * (i + ids_min) + 16 <= lineIn0 + stride * (ids_size + ids_min)
|
| 827 |
+
// --> i <= ids_size - 16.0 / stride
|
| 828 |
+
// Strict boundary:
|
| 829 |
+
// --> i < ids_size + 1 - int(ceil(16.0 / stride)) = ids_size - b4_delta
|
| 830 |
+
// Soft boundary for reading inside the buffer except its boundaries:
|
| 831 |
+
// --> i < ids_size + 1 - int(16.0 / stride) = ids_size - b4_delta_soft
|
| 832 |
+
// RGBA: b4_delta = b4_delta_soft = 3
|
| 833 |
+
// RGB : b4_delta = 5
|
| 834 |
+
// RGB : b4_delta_soft = 4
|
| 835 |
+
const auto b4_delta = (stride == 4) ? 3 : ((is_last_line) ? 5 : 4);
|
| 836 |
+
|
| 837 |
+
// In block 2 (2 means we process 2 weight values together), we read
|
| 838 |
+
// 8 bytes of input data.
|
| 839 |
+
// lineIn0 + stride * (i + ids_min) + 8 <= lineIn0 + stride * (ids_size + ids_min)
|
| 840 |
+
// --> i <= ids_size - 8.0 / stride
|
| 841 |
+
// Strict boundary:
|
| 842 |
+
// --> i < ids_size + 1 - int(ceil(8.0 / stride)) = ids_size - b2_delta
|
| 843 |
+
// Soft boundary for reading inside the buffer except its boundaries:
|
| 844 |
+
// --> i < ids_size + 1 - int(8.0 / stride) = ids_size - b2_delta_soft
|
| 845 |
+
// RGBA: b2_delta = b2_delta_soft = 1
|
| 846 |
+
// RGB : b2_delta = 2
|
| 847 |
+
// RGB : b2_delta_soft = 1
|
| 848 |
+
const auto b2_delta = (stride == 4) ? 1 : ((is_last_line) ? 2 : 1);
|
| 849 |
+
|
| 850 |
+
const auto max_out_x_strided = out_xsize * stride;
|
| 851 |
+
const auto max_in_x_strided = in_xsize * stride;
|
| 852 |
+
|
| 853 |
+
for (const auto out_x : c10::irange(out_xsize)) {
|
| 854 |
+
__m128i sss;
|
| 855 |
+
const auto ids_min = idx_ptr_xmin[out_x];
|
| 856 |
+
const auto ids_size = idx_ptr_size[out_x];
|
| 857 |
+
const auto * k = &kk[out_x * kmax];
|
| 858 |
+
int64_t i = 0;
|
| 859 |
+
|
| 860 |
+
const auto * lineIn_min = lineIn + ids_min;
|
| 861 |
+
|
| 862 |
+
if (ids_size < 8) {
|
| 863 |
+
sss = _mm_set1_epi32(1 << (coefs_precision - 1));
|
| 864 |
+
} else {
|
| 865 |
+
// Lower part will be added to higher, use only half of the error
|
| 866 |
+
auto sss256 = _mm256_set1_epi32(1 << (coefs_precision - 2));
|
| 867 |
+
|
| 868 |
+
// block 8
|
| 869 |
+
for (; i < ids_size - b8_delta; i += 8) {
|
| 870 |
+
// Load 8 values from weight vector
|
| 871 |
+
auto tmp = _mm_loadu_si128((__m128i*)&k[i]);
|
| 872 |
+
// ksource = [
|
| 873 |
+
// wl_0 wh_0 wl_1 wh_1 wl_2 wh_2 wl_3 wh_3 wl_4 wh_4 wl_5 wh_5 wl_6 wh_6 wl_7 wh_7
|
| 874 |
+
// wl_0 wh_0 wl_1 wh_1 wl_2 wh_2 wl_3 wh_3 wl_4 wh_4 wl_5 wh_5 wl_6 wh_6 wl_7 wh_7
|
| 875 |
+
// ]
|
| 876 |
+
auto ksource = _mm256_insertf128_si256(_mm256_castsi128_si256(tmp), tmp, 1);
|
| 877 |
+
|
| 878 |
+
// RGBA: Load 8 pixels from input:
|
| 879 |
+
// source = [
|
| 880 |
+
// r0 g0 b0 a0 r1 g1 b1 a1 r2 g2 b2 a2 r3 g3 b3 a3
|
| 881 |
+
// r4 g4 b4 a4 r5 g5 b5 a5 r6 g6 b6 a6 r7 g7 b7 a7
|
| 882 |
+
// ]
|
| 883 |
+
// RGB: Load 10 pixels from input (however we can process only 8 pixels):
|
| 884 |
+
// source = [
|
| 885 |
+
// r0 g0 b0 r1 g1 b1 r2 g2 b2 r3 g3 b3 r4 g4 b4 r5
|
| 886 |
+
// r4 g4 b4 r5 g5 b5 r6 g6 b6 r7 g7 b7 r8 g8 b8 r9
|
| 887 |
+
// ]
|
| 888 |
+
auto source = _mm256_inserti128_si256(_mm256_castsi128_si256(
|
| 889 |
+
_mm_loadu_si128((__m128i *) (lineIn_min + stride * i))),
|
| 890 |
+
_mm_loadu_si128((__m128i *) (lineIn_min + stride * (i + 4))), 1);
|
| 891 |
+
|
| 892 |
+
// Extract lower part of each lane, cast to epi16 and reoder RGBARGBA -> RRGGBBAA
|
| 893 |
+
// RGBA: pix1 = [
|
| 894 |
+
// r0 0 r1 0 g0 0 g1 0 b0 0 b1 0 a0 0 a1 0
|
| 895 |
+
// r4 0 r5 0 g4 0 g5 0 b4 0 b5 0 a4 0 a5 0
|
| 896 |
+
// ]
|
| 897 |
+
// RGB: pix1 = [
|
| 898 |
+
// r0 0 r1 0 g0 0 g1 0 b0 0 b1 0 0 0 0 0
|
| 899 |
+
// r4 0 r5 0 g4 0 g5 0 b4 0 b5 0 0 0 0 0
|
| 900 |
+
// ]
|
| 901 |
+
auto pix1 = _mm256_shuffle_epi8(source, mask_low);
|
| 902 |
+
// mmk1 = [
|
| 903 |
+
// wl_0 wh_0 wl_1 wh_1 wl_0 wh_0 wl_1 wh_1 ... ...
|
| 904 |
+
// wl_4 wh_4 wl_5 wh_5 wl_4 wh_4 wl_5 wh_5 ... ...
|
| 905 |
+
// ]
|
| 906 |
+
auto mmk1 = _mm256_shuffle_epi8(ksource, kmask_low);
|
| 907 |
+
// Compute output value as
|
| 908 |
+
// C += w0 * C0 + w1 * C1
|
| 909 |
+
// C += w4 * C4 + w5 * C5 for each channel in 32-bit precision
|
| 910 |
+
sss256 = _mm256_add_epi32(sss256, _mm256_madd_epi16(pix1, mmk1));
|
| 911 |
+
|
| 912 |
+
// Same as above for higher part of each lane
|
| 913 |
+
auto pix2 = _mm256_shuffle_epi8(source, mask_high);
|
| 914 |
+
auto mmk2 = _mm256_shuffle_epi8(ksource, kmask_high);
|
| 915 |
+
// Compute output value as
|
| 916 |
+
// C += w2 * C2 + w3 * C3
|
| 917 |
+
// C += w6 * C6 + w7 * C7 for each channel in 32-bit precision
|
| 918 |
+
sss256 = _mm256_add_epi32(sss256, _mm256_madd_epi16(pix2, mmk2));
|
| 919 |
+
}
|
| 920 |
+
|
| 921 |
+
// block 4
|
| 922 |
+
for (; i < ids_size - b4_delta; i += 4) {
|
| 923 |
+
// Load 4 values from weight vector
|
| 924 |
+
auto tmp = _mm_loadl_epi64((__m128i *) &k[i]);
|
| 925 |
+
// ksource = [
|
| 926 |
+
// wl_0 wh_0 wl_1 wh_1 wl_2 wh_2 wl_3 wh_3 0 0 0 0 0 0 0 0
|
| 927 |
+
// wl_0 wh_0 wl_1 wh_1 wl_2 wh_2 wl_3 wh_3 0 0 0 0 0 0 0 0
|
| 928 |
+
// ]
|
| 929 |
+
auto ksource = _mm256_insertf128_si256(_mm256_castsi128_si256(tmp), tmp, 1);
|
| 930 |
+
|
| 931 |
+
// Load pixels from input line
|
| 932 |
+
tmp = _mm_loadu_si128((__m128i *) (lineIn_min + stride * i));
|
| 933 |
+
// RGBA: source = [
|
| 934 |
+
// r0 g0 b0 a0 r1 g1 b1 a1 r2 g2 b2 a2 r3 g3 b3 a3
|
| 935 |
+
// r0 g0 b0 a0 r1 g1 b1 a1 r2 g2 b2 a2 r3 g3 b3 a3
|
| 936 |
+
// ]
|
| 937 |
+
// RGB: source = [
|
| 938 |
+
// r0 g0 b0 r1 g1 b1 r2 g2 b2 r3 g3 b3 r4 g4 b4 r5
|
| 939 |
+
// r0 g0 b0 r1 g1 b1 r2 g2 b2 r3 g3 b3 r4 g4 b4 r5
|
| 940 |
+
// ]
|
| 941 |
+
auto source = _mm256_insertf128_si256(_mm256_castsi128_si256(tmp), tmp, 1);
|
| 942 |
+
|
| 943 |
+
// Cast source to epi16 and reorder RGBARGBA -> RRGGBBAA
|
| 944 |
+
// RGBA: pix = [
|
| 945 |
+
// r0 0 r1 0 g0 0 g1 0 b0 0 b1 0 a0 0 a1 0
|
| 946 |
+
// r2 0 r3 0 g2 0 g3 0 b2 0 b3 0 a2 0 a3 0
|
| 947 |
+
// ]
|
| 948 |
+
// RGB: pix = [
|
| 949 |
+
// r0 0 r1 0 g0 0 g1 0 b0 0 b1 0 0 0 0 0
|
| 950 |
+
// r2 0 r3 0 g2 0 g3 0 b2 0 b3 0 0 0 0 0
|
| 951 |
+
// ]
|
| 952 |
+
auto pix = _mm256_shuffle_epi8(source, mask_hl);
|
| 953 |
+
// mmk = [
|
| 954 |
+
// wl_0 wh_0 wl_1 wh_1 wl_0 wh_0 wl_1 wh_1 ... ...
|
| 955 |
+
// wl_2 wh_2 wl_3 wh_3 wl_2 wh_2 wl_3 wh_3 ... ...
|
| 956 |
+
// ]
|
| 957 |
+
auto mmk = _mm256_shuffle_epi8(ksource, kmask_hl);
|
| 958 |
+
// Compute output value as
|
| 959 |
+
// C += w0 * C0 + w1 * C1
|
| 960 |
+
// C += w2 * C2 + w3 * C3 for each channel in 32-bit precision
|
| 961 |
+
sss256 = _mm256_add_epi32(sss256, _mm256_madd_epi16(pix, mmk));
|
| 962 |
+
}
|
| 963 |
+
|
| 964 |
+
// Sum results between the lanes
|
| 965 |
+
sss = _mm_add_epi32(
|
| 966 |
+
_mm256_extracti128_si256(sss256, 0),
|
| 967 |
+
_mm256_extracti128_si256(sss256, 1));
|
| 968 |
+
}
|
| 969 |
+
|
| 970 |
+
// block 2
|
| 971 |
+
for (; i < ids_size - b2_delta; i += 2) {
|
| 972 |
+
// Load 2 values from weight vector
|
| 973 |
+
// mmk = [wl_0 wh_0 wl_1 wh_1 wl_0 wh_0 wl_1 wh_1 ...]
|
| 974 |
+
auto mmk = _mm_set1_epi32(*(int32_t*)&k[i]);
|
| 975 |
+
// Load pixels from input line
|
| 976 |
+
// RGBA: source = [
|
| 977 |
+
// r0 g0 b0 a0 r1 g1 b1 a1 0 0 0 0 0 0 0 0
|
| 978 |
+
// ]
|
| 979 |
+
// RGB: source = [
|
| 980 |
+
// r0 g0 b0 r1 g1 b1 r2 g2 0 0 0 0 0 0 0 0
|
| 981 |
+
// ]
|
| 982 |
+
auto source = _mm_loadl_epi64((__m128i *) (lineIn_min + stride * i));
|
| 983 |
+
// Cast source to epi16 and reorder RGBARGBA -> RRGGBBAA
|
| 984 |
+
auto pix = _mm_shuffle_epi8(source, mask_low128);
|
| 985 |
+
// Compute output value as C += w0 * C0 + w1 * C1 for each channel in 32-bit precision
|
| 986 |
+
sss = _mm_add_epi32(sss, _mm_madd_epi16(pix, mmk));
|
| 987 |
+
}
|
| 988 |
+
|
| 989 |
+
// block 1
|
| 990 |
+
const auto i32_aligned = num_channels == 4;
|
| 991 |
+
for (; i < ids_size - 1; i++) {
|
| 992 |
+
// Load 1 value from weight vector
|
| 993 |
+
// mmk = [wl_0 wh_0 0 0 wl_0 wh_0 0 0 ...]
|
| 994 |
+
auto mmk = _mm_set1_epi32(k[i]);
|
| 995 |
+
// Load one pixel from input line
|
| 996 |
+
// RGBA: pix = [
|
| 997 |
+
// r0 0 0 0 g0 0 0 0 b0 0 0 0 a0 0 0 0
|
| 998 |
+
// ]
|
| 999 |
+
// RGB: pix = [
|
| 1000 |
+
// r0 0 0 0 g0 0 0 0 b0 0 0 0 r1 0 0 0
|
| 1001 |
+
// ]
|
| 1002 |
+
auto pix = mm_cvtepu8_epi32(lineIn_min + stride * i, i32_aligned);
|
| 1003 |
+
// Compute output value as C += w0 * C0 for each channel in 32-bit precision
|
| 1004 |
+
sss = _mm_add_epi32(sss, _mm_madd_epi16(pix, mmk));
|
| 1005 |
+
}
|
| 1006 |
+
|
| 1007 |
+
if (i == ids_size - 1) {
|
| 1008 |
+
// last element
|
| 1009 |
+
auto mmk = _mm_set1_epi32(k[i]);
|
| 1010 |
+
__m128i pix;
|
| 1011 |
+
auto p = lineIn_min + stride * i;
|
| 1012 |
+
if (num_channels == 3 && C10_UNLIKELY(is_last_line && ids_min + stride * i + 4 >= max_in_x_strided)) {
|
| 1013 |
+
uint8_t input[4];
|
| 1014 |
+
std::memcpy(input, p, 3);
|
| 1015 |
+
pix = mm_cvtepu8_epi32(input, true);
|
| 1016 |
+
} else {
|
| 1017 |
+
pix = mm_cvtepu8_epi32(p, i32_aligned);
|
| 1018 |
+
}
|
| 1019 |
+
sss = _mm_add_epi32(sss, _mm_madd_epi16(pix, mmk));
|
| 1020 |
+
}
|
| 1021 |
+
|
| 1022 |
+
// Convert fixed point values back to integers (truncating)
|
| 1023 |
+
sss = _mm_srai_epi32(sss, coefs_precision);
|
| 1024 |
+
// Convert packed signed 32-bit integers to packed 16-bit integers using signed saturation
|
| 1025 |
+
// (a a a a b b b b c c c c d d d d) -> (a a b b c c d d 0 0 0 0 0 0 0 0)
|
| 1026 |
+
sss = _mm_packs_epi32(sss, zero);
|
| 1027 |
+
// Convert packed signed 16-bit integers to packed 8-bit integers using unsigned saturation
|
| 1028 |
+
// (a a b b c c d d) -> (a b c d 0 0 0 0)
|
| 1029 |
+
sss = _mm_packus_epi16(sss, zero);
|
| 1030 |
+
// Write the output into single uint32
|
| 1031 |
+
// (a b c d) -> x_uint32
|
| 1032 |
+
auto o = _mm_cvtsi128_si32(sss);
|
| 1033 |
+
const auto out_x_strided = stride * out_x;
|
| 1034 |
+
if (num_channels == 3 && C10_UNLIKELY(out_x_strided + 4 >= max_out_x_strided)) {
|
| 1035 |
+
if (C10_UNLIKELY(is_last_line)) {
|
| 1036 |
+
// When we handle the last line, we can not access the next 4 bytes
|
| 1037 |
+
// as they are out of memory bounds.
|
| 1038 |
+
std::memcpy(lineOut + out_x_strided, (uint8_t *) &o, 3);
|
| 1039 |
+
} else {
|
| 1040 |
+
// Memcpy 4-bytes is faster than 3-bytes and this is a boundary case when we want to write
|
| 1041 |
+
// 4 bytes (R G B | X) to the output buffer (X1 X2 X3 | R1).
|
| 1042 |
+
// The 4th byte in the register (X) has a garbage value and 4th byte in the output buffer (R1) has a correct
|
| 1043 |
+
// value which was previously computed by another line. In other words, it means that we can not overwrite
|
| 1044 |
+
// it by simply writing 4 bytes from the register to the output. We'll do the following:
|
| 1045 |
+
// v----------|
|
| 1046 |
+
// Output = [... X1 X2 X3 | R1 G1 B1 R2 ...]
|
| 1047 |
+
// First, we write R1 value to the 4th byte of (R G B | X) -> (R G B | R1)
|
| 1048 |
+
// Second, we write 4 bytes from the register to the output: (X1 X2 X3 | R1) -> (R G B | R1)
|
| 1049 |
+
// Output = [... R G B | R1 G1 B1 R2 ...]
|
| 1050 |
+
_write_endline_rgb_as_uint32(lineOut + out_x_strided, o);
|
| 1051 |
+
}
|
| 1052 |
+
} else if (num_channels == 3) {
|
| 1053 |
+
// Memcpy 4-bytes is faster than 3-bytes and here
|
| 1054 |
+
// we simply write 4 bytes (... R G B X 0 0 0 0 0 ...) where X is a garbage value
|
| 1055 |
+
// that we will overwrite on the next iteration: (... R G B R G B X 0 0 ...)
|
| 1056 |
+
std::memcpy(lineOut + out_x_strided, (uint8_t *) &o, 4);
|
| 1057 |
+
} else {
|
| 1058 |
+
// num_channels = 4 -> lineOut + out_x_strided should be uint32 aligned
|
| 1059 |
+
*(uint32_t *)(lineOut + out_x_strided) = o;
|
| 1060 |
+
}
|
| 1061 |
+
}
|
| 1062 |
+
}
|
| 1063 |
+
|
| 1064 |
+
void ImagingResampleVerticalConvolution8u(
|
| 1065 |
+
uint8_t* C10_RESTRICT lineOut,
|
| 1066 |
+
const uint8_t* C10_RESTRICT lineIn,
|
| 1067 |
+
int64_t xsize,
|
| 1068 |
+
int64_t ids_min,
|
| 1069 |
+
int64_t ids_size,
|
| 1070 |
+
const int16_t* k,
|
| 1071 |
+
unsigned int coefs_precision,
|
| 1072 |
+
int64_t num_channels) {
|
| 1073 |
+
|
| 1074 |
+
// Interpolation vertical pass processing one line.
|
| 1075 |
+
// - We process x-axis data with blocks of 8, 2 and 1
|
| 1076 |
+
// - We split the size of weight vector for a given output index as a sum: K = n * 2 + m.
|
| 1077 |
+
|
| 1078 |
+
// xsize = output width, also equals to input width
|
| 1079 |
+
// ids_size = interpolation size
|
| 1080 |
+
// ids_min = input y start index
|
| 1081 |
+
const auto stride = num_channels * sizeof(uint8_t);
|
| 1082 |
+
|
| 1083 |
+
TORCH_INTERNAL_ASSERT(stride == 3 || stride == 4);
|
| 1084 |
+
|
| 1085 |
+
const int64_t data_size = xsize * stride;
|
| 1086 |
+
const int64_t data_stride = stride;
|
| 1087 |
+
constexpr auto vec_size = 256 / 8;
|
| 1088 |
+
|
| 1089 |
+
const auto initial = _mm_set1_epi32(1 << (coefs_precision - 1));
|
| 1090 |
+
const auto initial_256 = _mm256_set1_epi32(1 << (coefs_precision - 1));
|
| 1091 |
+
const auto zero = _mm_setzero_si128();
|
| 1092 |
+
const auto zero_256 = _mm256_setzero_si256();
|
| 1093 |
+
|
| 1094 |
+
int64_t j = 0;
|
| 1095 |
+
// block 8
|
| 1096 |
+
const auto b8_usable_vec_stride = (vec_size / data_stride) * data_stride;
|
| 1097 |
+
for (; j < data_size - vec_size; j += b8_usable_vec_stride) {
|
| 1098 |
+
auto sss0 = initial_256;
|
| 1099 |
+
auto sss1 = initial_256;
|
| 1100 |
+
auto sss2 = initial_256;
|
| 1101 |
+
auto sss3 = initial_256;
|
| 1102 |
+
int64_t i = 0;
|
| 1103 |
+
const auto * lineIn_min = lineIn + j + ids_min;
|
| 1104 |
+
|
| 1105 |
+
for (; i < ids_size - 1; i += 2) {
|
| 1106 |
+
// Load 2 values from weight vector
|
| 1107 |
+
auto mmk = _mm256_set1_epi32(*(int32_t*)&k[i]);
|
| 1108 |
+
|
| 1109 |
+
// RGBA: Load 8 pixels per line
|
| 1110 |
+
// source1 = [
|
| 1111 |
+
// r0 g0 b0 a0 r1 g1 b1 a1 r2 g2 b2 a2 r3 g3 b3 a3
|
| 1112 |
+
// r4 g4 b4 a4 r5 g5 b5 a5 r6 g6 b6 a6 r7 g7 b7 a7
|
| 1113 |
+
// ]
|
| 1114 |
+
// RGB: Load 10 pixels per line (however we can process only 8 pixels):
|
| 1115 |
+
// source1 = [
|
| 1116 |
+
// r0 g0 b0 r1 g1 b1 r2 g2 b2 r3 g3 b3 r4 g4 b4 r5
|
| 1117 |
+
// r4 g4 b4 r5 g5 b5 r6 g6 b6 r7 g7 b7 r8 g8 b8 r9
|
| 1118 |
+
// ]
|
| 1119 |
+
auto source1 =
|
| 1120 |
+
_mm256_loadu_si256((__m256i*)(lineIn_min + data_size * i));
|
| 1121 |
+
auto source2 =
|
| 1122 |
+
_mm256_loadu_si256((__m256i*)(lineIn_min + data_size * (i + 1)));
|
| 1123 |
+
|
| 1124 |
+
// Interleave source1 and source2 from the low half of each 128-bit lane
|
| 1125 |
+
// and cast the result to epi16
|
| 1126 |
+
// RGBA: pix1 = [
|
| 1127 |
+
// r0 0 R0 0 g0 0 G0 0 b0 0 B0 0 a0 0 A0 0
|
| 1128 |
+
// r1 0 R1 0 g1 0 G1 0 b1 0 B1 0 a1 0 A1 0
|
| 1129 |
+
// ]
|
| 1130 |
+
// RGB: pix1 = [
|
| 1131 |
+
// r0 0 R0 0 g0 0 G0 0 b0 0 B0 0 0 0 0 0
|
| 1132 |
+
// r1 0 R1 0 g1 0 G1 0 b1 0 B1 0 0 0 0 0
|
| 1133 |
+
// ]
|
| 1134 |
+
auto source_lo = _mm256_unpacklo_epi8(source1, source2);
|
| 1135 |
+
auto pix1 = _mm256_unpacklo_epi8(source_lo, zero_256);
|
| 1136 |
+
// Compute output value as
|
| 1137 |
+
// C += w0 * c0 + w1 * C0
|
| 1138 |
+
// C += w0 * c1 + w1 * C1 for each channel in 32-bit precision
|
| 1139 |
+
sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix1, mmk));
|
| 1140 |
+
|
| 1141 |
+
// RGBA: pix2 = [
|
| 1142 |
+
// r2 0 R2 0 g2 0 G2 0 b2 0 B2 0 a2 0 A2 0
|
| 1143 |
+
// r3 0 R3 0 g3 0 G3 0 b3 0 B3 0 a3 0 A3 0
|
| 1144 |
+
// ]
|
| 1145 |
+
// RGB: pix2 = [
|
| 1146 |
+
// r2 0 R2 0 g2 0 G2 0 b2 0 B2 0 0 0 0 0
|
| 1147 |
+
// r3 0 R3 0 g3 0 G3 0 b3 0 B3 0 0 0 0 0
|
| 1148 |
+
// ]
|
| 1149 |
+
auto pix2 = _mm256_unpackhi_epi8(source_lo, zero_256);
|
| 1150 |
+
// Compute output value as
|
| 1151 |
+
// C += w0 * c2 + w1 * C2
|
| 1152 |
+
// C += w0 * c3 + w1 * C3 for each channel in 32-bit precision
|
| 1153 |
+
sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix2, mmk));
|
| 1154 |
+
|
| 1155 |
+
// Same as above for the high half of each 128-bit lane
|
| 1156 |
+
auto source_hi = _mm256_unpackhi_epi8(source1, source2);
|
| 1157 |
+
auto pix3 = _mm256_unpacklo_epi8(source_hi, zero_256);
|
| 1158 |
+
sss2 = _mm256_add_epi32(sss2, _mm256_madd_epi16(pix3, mmk));
|
| 1159 |
+
auto pix4 = _mm256_unpackhi_epi8(source_hi, zero_256);
|
| 1160 |
+
sss3 = _mm256_add_epi32(sss3, _mm256_madd_epi16(pix4, mmk));
|
| 1161 |
+
}
|
| 1162 |
+
// Same processing as above but with a single weight value
|
| 1163 |
+
for (; i < ids_size; i += 1) {
|
| 1164 |
+
auto mmk = _mm256_set1_epi32(k[i]);
|
| 1165 |
+
|
| 1166 |
+
auto source1 = _mm256_loadu_si256((__m256i*)(lineIn_min + i * data_size));
|
| 1167 |
+
|
| 1168 |
+
auto source_lo = _mm256_unpacklo_epi8(source1, zero_256);
|
| 1169 |
+
auto pix1 = _mm256_unpacklo_epi8(source_lo, zero_256);
|
| 1170 |
+
sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix1, mmk));
|
| 1171 |
+
auto pix2 = _mm256_unpackhi_epi8(source_lo, zero_256);
|
| 1172 |
+
sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix2, mmk));
|
| 1173 |
+
|
| 1174 |
+
auto source_hi = _mm256_unpackhi_epi8(source1, zero_256);
|
| 1175 |
+
auto pix3 = _mm256_unpacklo_epi8(source_hi, _mm256_setzero_si256());
|
| 1176 |
+
sss2 = _mm256_add_epi32(sss2, _mm256_madd_epi16(pix3, mmk));
|
| 1177 |
+
auto pix4 = _mm256_unpackhi_epi8(source_hi, _mm256_setzero_si256());
|
| 1178 |
+
sss3 = _mm256_add_epi32(sss3, _mm256_madd_epi16(pix4, mmk));
|
| 1179 |
+
}
|
| 1180 |
+
// Convert fixed point values back to integers (truncating)
|
| 1181 |
+
sss0 = _mm256_srai_epi32(sss0, coefs_precision);
|
| 1182 |
+
sss1 = _mm256_srai_epi32(sss1, coefs_precision);
|
| 1183 |
+
sss2 = _mm256_srai_epi32(sss2, coefs_precision);
|
| 1184 |
+
sss3 = _mm256_srai_epi32(sss3, coefs_precision);
|
| 1185 |
+
// Convert packed signed 32-bit integers to packed 16-bit integers using signed saturation
|
| 1186 |
+
// (a a a a b b b b c c c c d d d d) -> (a a b b c c d d)
|
| 1187 |
+
sss0 = _mm256_packs_epi32(sss0, sss1);
|
| 1188 |
+
sss2 = _mm256_packs_epi32(sss2, sss3);
|
| 1189 |
+
// Convert packed signed 16-bit integers to packed 8-bit integers using unsigned saturation
|
| 1190 |
+
// (a a b b c c d d) -> (a b c d)
|
| 1191 |
+
sss0 = _mm256_packus_epi16(sss0, sss2);
|
| 1192 |
+
|
| 1193 |
+
// Stores 32 bytes
|
| 1194 |
+
_mm256_storeu_si256((__m256i*)(lineOut + j), sss0);
|
| 1195 |
+
}
|
| 1196 |
+
|
| 1197 |
+
// TODO: Do we also need block 4 ???
|
| 1198 |
+
// block 2
|
| 1199 |
+
const auto b2_usable_vec_stride = (8 / data_stride) * data_stride;
|
| 1200 |
+
for (; j < data_size - vec_size / 4; j += b2_usable_vec_stride) {
|
| 1201 |
+
auto sss0 = initial;
|
| 1202 |
+
auto sss1 = initial;
|
| 1203 |
+
int64_t i = 0;
|
| 1204 |
+
const auto * lineIn_min = lineIn + j + ids_min;
|
| 1205 |
+
|
| 1206 |
+
for (; i < ids_size - 1; i += 2) {
|
| 1207 |
+
// Load 2 values from weight vector
|
| 1208 |
+
// mmk = [wl_0 wh_0 wl_1 wh_1 wl_0 wh_0 wl_1 wh_1 ... ]
|
| 1209 |
+
auto mmk = _mm_set1_epi32(*(int32_t*)&k[i]);
|
| 1210 |
+
|
| 1211 |
+
// Load 2 pixels per line
|
| 1212 |
+
// RGBA: source1 = [
|
| 1213 |
+
// r0 g0 b0 a0 r1 g1 b1 a1 0 0 0 0 0 0 0 0
|
| 1214 |
+
// ]
|
| 1215 |
+
// RGB: source1 = [
|
| 1216 |
+
// r0 g0 b0 r1 g1 b1 r2 g2 0 0 0 0 0 0 0 0
|
| 1217 |
+
// ]
|
| 1218 |
+
auto source1 = _mm_loadl_epi64((__m128i *) (lineIn_min + i * data_size));
|
| 1219 |
+
auto source2 = _mm_loadl_epi64((__m128i *) (lineIn_min + (i + 1) * data_size));
|
| 1220 |
+
// Interleave source1 and source2 and cast the result to epi16
|
| 1221 |
+
// RGBA: pix = [
|
| 1222 |
+
// r0 0 R0 0 g0 0 G0 0 b0 0 B0 0 a0 0 A0 0
|
| 1223 |
+
// ]
|
| 1224 |
+
// RGB: pix = [
|
| 1225 |
+
// r0 0 R0 0 g0 0 G0 0 b0 0 B0 0 0 0 0 0
|
| 1226 |
+
// ]
|
| 1227 |
+
auto source = _mm_unpacklo_epi8(source1, source2);
|
| 1228 |
+
auto pix = _mm_unpacklo_epi8(source, zero);
|
| 1229 |
+
// Compute output value as C += w0 * c0 + w1 * C0 for each channel in 32-bit precision
|
| 1230 |
+
sss0 = _mm_add_epi32(sss0, _mm_madd_epi16(pix, mmk));
|
| 1231 |
+
// RGBA: pix = [
|
| 1232 |
+
// r1 0 R1 0 g1 0 G1 0 b1 0 B1 0 a1 0 A1 0
|
| 1233 |
+
// ]
|
| 1234 |
+
// RGB: pix = [
|
| 1235 |
+
// r1 0 R1 0 g1 0 G1 0 b1 0 B1 0 0 0 0 0
|
| 1236 |
+
// ]
|
| 1237 |
+
pix = _mm_unpackhi_epi8(source, zero);
|
| 1238 |
+
// Compute output value as C += w0 * c1 + w1 * C1 for each channel in 32-bit precision
|
| 1239 |
+
sss1 = _mm_add_epi32(sss1, _mm_madd_epi16(pix, mmk));
|
| 1240 |
+
}
|
| 1241 |
+
// Same processing as above but with a single weight value
|
| 1242 |
+
for (; i < ids_size; i += 1) {
|
| 1243 |
+
auto mmk = _mm_set1_epi32(k[i]);
|
| 1244 |
+
|
| 1245 |
+
auto source1 = _mm_loadl_epi64((__m128i*) (lineIn_min + i * data_size));
|
| 1246 |
+
|
| 1247 |
+
auto source = _mm_unpacklo_epi8(source1, zero);
|
| 1248 |
+
auto pix1 = _mm_unpacklo_epi8(source, zero);
|
| 1249 |
+
sss0 = _mm_add_epi32(sss0, _mm_madd_epi16(pix1, mmk));
|
| 1250 |
+
auto pix2 = _mm_unpackhi_epi8(source, zero);
|
| 1251 |
+
sss1 = _mm_add_epi32(sss1, _mm_madd_epi16(pix2, mmk));
|
| 1252 |
+
}
|
| 1253 |
+
// Convert fixed point values back to integers (truncating)
|
| 1254 |
+
sss0 = _mm_srai_epi32(sss0, coefs_precision);
|
| 1255 |
+
sss1 = _mm_srai_epi32(sss1, coefs_precision);
|
| 1256 |
+
// Convert packed signed 32-bit integers to packed 16-bit integers using signed saturation
|
| 1257 |
+
// (a a a a b b b b c c c c d d d d) -> (a a b b c c d d)
|
| 1258 |
+
sss0 = _mm_packs_epi32(sss0, sss1);
|
| 1259 |
+
// Convert packed signed 16-bit integers to packed 8-bit integers using unsigned saturation
|
| 1260 |
+
// (a a b b c c d d) -> (a b c d)
|
| 1261 |
+
sss0 = _mm_packus_epi16(sss0, sss0);
|
| 1262 |
+
// Store 2 pixels to the output
|
| 1263 |
+
_mm_storel_epi64((__m128i*)(lineOut + j), sss0);
|
| 1264 |
+
}
|
| 1265 |
+
|
| 1266 |
+
// block 1
|
| 1267 |
+
const auto b1_usable_vec_stride = (4 / data_stride) * data_stride;
|
| 1268 |
+
const auto i32_aligned = num_channels == 4;
|
| 1269 |
+
for (; j < data_size - 4; j += b1_usable_vec_stride) {
|
| 1270 |
+
auto sss = initial;
|
| 1271 |
+
int64_t i = 0;
|
| 1272 |
+
const auto * lineIn_min = lineIn + j + ids_min;
|
| 1273 |
+
|
| 1274 |
+
for (; i < ids_size - 1; i += 2) {
|
| 1275 |
+
// Load 2 values from weight vector
|
| 1276 |
+
// mmk = [wl_0 wh_0 wl_1 wh_1 wl_0 wh_0 wl_1 wh_1 ... ]
|
| 1277 |
+
auto mmk = _mm_set1_epi32(*(int32_t*)&k[i]);
|
| 1278 |
+
|
| 1279 |
+
// Load one pixel per line
|
| 1280 |
+
// RGBA: source1 = [
|
| 1281 |
+
// r0 g0 b0 a0 0 0 0 0 0 0 0 0 0 0 0 0
|
| 1282 |
+
// ]
|
| 1283 |
+
// RGB: source1 = [
|
| 1284 |
+
// r0 g0 b0 r1 0 0 0 0 0 0 0 0 0 0 0 0
|
| 1285 |
+
// ]
|
| 1286 |
+
auto source1 = mm_cvtsi32_si128(lineIn_min + i * data_size, i32_aligned);
|
| 1287 |
+
auto source2 = mm_cvtsi32_si128(lineIn_min + (i + 1) * data_size, i32_aligned);
|
| 1288 |
+
|
| 1289 |
+
// Interleave source1 and source2 and cast the result to epi16
|
| 1290 |
+
// RGBA: pix = [
|
| 1291 |
+
// r0 0 R0 0 g0 0 G0 0 b0 0 B0 0 a0 0 A0 0
|
| 1292 |
+
// ]
|
| 1293 |
+
// RGB: pix = [
|
| 1294 |
+
// r0 0 R0 0 g0 0 G0 0 b0 0 B0 0 0 0 0 0
|
| 1295 |
+
// ]
|
| 1296 |
+
auto source = _mm_unpacklo_epi8(source1, source2);
|
| 1297 |
+
auto pix = _mm_unpacklo_epi8(source, zero);
|
| 1298 |
+
// Compute output value as C += w0 * c0 + w1 * C0 for each channel in 32-bit precision
|
| 1299 |
+
sss = _mm_add_epi32(sss, _mm_madd_epi16(pix, mmk));
|
| 1300 |
+
}
|
| 1301 |
+
|
| 1302 |
+
for (; i < ids_size; i++) {
|
| 1303 |
+
auto mmk = _mm_set1_epi32(k[i]);
|
| 1304 |
+
auto pix = mm_cvtepu8_epi32(lineIn_min + i * data_size, i32_aligned);
|
| 1305 |
+
sss = _mm_add_epi32(sss, _mm_madd_epi16(pix, mmk));
|
| 1306 |
+
}
|
| 1307 |
+
sss = _mm_srai_epi32(sss, coefs_precision);
|
| 1308 |
+
sss = _mm_packs_epi32(sss, zero);
|
| 1309 |
+
sss = _mm_packus_epi16(sss, zero);
|
| 1310 |
+
|
| 1311 |
+
auto o = _mm_cvtsi128_si32(sss);
|
| 1312 |
+
|
| 1313 |
+
// Here we write 4 bytes to the output even if num_channels < 4, e.g o = {r,g,b,X} for num_channels=3
|
| 1314 |
+
// It is OK to write 4th byte (e.g. X) as on the next step we will overwrite it with new data.
|
| 1315 |
+
// We also wont go out of bounds of lineOut memory allocation
|
| 1316 |
+
std::memcpy(lineOut + j, (uint8_t *) &o, 4);
|
| 1317 |
+
}
|
| 1318 |
+
|
| 1319 |
+
for (; j < data_size; j += data_stride) {
|
| 1320 |
+
auto sss = initial;
|
| 1321 |
+
int64_t i = 0;
|
| 1322 |
+
const auto * lineIn_min = lineIn + j + ids_min;
|
| 1323 |
+
// For RGBA we can use (ids_size - 1) as tighter limit but for RGB we can read outside memory boundary
|
| 1324 |
+
// for the last remaining line
|
| 1325 |
+
for (; i < ids_size - 2; i += 2) {
|
| 1326 |
+
// Load two coefficients at once
|
| 1327 |
+
auto mmk = _mm_set1_epi32(*(int32_t*)&k[i]);
|
| 1328 |
+
|
| 1329 |
+
// Load 2 lines
|
| 1330 |
+
auto source1 = mm_cvtsi32_si128(lineIn_min + i * data_size, i32_aligned);
|
| 1331 |
+
auto source2 = mm_cvtsi32_si128(lineIn_min + (i + 1) * data_size, i32_aligned);
|
| 1332 |
+
|
| 1333 |
+
auto source = _mm_unpacklo_epi8(source1, source2);
|
| 1334 |
+
auto pix = _mm_unpacklo_epi8(source, zero);
|
| 1335 |
+
sss = _mm_add_epi32(sss, _mm_madd_epi16(pix, mmk));
|
| 1336 |
+
}
|
| 1337 |
+
|
| 1338 |
+
// Same processing as above but with a single weight value
|
| 1339 |
+
for (; i < ids_size; i++) {
|
| 1340 |
+
auto mmk = _mm_set1_epi32(k[i]);
|
| 1341 |
+
|
| 1342 |
+
const uint8_t * p = lineIn_min + i * data_size;
|
| 1343 |
+
__m128i pix;
|
| 1344 |
+
// There is no much perf gain using more detailed condition like
|
| 1345 |
+
// num_channels == 3 && ids_min + j + data_size * i + 4 >= in_max_size
|
| 1346 |
+
// const int64_t in_max_size = data_size * in_ysize;
|
| 1347 |
+
if (num_channels == 3) {
|
| 1348 |
+
uint8_t input[4];
|
| 1349 |
+
std::memcpy(input, p, 3);
|
| 1350 |
+
pix = mm_cvtepu8_epi32(input, true);
|
| 1351 |
+
} else {
|
| 1352 |
+
pix = mm_cvtepu8_epi32(p, true);
|
| 1353 |
+
}
|
| 1354 |
+
sss = _mm_add_epi32(sss, _mm_madd_epi16(pix, mmk));
|
| 1355 |
+
}
|
| 1356 |
+
|
| 1357 |
+
// Convert fixed point values back to integers (truncating)
|
| 1358 |
+
sss = _mm_srai_epi32(sss, coefs_precision);
|
| 1359 |
+
// Convert packed signed 32-bit integers to packed 16-bit integers using signed saturation
|
| 1360 |
+
// (a a a a b b b b c c c c d d d d) -> (a a b b c c d d)
|
| 1361 |
+
sss = _mm_packs_epi32(sss, zero);
|
| 1362 |
+
// Convert packed signed 16-bit integers to packed 8-bit integers using unsigned saturation
|
| 1363 |
+
// (a a b b c c d d) -> (a b c d)
|
| 1364 |
+
sss = _mm_packus_epi16(sss, zero);
|
| 1365 |
+
// Store one pixel to the output
|
| 1366 |
+
auto o = _mm_cvtsi128_si32(sss);
|
| 1367 |
+
if (num_channels == 3 && C10_UNLIKELY(j + 4 >= data_size)) {
|
| 1368 |
+
std::memcpy(lineOut + j, (uint8_t *) &o, 3);
|
| 1369 |
+
} else {
|
| 1370 |
+
std::memcpy(lineOut + j, (uint8_t *) &o, 4);
|
| 1371 |
+
}
|
| 1372 |
+
}
|
| 1373 |
+
}
|
| 1374 |
+
|
| 1375 |
+
} // anonymous namespace
|
| 1376 |
+
#endif // CPU_CAPABILITY_AVX2
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/WeightNormKernel.h
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <ATen/native/DispatchStub.h>
|
| 3 |
+
#include <cstdint>
|
| 4 |
+
|
| 5 |
+
namespace at {
|
| 6 |
+
class TensorBase;
|
| 7 |
+
}
|
| 8 |
+
|
| 9 |
+
namespace at { namespace native {
|
| 10 |
+
|
| 11 |
+
using weight_norm_fn = void(*)(
|
| 12 |
+
TensorBase&, TensorBase&, const TensorBase&, const TensorBase&, int64_t);
|
| 13 |
+
using weight_norm_backward_fn = void(*)(
|
| 14 |
+
TensorBase&, TensorBase&, const TensorBase&, const TensorBase&,
|
| 15 |
+
const TensorBase&, const TensorBase&, int64_t);
|
| 16 |
+
|
| 17 |
+
DECLARE_DISPATCH(weight_norm_fn, weight_norm_stub);
|
| 18 |
+
DECLARE_DISPATCH(weight_norm_backward_fn, weight_norm_backward_stub);
|
| 19 |
+
|
| 20 |
+
}} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/mixed_data_type.h
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/core/Tensor.h>
|
| 4 |
+
|
| 5 |
+
namespace at { namespace native {
|
| 6 |
+
|
| 7 |
+
inline ScalarType first_type() {
|
| 8 |
+
return ScalarType::Undefined;
|
| 9 |
+
}
|
| 10 |
+
|
| 11 |
+
template <typename... Args>
|
| 12 |
+
inline ScalarType first_type(const Tensor& arg, const Args&... parameters) {
|
| 13 |
+
return arg.defined() ? arg.scalar_type() : first_type(parameters...);
|
| 14 |
+
}
|
| 15 |
+
|
| 16 |
+
template <typename... Args>
|
| 17 |
+
inline bool is_mixed_type(const Tensor& input, const Args&... parameters) {
|
| 18 |
+
const auto parameter_type = first_type(parameters...);
|
| 19 |
+
return ((parameter_type != ScalarType::Undefined) &&
|
| 20 |
+
(parameter_type != input.scalar_type()));
|
| 21 |
+
}
|
| 22 |
+
|
| 23 |
+
// currently on CPU, mixed data type is only supported
|
| 24 |
+
// when input is 'BFloat16' or 'Half' and parameters are 'Float'
|
| 25 |
+
inline void check_mixed_data_type(const Tensor& input) {
|
| 26 |
+
TORCH_CHECK(at::isReducedFloatingType(input.scalar_type()),
|
| 27 |
+
"mixed dtype (CPU): all inputs must share same datatype.");
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
template <typename... Args>
|
| 31 |
+
inline void check_mixed_data_type(const Tensor& input, const Tensor& parameter, const Args&... parameters) {
|
| 32 |
+
TORCH_CHECK(!parameter.defined() || parameter.scalar_type() == ScalarType::Float,
|
| 33 |
+
"mixed dtype (CPU): expect parameter to have scalar type of Float");
|
| 34 |
+
check_mixed_data_type(input, parameters...);
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
// In the mixed-dtype path parameters are promoted to Float; otherwise they
// keep the tensor's own dtype.
inline ScalarType param_scalar_type(const Tensor& t, bool is_mixed_type) {
  if (is_mixed_type) {
    return ScalarType::Float;
  }
  return t.scalar_type();
}
|
| 40 |
+
|
| 41 |
+
}} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/moments_utils.h
ADDED
|
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <array>
|
| 4 |
+
#include <cstring>
|
| 5 |
+
#include <numeric>
|
| 6 |
+
#include <utility>
|
| 7 |
+
#include <vector>
|
| 8 |
+
|
| 9 |
+
#include <ATen/Parallel.h>
|
| 10 |
+
#include <ATen/OpMathType.h>
|
| 11 |
+
#include <ATen/cpu/vec/vec.h>
|
| 12 |
+
#include <ATen/native/cpu/utils.h>
|
| 13 |
+
#include <c10/util/SmallVector.h>
|
| 14 |
+
#include <c10/util/irange.h>
|
| 15 |
+
|
| 16 |
+
namespace at {
|
| 17 |
+
namespace native {
|
| 18 |
+
inline namespace CPU_CAPABILITY {
|
| 19 |
+
|
| 20 |
+
template<typename T> using opmath_t = at::opmath_type<T>;
|
| 21 |
+
|
| 22 |
+
constexpr int64_t kChunkSize = 16;
|
| 23 |
+
|
| 24 |
+
// Merges one group of Welford moments (m0_add samples with mean m1_add and
// sum of squared deviations m2_add) into the running moments (m0, m1, m2)
// using Chan et al.'s parallel combination formula.
template <typename T>
void AddMoments(
    int64_t m0_add,
    const T& m1_add,
    const T& m2_add,
    int64_t& m0,
    T& m1,
    T& m2) {
  const int64_t total = m0 + m0_add;
  // Fraction of the merged population contributed by the incoming group;
  // guard against dividing by zero when both groups are empty.
  T weight = static_cast<T>(0);
  if (total != 0) {
    weight = static_cast<T>(m0_add) / static_cast<T>(total);
  }
  const T diff = m1_add - m1;
  m1 += weight * diff;
  m2 += m2_add + diff * diff * weight * static_cast<T>(m0);
  m0 = total;
}
|
| 39 |
+
|
| 40 |
+
template <typename T>
|
| 41 |
+
C10_ALWAYS_INLINE void AddMomentsVec(
|
| 42 |
+
int64_t m0_add,
|
| 43 |
+
const vec::Vectorized<T>& m1_add,
|
| 44 |
+
const vec::Vectorized<T>& m2_add,
|
| 45 |
+
int64_t& m0,
|
| 46 |
+
vec::Vectorized<T>& m1,
|
| 47 |
+
vec::Vectorized<T>& m2) {
|
| 48 |
+
using Vec = vec::Vectorized<T>;
|
| 49 |
+
const int64_t n = m0 + m0_add;
|
| 50 |
+
const T c = n == 0 ? static_cast<T>(0) : static_cast<T>(m0_add) / static_cast<T>(n);
|
| 51 |
+
const Vec c_vec(c);
|
| 52 |
+
const Vec delta = m1_add - m1;
|
| 53 |
+
m1 += c_vec * delta;
|
| 54 |
+
m2 += m2_add + delta * delta * c_vec * Vec(static_cast<T>(m0));
|
| 55 |
+
m0 = n;
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
template <typename T>
|
| 59 |
+
inline typename std::enable_if<std::is_same<T, opmath_t<T>>::value, void>::type
|
| 60 |
+
UpdateMomentsVec(
|
| 61 |
+
int64_t m0,
|
| 62 |
+
const T* X_ptr,
|
| 63 |
+
const std::array<vec::Vectorized<opmath_t<T>>, kChunkSize>& c_vecs,
|
| 64 |
+
int64_t& m0_stk0,
|
| 65 |
+
vec::Vectorized<opmath_t<T>>& m1_stk0,
|
| 66 |
+
vec::Vectorized<opmath_t<T>>& m2_stk0) {
|
| 67 |
+
using Vec = vec::Vectorized<opmath_t<T>>;
|
| 68 |
+
Vec m1_vec(0);
|
| 69 |
+
Vec m2_vec(0);
|
| 70 |
+
for (const auto j : c10::irange(m0)) {
|
| 71 |
+
const Vec x_vec = Vec::loadu(X_ptr + j * Vec::size());
|
| 72 |
+
const Vec delta_vec = x_vec - m1_vec;
|
| 73 |
+
m1_vec += delta_vec * c_vecs[j];
|
| 74 |
+
m2_vec += delta_vec * (x_vec - m1_vec);
|
| 75 |
+
}
|
| 76 |
+
AddMomentsVec(m0, m1_vec, m2_vec, m0_stk0, m1_stk0, m2_stk0);
|
| 77 |
+
}
|
| 78 |
+
|
| 79 |
+
// each bfloat16/half vector will be converted to two float vectors,
|
| 80 |
+
// and accumulated successively on m1_stk0/m2_stk0.
|
| 81 |
+
// Reduced-precision overload: each BFloat16/Half vector is widened into two
// float vectors, and both halves are accumulated independently before being
// folded (successively) into m1_stk0/m2_stk0. c_vecs[j] holds 1/(j+1).
template <typename T>
inline typename std::enable_if<!std::is_same<T, at::opmath_type<T>>::value, void>::type
UpdateMomentsVec(
    int64_t m0,
    const T* X_ptr,
    const std::array<vec::Vectorized<at::opmath_type<T>>, kChunkSize>& c_vecs,
    int64_t& m0_stk0,
    vec::Vectorized<at::opmath_type<T>>& m1_stk0,
    vec::Vectorized<at::opmath_type<T>>& m2_stk0) {
  using Vec = vec::Vectorized<T>;
  using fVec = vec::Vectorized<at::opmath_type<T>>;
  fVec m1_fvec0(0), m1_fvec1(0);
  fVec m2_fvec0(0), m2_fvec1(0);
  for (const auto j : c10::irange(m0)) {
    // Load one reduced-precision vector and split it into two float halves.
    const Vec x_bvec = Vec::loadu(X_ptr + j * Vec::size());
    auto [x_fvec0, x_fvec1] = convert_to_float<T>(x_bvec);
    // Welford update applied independently to each half.
    const fVec delta_fvec0 = x_fvec0 - m1_fvec0;
    const fVec delta_fvec1 = x_fvec1 - m1_fvec1;
    m1_fvec0 += delta_fvec0 * c_vecs[j];
    m1_fvec1 += delta_fvec1 * c_vecs[j];
    m2_fvec0 += delta_fvec0 * (x_fvec0 - m1_fvec0);
    m2_fvec1 += delta_fvec1 * (x_fvec1 - m1_fvec1);
  }
  // Merge both half-accumulators into the level-0 stack entry, one after the
  // other (each carries m0 samples).
  AddMomentsVec(m0, m1_fvec0, m2_fvec0, m0_stk0, m1_stk0, m2_stk0);
  AddMomentsVec(m0, m1_fvec1, m2_fvec1, m0_stk0, m1_stk0, m2_stk0);
}
|
| 107 |
+
|
| 108 |
+
// Compute rowwise moments by Welford algorithm and cascade sum to improve
|
| 109 |
+
// numerical stability.
|
| 110 |
+
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
|
| 111 |
+
// https://en.wikipedia.org/wiki/Pairwise_summation
|
| 112 |
+
// Computes (mean, variance) of the N-element row X with a vectorized Welford
// update plus a cascade ("pairwise") merge stack for numerical stability.
// kMaxDepth is the inline capacity hint for the merge stack; ddof is the
// delta-degrees-of-freedom used in the variance divisor (N - ddof).
template <typename T, int64_t kMaxDepth>
std::pair<opmath_t<T>, opmath_t<T>> RowwiseMomentsImpl(const T* X, int64_t N, int64_t ddof = 0) {
  using math_t = opmath_t<T>;

  constexpr int64_t kVecSize = vec::Vectorized<T>::size();
  constexpr int64_t kAccVecSize = vec::Vectorized<math_t>::size();
  const int64_t n = N / kVecSize;            // number of full input vectors
  const int64_t m = divup(n, kChunkSize);    // number of kChunkSize-vector chunks
  const int64_t depth = utils::CeilLog2(m);  // levels in the merge cascade

  using Vec = vec::Vectorized<math_t>;
  const Vec kZeroVec(math_t(0));
  // Per-level partial moments: sample count, mean vector, M2 vector.
  c10::SmallVector<int64_t, kMaxDepth> m0_stk(depth, 0);
  c10::SmallVector<Vec, kMaxDepth> m1_stk(depth, kZeroVec);
  c10::SmallVector<Vec, kMaxDepth> m2_stk(depth, kZeroVec);

  for (const auto i : c10::irange(m)) {
    const T* X_ptr = X + i * kChunkSize * kVecSize;
    const int64_t m0 = std::min(kChunkSize, n - i * kChunkSize);
    // Broadcast 1/(j+1) coefficients for the in-chunk Welford updates; built
    // once per template instantiation (function-local static).
    static std::array<Vec, kChunkSize> c_vecs = ([]() {
      std::array<Vec, kChunkSize> result;
      for (const auto i : c10::irange(kChunkSize)) {
        result[i] = Vec(math_t(1) / static_cast<math_t>(i + 1));
      }
      return result;
    })();
    UpdateMomentsVec(m0, X_ptr, c_vecs, m0_stk[0], m1_stk[0], m2_stk[0]);

    // Binary-counter merge: after chunk i, fold level j-1 into level j for
    // each trailing zero bit of (i+1), keeping level sizes balanced
    // (pairwise-summation style).
    int64_t mask = i + 1;
    for (int64_t j = 1; j < depth && (mask & 1) == 0; ++j) {
      AddMomentsVec(
          m0_stk[j - 1],
          m1_stk[j - 1],
          m2_stk[j - 1],
          m0_stk[j],
          m1_stk[j],
          m2_stk[j]);
      m0_stk[j - 1] = 0;
      m1_stk[j - 1] = kZeroVec;
      m2_stk[j - 1] = kZeroVec;
      mask >>= 1;
    }
  }
  // Collapse any remaining partial levels into level 0.
  for (const auto i : c10::irange(1, depth)) {
    AddMomentsVec(
        m0_stk[i], m1_stk[i], m2_stk[i], m0_stk[0], m1_stk[0], m2_stk[0]);
  }

  // Spill the per-lane accumulators to scalar arrays.
  std::array<math_t, kAccVecSize> m1_arr{};
  std::array<math_t, kAccVecSize> m2_arr{};
  m1_stk[0].store(m1_arr.data());
  m2_stk[0].store(m2_arr.data());

  // Scalar Welford pass over the tail elements that did not fill a vector.
  int64_t m0 = 0;
  math_t m1 = 0;
  math_t m2 = 0;
  for (int64_t i = n * kVecSize; i < N; ++i) {
    math_t x = static_cast<math_t>(X[i]);
    const math_t delta = x - m1;
    ++m0;
    m1 += delta / static_cast<math_t>(m0);
    m2 += delta * (x - m1);
  }
  // Each accumulator lane holds n * kVecSize / kAccVecSize samples (for
  // BFloat16/Half this is 2*n, since each input vector widened to two).
  int64_t m0_add = n * kVecSize / kAccVecSize;
  for (const auto i : c10::irange(kAccVecSize)) {
    AddMoments(m0_add, m1_arr[i], m2_arr[i], m0, m1, m2);
  }

  return std::make_pair(m1, m2 / static_cast<math_t>(N - ddof));
}
|
| 183 |
+
|
| 184 |
+
template <typename T>
|
| 185 |
+
std::pair<opmath_t<T>, opmath_t<T>> RowwiseMoments(const T* X, int64_t N, int64_t ddof = 0) {
|
| 186 |
+
using Vec = vec::Vectorized<T>;
|
| 187 |
+
constexpr int64_t kVecSize = Vec::size();
|
| 188 |
+
const int64_t n = N / kVecSize;
|
| 189 |
+
const int64_t m = divup(n, kChunkSize);
|
| 190 |
+
const int64_t depth = utils::CeilLog2(m);
|
| 191 |
+
if (depth <= 4) {
|
| 192 |
+
return RowwiseMomentsImpl<T, 4>(X, N, ddof);
|
| 193 |
+
} else if (depth <= 8) {
|
| 194 |
+
return RowwiseMomentsImpl<T, 8>(X, N, ddof);
|
| 195 |
+
} else if (depth <= 16) {
|
| 196 |
+
return RowwiseMomentsImpl<T, 16>(X, N, ddof);
|
| 197 |
+
} else if (depth <= 32) {
|
| 198 |
+
return RowwiseMomentsImpl<T, 32>(X, N, ddof);
|
| 199 |
+
} else {
|
| 200 |
+
return RowwiseMomentsImpl<T, 64>(X, N, ddof);
|
| 201 |
+
}
|
| 202 |
+
}
|
| 203 |
+
|
| 204 |
+
} // namespace CPU_CAPABILITY
|
| 205 |
+
} // namespace native
|
| 206 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/TensorModeKernel.cuh
ADDED
|
@@ -0,0 +1,435 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/cuda/detail/IndexUtils.cuh>
|
| 4 |
+
#include <ATen/native/cuda/Loops.cuh>
|
| 5 |
+
#include <ATen/native/cuda/SortingCommon.cuh>
|
| 6 |
+
#include <ATen/native/cuda/block_reduce.cuh>
|
| 7 |
+
|
| 8 |
+
namespace at {
|
| 9 |
+
namespace native {
|
| 10 |
+
|
| 11 |
+
// Used for a segmented reduction: `flag` marks the first element of a
// segment while `val` carries the running per-segment count.
struct ModeUnsignedBoolPair {
  unsigned int val;
  bool flag;
};
|
| 16 |
+
|
| 17 |
+
// In the kernel below, we have a common pattern of reducing (unsigned int,
// unsigned int) pairs of data: `val` is the quantity being maximized and
// `index` is carried along so the arg-max position survives the reduction.
struct ModeUnsignedPair {
  unsigned int val;
  unsigned int index;
};
|
| 23 |
+
|
| 24 |
+
// Inclusive Scan via an upsweep/downsweep mechanism. Assumes:
|
| 25 |
+
//
|
| 26 |
+
// 1. Power2ScanSize is a power of 2. This code still works for collections that
|
| 27 |
+
// do not exactly contain a power of 2 number of elements, simply round up to
|
| 28 |
+
// the nearest power of 2 and then call.
|
| 29 |
+
//
|
| 30 |
+
// 2. That there are two-elements per thread, i.e. the size of the smem storage
|
| 31 |
+
// is 2 * blockDim.x * sizeof(T).
|
| 32 |
+
//
|
| 33 |
+
// Consider a (+)-Scan on the following elements:
|
| 34 |
+
//
|
| 35 |
+
// Upsweep:
|
| 36 |
+
//
|
| 37 |
+
// 0 1 2 3 4 5 6 7
|
| 38 |
+
// 1 5 9 13
|
| 39 |
+
// 6 22
|
| 40 |
+
// 28
|
| 41 |
+
//
|
| 42 |
+
// Downsweep:
|
| 43 |
+
// 15
|
| 44 |
+
// 3 10 21
|
| 45 |
+
// Inclusive prefix scan over the 2*blockDim.x elements in smem using the
// upsweep/downsweep tree illustrated in the block comment above.
// Power2ScanSize must be a power of two; every thread in the block must call
// this together because of the __syncthreads() barriers.
template <int Power2ScanSize, typename T, class BinaryOp>
__device__ void inclusivePrefixScan(T* smem, BinaryOp binop) {
  // Reduce step ("upsweep"): combine pairs at doubling strides until the
  // last slot holds the total.
#pragma unroll
  for (int stride = 1; stride < Power2ScanSize; stride <<= 1) {
    int index = (threadIdx.x + 1) * stride * 2 - 1;
    if (index < Power2ScanSize) {
      smem[index] = binop(smem[index], smem[index - stride]);
    }
    __syncthreads();
  }

  // Post-reduce step ("downsweep"): push partial results back into interior
  // slots so every position ends up with an inclusive prefix.
#pragma unroll
  for (int stride = Power2ScanSize / 4; stride > 0; stride >>= 1) {
    int index = (threadIdx.x + 1) * stride * 2 - 1;
    if ((index + stride) < Power2ScanSize) {
      smem[index + stride] = binop(smem[index + stride], smem[index]);
    }
    __syncthreads();
  }
}
|
| 67 |
+
|
| 68 |
+
// Block-wide reduction where each thread locally reduces N
|
| 69 |
+
// values before letting a single warp take over - assumes
|
| 70 |
+
// threadVals is in registers, not shared memory
|
| 71 |
+
//
|
| 72 |
+
// If smem is not used again, there is no need to __syncthreads before this
|
| 73 |
+
// call. However, if smem will be used, e.g., this function is called in a loop,
|
| 74 |
+
// then __syncthreads is needed either before or afterwards to prevent non-0
|
| 75 |
+
// threads overriding smem in the next loop before num-0 thread reads from it.
|
| 76 |
+
// Block-wide reduction where each thread first combines its N register-held
// values (threadVals), then the per-thread partials are reduced with
// cuda_utils::BlockReduce. `numVals` bounds the valid inputs; out-of-range
// slots contribute `init`. If smem will be reused (e.g. called in a loop),
// the caller must __syncthreads() before or after — see comment above.
template <int N, typename T, typename ReduceOp>
__device__ T reduceBlockWithNThreadLocalReductions(
    T* smem,
    T threadVals[N],
    const unsigned int numVals,
    ReduceOp reduceOp,
    T init) {
  // This thread owns the N consecutive logical positions starting here.
  int offset = threadIdx.x * N;
  T local = offset < numVals ? threadVals[0] : init;

#pragma unroll
  for (int i = 1; i < N; ++i) {
    ++offset;
    T next = offset < numVals ? threadVals[i] : init;
    local = reduceOp.combine(local, next);
  }

  return cuda_utils::BlockReduce(local, reduceOp, init, smem);
}
|
| 95 |
+
|
| 96 |
+
template <typename T>
|
| 97 |
+
__device__ inline void swapVars(T& t1, T& t2) {
|
| 98 |
+
T tmp = t1;
|
| 99 |
+
t1 = t2;
|
| 100 |
+
t2 = tmp;
|
| 101 |
+
}
|
| 102 |
+
|
| 103 |
+
// Compare-and-exchange step of a bitonic sort over (key, value) pairs, each
// paired with a validity flag. `dir` is the sorting direction of the current
// bitonic subsequence.
template <typename Comparator, typename K, typename V>
__device__ inline void bitonicSwap(
    K& kA,
    V& vA,
    bool& validA,
    K& kB,
    V& vB,
    bool& validB,
    bool dir,
    const Comparator& comp) {
  // Invalid entries always sort to the end
  bool swap = (comp(kA, kB) && validA) || !validB;
  if (swap == dir) {
    swapVars(kA, kB);
    swapVars(vA, vB);
    swapVars(validA, validB);
  }
};
|
| 121 |
+
|
| 122 |
+
template <typename Comparator, typename K>
|
| 123 |
+
__device__ inline void bitonicSwapKeys(
|
| 124 |
+
K& kA,
|
| 125 |
+
bool& validA,
|
| 126 |
+
K& kB,
|
| 127 |
+
bool& validB,
|
| 128 |
+
bool dir,
|
| 129 |
+
const Comparator& comp) {
|
| 130 |
+
bool swap = (comp(kA, kB) && validA) || !validB;
|
| 131 |
+
if (swap == dir) {
|
| 132 |
+
swapVars(kA, kB);
|
| 133 |
+
swapVars(validA, validB);
|
| 134 |
+
}
|
| 135 |
+
}
|
| 136 |
+
|
| 137 |
+
// In-place bitonic sort (ascending under `comp`) of Power2SortSize keys in
// shared memory, two elements per thread. `valid[i]` marks live entries;
// invalid ones sink to the end. All threads of the block must participate
// (the network is driven by __syncthreads() barriers).
// NOTE(review): the loops are only unrolled off ROCm — presumably a ROCm
// compiler workaround; confirm before changing.
template <
    typename K,
    typename IndexType,
    int Power2SortSize,
    typename Comparator>
__device__ inline void bitonicSortKeys(
    K keys[Power2SortSize],
    bool valid[Power2SortSize],
    const Comparator& comp) {
#if !defined(USE_ROCM)
#pragma unroll
#endif
  for (unsigned int size = 2; size < Power2SortSize; size *= 2) {
    // First half of each size-wide subsequence sorts one way, second half the
    // other, producing a bitonic sequence for the next stage.
    bool flag = ((threadIdx.x & (size / 2)) != 0);

#if !defined(USE_ROCM)
#pragma unroll
#endif
    for (unsigned int stride = size / 2; stride > 0; stride /= 2) {
      __syncthreads();

      // Index of this thread's lower element within its stride-wide lane.
      unsigned int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1));
      bitonicSwapKeys<Comparator, K>(
          keys[pos],
          valid[pos],
          keys[pos + stride],
          valid[pos + stride],
          flag,
          comp);
    }
  }

  // Final merge of the single bitonic sequence into fully ascending order.
#if !defined(USE_ROCM)
#pragma unroll
#endif
  for (unsigned int stride = Power2SortSize / 2; stride > 0; stride /= 2) {
    __syncthreads();

    unsigned int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1));
    bitonicSwapKeys<Comparator, K>(
        keys[pos],
        valid[pos],
        keys[pos + stride],
        valid[pos + stride],
        false,
        comp);
  }

  __syncthreads();
}
|
| 187 |
+
|
| 188 |
+
// The mode kernel has the following characteristics: It uses internal shared
|
| 189 |
+
// memory buffers of Power2Size, which must be greater than the number of
|
| 190 |
+
// elements. Additionally, there is one block for every slice to calculate the
|
| 191 |
+
// mode for, and in each block there is one thread for every two elements.
|
| 192 |
+
//
|
| 193 |
+
// Both sorted and positions are assumed to be contiguous Tensors with the mode
|
| 194 |
+
// dimension as the innermost dim, such that we can get the particular slice for
|
| 195 |
+
// a Tensor via its linear block dimension * the slice size.
|
| 196 |
+
// Computes the mode (most frequent value) of each slice, one block per slice
// (see the block comment above for the shared-memory layout contract).
// Pipeline: bitonic-sort the slice in shared memory -> mark starts of equal
// runs -> segmented prefix-scan the run lengths -> max-reduce to find the
// longest run (the mode) -> max-reduce again to find an index of the mode in
// the *input* ordering, then write both outputs from thread 0.
template <typename T, unsigned int Power2Size>
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11070
__launch_bounds__(1024, 1)
#endif
__global__ void compute_mode(
    const T* input,
    at::cuda::detail::TensorInfo<T, unsigned int> values,
    at::cuda::detail::TensorInfo<int64_t, unsigned int> indices,
    int64_t sliceSize,
    int64_t slices) {
  int tidx = threadIdx.x;
  int stidx = blockDim.x + threadIdx.x; // Second index this thread responsible for

  // Offset into the input representing the start of this block's slice.
  unsigned int blockId = getLinearBlockId<unsigned int>();
  unsigned int linearOffset = blockId * sliceSize;

  if (blockId >= slices) {
    return;
  }

  // Dynamically sized shared buffer reused throughout the kernel. Required
  // size: sizeof(T) * Power2Size + 2 * sizeof(unsigned int) * Power2Size.
  // Initial layout: [smem (slice elements) | bmem (valid flags) | scratch].
  extern __shared__ char shmem[];

  // smem holds the slice elements being sorted.
  T* smem = reinterpret_cast<T*>(shmem);

  // Each thread loads up to two elements from the Tensor into shared memory
  if (tidx < sliceSize) {
    smem[tidx] = c10::load(&input[linearOffset + tidx]);
  }
  if (stidx < sliceSize) {
    smem[stidx] = c10::load(&input[linearOffset + stidx]);
  }

  // Boolean region directly after the element region: bmem[i] marks whether
  // smem[i] holds a real slice element (i < sliceSize).
  bool* bmem = reinterpret_cast<bool*>(&smem[Power2Size]);

  bmem[tidx] = tidx < sliceSize;
  bmem[stidx] = stidx < sliceSize;
  __syncthreads(); // barrier for smem, bmem initialization

  // Sort the slice ascending; invalid (padding) entries sink to the end.
  bitonicSortKeys<T, unsigned int, Power2Size>(
      smem, bmem, [&] GPU_LAMBDA(const auto& a, const auto& b) {
        return a < b;
      });
  __syncthreads(); // make no assumptions that the sort syncs at end

  // Build B[i] = (A[i-1] != A[i]) marking run starts, and C[i] = !B[i] (run
  // continuation), stored as (val=C, flag=B) pairs overlaying bmem/scratch:
  // [smem (sorted slice) | ubpmem (index, flag pairs)].
  struct ModeUnsignedBoolPair* ubpmem =
      reinterpret_cast<struct ModeUnsignedBoolPair*>(&smem[Power2Size]);

  if (tidx == 0) {
    ubpmem[0].flag = true;
    ubpmem[0].val = 0;
  }

  // Compares elements (0, 1), (2, 3), ... and sets 1, 3, ...
  ubpmem[tidx * 2 + 1].flag =
      smem[tidx * 2] != smem[tidx * 2 + 1]; // (0, 1), (1, 2), etc.
  ubpmem[tidx * 2 + 1].val = !ubpmem[tidx * 2 + 1].flag;

  // Compares elements (1, 2), (3, 4), ... and sets 2, 4, ...
  if (((tidx + 1) * 2) < Power2Size) {
    ubpmem[(tidx + 1) * 2].flag =
        smem[((tidx + 1) * 2) - 1] != smem[(tidx + 1) * 2];
    ubpmem[(tidx + 1) * 2].val = !ubpmem[(tidx + 1) * 2].flag;
  }
  __syncthreads(); // barrier for ubpmem initialization

  // Segmented prefix-sum of C with segment-start flags B: afterwards each
  // val holds the length (minus 1) of its run so far, i.e. element counts.
  inclusivePrefixScan<Power2Size>(
      ubpmem, [=] GPU_LAMBDA(const auto& a, const auto& b) {
        ModeUnsignedBoolPair c;
        c.val = a.flag ? a.val : a.val + b.val;
        c.flag = a.flag | b.flag;
        return c;
      });
  // assumes scan syncs at the end

  // Reinterpret the scratch as (val, index) pairs for the arg-max reduction.
  struct ModeUnsignedPair* uupmem =
      reinterpret_cast<struct ModeUnsignedPair*>(ubpmem);

  // Max-reduce the run lengths; the winning index also locates the mode in
  // the sorted buffer. Each thread locally reduces its two adjacent elements
  // first (the reduce code relies on thread elements being adjacent).
  struct ModeUnsignedPair uup[2];
  uup[0].index = tidx * 2;
  uup[0].val = ubpmem[tidx * 2].val;
  uup[1].index = tidx * 2 + 1;
  uup[1].val = ubpmem[tidx * 2 + 1].val;
  __syncthreads();

  struct ModeUnsignedPair max = {0, 0};

  struct MaxOp {
    inline __device__ ModeUnsignedPair combine(ModeUnsignedPair a, ModeUnsignedPair b) const {
      return b.val > a.val ? b : a;
    }

    inline __device__ ModeUnsignedPair warp_shfl_down(ModeUnsignedPair acc, int offset) const {
      ModeUnsignedPair ret;
      ret.index = WARP_SHFL_DOWN(acc.index, offset);
      ret.val = WARP_SHFL_DOWN(acc.val, offset);
      return ret;
    }
  } max_op;

  max = reduceBlockWithNThreadLocalReductions<2>(
      uupmem,
      uup,
      sliceSize,
      max_op,
      max);

  // Broadcast the mode value to the whole block via shared memory.
  __shared__ T mode;

  // The mode is the value at the reduced index in the sorted element buffer.
  if (tidx == 0) {
    mode = smem[max.index];
  }
  __syncthreads(); // broadcast mode

  // Find "an" index of the mode in the original input ordering — the API
  // does not constrain which; this picks the largest matching index (store
  // idx where input == mode, else 0, then max-reduce). Two elements per
  // thread again.
  unsigned mode_index[2] = {0u, 0u};
  if (tidx * 2 < sliceSize) {
    const unsigned idx = tidx * 2;
    mode_index[0] = c10::load(&input[linearOffset + idx]) == mode ? idx : 0u;
  }
  if (tidx * 2 + 1 < sliceSize) {
    const unsigned idx = tidx * 2 + 1;
    mode_index[1] = c10::load(&input[linearOffset + idx]) == mode ? idx : 0u;
  }

  struct MaxIndexOp {
    inline __device__ unsigned combine(unsigned a, unsigned b) const {
      return b > a ? b : a;
    }

    inline __device__ unsigned warp_shfl_down(unsigned acc, int offset) const {
      return WARP_SHFL_DOWN(acc, offset);
    }
  } max_index_op;

  int64_t index = reduceBlockWithNThreadLocalReductions<2>(
      reinterpret_cast<unsigned*>(&shmem[0]),
      mode_index,
      sliceSize,
      max_index_op,
      0u);

  // A single thread writes the (mode, index) result for this slice.
  if (tidx == 0) {
    unsigned int outputOffset =
        at::cuda::detail::IndexToOffset<T, unsigned int, -1>::get(
            blockId, values);
    values.data[outputOffset] = mode;
    indices.data[outputOffset] = index;
  }
}
|
| 433 |
+
|
| 434 |
+
} // namespace native
|
| 435 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/nested/NestedTensorBinaryOps.h
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/core/ATen_fwd.h>
|
| 4 |
+
#include <ATen/native/DispatchStub.h>
|
| 5 |
+
|
| 6 |
+
namespace at {
|
| 7 |
+
namespace native {
|
| 8 |
+
|
| 9 |
+
enum class NESTED_DENSE_OP: uint8_t {ADD, MUL};
|
| 10 |
+
|
| 11 |
+
using nested_dense_elementwise_fn = void (*)(Tensor& result, const Tensor & self, const Tensor & other, const NESTED_DENSE_OP& op);
|
| 12 |
+
|
| 13 |
+
DECLARE_DISPATCH(nested_dense_elementwise_fn, nested_dense_elementwise_stub);
|
| 14 |
+
|
| 15 |
+
} // namespace native
|
| 16 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/vol2col.h
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <cstring>
|
| 4 |
+
|
| 5 |
+
namespace at::native {
|
| 6 |
+
|
| 7 |
+
template <typename T>
|
| 8 |
+
static void vol2col(
|
| 9 |
+
const T* data_vol,
|
| 10 |
+
const int64_t channels,
|
| 11 |
+
const int64_t depth,
|
| 12 |
+
const int64_t height,
|
| 13 |
+
const int64_t width,
|
| 14 |
+
const int64_t depth_col,
|
| 15 |
+
const int64_t height_col,
|
| 16 |
+
const int64_t width_col,
|
| 17 |
+
const int64_t kT,
|
| 18 |
+
const int64_t kernel_height,
|
| 19 |
+
const int64_t kernel_width,
|
| 20 |
+
const int64_t pT,
|
| 21 |
+
const int64_t pH,
|
| 22 |
+
const int64_t pW,
|
| 23 |
+
const int64_t dT,
|
| 24 |
+
const int64_t dH,
|
| 25 |
+
const int64_t dW,
|
| 26 |
+
const int64_t dilationT,
|
| 27 |
+
const int64_t dilationH,
|
| 28 |
+
const int64_t dilationW,
|
| 29 |
+
T* data_col) {
|
| 30 |
+
int64_t c, t, h, w;
|
| 31 |
+
int64_t channels_col = channels * kT * kernel_height * kernel_width;
|
| 32 |
+
for (c = 0; c < channels_col; ++c) {
|
| 33 |
+
int64_t w_offset = c % kernel_width;
|
| 34 |
+
int64_t h_offset = (c / kernel_width) % kernel_height;
|
| 35 |
+
int64_t t_offset = (c / kernel_width / kernel_height) % kT;
|
| 36 |
+
int64_t c_vol = c / kT / kernel_height / kernel_width;
|
| 37 |
+
for (t = 0; t < depth_col; ++t) {
|
| 38 |
+
int64_t t_pad = t * dT - pT + t_offset * dilationT;
|
| 39 |
+
for (h = 0; h < height_col; ++h) {
|
| 40 |
+
int64_t h_pad = h * dH - pH + h_offset * dilationH;
|
| 41 |
+
for (w = 0; w < width_col; ++w) {
|
| 42 |
+
int64_t w_pad = w * dW - pW + w_offset * dilationW;
|
| 43 |
+
if (t_pad >= 0 && t_pad < depth && h_pad >= 0 && h_pad < height &&
|
| 44 |
+
w_pad >= 0 && w_pad < width)
|
| 45 |
+
data_col[((c * depth_col + t) * height_col + h) * width_col + w] =
|
| 46 |
+
data_vol
|
| 47 |
+
[((c_vol * depth + t_pad) * height + h_pad) * width +
|
| 48 |
+
w_pad];
|
| 49 |
+
else
|
| 50 |
+
data_col[((c * depth_col + t) * height_col + h) * width_col + w] =
|
| 51 |
+
0;
|
| 52 |
+
}
|
| 53 |
+
}
|
| 54 |
+
}
|
| 55 |
+
}
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
template <typename T>
|
| 59 |
+
static void col2vol(
|
| 60 |
+
const T* data_col,
|
| 61 |
+
const int64_t channels,
|
| 62 |
+
const int64_t depth,
|
| 63 |
+
const int64_t height,
|
| 64 |
+
const int64_t width,
|
| 65 |
+
const int64_t out_depth,
|
| 66 |
+
const int64_t out_height,
|
| 67 |
+
const int64_t out_width,
|
| 68 |
+
const int64_t kT,
|
| 69 |
+
const int64_t kernel_height,
|
| 70 |
+
const int64_t kernel_width,
|
| 71 |
+
const int64_t pT,
|
| 72 |
+
const int64_t pH,
|
| 73 |
+
const int64_t pW,
|
| 74 |
+
const int64_t dT,
|
| 75 |
+
const int64_t dH,
|
| 76 |
+
const int64_t dW,
|
| 77 |
+
const int64_t dilationT,
|
| 78 |
+
const int64_t dilationH,
|
| 79 |
+
const int64_t dilationW,
|
| 80 |
+
T* data_vol) {
|
| 81 |
+
memset(data_vol, 0, sizeof(T) * depth * height * width * channels);
|
| 82 |
+
int64_t depth_col = out_depth;
|
| 83 |
+
int64_t height_col = out_height;
|
| 84 |
+
int64_t width_col = out_width;
|
| 85 |
+
int64_t channels_col = channels * kT * kernel_height * kernel_width;
|
| 86 |
+
for (int64_t c = 0; c < channels_col; ++c) {
|
| 87 |
+
int64_t w_offset = c % kernel_width;
|
| 88 |
+
int64_t h_offset = (c / kernel_width) % kernel_height;
|
| 89 |
+
int64_t t_offset = (c / kernel_width / kernel_height) % kT;
|
| 90 |
+
int64_t c_vol = c / kT / kernel_height / kernel_width;
|
| 91 |
+
for (int64_t t = 0; t < depth_col; ++t) {
|
| 92 |
+
int64_t t_pad = t * dT - pT + t_offset * dilationT;
|
| 93 |
+
for (int64_t h = 0; h < height_col; ++h) {
|
| 94 |
+
int64_t h_pad = h * dH - pH + h_offset * dilationH;
|
| 95 |
+
for (int64_t w = 0; w < width_col; ++w) {
|
| 96 |
+
int64_t w_pad = w * dW - pW + w_offset * dilationW;
|
| 97 |
+
if (t_pad >= 0 && t_pad < depth && h_pad >= 0 && h_pad < height &&
|
| 98 |
+
w_pad >= 0 && w_pad < width)
|
| 99 |
+
data_vol
|
| 100 |
+
[((c_vol * depth + t_pad) * height + h_pad) * width + w_pad] +=
|
| 101 |
+
data_col
|
| 102 |
+
[((c * depth_col + t) * height_col + h) * width_col + w];
|
| 103 |
+
}
|
| 104 |
+
}
|
| 105 |
+
}
|
| 106 |
+
}
|
| 107 |
+
}
|
| 108 |
+
|
| 109 |
+
} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_cast_Long_native.h
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// @generated by torchgen/gen.py from NativeFunction.h
|
| 4 |
+
|
| 5 |
+
#include <c10/core/Scalar.h>
|
| 6 |
+
#include <c10/core/Storage.h>
|
| 7 |
+
#include <c10/core/TensorOptions.h>
|
| 8 |
+
#include <c10/util/Deprecated.h>
|
| 9 |
+
#include <c10/util/Optional.h>
|
| 10 |
+
#include <c10/core/QScheme.h>
|
| 11 |
+
#include <ATen/core/Reduction.h>
|
| 12 |
+
#include <ATen/core/Tensor.h>
|
| 13 |
+
#include <tuple>
|
| 14 |
+
#include <vector>
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
namespace at {
|
| 18 |
+
namespace native {
|
| 19 |
+
TORCH_API at::Tensor _cast_Long(const at::Tensor & self, bool non_blocking=false);
|
| 20 |
+
} // namespace native
|
| 21 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_cufft_clear_plan_cache_compositeimplicitautograd_dispatch.h
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
// @generated by torchgen/gen.py from DispatchKeyFunction.h
|
| 3 |
+
|
| 4 |
+
// NB: The implementing C++ file is RegisterDispatchKey.cpp
|
| 5 |
+
|
| 6 |
+
// The only #includes we need are for custom classes that have defaults in the C++ API
|
| 7 |
+
#include <c10/core/MemoryFormat.h>
|
| 8 |
+
#include <c10/core/Scalar.h>
|
| 9 |
+
#include <ATen/core/Reduction.h>
|
| 10 |
+
|
| 11 |
+
// Forward declarations of any types needed in the operator signatures.
|
| 12 |
+
// We can't directly include these classes because it will cause circular include dependencies.
|
| 13 |
+
// This file is included by TensorBody.h, which defines the Tensor class.
|
| 14 |
+
#include <ATen/core/ATen_fwd.h>
|
| 15 |
+
|
| 16 |
+
namespace at {
|
| 17 |
+
|
| 18 |
+
namespace compositeimplicitautograd {
|
| 19 |
+
|
| 20 |
+
TORCH_API void _cufft_clear_plan_cache(at::DeviceIndex device_index);
|
| 21 |
+
|
| 22 |
+
} // namespace compositeimplicitautograd
|
| 23 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_fft_c2c_ops.h
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// @generated by torchgen/gen.py from Operator.h
|
| 4 |
+
|
| 5 |
+
#include <tuple>
|
| 6 |
+
#include <vector>
|
| 7 |
+
|
| 8 |
+
// Forward declarations of any types needed in the operator signatures.
|
| 9 |
+
// We can't directly include these classes because it will cause circular include dependencies.
|
| 10 |
+
// This file is included by TensorBody.h, which defines the Tensor class.
|
| 11 |
+
#include <ATen/core/ATen_fwd.h>
|
| 12 |
+
|
| 13 |
+
namespace at {
|
| 14 |
+
namespace _ops {
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
struct TORCH_API _fft_c2c {
|
| 18 |
+
using schema = at::Tensor (const at::Tensor &, c10::SymIntArrayRef, int64_t, bool);
|
| 19 |
+
using ptr_schema = schema*;
|
| 20 |
+
// See Note [static constexpr char* members for windows NVCC]
|
| 21 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::_fft_c2c")
|
| 22 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "")
|
| 23 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "_fft_c2c(Tensor self, SymInt[] dim, int normalization, bool forward) -> Tensor")
|
| 24 |
+
static at::Tensor call(const at::Tensor & self, c10::SymIntArrayRef dim, int64_t normalization, bool forward);
|
| 25 |
+
static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef dim, int64_t normalization, bool forward);
|
| 26 |
+
};
|
| 27 |
+
|
| 28 |
+
struct TORCH_API _fft_c2c_out {
|
| 29 |
+
using schema = at::Tensor & (const at::Tensor &, c10::SymIntArrayRef, int64_t, bool, at::Tensor &);
|
| 30 |
+
using ptr_schema = schema*;
|
| 31 |
+
// See Note [static constexpr char* members for windows NVCC]
|
| 32 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::_fft_c2c")
|
| 33 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "out")
|
| 34 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "_fft_c2c.out(Tensor self, SymInt[] dim, int normalization, bool forward, *, Tensor(a!) out) -> Tensor(a!)")
|
| 35 |
+
static at::Tensor & call(const at::Tensor & self, c10::SymIntArrayRef dim, int64_t normalization, bool forward, at::Tensor & out);
|
| 36 |
+
static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef dim, int64_t normalization, bool forward, at::Tensor & out);
|
| 37 |
+
};
|
| 38 |
+
|
| 39 |
+
}} // namespace at::_ops
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_foreach_expm1_ops.h
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// @generated by torchgen/gen.py from Operator.h
|
| 4 |
+
|
| 5 |
+
#include <tuple>
|
| 6 |
+
#include <vector>
|
| 7 |
+
|
| 8 |
+
// Forward declarations of any types needed in the operator signatures.
|
| 9 |
+
// We can't directly include these classes because it will cause circular include dependencies.
|
| 10 |
+
// This file is included by TensorBody.h, which defines the Tensor class.
|
| 11 |
+
#include <ATen/core/ATen_fwd.h>
|
| 12 |
+
|
| 13 |
+
namespace at {
|
| 14 |
+
namespace _ops {
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
struct TORCH_API _foreach_expm1 {
|
| 18 |
+
using schema = ::std::vector<at::Tensor> (at::TensorList);
|
| 19 |
+
using ptr_schema = schema*;
|
| 20 |
+
// See Note [static constexpr char* members for windows NVCC]
|
| 21 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::_foreach_expm1")
|
| 22 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "")
|
| 23 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "_foreach_expm1(Tensor[] self) -> Tensor[]")
|
| 24 |
+
static ::std::vector<at::Tensor> call(at::TensorList self);
|
| 25 |
+
static ::std::vector<at::Tensor> redispatch(c10::DispatchKeySet dispatchKeySet, at::TensorList self);
|
| 26 |
+
};
|
| 27 |
+
|
| 28 |
+
struct TORCH_API _foreach_expm1_ {
|
| 29 |
+
using schema = void (at::TensorList);
|
| 30 |
+
using ptr_schema = schema*;
|
| 31 |
+
// See Note [static constexpr char* members for windows NVCC]
|
| 32 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::_foreach_expm1_")
|
| 33 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "")
|
| 34 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "_foreach_expm1_(Tensor(a!)[] self) -> ()")
|
| 35 |
+
static void call(at::TensorList self);
|
| 36 |
+
static void redispatch(c10::DispatchKeySet dispatchKeySet, at::TensorList self);
|
| 37 |
+
};
|
| 38 |
+
|
| 39 |
+
struct TORCH_API _foreach_expm1_out {
|
| 40 |
+
using schema = void (at::TensorList, at::TensorList);
|
| 41 |
+
using ptr_schema = schema*;
|
| 42 |
+
// See Note [static constexpr char* members for windows NVCC]
|
| 43 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::_foreach_expm1")
|
| 44 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "out")
|
| 45 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "_foreach_expm1.out(Tensor[] self, *, Tensor(a!)[] out) -> ()")
|
| 46 |
+
static void call(at::TensorList self, at::TensorList out);
|
| 47 |
+
static void redispatch(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out);
|
| 48 |
+
};
|
| 49 |
+
|
| 50 |
+
}} // namespace at::_ops
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_foreach_log1p_cuda_dispatch.h
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
// @generated by torchgen/gen.py from DispatchKeyFunction.h
|
| 3 |
+
|
| 4 |
+
// NB: The implementing C++ file is RegisterDispatchKey.cpp
|
| 5 |
+
|
| 6 |
+
// The only #includes we need are for custom classes that have defaults in the C++ API
|
| 7 |
+
#include <c10/core/MemoryFormat.h>
|
| 8 |
+
#include <c10/core/Scalar.h>
|
| 9 |
+
#include <ATen/core/Reduction.h>
|
| 10 |
+
|
| 11 |
+
// Forward declarations of any types needed in the operator signatures.
|
| 12 |
+
// We can't directly include these classes because it will cause circular include dependencies.
|
| 13 |
+
// This file is included by TensorBody.h, which defines the Tensor class.
|
| 14 |
+
#include <ATen/core/ATen_fwd.h>
|
| 15 |
+
|
| 16 |
+
namespace at {
|
| 17 |
+
|
| 18 |
+
namespace cuda {
|
| 19 |
+
|
| 20 |
+
TORCH_API ::std::vector<at::Tensor> _foreach_log1p(at::TensorList self);
|
| 21 |
+
TORCH_API void _foreach_log1p_(at::TensorList self);
|
| 22 |
+
|
| 23 |
+
} // namespace cuda
|
| 24 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_foreach_reciprocal_ops.h
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// @generated by torchgen/gen.py from Operator.h
|
| 4 |
+
|
| 5 |
+
#include <tuple>
|
| 6 |
+
#include <vector>
|
| 7 |
+
|
| 8 |
+
// Forward declarations of any types needed in the operator signatures.
|
| 9 |
+
// We can't directly include these classes because it will cause circular include dependencies.
|
| 10 |
+
// This file is included by TensorBody.h, which defines the Tensor class.
|
| 11 |
+
#include <ATen/core/ATen_fwd.h>
|
| 12 |
+
|
| 13 |
+
namespace at {
|
| 14 |
+
namespace _ops {
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
struct TORCH_API _foreach_reciprocal {
|
| 18 |
+
using schema = ::std::vector<at::Tensor> (at::TensorList);
|
| 19 |
+
using ptr_schema = schema*;
|
| 20 |
+
// See Note [static constexpr char* members for windows NVCC]
|
| 21 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::_foreach_reciprocal")
|
| 22 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "")
|
| 23 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "_foreach_reciprocal(Tensor[] self) -> Tensor[]")
|
| 24 |
+
static ::std::vector<at::Tensor> call(at::TensorList self);
|
| 25 |
+
static ::std::vector<at::Tensor> redispatch(c10::DispatchKeySet dispatchKeySet, at::TensorList self);
|
| 26 |
+
};
|
| 27 |
+
|
| 28 |
+
struct TORCH_API _foreach_reciprocal_ {
|
| 29 |
+
using schema = void (at::TensorList);
|
| 30 |
+
using ptr_schema = schema*;
|
| 31 |
+
// See Note [static constexpr char* members for windows NVCC]
|
| 32 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::_foreach_reciprocal_")
|
| 33 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "")
|
| 34 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "_foreach_reciprocal_(Tensor(a!)[] self) -> ()")
|
| 35 |
+
static void call(at::TensorList self);
|
| 36 |
+
static void redispatch(c10::DispatchKeySet dispatchKeySet, at::TensorList self);
|
| 37 |
+
};
|
| 38 |
+
|
| 39 |
+
struct TORCH_API _foreach_reciprocal_out {
|
| 40 |
+
using schema = void (at::TensorList, at::TensorList);
|
| 41 |
+
using ptr_schema = schema*;
|
| 42 |
+
// See Note [static constexpr char* members for windows NVCC]
|
| 43 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::_foreach_reciprocal")
|
| 44 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "out")
|
| 45 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "_foreach_reciprocal.out(Tensor[] self, *, Tensor(a!)[] out) -> ()")
|
| 46 |
+
static void call(at::TensorList self, at::TensorList out);
|
| 47 |
+
static void redispatch(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out);
|
| 48 |
+
};
|
| 49 |
+
|
| 50 |
+
}} // namespace at::_ops
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_indices_copy.h
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// @generated by torchgen/gen.py from Function.h
|
| 4 |
+
|
| 5 |
+
#include <ATen/Context.h>
|
| 6 |
+
#include <ATen/DeviceGuard.h>
|
| 7 |
+
#include <ATen/TensorUtils.h>
|
| 8 |
+
#include <ATen/TracerMode.h>
|
| 9 |
+
#include <ATen/core/Generator.h>
|
| 10 |
+
#include <ATen/core/Reduction.h>
|
| 11 |
+
#include <ATen/core/Tensor.h>
|
| 12 |
+
#include <c10/core/Scalar.h>
|
| 13 |
+
#include <c10/core/Storage.h>
|
| 14 |
+
#include <c10/core/TensorOptions.h>
|
| 15 |
+
#include <c10/util/Deprecated.h>
|
| 16 |
+
#include <c10/util/Optional.h>
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
#include <ATen/ops/_indices_copy_ops.h>
|
| 21 |
+
|
| 22 |
+
namespace at {
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
// aten::_indices_copy(Tensor self) -> Tensor
|
| 26 |
+
inline at::Tensor _indices_copy(const at::Tensor & self) {
|
| 27 |
+
return at::_ops::_indices_copy::call(self);
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
// aten::_indices_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
|
| 31 |
+
inline at::Tensor & _indices_copy_out(at::Tensor & out, const at::Tensor & self) {
|
| 32 |
+
return at::_ops::_indices_copy_out::call(self, out);
|
| 33 |
+
}
|
| 34 |
+
// aten::_indices_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
|
| 35 |
+
inline at::Tensor & _indices_copy_outf(const at::Tensor & self, at::Tensor & out) {
|
| 36 |
+
return at::_ops::_indices_copy_out::call(self, out);
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
}
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_linalg_svd_meta_dispatch.h
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
// @generated by torchgen/gen.py from DispatchKeyFunction.h
|
| 3 |
+
|
| 4 |
+
// NB: The implementing C++ file is RegisterDispatchKey.cpp
|
| 5 |
+
|
| 6 |
+
// The only #includes we need are for custom classes that have defaults in the C++ API
|
| 7 |
+
#include <c10/core/MemoryFormat.h>
|
| 8 |
+
#include <c10/core/Scalar.h>
|
| 9 |
+
#include <ATen/core/Reduction.h>
|
| 10 |
+
|
| 11 |
+
// Forward declarations of any types needed in the operator signatures.
|
| 12 |
+
// We can't directly include these classes because it will cause circular include dependencies.
|
| 13 |
+
// This file is included by TensorBody.h, which defines the Tensor class.
|
| 14 |
+
#include <ATen/core/ATen_fwd.h>
|
| 15 |
+
|
| 16 |
+
namespace at {
|
| 17 |
+
|
| 18 |
+
namespace meta {
|
| 19 |
+
|
| 20 |
+
TORCH_API ::std::tuple<at::Tensor,at::Tensor,at::Tensor> _linalg_svd(const at::Tensor & A, bool full_matrices=false, bool compute_uv=true, c10::optional<c10::string_view> driver=c10::nullopt);
|
| 21 |
+
TORCH_API ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> _linalg_svd_out(at::Tensor & U, at::Tensor & S, at::Tensor & Vh, const at::Tensor & A, bool full_matrices=false, bool compute_uv=true, c10::optional<c10::string_view> driver=c10::nullopt);
|
| 22 |
+
TORCH_API ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &> _linalg_svd_outf(const at::Tensor & A, bool full_matrices, bool compute_uv, c10::optional<c10::string_view> driver, at::Tensor & U, at::Tensor & S, at::Tensor & Vh);
|
| 23 |
+
|
| 24 |
+
} // namespace meta
|
| 25 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_lstm_mps.h
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// @generated by torchgen/gen.py from Function.h
|
| 4 |
+
|
| 5 |
+
#include <ATen/Context.h>
|
| 6 |
+
#include <ATen/DeviceGuard.h>
|
| 7 |
+
#include <ATen/TensorUtils.h>
|
| 8 |
+
#include <ATen/TracerMode.h>
|
| 9 |
+
#include <ATen/core/Generator.h>
|
| 10 |
+
#include <ATen/core/Reduction.h>
|
| 11 |
+
#include <ATen/core/Tensor.h>
|
| 12 |
+
#include <c10/core/Scalar.h>
|
| 13 |
+
#include <c10/core/Storage.h>
|
| 14 |
+
#include <c10/core/TensorOptions.h>
|
| 15 |
+
#include <c10/util/Deprecated.h>
|
| 16 |
+
#include <c10/util/Optional.h>
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
#include <ATen/ops/_lstm_mps_ops.h>
|
| 21 |
+
|
| 22 |
+
namespace at {
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
// aten::_lstm_mps(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)
|
| 26 |
+
inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor> _lstm_mps(const at::Tensor & input, at::TensorList hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool batch_first) {
|
| 27 |
+
return at::_ops::_lstm_mps::call(input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first);
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
// aten::_lstm_mps.out(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3, Tensor(e!) out4, Tensor(f!) out5) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!), Tensor(e!), Tensor(f!))
|
| 31 |
+
inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &> _lstm_mps_out(at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3, at::Tensor & out4, at::Tensor & out5, const at::Tensor & input, at::TensorList hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool batch_first) {
|
| 32 |
+
return at::_ops::_lstm_mps_out::call(input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first, out0, out1, out2, out3, out4, out5);
|
| 33 |
+
}
|
| 34 |
+
// aten::_lstm_mps.out(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3, Tensor(e!) out4, Tensor(f!) out5) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!), Tensor(e!), Tensor(f!))
|
| 35 |
+
inline ::std::tuple<at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &,at::Tensor &> _lstm_mps_outf(const at::Tensor & input, at::TensorList hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool batch_first, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3, at::Tensor & out4, at::Tensor & out5) {
|
| 36 |
+
return at::_ops::_lstm_mps_out::call(input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first, out0, out1, out2, out3, out4, out5);
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
}
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_nested_get_values_copy_native.h
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// @generated by torchgen/gen.py from NativeFunction.h
|
| 4 |
+
|
| 5 |
+
#include <c10/core/Scalar.h>
|
| 6 |
+
#include <c10/core/Storage.h>
|
| 7 |
+
#include <c10/core/TensorOptions.h>
|
| 8 |
+
#include <c10/util/Deprecated.h>
|
| 9 |
+
#include <c10/util/Optional.h>
|
| 10 |
+
#include <c10/core/QScheme.h>
|
| 11 |
+
#include <ATen/core/Reduction.h>
|
| 12 |
+
#include <ATen/core/Tensor.h>
|
| 13 |
+
#include <tuple>
|
| 14 |
+
#include <vector>
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
namespace at {
|
| 18 |
+
namespace native {
|
| 19 |
+
TORCH_API at::Tensor & _nested_get_values_copy_out(const at::Tensor & self, at::Tensor & out);
|
| 20 |
+
TORCH_API at::Tensor _nested_get_values_copy(const at::Tensor & self);
|
| 21 |
+
} // namespace native
|
| 22 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_sparse_mask_projection_ops.h
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// @generated by torchgen/gen.py from Operator.h
|
| 4 |
+
|
| 5 |
+
#include <tuple>
|
| 6 |
+
#include <vector>
|
| 7 |
+
|
| 8 |
+
// Forward declarations of any types needed in the operator signatures.
|
| 9 |
+
// We can't directly include these classes because it will cause circular include dependencies.
|
| 10 |
+
// This file is included by TensorBody.h, which defines the Tensor class.
|
| 11 |
+
#include <ATen/core/ATen_fwd.h>
|
| 12 |
+
|
| 13 |
+
namespace at {
|
| 14 |
+
namespace _ops {
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
struct TORCH_API _sparse_mask_projection {
|
| 18 |
+
using schema = at::Tensor (const at::Tensor &, const at::Tensor &, bool);
|
| 19 |
+
using ptr_schema = schema*;
|
| 20 |
+
// See Note [static constexpr char* members for windows NVCC]
|
| 21 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::_sparse_mask_projection")
|
| 22 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "")
|
| 23 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "_sparse_mask_projection(Tensor self, Tensor mask, bool accumulate_matches=False) -> Tensor")
|
| 24 |
+
static at::Tensor call(const at::Tensor & self, const at::Tensor & mask, bool accumulate_matches);
|
| 25 |
+
static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mask, bool accumulate_matches);
|
| 26 |
+
};
|
| 27 |
+
|
| 28 |
+
struct TORCH_API _sparse_mask_projection_out {
|
| 29 |
+
using schema = at::Tensor & (const at::Tensor &, const at::Tensor &, bool, at::Tensor &);
|
| 30 |
+
using ptr_schema = schema*;
|
| 31 |
+
// See Note [static constexpr char* members for windows NVCC]
|
| 32 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::_sparse_mask_projection")
|
| 33 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "out")
|
| 34 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "_sparse_mask_projection.out(Tensor self, Tensor mask, bool accumulate_matches=False, *, Tensor(a!) out) -> Tensor(a!)")
|
| 35 |
+
static at::Tensor & call(const at::Tensor & self, const at::Tensor & mask, bool accumulate_matches, at::Tensor & out);
|
| 36 |
+
static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mask, bool accumulate_matches, at::Tensor & out);
|
| 37 |
+
};
|
| 38 |
+
|
| 39 |
+
}} // namespace at::_ops
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_spdiags.h
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// @generated by torchgen/gen.py from Function.h
|
| 4 |
+
|
| 5 |
+
#include <ATen/Context.h>
|
| 6 |
+
#include <ATen/DeviceGuard.h>
|
| 7 |
+
#include <ATen/TensorUtils.h>
|
| 8 |
+
#include <ATen/TracerMode.h>
|
| 9 |
+
#include <ATen/core/Generator.h>
|
| 10 |
+
#include <ATen/core/Reduction.h>
|
| 11 |
+
#include <ATen/core/Tensor.h>
|
| 12 |
+
#include <c10/core/Scalar.h>
|
| 13 |
+
#include <c10/core/Storage.h>
|
| 14 |
+
#include <c10/core/TensorOptions.h>
|
| 15 |
+
#include <c10/util/Deprecated.h>
|
| 16 |
+
#include <c10/util/Optional.h>
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
#include <ATen/ops/_spdiags_ops.h>
|
| 21 |
+
|
| 22 |
+
namespace at {
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
// aten::_spdiags(Tensor diagonals, Tensor offsets, int[] shape, Layout? layout=None) -> Tensor
|
| 26 |
+
inline at::Tensor _spdiags(const at::Tensor & diagonals, const at::Tensor & offsets, at::IntArrayRef shape, c10::optional<at::Layout> layout=c10::nullopt) {
|
| 27 |
+
return at::_ops::_spdiags::call(diagonals, offsets, shape, layout);
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
// aten::_spdiags.out(Tensor diagonals, Tensor offsets, int[] shape, Layout? layout=None, *, Tensor(a!) out) -> Tensor(a!)
|
| 31 |
+
inline at::Tensor & _spdiags_out(at::Tensor & out, const at::Tensor & diagonals, const at::Tensor & offsets, at::IntArrayRef shape, c10::optional<at::Layout> layout=c10::nullopt) {
|
| 32 |
+
return at::_ops::_spdiags_out::call(diagonals, offsets, shape, layout, out);
|
| 33 |
+
}
|
| 34 |
+
// aten::_spdiags.out(Tensor diagonals, Tensor offsets, int[] shape, Layout? layout=None, *, Tensor(a!) out) -> Tensor(a!)
|
| 35 |
+
inline at::Tensor & _spdiags_outf(const at::Tensor & diagonals, const at::Tensor & offsets, at::IntArrayRef shape, c10::optional<at::Layout> layout, at::Tensor & out) {
|
| 36 |
+
return at::_ops::_spdiags_out::call(diagonals, offsets, shape, layout, out);
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
}
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_test_optional_intlist_compositeexplicitautograd_dispatch.h
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
// @generated by torchgen/gen.py from DispatchKeyFunction.h
|
| 3 |
+
|
| 4 |
+
// NB: The implementing C++ file is RegisterDispatchKey.cpp
|
| 5 |
+
|
| 6 |
+
// The only #includes we need are for custom classes that have defaults in the C++ API
|
| 7 |
+
#include <c10/core/MemoryFormat.h>
|
| 8 |
+
#include <c10/core/Scalar.h>
|
| 9 |
+
#include <ATen/core/Reduction.h>
|
| 10 |
+
|
| 11 |
+
// Forward declarations of any types needed in the operator signatures.
|
| 12 |
+
// We can't directly include these classes because it will cause circular include dependencies.
|
| 13 |
+
// This file is included by TensorBody.h, which defines the Tensor class.
|
| 14 |
+
#include <ATen/core/ATen_fwd.h>
|
| 15 |
+
|
| 16 |
+
namespace at {
|
| 17 |
+
|
| 18 |
+
namespace compositeexplicitautograd {
|
| 19 |
+
|
| 20 |
+
TORCH_API at::Tensor & _test_optional_intlist_out(at::Tensor & out, const at::Tensor & values, at::OptionalIntArrayRef addends);
|
| 21 |
+
TORCH_API at::Tensor & _test_optional_intlist_outf(const at::Tensor & values, at::OptionalIntArrayRef addends, at::Tensor & out);
|
| 22 |
+
|
| 23 |
+
} // namespace compositeexplicitautograd
|
| 24 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/_upsample_bilinear2d_aa_cpu_dispatch.h
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
// @generated by torchgen/gen.py from DispatchKeyFunction.h
|
| 3 |
+
|
| 4 |
+
// NB: The implementing C++ file is RegisterDispatchKey.cpp
|
| 5 |
+
|
| 6 |
+
// The only #includes we need are for custom classes that have defaults in the C++ API
|
| 7 |
+
#include <c10/core/MemoryFormat.h>
|
| 8 |
+
#include <c10/core/Scalar.h>
|
| 9 |
+
#include <ATen/core/Reduction.h>
|
| 10 |
+
|
| 11 |
+
// Forward declarations of any types needed in the operator signatures.
|
| 12 |
+
// We can't directly include these classes because it will cause circular include dependencies.
|
| 13 |
+
// This file is included by TensorBody.h, which defines the Tensor class.
|
| 14 |
+
#include <ATen/core/ATen_fwd.h>
|
| 15 |
+
|
| 16 |
+
namespace at {
|
| 17 |
+
|
| 18 |
+
namespace cpu {
|
| 19 |
+
|
| 20 |
+
TORCH_API at::Tensor _upsample_bilinear2d_aa(const at::Tensor & self, at::IntArrayRef output_size, bool align_corners, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt);
|
| 21 |
+
TORCH_API at::Tensor _upsample_bilinear2d_aa_symint(const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt);
|
| 22 |
+
TORCH_API at::Tensor & _upsample_bilinear2d_aa_out(at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size, bool align_corners, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt);
|
| 23 |
+
TORCH_API at::Tensor & _upsample_bilinear2d_aa_outf(const at::Tensor & self, at::IntArrayRef output_size, bool align_corners, c10::optional<double> scales_h, c10::optional<double> scales_w, at::Tensor & out);
|
| 24 |
+
TORCH_API at::Tensor & _upsample_bilinear2d_aa_symint_out(at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, c10::optional<double> scales_h=c10::nullopt, c10::optional<double> scales_w=c10::nullopt);
|
| 25 |
+
TORCH_API at::Tensor & _upsample_bilinear2d_aa_symint_outf(const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, c10::optional<double> scales_h, c10::optional<double> scales_w, at::Tensor & out);
|
| 26 |
+
|
| 27 |
+
} // namespace cpu
|
| 28 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/abs_ops.h
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// @generated by torchgen/gen.py from Operator.h
|
| 4 |
+
|
| 5 |
+
#include <tuple>
|
| 6 |
+
#include <vector>
|
| 7 |
+
|
| 8 |
+
// Forward declarations of any types needed in the operator signatures.
|
| 9 |
+
// We can't directly include these classes because it will cause circular include dependencies.
|
| 10 |
+
// This file is included by TensorBody.h, which defines the Tensor class.
|
| 11 |
+
#include <ATen/core/ATen_fwd.h>
|
| 12 |
+
|
| 13 |
+
namespace at {
|
| 14 |
+
namespace _ops {
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
struct TORCH_API abs {
|
| 18 |
+
using schema = at::Tensor (const at::Tensor &);
|
| 19 |
+
using ptr_schema = schema*;
|
| 20 |
+
// See Note [static constexpr char* members for windows NVCC]
|
| 21 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::abs")
|
| 22 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "")
|
| 23 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "abs(Tensor self) -> Tensor")
|
| 24 |
+
static at::Tensor call(const at::Tensor & self);
|
| 25 |
+
static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self);
|
| 26 |
+
};
|
| 27 |
+
|
| 28 |
+
struct TORCH_API abs_ {
|
| 29 |
+
using schema = at::Tensor & (at::Tensor &);
|
| 30 |
+
using ptr_schema = schema*;
|
| 31 |
+
// See Note [static constexpr char* members for windows NVCC]
|
| 32 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::abs_")
|
| 33 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "")
|
| 34 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "abs_(Tensor(a!) self) -> Tensor(a!)")
|
| 35 |
+
static at::Tensor & call(at::Tensor & self);
|
| 36 |
+
static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, at::Tensor & self);
|
| 37 |
+
};
|
| 38 |
+
|
| 39 |
+
struct TORCH_API abs_out {
|
| 40 |
+
using schema = at::Tensor & (const at::Tensor &, at::Tensor &);
|
| 41 |
+
using ptr_schema = schema*;
|
| 42 |
+
// See Note [static constexpr char* members for windows NVCC]
|
| 43 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::abs")
|
| 44 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "out")
|
| 45 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "abs.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)")
|
| 46 |
+
static at::Tensor & call(const at::Tensor & self, at::Tensor & out);
|
| 47 |
+
static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out);
|
| 48 |
+
};
|
| 49 |
+
|
| 50 |
+
}} // namespace at::_ops
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/aminmax.h
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// @generated by torchgen/gen.py from Function.h
|
| 4 |
+
|
| 5 |
+
#include <ATen/Context.h>
|
| 6 |
+
#include <ATen/DeviceGuard.h>
|
| 7 |
+
#include <ATen/TensorUtils.h>
|
| 8 |
+
#include <ATen/TracerMode.h>
|
| 9 |
+
#include <ATen/core/Generator.h>
|
| 10 |
+
#include <ATen/core/Reduction.h>
|
| 11 |
+
#include <ATen/core/Tensor.h>
|
| 12 |
+
#include <c10/core/Scalar.h>
|
| 13 |
+
#include <c10/core/Storage.h>
|
| 14 |
+
#include <c10/core/TensorOptions.h>
|
| 15 |
+
#include <c10/util/Deprecated.h>
|
| 16 |
+
#include <c10/util/Optional.h>
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
#include <ATen/ops/aminmax_ops.h>
|
| 21 |
+
|
| 22 |
+
namespace at {
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
// aten::aminmax(Tensor self, *, int? dim=None, bool keepdim=False) -> (Tensor min, Tensor max)
|
| 26 |
+
inline ::std::tuple<at::Tensor,at::Tensor> aminmax(const at::Tensor & self, c10::optional<int64_t> dim=c10::nullopt, bool keepdim=false) {
|
| 27 |
+
return at::_ops::aminmax::call(self, dim, keepdim);
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
// aten::aminmax.out(Tensor self, *, int? dim=None, bool keepdim=False, Tensor(a!) min, Tensor(b!) max) -> (Tensor(a!) min, Tensor(b!) max)
|
| 31 |
+
inline ::std::tuple<at::Tensor &,at::Tensor &> aminmax_out(at::Tensor & min, at::Tensor & max, const at::Tensor & self, c10::optional<int64_t> dim=c10::nullopt, bool keepdim=false) {
|
| 32 |
+
return at::_ops::aminmax_out::call(self, dim, keepdim, min, max);
|
| 33 |
+
}
|
| 34 |
+
// aten::aminmax.out(Tensor self, *, int? dim=None, bool keepdim=False, Tensor(a!) min, Tensor(b!) max) -> (Tensor(a!) min, Tensor(b!) max)
|
| 35 |
+
inline ::std::tuple<at::Tensor &,at::Tensor &> aminmax_outf(const at::Tensor & self, c10::optional<int64_t> dim, bool keepdim, at::Tensor & min, at::Tensor & max) {
|
| 36 |
+
return at::_ops::aminmax_out::call(self, dim, keepdim, min, max);
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
}
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/arcsinh_compositeimplicitautograd_dispatch.h
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
// @generated by torchgen/gen.py from DispatchKeyFunction.h
|
| 3 |
+
|
| 4 |
+
// NB: The implementing C++ file is RegisterDispatchKey.cpp
|
| 5 |
+
|
| 6 |
+
// The only #includes we need are for custom classes that have defaults in the C++ API
|
| 7 |
+
#include <c10/core/MemoryFormat.h>
|
| 8 |
+
#include <c10/core/Scalar.h>
|
| 9 |
+
#include <ATen/core/Reduction.h>
|
| 10 |
+
|
| 11 |
+
// Forward declarations of any types needed in the operator signatures.
|
| 12 |
+
// We can't directly include these classes because it will cause circular include dependencies.
|
| 13 |
+
// This file is included by TensorBody.h, which defines the Tensor class.
|
| 14 |
+
#include <ATen/core/ATen_fwd.h>
|
| 15 |
+
|
| 16 |
+
namespace at {
|
| 17 |
+
|
| 18 |
+
namespace compositeimplicitautograd {
|
| 19 |
+
|
| 20 |
+
TORCH_API at::Tensor arcsinh(const at::Tensor & self);
|
| 21 |
+
TORCH_API at::Tensor & arcsinh_out(at::Tensor & out, const at::Tensor & self);
|
| 22 |
+
TORCH_API at::Tensor & arcsinh_outf(const at::Tensor & self, at::Tensor & out);
|
| 23 |
+
TORCH_API at::Tensor & arcsinh_(at::Tensor & self);
|
| 24 |
+
|
| 25 |
+
} // namespace compositeimplicitautograd
|
| 26 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/avg_pool2d_ops.h
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// @generated by torchgen/gen.py from Operator.h
|
| 4 |
+
|
| 5 |
+
#include <tuple>
|
| 6 |
+
#include <vector>
|
| 7 |
+
|
| 8 |
+
// Forward declarations of any types needed in the operator signatures.
|
| 9 |
+
// We can't directly include these classes because it will cause circular include dependencies.
|
| 10 |
+
// This file is included by TensorBody.h, which defines the Tensor class.
|
| 11 |
+
#include <ATen/core/ATen_fwd.h>
|
| 12 |
+
|
| 13 |
+
namespace at {
|
| 14 |
+
namespace _ops {
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
struct TORCH_API avg_pool2d_out {
|
| 18 |
+
using schema = at::Tensor & (const at::Tensor &, at::IntArrayRef, at::IntArrayRef, at::IntArrayRef, bool, bool, c10::optional<int64_t>, at::Tensor &);
|
| 19 |
+
using ptr_schema = schema*;
|
| 20 |
+
// See Note [static constexpr char* members for windows NVCC]
|
| 21 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::avg_pool2d")
|
| 22 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "out")
|
| 23 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "avg_pool2d.out(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None, *, Tensor(a!) out) -> Tensor(a!)")
|
| 24 |
+
static at::Tensor & call(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, c10::optional<int64_t> divisor_override, at::Tensor & out);
|
| 25 |
+
static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, c10::optional<int64_t> divisor_override, at::Tensor & out);
|
| 26 |
+
};
|
| 27 |
+
|
| 28 |
+
struct TORCH_API avg_pool2d {
|
| 29 |
+
using schema = at::Tensor (const at::Tensor &, at::IntArrayRef, at::IntArrayRef, at::IntArrayRef, bool, bool, c10::optional<int64_t>);
|
| 30 |
+
using ptr_schema = schema*;
|
| 31 |
+
// See Note [static constexpr char* members for windows NVCC]
|
| 32 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::avg_pool2d")
|
| 33 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "")
|
| 34 |
+
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "avg_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> Tensor")
|
| 35 |
+
static at::Tensor call(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, c10::optional<int64_t> divisor_override);
|
| 36 |
+
static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, c10::optional<int64_t> divisor_override);
|
| 37 |
+
};
|
| 38 |
+
|
| 39 |
+
}} // namespace at::_ops
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/batch_norm_compositeimplicitautograd_dispatch.h
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
// @generated by torchgen/gen.py from DispatchKeyFunction.h
|
| 3 |
+
|
| 4 |
+
// NB: The implementing C++ file is RegisterDispatchKey.cpp
|
| 5 |
+
|
| 6 |
+
// The only #includes we need are for custom classes that have defaults in the C++ API
|
| 7 |
+
#include <c10/core/MemoryFormat.h>
|
| 8 |
+
#include <c10/core/Scalar.h>
|
| 9 |
+
#include <ATen/core/Reduction.h>
|
| 10 |
+
|
| 11 |
+
// Forward declarations of any types needed in the operator signatures.
|
| 12 |
+
// We can't directly include these classes because it will cause circular include dependencies.
|
| 13 |
+
// This file is included by TensorBody.h, which defines the Tensor class.
|
| 14 |
+
#include <ATen/core/ATen_fwd.h>
|
| 15 |
+
|
| 16 |
+
namespace at {
|
| 17 |
+
|
| 18 |
+
namespace compositeimplicitautograd {
|
| 19 |
+
|
| 20 |
+
TORCH_API at::Tensor batch_norm(const at::Tensor & input, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, const c10::optional<at::Tensor> & running_mean, const c10::optional<at::Tensor> & running_var, bool training, double momentum, double eps, bool cudnn_enabled);
|
| 21 |
+
|
| 22 |
+
} // namespace compositeimplicitautograd
|
| 23 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/block_diag_native.h
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// @generated by torchgen/gen.py from NativeFunction.h
|
| 4 |
+
|
| 5 |
+
#include <c10/core/Scalar.h>
|
| 6 |
+
#include <c10/core/Storage.h>
|
| 7 |
+
#include <c10/core/TensorOptions.h>
|
| 8 |
+
#include <c10/util/Deprecated.h>
|
| 9 |
+
#include <c10/util/Optional.h>
|
| 10 |
+
#include <c10/core/QScheme.h>
|
| 11 |
+
#include <ATen/core/Reduction.h>
|
| 12 |
+
#include <ATen/core/Tensor.h>
|
| 13 |
+
#include <tuple>
|
| 14 |
+
#include <vector>
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
namespace at {
|
| 18 |
+
namespace native {
|
| 19 |
+
TORCH_API at::Tensor block_diag(at::TensorList tensors);
|
| 20 |
+
TORCH_API at::Tensor & block_diag_out(at::TensorList tensors, at::Tensor & out);
|
| 21 |
+
} // namespace native
|
| 22 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/cat_compositeimplicitautograd_dispatch.h
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
// @generated by torchgen/gen.py from DispatchKeyFunction.h
|
| 3 |
+
|
| 4 |
+
// NB: The implementing C++ file is RegisterDispatchKey.cpp
|
| 5 |
+
|
| 6 |
+
// The only #includes we need are for custom classes that have defaults in the C++ API
|
| 7 |
+
#include <c10/core/MemoryFormat.h>
|
| 8 |
+
#include <c10/core/Scalar.h>
|
| 9 |
+
#include <ATen/core/Reduction.h>
|
| 10 |
+
|
| 11 |
+
// Forward declarations of any types needed in the operator signatures.
|
| 12 |
+
// We can't directly include these classes because it will cause circular include dependencies.
|
| 13 |
+
// This file is included by TensorBody.h, which defines the Tensor class.
|
| 14 |
+
#include <ATen/core/ATen_fwd.h>
|
| 15 |
+
|
| 16 |
+
namespace at {
|
| 17 |
+
|
| 18 |
+
namespace compositeimplicitautograd {
|
| 19 |
+
|
| 20 |
+
TORCH_API at::Tensor cat(at::TensorList tensors, at::Dimname dim);
|
| 21 |
+
TORCH_API at::Tensor & cat_out(at::Tensor & out, at::TensorList tensors, at::Dimname dim);
|
| 22 |
+
TORCH_API at::Tensor & cat_outf(at::TensorList tensors, at::Dimname dim, at::Tensor & out);
|
| 23 |
+
|
| 24 |
+
} // namespace compositeimplicitautograd
|
| 25 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ops/cat_native.h
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// @generated by torchgen/gen.py from NativeFunction.h
|
| 4 |
+
|
| 5 |
+
#include <c10/core/Scalar.h>
|
| 6 |
+
#include <c10/core/Storage.h>
|
| 7 |
+
#include <c10/core/TensorOptions.h>
|
| 8 |
+
#include <c10/util/Deprecated.h>
|
| 9 |
+
#include <c10/util/Optional.h>
|
| 10 |
+
#include <c10/core/QScheme.h>
|
| 11 |
+
#include <ATen/core/Reduction.h>
|
| 12 |
+
#include <ATen/core/Tensor.h>
|
| 13 |
+
#include <tuple>
|
| 14 |
+
#include <vector>
|
| 15 |
+
#include <ATen/ops/cat_meta.h>
|
| 16 |
+
|
| 17 |
+
namespace at {
|
| 18 |
+
namespace native {
|
| 19 |
+
struct TORCH_API structured_cat_out_cpu : public at::meta::structured_cat {
|
| 20 |
+
void impl(const at::ITensorListRef & tensors, int64_t dim, int64_t valid, bool all_contiguous, bool all_same_dtype, bool all_same_sizes_and_stride, at::MemoryFormat memory_format, const at::Tensor & out);
|
| 21 |
+
};
|
| 22 |
+
struct TORCH_API structured_cat_out_cuda : public at::meta::structured_cat {
|
| 23 |
+
void impl(const at::ITensorListRef & tensors, int64_t dim, int64_t valid, bool all_contiguous, bool all_same_dtype, bool all_same_sizes_and_stride, at::MemoryFormat memory_format, const at::Tensor & out);
|
| 24 |
+
};
|
| 25 |
+
TORCH_API at::Tensor cat_nested(const at::ITensorListRef & tensors, int64_t dim=0);
|
| 26 |
+
TORCH_API at::Tensor cat_sparse(const at::ITensorListRef & tensors, int64_t dim=0);
|
| 27 |
+
TORCH_API at::Tensor cat_quantized_cpu(const at::ITensorListRef & tensors, int64_t dim=0);
|
| 28 |
+
TORCH_API at::Tensor & cat_out_quantized_cpu(const at::ITensorListRef & tensors, int64_t dim, at::Tensor & out);
|
| 29 |
+
TORCH_API at::Tensor cat(at::TensorList tensors, at::Dimname dim);
|
| 30 |
+
TORCH_API at::Tensor & cat_out(at::TensorList tensors, at::Dimname dim, at::Tensor & out);
|
| 31 |
+
} // namespace native
|
| 32 |
+
} // namespace at
|