Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/functional.h +9 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/functional_base.h +480 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/functional_bfloat16.h +652 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/intrinsics.h +6 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec.h +62 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec128/vec128_half_neon.h +627 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec128/vec128_reduced_precision_common_neon.h +316 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec_base.h +1537 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec_convert.h +84 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec_half.h +123 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec_mask.h +318 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec_n.h +412 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec_quant.h +258 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/ADInterpreters.h +43 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/BatchedFallback.h +86 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/BatchedTensorImpl.h +181 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/DynamicLayer.h +129 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/FunctionalizeInterpreter.h +27 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/Interpreter.h +358 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/PlumbingHelper.h +68 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/TensorWrapper.h +108 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/ao_sparse/quantized/cpu/fbgemm_utils.h +102 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/ao_sparse/quantized/cpu/packed_params.h +78 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/ao_sparse/quantized/cpu/qnnpack_utils.h +95 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/AtomicAddFloat.h +42 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/ChannelShuffleKernel.h +19 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/DepthwiseConvKernel.h +26 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/Elu.h +79 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/GridSamplerKernel.h +39 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/IndexKernelUtils.h +90 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/Intrinsics.h +38 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/IsContiguous.h +69 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/LogSoftmaxKernelImpl.h +342 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/Loops.h +400 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/MaxUnpoolKernel.h +19 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/PixelShuffleKernel.h +19 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/ReduceUtils.h +242 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/ReducedPrecisionFloatGemvFastPathKernel.h +32 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/SampledAddmmKernel.h +17 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/SerialStackImpl.h +151 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/SoftmaxKernel.h +33 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/SpmmReduceKernel.h +27 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/StackKernel.h +17 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/UpSampleKernelAVXAntialias.h +1381 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/avx_mathfun.h +527 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/int_mm_kernel.h +43 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/mixed_data_type.h +46 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/utils.h +225 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/zmath.h +255 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/CUDAJitLoops.cuh +332 -0
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/functional.h
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
#include <ATen/cpu/vec/functional_base.h>
|
| 5 |
+
#include <ATen/cpu/vec/functional_bfloat16.h>
|
| 6 |
+
|
| 7 |
+
#else
|
| 8 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 9 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/functional_base.h
ADDED
|
@@ -0,0 +1,480 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
// DO NOT DEFINE STATIC DATA IN THIS HEADER!
|
| 5 |
+
// See Note [Do not compile initializers with AVX]
|
| 6 |
+
|
| 7 |
+
#include <ATen/cpu/vec/vec.h>
|
| 8 |
+
#include <c10/util/irange.h>
|
| 9 |
+
|
| 10 |
+
namespace at {
|
| 11 |
+
namespace detail {
|
| 12 |
+
// We prefer to convert through float for reduced-precision floating
|
| 13 |
+
// point types if we have a Vectorized specialization for float and we
|
| 14 |
+
// don't have one for the actual type in question.
|
| 15 |
+
template <typename T>
|
| 16 |
+
struct should_prefer_converting_through_float
|
| 17 |
+
: std::bool_constant<
|
| 18 |
+
is_reduced_floating_point_v<T> &&
|
| 19 |
+
vec::is_vec_specialized_for_v<float> &&
|
| 20 |
+
!vec::is_vec_specialized_for_v<T>> {};
|
| 21 |
+
|
| 22 |
+
template <typename T>
|
| 23 |
+
constexpr auto should_prefer_converting_through_float_v =
|
| 24 |
+
should_prefer_converting_through_float<T>::value;
|
| 25 |
+
} // namespace detail
|
| 26 |
+
|
| 27 |
+
namespace vec {
|
| 28 |
+
// slow path
|
| 29 |
+
template <typename scalar_t, typename Op>
|
| 30 |
+
inline scalar_t vec_reduce_all(
|
| 31 |
+
const Op& vec_fun,
|
| 32 |
+
vec::Vectorized<scalar_t> acc_vec,
|
| 33 |
+
int64_t size) {
|
| 34 |
+
using Vec = vec::Vectorized<scalar_t>;
|
| 35 |
+
scalar_t acc_arr[Vec::size()];
|
| 36 |
+
acc_vec.store(acc_arr);
|
| 37 |
+
for (const auto i : c10::irange(1, size)) {
|
| 38 |
+
std::array<scalar_t, Vec::size()> acc_arr_next = {0};
|
| 39 |
+
acc_arr_next[0] = acc_arr[i];
|
| 40 |
+
Vec acc_vec_next = Vec::loadu(acc_arr_next.data());
|
| 41 |
+
acc_vec = vec_fun(acc_vec, acc_vec_next);
|
| 42 |
+
}
|
| 43 |
+
acc_vec.store(acc_arr);
|
| 44 |
+
return acc_arr[0];
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
template <typename scalar_t, typename Op>
|
| 48 |
+
struct VecReduceAllSIMD {
|
| 49 |
+
static inline scalar_t apply(
|
| 50 |
+
const Op& vec_fun,
|
| 51 |
+
const Vectorized<scalar_t>& acc_vec) {
|
| 52 |
+
return vec_reduce_all(vec_fun, acc_vec, Vectorized<scalar_t>::size());
|
| 53 |
+
}
|
| 54 |
+
};
|
| 55 |
+
|
| 56 |
+
#if defined(__GNUC__) && (__GNUC__ > 5) && !defined(_MSC_VER) && \
|
| 57 |
+
!defined(C10_MOBILE)
|
| 58 |
+
#if defined(CPU_CAPABILITY_AVX2)
|
| 59 |
+
template <typename Op>
|
| 60 |
+
struct VecReduceAllSIMD<float, Op> {
|
| 61 |
+
static inline float apply(
|
| 62 |
+
const Op& vec_fun,
|
| 63 |
+
const Vectorized<float>& acc_vec) {
|
| 64 |
+
using Vec = Vectorized<float>;
|
| 65 |
+
Vec v = acc_vec;
|
| 66 |
+
// 128-bit shuffle
|
| 67 |
+
Vec v1 = _mm256_permute2f128_ps(v, v, 0x1);
|
| 68 |
+
v = vec_fun(v, v1);
|
| 69 |
+
// 64-bit shuffle
|
| 70 |
+
v1 = _mm256_shuffle_ps(v, v, 0x4E);
|
| 71 |
+
v = vec_fun(v, v1);
|
| 72 |
+
// 32-bit shuffle
|
| 73 |
+
v1 = _mm256_shuffle_ps(v, v, 0xB1);
|
| 74 |
+
v = vec_fun(v, v1);
|
| 75 |
+
return _mm256_cvtss_f32(v);
|
| 76 |
+
}
|
| 77 |
+
};
|
| 78 |
+
#endif // defined(CPU_CAPABILITY_AVX2)
|
| 79 |
+
#if defined(CPU_CAPABILITY_AVX512)
|
| 80 |
+
template <typename Op>
|
| 81 |
+
struct VecReduceAllSIMD<float, Op> {
|
| 82 |
+
static inline float apply(
|
| 83 |
+
const Op& vec_fun,
|
| 84 |
+
const Vectorized<float>& acc_vec) {
|
| 85 |
+
using Vec = Vectorized<float>;
|
| 86 |
+
Vec v = acc_vec;
|
| 87 |
+
// 256-bit shuffle
|
| 88 |
+
Vec v1 = _mm512_shuffle_f32x4(v, v, 0x4E);
|
| 89 |
+
v = vec_fun(v, v1);
|
| 90 |
+
// 128-bit shuffle
|
| 91 |
+
v1 = _mm512_shuffle_f32x4(v, v, 0xB1);
|
| 92 |
+
v = vec_fun(v, v1);
|
| 93 |
+
// 64-bit shuffle
|
| 94 |
+
v1 = _mm512_shuffle_ps(v, v, 0x4E);
|
| 95 |
+
v = vec_fun(v, v1);
|
| 96 |
+
// 32-bit shuffle
|
| 97 |
+
v1 = _mm512_shuffle_ps(v, v, 0xB1);
|
| 98 |
+
v = vec_fun(v, v1);
|
| 99 |
+
return _mm512_cvtss_f32(v);
|
| 100 |
+
}
|
| 101 |
+
};
|
| 102 |
+
#endif // defined(CPU_CAPABILITY_AVX512)
|
| 103 |
+
#endif // defined(__GNUC__) && (__GNUC__ > 5) && !defined(_MSC_VER) &&
|
| 104 |
+
// !defined(C10_MOBILE)
|
| 105 |
+
|
| 106 |
+
#if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && \
|
| 107 |
+
!defined(CPU_CAPABILITY_SVE)
|
| 108 |
+
template <typename Op>
|
| 109 |
+
struct VecReduceAllSIMD<float, Op> {
|
| 110 |
+
static inline float apply(
|
| 111 |
+
const Op& vec_fun,
|
| 112 |
+
const Vectorized<float>& acc_vec) {
|
| 113 |
+
using Vec = Vectorized<float>;
|
| 114 |
+
Vec v = acc_vec;
|
| 115 |
+
|
| 116 |
+
// 64-bit shuffle: [a1+a5, a2+a6, a3+a7, a4+a8, -, -, -, -] -> [a3+a7,
|
| 117 |
+
// a4+a8, a1+a5, a2+a6, -, -, -, -]
|
| 118 |
+
float32x4_t v1_1 = vextq_f32(v, v, 2);
|
| 119 |
+
Vec v1 = v1_1;
|
| 120 |
+
// [a1+a3+a5+a7, a2+a4+a6+a8, a1+a3+a5+a7, a2+a4+a6+a8, -, -, -, -]
|
| 121 |
+
v = vec_fun(v, v1);
|
| 122 |
+
|
| 123 |
+
// 32-bit shuffle: [a1+a3+a5+a7, a2+a4+a6+a8, a1+a3+a5+a7, a2+a4+a6+a8, -,
|
| 124 |
+
// -, -, -] -> [a2+a4+a6+a8, a1+a3+a5+a7, a2+a4+a6+a8, a1+a3+a5+a7, -, -, -,
|
| 125 |
+
// -]
|
| 126 |
+
v1_1 = vrev64q_f32(v);
|
| 127 |
+
v1 = v1_1;
|
| 128 |
+
// [a1+a2+a3+a4+a5+a6+a7+a8, a1+a2+a3+a4+a5+a6+a7+a8,
|
| 129 |
+
// a1+a2+a3+a4+a5+a6+a7+a8, a1+a2+a3+a4+a5+a6+a7+a8, -, -, -, -]
|
| 130 |
+
v = vec_fun(v, v1);
|
| 131 |
+
|
| 132 |
+
return v[0];
|
| 133 |
+
}
|
| 134 |
+
};
|
| 135 |
+
|
| 136 |
+
template <>
|
| 137 |
+
struct VecReduceAllSIMD<float, std::plus<Vectorized<float>>> {
|
| 138 |
+
static inline float apply(
|
| 139 |
+
const std::plus<Vectorized<float>>& vec_fun,
|
| 140 |
+
const Vectorized<float>& acc_vec) {
|
| 141 |
+
return vaddvq_f32(acc_vec);
|
| 142 |
+
}
|
| 143 |
+
};
|
| 144 |
+
#endif // defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__)
|
| 145 |
+
// && !defined(CPU_CAPABILITY_SVE)
|
| 146 |
+
|
| 147 |
+
#if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && \
|
| 148 |
+
defined(CPU_CAPABILITY_SVE256)
|
| 149 |
+
template <typename Op>
|
| 150 |
+
struct VecReduceAllSIMD<float, Op> {
|
| 151 |
+
static inline float apply(
|
| 152 |
+
const Op& vec_fun,
|
| 153 |
+
const Vectorized<float>& acc_vec) {
|
| 154 |
+
using Vec = Vectorized<float>;
|
| 155 |
+
Vec v = acc_vec;
|
| 156 |
+
// 128-bit shuffle
|
| 157 |
+
svuint32_t ind = svdupq_n_u32(4, 5, 6, 7);
|
| 158 |
+
Vec v1 = svtbl_f32(v, ind);
|
| 159 |
+
v = vec_fun(v, v1);
|
| 160 |
+
// 64-bit shuffle
|
| 161 |
+
ind = svdupq_n_u32(2, 3, 0, 1);
|
| 162 |
+
v1 = svtbl_f32(v, ind);
|
| 163 |
+
v = vec_fun(v, v1);
|
| 164 |
+
// 32-bit shuffle
|
| 165 |
+
ind = svdupq_n_u32(1, 0, 2, 3);
|
| 166 |
+
v1 = svtbl_f32(v, ind);
|
| 167 |
+
v = vec_fun(v, v1);
|
| 168 |
+
return svlasta(svpfalse(), v);
|
| 169 |
+
}
|
| 170 |
+
};
|
| 171 |
+
#endif // defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__)
|
| 172 |
+
// && defined(CPU_CAPABILITY_SVE256)
|
| 173 |
+
|
| 174 |
+
template <typename scalar_t, typename Op>
|
| 175 |
+
inline scalar_t vec_reduce_all(
|
| 176 |
+
const Op& vec_fun,
|
| 177 |
+
const Vectorized<scalar_t>& acc_vec) {
|
| 178 |
+
return VecReduceAllSIMD<scalar_t, Op>::apply(vec_fun, acc_vec);
|
| 179 |
+
}
|
| 180 |
+
|
| 181 |
+
template <
|
| 182 |
+
typename scalar_t,
|
| 183 |
+
typename Op,
|
| 184 |
+
typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, int> = 0>
|
| 185 |
+
inline scalar_t reduce_all(
|
| 186 |
+
const Op& vec_fun,
|
| 187 |
+
const scalar_t* data,
|
| 188 |
+
int64_t size) {
|
| 189 |
+
using Vec = vec::Vectorized<scalar_t>;
|
| 190 |
+
if (size < Vec::size())
|
| 191 |
+
return vec_reduce_all(vec_fun, Vec::loadu(data, size), size);
|
| 192 |
+
int64_t d = Vec::size();
|
| 193 |
+
Vec acc_vec = Vec::loadu(data);
|
| 194 |
+
for (; d < size - (size % Vec::size()); d += Vec::size()) {
|
| 195 |
+
Vec data_vec = Vec::loadu(data + d);
|
| 196 |
+
acc_vec = vec_fun(acc_vec, data_vec);
|
| 197 |
+
}
|
| 198 |
+
if (size - d > 0) {
|
| 199 |
+
Vec data_vec = Vec::loadu(data + d, size - d);
|
| 200 |
+
acc_vec = Vec::set(acc_vec, vec_fun(acc_vec, data_vec), size - d);
|
| 201 |
+
}
|
| 202 |
+
return vec_reduce_all(vec_fun, acc_vec);
|
| 203 |
+
}
|
| 204 |
+
|
| 205 |
+
// similar to reduce_all, but reduces into two outputs
|
| 206 |
+
template <
|
| 207 |
+
typename scalar_t,
|
| 208 |
+
typename Op1,
|
| 209 |
+
typename Op2,
|
| 210 |
+
typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, int> = 0>
|
| 211 |
+
inline std::pair<scalar_t, scalar_t> reduce2_all(
|
| 212 |
+
const Op1& vec_fun1,
|
| 213 |
+
const Op2& vec_fun2,
|
| 214 |
+
const scalar_t* data,
|
| 215 |
+
int64_t size) {
|
| 216 |
+
using Vec = vec::Vectorized<scalar_t>;
|
| 217 |
+
if (size < Vec::size()) {
|
| 218 |
+
auto loaded_data = Vec::loadu(data, size);
|
| 219 |
+
return std::pair<scalar_t, scalar_t>(
|
| 220 |
+
vec_reduce_all(vec_fun1, loaded_data, size),
|
| 221 |
+
vec_reduce_all(vec_fun2, loaded_data, size));
|
| 222 |
+
}
|
| 223 |
+
int64_t d = Vec::size();
|
| 224 |
+
Vec acc_vec1 = Vec::loadu(data);
|
| 225 |
+
Vec acc_vec2 = Vec::loadu(data);
|
| 226 |
+
for (; d < size - (size % Vec::size()); d += Vec::size()) {
|
| 227 |
+
Vec data_vec = Vec::loadu(data + d);
|
| 228 |
+
acc_vec1 = vec_fun1(acc_vec1, data_vec);
|
| 229 |
+
acc_vec2 = vec_fun2(acc_vec2, data_vec);
|
| 230 |
+
}
|
| 231 |
+
if (size - d > 0) {
|
| 232 |
+
Vec data_vec = Vec::loadu(data + d, size - d);
|
| 233 |
+
acc_vec1 = Vec::set(acc_vec1, vec_fun1(acc_vec1, data_vec), size - d);
|
| 234 |
+
acc_vec2 = Vec::set(acc_vec2, vec_fun2(acc_vec2, data_vec), size - d);
|
| 235 |
+
}
|
| 236 |
+
return std::pair<scalar_t, scalar_t>(
|
| 237 |
+
vec_reduce_all(vec_fun1, acc_vec1), vec_reduce_all(vec_fun2, acc_vec2));
|
| 238 |
+
}
|
| 239 |
+
|
| 240 |
+
template <
|
| 241 |
+
typename scalar_t,
|
| 242 |
+
typename MapOp,
|
| 243 |
+
typename ReduceOp,
|
| 244 |
+
typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, int> = 0>
|
| 245 |
+
inline scalar_t map_reduce_all(
|
| 246 |
+
const MapOp& map_fun,
|
| 247 |
+
const ReduceOp& red_fun,
|
| 248 |
+
const scalar_t* data,
|
| 249 |
+
int64_t size) {
|
| 250 |
+
using Vec = vec::Vectorized<scalar_t>;
|
| 251 |
+
if (size < Vec::size())
|
| 252 |
+
return vec_reduce_all(red_fun, map_fun(Vec::loadu(data, size)), size);
|
| 253 |
+
int64_t d = Vec::size();
|
| 254 |
+
Vec acc_vec = map_fun(Vec::loadu(data));
|
| 255 |
+
for (; d < size - (size % Vec::size()); d += Vec::size()) {
|
| 256 |
+
Vec data_vec = Vec::loadu(data + d);
|
| 257 |
+
data_vec = map_fun(data_vec);
|
| 258 |
+
acc_vec = red_fun(acc_vec, data_vec);
|
| 259 |
+
}
|
| 260 |
+
if (size - d > 0) {
|
| 261 |
+
Vec data_vec = Vec::loadu(data + d, size - d);
|
| 262 |
+
data_vec = map_fun(data_vec);
|
| 263 |
+
acc_vec = Vec::set(acc_vec, red_fun(acc_vec, data_vec), size - d);
|
| 264 |
+
}
|
| 265 |
+
return vec_reduce_all(red_fun, acc_vec);
|
| 266 |
+
}
|
| 267 |
+
|
| 268 |
+
template <
|
| 269 |
+
typename scalar_t,
|
| 270 |
+
typename MapOp,
|
| 271 |
+
typename ReduceOp,
|
| 272 |
+
typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, int> = 0>
|
| 273 |
+
inline scalar_t map2_reduce_all(
|
| 274 |
+
const MapOp& map_fun,
|
| 275 |
+
const ReduceOp& red_fun,
|
| 276 |
+
const scalar_t* data,
|
| 277 |
+
const scalar_t* data2,
|
| 278 |
+
int64_t size) {
|
| 279 |
+
using Vec = vec::Vectorized<scalar_t>;
|
| 280 |
+
if (size < Vec::size()) {
|
| 281 |
+
Vec data_vec = Vec::loadu(data, size);
|
| 282 |
+
Vec data2_vec = Vec::loadu(data2, size);
|
| 283 |
+
data_vec = map_fun(data_vec, data2_vec);
|
| 284 |
+
return vec_reduce_all(red_fun, data_vec, size);
|
| 285 |
+
}
|
| 286 |
+
int64_t d = Vec::size();
|
| 287 |
+
Vec acc_vec = map_fun(Vec::loadu(data), Vec::loadu(data2));
|
| 288 |
+
for (; d < size - (size % Vec::size()); d += Vec::size()) {
|
| 289 |
+
Vec data_vec = Vec::loadu(data + d);
|
| 290 |
+
Vec data2_vec = Vec::loadu(data2 + d);
|
| 291 |
+
data_vec = map_fun(data_vec, data2_vec);
|
| 292 |
+
acc_vec = red_fun(acc_vec, data_vec);
|
| 293 |
+
}
|
| 294 |
+
if (size - d > 0) {
|
| 295 |
+
Vec data_vec = Vec::loadu(data + d, size - d);
|
| 296 |
+
Vec data2_vec = Vec::loadu(data2 + d, size - d);
|
| 297 |
+
data_vec = map_fun(data_vec, data2_vec);
|
| 298 |
+
acc_vec = Vec::set(acc_vec, red_fun(acc_vec, data_vec), size - d);
|
| 299 |
+
}
|
| 300 |
+
return vec_reduce_all(red_fun, acc_vec);
|
| 301 |
+
}
|
| 302 |
+
|
| 303 |
+
template <
|
| 304 |
+
typename scalar_t,
|
| 305 |
+
typename MapOp,
|
| 306 |
+
typename ReduceOp,
|
| 307 |
+
typename std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, int> = 0>
|
| 308 |
+
inline scalar_t map3_reduce_all(
|
| 309 |
+
const MapOp& map_fun,
|
| 310 |
+
const ReduceOp& red_fun,
|
| 311 |
+
const scalar_t* data,
|
| 312 |
+
const scalar_t* data2,
|
| 313 |
+
const scalar_t* data3,
|
| 314 |
+
int64_t size) {
|
| 315 |
+
using Vec = vec::Vectorized<scalar_t>;
|
| 316 |
+
if (size < Vec::size()) {
|
| 317 |
+
Vec data_vec = Vec::loadu(data, size);
|
| 318 |
+
Vec data2_vec = Vec::loadu(data2, size);
|
| 319 |
+
Vec data3_vec = Vec::loadu(data3, size);
|
| 320 |
+
data_vec = map_fun(data_vec, data2_vec, data3_vec);
|
| 321 |
+
return vec_reduce_all(red_fun, data_vec, size);
|
| 322 |
+
}
|
| 323 |
+
|
| 324 |
+
int64_t d = Vec::size();
|
| 325 |
+
Vec acc_vec = map_fun(Vec::loadu(data), Vec::loadu(data2), Vec::loadu(data3));
|
| 326 |
+
for (; d < size - (size % Vec::size()); d += Vec::size()) {
|
| 327 |
+
Vec data_vec = Vec::loadu(data + d);
|
| 328 |
+
Vec data2_vec = Vec::loadu(data2 + d);
|
| 329 |
+
Vec data3_vec = Vec::loadu(data3 + d);
|
| 330 |
+
data_vec = map_fun(data_vec, data2_vec, data3_vec);
|
| 331 |
+
acc_vec = red_fun(acc_vec, data_vec);
|
| 332 |
+
}
|
| 333 |
+
if (size - d > 0) {
|
| 334 |
+
Vec data_vec = Vec::loadu(data + d, size - d);
|
| 335 |
+
Vec data2_vec = Vec::loadu(data2 + d, size - d);
|
| 336 |
+
Vec data3_vec = Vec::loadu(data3 + d, size - d);
|
| 337 |
+
data_vec = map_fun(data_vec, data2_vec, data3_vec);
|
| 338 |
+
acc_vec = Vec::set(acc_vec, red_fun(acc_vec, data_vec), size - d);
|
| 339 |
+
}
|
| 340 |
+
return vec_reduce_all(red_fun, acc_vec);
|
| 341 |
+
}
|
| 342 |
+
|
| 343 |
+
template <
|
| 344 |
+
typename scalar_t,
|
| 345 |
+
typename Op,
|
| 346 |
+
typename std::enable_if_t<
|
| 347 |
+
!detail::should_prefer_converting_through_float_v<scalar_t> &&
|
| 348 |
+
std::is_invocable_v<Op, vec::Vectorized<scalar_t>>,
|
| 349 |
+
int> = 0>
|
| 350 |
+
inline void map(
|
| 351 |
+
const Op& vec_fun,
|
| 352 |
+
scalar_t* output_data,
|
| 353 |
+
const scalar_t* input_data,
|
| 354 |
+
int64_t size) {
|
| 355 |
+
using Vec = vec::Vectorized<scalar_t>;
|
| 356 |
+
int64_t d = 0;
|
| 357 |
+
for (; d < size - (size % Vec::size()); d += Vec::size()) {
|
| 358 |
+
Vec output_vec = vec_fun(Vec::loadu(input_data + d));
|
| 359 |
+
output_vec.store(output_data + d);
|
| 360 |
+
}
|
| 361 |
+
if (size - d > 0) {
|
| 362 |
+
Vec output_vec = vec_fun(Vec::loadu(input_data + d, size - d));
|
| 363 |
+
output_vec.store(output_data + d, size - d);
|
| 364 |
+
}
|
| 365 |
+
}
|
| 366 |
+
|
| 367 |
+
template <
|
| 368 |
+
typename scalar_t,
|
| 369 |
+
typename Op,
|
| 370 |
+
typename std::enable_if_t<
|
| 371 |
+
!detail::should_prefer_converting_through_float_v<scalar_t> &&
|
| 372 |
+
std::is_invocable_v<
|
| 373 |
+
Op,
|
| 374 |
+
vec::Vectorized<scalar_t>,
|
| 375 |
+
vec::Vectorized<scalar_t>>,
|
| 376 |
+
int> = 0>
|
| 377 |
+
inline void map2(
|
| 378 |
+
const Op& vec_fun,
|
| 379 |
+
scalar_t* output_data,
|
| 380 |
+
const scalar_t* input_data,
|
| 381 |
+
const scalar_t* input_data2,
|
| 382 |
+
int64_t size) {
|
| 383 |
+
using Vec = vec::Vectorized<scalar_t>;
|
| 384 |
+
int64_t d = 0;
|
| 385 |
+
for (; d < size - (size % Vec::size()); d += Vec::size()) {
|
| 386 |
+
Vec data_vec = Vec::loadu(input_data + d);
|
| 387 |
+
Vec data_vec2 = Vec::loadu(input_data2 + d);
|
| 388 |
+
Vec output_vec = vec_fun(data_vec, data_vec2);
|
| 389 |
+
output_vec.store(output_data + d);
|
| 390 |
+
}
|
| 391 |
+
if (size - d > 0) {
|
| 392 |
+
Vec data_vec = Vec::loadu(input_data + d, size - d);
|
| 393 |
+
Vec data_vec2 = Vec::loadu(input_data2 + d, size - d);
|
| 394 |
+
Vec output_vec = vec_fun(data_vec, data_vec2);
|
| 395 |
+
output_vec.store(output_data + d, size - d);
|
| 396 |
+
}
|
| 397 |
+
}
|
| 398 |
+
|
| 399 |
+
template <
|
| 400 |
+
typename scalar_t,
|
| 401 |
+
typename Op,
|
| 402 |
+
typename std::enable_if_t<
|
| 403 |
+
!detail::should_prefer_converting_through_float_v<scalar_t> &&
|
| 404 |
+
std::is_invocable_v<
|
| 405 |
+
Op,
|
| 406 |
+
vec::Vectorized<scalar_t>,
|
| 407 |
+
vec::Vectorized<scalar_t>,
|
| 408 |
+
vec::Vectorized<scalar_t>>,
|
| 409 |
+
int> = 0>
|
| 410 |
+
inline void map3(
|
| 411 |
+
const Op& vec_fun,
|
| 412 |
+
scalar_t* output_data,
|
| 413 |
+
const scalar_t* input_data1,
|
| 414 |
+
const scalar_t* input_data2,
|
| 415 |
+
const scalar_t* input_data3,
|
| 416 |
+
int64_t size) {
|
| 417 |
+
using Vec = vec::Vectorized<scalar_t>;
|
| 418 |
+
int64_t d = 0;
|
| 419 |
+
for (; d < size - (size % Vec::size()); d += Vec::size()) {
|
| 420 |
+
Vec data_vec1 = Vec::loadu(input_data1 + d);
|
| 421 |
+
Vec data_vec2 = Vec::loadu(input_data2 + d);
|
| 422 |
+
Vec data_vec3 = Vec::loadu(input_data3 + d);
|
| 423 |
+
Vec output_vec = vec_fun(data_vec1, data_vec2, data_vec3);
|
| 424 |
+
output_vec.store(output_data + d);
|
| 425 |
+
}
|
| 426 |
+
if (size - d > 0) {
|
| 427 |
+
Vec data_vec1 = Vec::loadu(input_data1 + d, size - d);
|
| 428 |
+
Vec data_vec2 = Vec::loadu(input_data2 + d, size - d);
|
| 429 |
+
Vec data_vec3 = Vec::loadu(input_data3 + d, size - d);
|
| 430 |
+
Vec output_vec = vec_fun(data_vec1, data_vec2, data_vec3);
|
| 431 |
+
output_vec.store(output_data + d, size - d);
|
| 432 |
+
}
|
| 433 |
+
}
|
| 434 |
+
|
| 435 |
+
template <
|
| 436 |
+
typename scalar_t,
|
| 437 |
+
typename Op,
|
| 438 |
+
typename std::enable_if_t<
|
| 439 |
+
!detail::should_prefer_converting_through_float_v<scalar_t> &&
|
| 440 |
+
std::is_invocable_v<
|
| 441 |
+
Op,
|
| 442 |
+
vec::Vectorized<scalar_t>,
|
| 443 |
+
vec::Vectorized<scalar_t>,
|
| 444 |
+
vec::Vectorized<scalar_t>,
|
| 445 |
+
vec::Vectorized<scalar_t>>,
|
| 446 |
+
int> = 0>
|
| 447 |
+
inline void map4(
|
| 448 |
+
const Op& vec_fun,
|
| 449 |
+
scalar_t* output_data,
|
| 450 |
+
const scalar_t* input_data1,
|
| 451 |
+
const scalar_t* input_data2,
|
| 452 |
+
const scalar_t* input_data3,
|
| 453 |
+
const scalar_t* input_data4,
|
| 454 |
+
int64_t size) {
|
| 455 |
+
using Vec = vec::Vectorized<scalar_t>;
|
| 456 |
+
int64_t d = 0;
|
| 457 |
+
for (; d < size - (size % Vec::size()); d += Vec::size()) {
|
| 458 |
+
Vec data_vec1 = Vec::loadu(input_data1 + d);
|
| 459 |
+
Vec data_vec2 = Vec::loadu(input_data2 + d);
|
| 460 |
+
Vec data_vec3 = Vec::loadu(input_data3 + d);
|
| 461 |
+
Vec data_vec4 = Vec::loadu(input_data4 + d);
|
| 462 |
+
Vec output_vec = vec_fun(data_vec1, data_vec2, data_vec3, data_vec4);
|
| 463 |
+
output_vec.store(output_data + d);
|
| 464 |
+
}
|
| 465 |
+
if (size - d > 0) {
|
| 466 |
+
Vec data_vec1 = Vec::loadu(input_data1 + d, size - d);
|
| 467 |
+
Vec data_vec2 = Vec::loadu(input_data2 + d, size - d);
|
| 468 |
+
Vec data_vec3 = Vec::loadu(input_data3 + d, size - d);
|
| 469 |
+
Vec data_vec4 = Vec::loadu(input_data4 + d, size - d);
|
| 470 |
+
Vec output_vec = vec_fun(data_vec1, data_vec2, data_vec3, data_vec4);
|
| 471 |
+
output_vec.store(output_data + d, size - d);
|
| 472 |
+
}
|
| 473 |
+
}
|
| 474 |
+
|
| 475 |
+
} // namespace vec
|
| 476 |
+
} // namespace at
|
| 477 |
+
|
| 478 |
+
#else
|
| 479 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 480 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/functional_bfloat16.h
ADDED
|
@@ -0,0 +1,652 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
// DO NOT DEFINE STATIC DATA IN THIS HEADER!
|
| 5 |
+
// See Note [Do not compile initializers with AVX]
|
| 6 |
+
|
| 7 |
+
#include <ATen/cpu/vec/vec.h>
|
| 8 |
+
|
| 9 |
+
namespace at::vec {
|
| 10 |
+
// BFloat16 specification
|
| 11 |
+
template <typename scalar_t>
|
| 12 |
+
struct VecScalarType {
|
| 13 |
+
using type = scalar_t;
|
| 14 |
+
};
|
| 15 |
+
template <>
|
| 16 |
+
struct VecScalarType<BFloat16> {
|
| 17 |
+
using type = float;
|
| 18 |
+
};
|
| 19 |
+
template <>
|
| 20 |
+
struct VecScalarType<Half> {
|
| 21 |
+
using type = float;
|
| 22 |
+
};
|
| 23 |
+
|
| 24 |
+
// This is different from at::acc_type since we only need to specialize BFloat16
|
| 25 |
+
template <typename scalar_t>
|
| 26 |
+
using vec_scalar_t = typename VecScalarType<scalar_t>::type;
|
| 27 |
+
|
| 28 |
+
// Vector conversion between float and bfloat16/half
|
| 29 |
+
template <>
|
| 30 |
+
inline std::tuple<Vectorized<float>, Vectorized<float>> convert_to_float<
|
| 31 |
+
BFloat16>(const Vectorized<BFloat16>& a) {
|
| 32 |
+
return convert_bfloat16_float(a);
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
template <>
|
| 36 |
+
inline std::tuple<Vectorized<float>, Vectorized<float>> convert_to_float<Half>(
|
| 37 |
+
const Vectorized<Half>& a) {
|
| 38 |
+
return convert_half_float(a);
|
| 39 |
+
}
|
| 40 |
+
|
| 41 |
+
template <>
|
| 42 |
+
inline Vectorized<BFloat16> convert_from_float<BFloat16>(
|
| 43 |
+
const Vectorized<float>& a,
|
| 44 |
+
const Vectorized<float>& b) {
|
| 45 |
+
return convert_float_bfloat16(a, b);
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
template <>
|
| 49 |
+
inline Vectorized<Half> convert_from_float<Half>(
|
| 50 |
+
const Vectorized<float>& a,
|
| 51 |
+
const Vectorized<float>& b) {
|
| 52 |
+
return convert_float_half(a, b);
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
template <
|
| 56 |
+
typename scalar_t,
|
| 57 |
+
typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
|
| 58 |
+
inline void load_to_float(
|
| 59 |
+
const scalar_t* data,
|
| 60 |
+
Vectorized<float>& out1,
|
| 61 |
+
Vectorized<float>& out2);
|
| 62 |
+
|
| 63 |
+
template <>
|
| 64 |
+
inline void load_to_float<BFloat16>(
|
| 65 |
+
const BFloat16* data,
|
| 66 |
+
Vectorized<float>& out1,
|
| 67 |
+
Vectorized<float>& out2) {
|
| 68 |
+
load_fp32_from_bf16(data, out1, out2);
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
template <>
|
| 72 |
+
inline void load_to_float<Half>(
|
| 73 |
+
const Half* data,
|
| 74 |
+
Vectorized<float>& out1,
|
| 75 |
+
Vectorized<float>& out2) {
|
| 76 |
+
load_fp32_from_fp16(data, out1, out2);
|
| 77 |
+
}
|
| 78 |
+
|
| 79 |
+
template <
|
| 80 |
+
typename scalar_t,
|
| 81 |
+
typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
|
| 82 |
+
inline void load_to_float(const scalar_t* data, Vectorized<float>& out);
|
| 83 |
+
|
| 84 |
+
template <>
|
| 85 |
+
inline void load_to_float<BFloat16>(
|
| 86 |
+
const BFloat16* data,
|
| 87 |
+
Vectorized<float>& out) {
|
| 88 |
+
load_fp32_from_bf16(data, out);
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
template <>
|
| 92 |
+
inline void load_to_float<Half>(const Half* data, Vectorized<float>& out) {
|
| 93 |
+
load_fp32_from_fp16(data, out);
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
// Note that we already have specialized member of Vectorized<scalar_t> for
|
| 97 |
+
// BFloat16 so the following functions would run smoothly:
|
| 98 |
+
// using Vec = Vectorized<BFloat16>;
|
| 99 |
+
// Vec one = Vec(BFloat16(1));
|
| 100 |
+
// vec::map([](Vec x) { return one / (one + x.exp()); }, y_ptr, x_ptr, N);
|
| 101 |
+
//
|
| 102 |
+
// Then why we still need to specialize "functional"?
|
| 103 |
+
// If we do specialization at Vectorized<> level, the above example would need
|
| 104 |
+
// 3 pairs of conversion of bf16->fp32/fp32->bf16, each for ".exp()", "+" and
|
| 105 |
+
// "/". If we do specialization at vec::map<>() level, we have only 1 pair of
|
| 106 |
+
// conversion of bf16->fp32/fp32->bf16, for the input and output BFloat16
|
| 107 |
+
// vector only.
|
| 108 |
+
//
|
| 109 |
+
// The following BFloat16 functionality will only do data type conversion for
|
| 110 |
+
// input and output vector (reduce functionality will only convert the final
|
| 111 |
+
// scalar back to bf16). Compared to Vectorized<> specialization,
|
| 112 |
+
// 1. better performance since we have less data type conversion;
|
| 113 |
+
// 2. less rounding error since immediate results are kept in fp32;
|
| 114 |
+
// 3. accumulation done on data type of fp32.
|
| 115 |
+
//
|
| 116 |
+
// If you plan to extend this file, please ensure adding unit tests at
|
| 117 |
+
// aten/src/ATen/test/vec_test_all_types.cpp
|
| 118 |
+
//
|
| 119 |
+
template <
|
| 120 |
+
typename scalar_t,
|
| 121 |
+
typename Op,
|
| 122 |
+
typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
|
| 123 |
+
inline float reduce_all(const Op& vec_fun, const scalar_t* data, int64_t size) {
|
| 124 |
+
using bVec = vec::Vectorized<scalar_t>;
|
| 125 |
+
using fVec = vec::Vectorized<float>;
|
| 126 |
+
if (size < bVec::size()) {
|
| 127 |
+
bVec data_bvec = bVec::loadu(data, size);
|
| 128 |
+
auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
|
| 129 |
+
if (size > fVec::size()) {
|
| 130 |
+
data_fvec0 = fVec::set(
|
| 131 |
+
data_fvec0, vec_fun(data_fvec0, data_fvec1), size - fVec::size());
|
| 132 |
+
return vec_reduce_all<float>(vec_fun, data_fvec0, fVec::size());
|
| 133 |
+
} else {
|
| 134 |
+
return vec_reduce_all<float>(vec_fun, data_fvec0, size);
|
| 135 |
+
}
|
| 136 |
+
}
|
| 137 |
+
int64_t d = bVec::size();
|
| 138 |
+
bVec acc_bvec = bVec::loadu(data);
|
| 139 |
+
auto [acc_fvec0, acc_fvec1] = convert_to_float<scalar_t>(acc_bvec);
|
| 140 |
+
for (; d < size - (size % bVec::size()); d += bVec::size()) {
|
| 141 |
+
bVec data_bvec = bVec::loadu(data + d);
|
| 142 |
+
auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
|
| 143 |
+
acc_fvec0 = vec_fun(acc_fvec0, data_fvec0);
|
| 144 |
+
acc_fvec1 = vec_fun(acc_fvec1, data_fvec1);
|
| 145 |
+
}
|
| 146 |
+
if (size - d > 0) {
|
| 147 |
+
bVec data_bvec = bVec::loadu(data + d, size - d);
|
| 148 |
+
auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
|
| 149 |
+
if (size - d > fVec::size()) {
|
| 150 |
+
acc_fvec0 = vec_fun(acc_fvec0, data_fvec0);
|
| 151 |
+
acc_fvec1 = fVec::set(
|
| 152 |
+
acc_fvec1, vec_fun(acc_fvec1, data_fvec1), size - d - fVec::size());
|
| 153 |
+
} else {
|
| 154 |
+
acc_fvec0 =
|
| 155 |
+
fVec::set(acc_fvec0, vec_fun(acc_fvec0, data_fvec0), size - d);
|
| 156 |
+
}
|
| 157 |
+
}
|
| 158 |
+
acc_fvec0 = vec_fun(acc_fvec0, acc_fvec1);
|
| 159 |
+
return vec_reduce_all<float>(vec_fun, acc_fvec0);
|
| 160 |
+
}
|
| 161 |
+
|
| 162 |
+
template <
|
| 163 |
+
typename scalar_t,
|
| 164 |
+
typename Op1,
|
| 165 |
+
typename Op2,
|
| 166 |
+
typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
|
| 167 |
+
inline std::pair<float, float> reduce2_all(
|
| 168 |
+
const Op1& vec_fun1,
|
| 169 |
+
const Op2& vec_fun2,
|
| 170 |
+
const scalar_t* data,
|
| 171 |
+
int64_t size) {
|
| 172 |
+
using bVec = vec::Vectorized<scalar_t>;
|
| 173 |
+
using fVec = vec::Vectorized<float>;
|
| 174 |
+
if (size < bVec::size()) {
|
| 175 |
+
bVec data_bvec = bVec::loadu(data, size);
|
| 176 |
+
auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
|
| 177 |
+
if (size > fVec::size()) {
|
| 178 |
+
fVec acc1_fvec = fVec::set(
|
| 179 |
+
data_fvec0, vec_fun1(data_fvec0, data_fvec1), size - fVec::size());
|
| 180 |
+
fVec acc2_fvec = fVec::set(
|
| 181 |
+
data_fvec0, vec_fun2(data_fvec0, data_fvec1), size - fVec::size());
|
| 182 |
+
return std::pair<scalar_t, scalar_t>(
|
| 183 |
+
vec_reduce_all<float>(vec_fun1, acc1_fvec, fVec::size()),
|
| 184 |
+
vec_reduce_all<float>(vec_fun2, acc2_fvec, fVec::size()));
|
| 185 |
+
} else {
|
| 186 |
+
return std::pair<scalar_t, scalar_t>(
|
| 187 |
+
vec_reduce_all<float>(vec_fun1, data_fvec0, size),
|
| 188 |
+
vec_reduce_all<float>(vec_fun2, data_fvec0, size));
|
| 189 |
+
}
|
| 190 |
+
}
|
| 191 |
+
int64_t d = bVec::size();
|
| 192 |
+
bVec acc_bvec = bVec::loadu(data);
|
| 193 |
+
auto [acc1_fvec0, acc1_fvec1] = convert_to_float<scalar_t>(acc_bvec);
|
| 194 |
+
auto [acc2_fvec0, acc2_fvec1] = convert_to_float<scalar_t>(acc_bvec);
|
| 195 |
+
for (; d < size - (size % bVec::size()); d += bVec::size()) {
|
| 196 |
+
bVec data_bvec = bVec::loadu(data + d);
|
| 197 |
+
auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
|
| 198 |
+
acc1_fvec0 = vec_fun1(acc1_fvec0, data_fvec0);
|
| 199 |
+
acc1_fvec1 = vec_fun1(acc1_fvec1, data_fvec1);
|
| 200 |
+
acc2_fvec0 = vec_fun2(acc2_fvec0, data_fvec0);
|
| 201 |
+
acc2_fvec1 = vec_fun2(acc2_fvec1, data_fvec1);
|
| 202 |
+
}
|
| 203 |
+
if (size - d > 0) {
|
| 204 |
+
bVec data_bvec = bVec::loadu(data + d, size - d);
|
| 205 |
+
auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
|
| 206 |
+
if (size - d > fVec::size()) {
|
| 207 |
+
acc1_fvec0 = vec_fun1(acc1_fvec0, data_fvec0);
|
| 208 |
+
acc1_fvec1 = fVec::set(
|
| 209 |
+
acc1_fvec1,
|
| 210 |
+
vec_fun1(acc1_fvec1, data_fvec1),
|
| 211 |
+
size - d - fVec::size());
|
| 212 |
+
acc2_fvec0 = vec_fun2(acc2_fvec0, data_fvec0);
|
| 213 |
+
acc2_fvec1 = fVec::set(
|
| 214 |
+
acc2_fvec1,
|
| 215 |
+
vec_fun2(acc2_fvec1, data_fvec1),
|
| 216 |
+
size - d - fVec::size());
|
| 217 |
+
} else {
|
| 218 |
+
acc1_fvec0 =
|
| 219 |
+
fVec::set(acc1_fvec0, vec_fun1(acc1_fvec0, data_fvec0), size - d);
|
| 220 |
+
acc2_fvec0 =
|
| 221 |
+
fVec::set(acc2_fvec0, vec_fun2(acc2_fvec0, data_fvec0), size - d);
|
| 222 |
+
}
|
| 223 |
+
}
|
| 224 |
+
acc1_fvec0 = vec_fun1(acc1_fvec0, acc1_fvec1);
|
| 225 |
+
acc2_fvec0 = vec_fun2(acc2_fvec0, acc2_fvec1);
|
| 226 |
+
return std::pair<scalar_t, scalar_t>(
|
| 227 |
+
vec_reduce_all<float>(vec_fun1, acc1_fvec0),
|
| 228 |
+
vec_reduce_all<float>(vec_fun2, acc2_fvec0));
|
| 229 |
+
}
|
| 230 |
+
|
| 231 |
+
template <
|
| 232 |
+
typename scalar_t,
|
| 233 |
+
typename MapOp,
|
| 234 |
+
typename ReduceOp,
|
| 235 |
+
typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
|
| 236 |
+
inline float map_reduce_all(
|
| 237 |
+
const MapOp& map_fun,
|
| 238 |
+
const ReduceOp& red_fun,
|
| 239 |
+
const scalar_t* data,
|
| 240 |
+
int64_t size) {
|
| 241 |
+
using bVec = vec::Vectorized<scalar_t>;
|
| 242 |
+
using fVec = vec::Vectorized<float>;
|
| 243 |
+
if (size < bVec::size()) {
|
| 244 |
+
bVec data_bvec = bVec::loadu(data, size);
|
| 245 |
+
auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
|
| 246 |
+
if (size > fVec::size()) {
|
| 247 |
+
data_fvec0 = map_fun(data_fvec0);
|
| 248 |
+
data_fvec1 = map_fun(data_fvec1);
|
| 249 |
+
data_fvec0 = fVec::set(
|
| 250 |
+
data_fvec0, red_fun(data_fvec0, data_fvec1), size - fVec::size());
|
| 251 |
+
return vec_reduce_all<float>(red_fun, data_fvec0, fVec::size());
|
| 252 |
+
} else {
|
| 253 |
+
data_fvec0 = map_fun(data_fvec0);
|
| 254 |
+
return vec_reduce_all<float>(red_fun, data_fvec0, size);
|
| 255 |
+
}
|
| 256 |
+
}
|
| 257 |
+
int64_t d = bVec::size();
|
| 258 |
+
bVec acc_bvec = bVec::loadu(data);
|
| 259 |
+
auto [acc_fvec0, acc_fvec1] = convert_to_float<scalar_t>(acc_bvec);
|
| 260 |
+
acc_fvec0 = map_fun(acc_fvec0);
|
| 261 |
+
acc_fvec1 = map_fun(acc_fvec1);
|
| 262 |
+
for (; d < size - (size % bVec::size()); d += bVec::size()) {
|
| 263 |
+
bVec data_bvec = bVec::loadu(data + d);
|
| 264 |
+
auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
|
| 265 |
+
data_fvec0 = map_fun(data_fvec0);
|
| 266 |
+
data_fvec1 = map_fun(data_fvec1);
|
| 267 |
+
acc_fvec0 = red_fun(acc_fvec0, data_fvec0);
|
| 268 |
+
acc_fvec1 = red_fun(acc_fvec1, data_fvec1);
|
| 269 |
+
}
|
| 270 |
+
if (size - d > 0) {
|
| 271 |
+
bVec data_bvec = bVec::loadu(data + d, size - d);
|
| 272 |
+
auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
|
| 273 |
+
if (size - d > fVec::size()) {
|
| 274 |
+
data_fvec0 = map_fun(data_fvec0);
|
| 275 |
+
data_fvec1 = map_fun(data_fvec1);
|
| 276 |
+
acc_fvec0 = red_fun(acc_fvec0, data_fvec0);
|
| 277 |
+
acc_fvec1 = fVec::set(
|
| 278 |
+
acc_fvec1, red_fun(acc_fvec1, data_fvec1), size - d - fVec::size());
|
| 279 |
+
} else {
|
| 280 |
+
data_fvec0 = map_fun(data_fvec0);
|
| 281 |
+
acc_fvec0 =
|
| 282 |
+
fVec::set(acc_fvec0, red_fun(acc_fvec0, data_fvec0), size - d);
|
| 283 |
+
}
|
| 284 |
+
}
|
| 285 |
+
acc_fvec0 = red_fun(acc_fvec0, acc_fvec1);
|
| 286 |
+
return vec_reduce_all<float>(red_fun, acc_fvec0);
|
| 287 |
+
}
|
| 288 |
+
|
| 289 |
+
template <
|
| 290 |
+
typename scalar_t,
|
| 291 |
+
typename MapOp,
|
| 292 |
+
typename ReduceOp,
|
| 293 |
+
typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
|
| 294 |
+
inline float map2_reduce_all(
|
| 295 |
+
const MapOp& map_fun,
|
| 296 |
+
const ReduceOp& red_fun,
|
| 297 |
+
const scalar_t* data,
|
| 298 |
+
const scalar_t* data2,
|
| 299 |
+
int64_t size) {
|
| 300 |
+
using bVec = vec::Vectorized<scalar_t>;
|
| 301 |
+
using fVec = vec::Vectorized<float>;
|
| 302 |
+
if (size < bVec::size()) {
|
| 303 |
+
bVec data_bvec = bVec::loadu(data, size);
|
| 304 |
+
auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
|
| 305 |
+
bVec data2_bvec = bVec::loadu(data2, size);
|
| 306 |
+
auto [data2_fvec0, data2_fvec1] = convert_to_float<scalar_t>(data2_bvec);
|
| 307 |
+
if (size > fVec::size()) {
|
| 308 |
+
data_fvec0 = map_fun(data_fvec0, data2_fvec0);
|
| 309 |
+
data_fvec1 = map_fun(data_fvec1, data2_fvec1);
|
| 310 |
+
data_fvec0 = fVec::set(
|
| 311 |
+
data_fvec0, red_fun(data_fvec0, data_fvec1), size - fVec::size());
|
| 312 |
+
return vec_reduce_all<float>(red_fun, data_fvec0, fVec::size());
|
| 313 |
+
} else {
|
| 314 |
+
data_fvec0 = map_fun(data_fvec0, data2_fvec0);
|
| 315 |
+
return vec_reduce_all<float>(red_fun, data_fvec0, size);
|
| 316 |
+
}
|
| 317 |
+
}
|
| 318 |
+
int64_t d = bVec::size();
|
| 319 |
+
bVec acc_bvec = bVec::loadu(data);
|
| 320 |
+
auto [acc_fvec0, acc_fvec1] = convert_to_float<scalar_t>(acc_bvec);
|
| 321 |
+
bVec acc2_bvec = bVec::loadu(data2);
|
| 322 |
+
auto [acc2_fvec0, acc2_fvec1] = convert_to_float<scalar_t>(acc2_bvec);
|
| 323 |
+
acc_fvec0 = map_fun(acc_fvec0, acc2_fvec0);
|
| 324 |
+
acc_fvec1 = map_fun(acc_fvec1, acc2_fvec1);
|
| 325 |
+
for (; d < size - (size % bVec::size()); d += bVec::size()) {
|
| 326 |
+
bVec data_bvec = bVec::loadu(data + d);
|
| 327 |
+
auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
|
| 328 |
+
bVec data2_bvec = bVec::loadu(data2 + d);
|
| 329 |
+
auto [data2_fvec0, data2_fvec1] = convert_to_float<scalar_t>(data2_bvec);
|
| 330 |
+
data_fvec0 = map_fun(data_fvec0, data2_fvec0);
|
| 331 |
+
data_fvec1 = map_fun(data_fvec1, data2_fvec1);
|
| 332 |
+
acc_fvec0 = red_fun(acc_fvec0, data_fvec0);
|
| 333 |
+
acc_fvec1 = red_fun(acc_fvec1, data_fvec1);
|
| 334 |
+
}
|
| 335 |
+
if (size - d > 0) {
|
| 336 |
+
bVec data_bvec = bVec::loadu(data + d, size - d);
|
| 337 |
+
auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
|
| 338 |
+
bVec data2_bvec = bVec::loadu(data2 + d, size - d);
|
| 339 |
+
auto [data2_fvec0, data2_fvec1] = convert_to_float<scalar_t>(data2_bvec);
|
| 340 |
+
if (size - d > fVec::size()) {
|
| 341 |
+
data_fvec0 = map_fun(data_fvec0, data2_fvec0);
|
| 342 |
+
data_fvec1 = map_fun(data_fvec1, data2_fvec1);
|
| 343 |
+
acc_fvec0 = red_fun(acc_fvec0, data_fvec0);
|
| 344 |
+
acc_fvec1 = fVec::set(
|
| 345 |
+
acc_fvec1, red_fun(acc_fvec1, data_fvec1), size - d - fVec::size());
|
| 346 |
+
} else {
|
| 347 |
+
data_fvec0 = map_fun(data_fvec0, data2_fvec0);
|
| 348 |
+
acc_fvec0 =
|
| 349 |
+
fVec::set(acc_fvec0, red_fun(acc_fvec0, data_fvec0), size - d);
|
| 350 |
+
}
|
| 351 |
+
}
|
| 352 |
+
acc_fvec0 = red_fun(acc_fvec0, acc_fvec1);
|
| 353 |
+
return vec_reduce_all<float>(red_fun, acc_fvec0);
|
| 354 |
+
}
|
| 355 |
+
|
| 356 |
+
template <
|
| 357 |
+
typename scalar_t,
|
| 358 |
+
typename MapOp,
|
| 359 |
+
typename ReduceOp,
|
| 360 |
+
typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
|
| 361 |
+
inline float map3_reduce_all(
|
| 362 |
+
const MapOp& map_fun,
|
| 363 |
+
const ReduceOp& red_fun,
|
| 364 |
+
const scalar_t* data,
|
| 365 |
+
const scalar_t* data2,
|
| 366 |
+
const scalar_t* data3,
|
| 367 |
+
int64_t size) {
|
| 368 |
+
using bVec = vec::Vectorized<scalar_t>;
|
| 369 |
+
using fVec = vec::Vectorized<float>;
|
| 370 |
+
if (size < bVec::size()) {
|
| 371 |
+
bVec data_bvec = bVec::loadu(data, size);
|
| 372 |
+
auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
|
| 373 |
+
bVec data2_bvec = bVec::loadu(data2, size);
|
| 374 |
+
auto [data2_fvec0, data2_fvec1] = convert_to_float<scalar_t>(data2_bvec);
|
| 375 |
+
bVec data3_bvec = bVec::loadu(data3, size);
|
| 376 |
+
auto [data3_fvec0, data3_fvec1] = convert_to_float<scalar_t>(data3_bvec);
|
| 377 |
+
if (size > fVec::size()) {
|
| 378 |
+
data_fvec0 = map_fun(data_fvec0, data2_fvec0, data3_fvec0);
|
| 379 |
+
data_fvec1 = map_fun(data_fvec1, data2_fvec1, data3_fvec1);
|
| 380 |
+
data_fvec0 = fVec::set(
|
| 381 |
+
data_fvec0, red_fun(data_fvec0, data_fvec1), size - fVec::size());
|
| 382 |
+
return vec_reduce_all<float>(red_fun, data_fvec0, fVec::size());
|
| 383 |
+
} else {
|
| 384 |
+
data_fvec0 = map_fun(data_fvec0, data2_fvec0, data3_fvec0);
|
| 385 |
+
return vec_reduce_all<float>(red_fun, data_fvec0, size);
|
| 386 |
+
}
|
| 387 |
+
}
|
| 388 |
+
int64_t d = bVec::size();
|
| 389 |
+
bVec acc_bvec = bVec::loadu(data);
|
| 390 |
+
auto [acc_fvec0, acc_fvec1] = convert_to_float<scalar_t>(acc_bvec);
|
| 391 |
+
bVec acc2_bvec = bVec::loadu(data2);
|
| 392 |
+
auto [acc2_fvec0, acc2_fvec1] = convert_to_float<scalar_t>(acc2_bvec);
|
| 393 |
+
bVec acc3_bvec = bVec::loadu(data3);
|
| 394 |
+
auto [acc3_fvec0, acc3_fvec1] = convert_to_float<scalar_t>(acc3_bvec);
|
| 395 |
+
acc_fvec0 = map_fun(acc_fvec0, acc2_fvec0, acc3_fvec0);
|
| 396 |
+
acc_fvec1 = map_fun(acc_fvec1, acc2_fvec1, acc3_fvec1);
|
| 397 |
+
for (; d < size - (size % bVec::size()); d += bVec::size()) {
|
| 398 |
+
bVec data_bvec = bVec::loadu(data + d);
|
| 399 |
+
auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
|
| 400 |
+
bVec data2_bvec = bVec::loadu(data2 + d);
|
| 401 |
+
auto [data2_fvec0, data2_fvec1] = convert_to_float<scalar_t>(data2_bvec);
|
| 402 |
+
bVec data3_bvec = bVec::loadu(data3 + d);
|
| 403 |
+
auto [data3_fvec0, data3_fvec1] = convert_to_float<scalar_t>(data3_bvec);
|
| 404 |
+
data_fvec0 = map_fun(data_fvec0, data2_fvec0, data3_fvec0);
|
| 405 |
+
data_fvec1 = map_fun(data_fvec1, data2_fvec1, data3_fvec1);
|
| 406 |
+
acc_fvec0 = red_fun(acc_fvec0, data_fvec0);
|
| 407 |
+
acc_fvec1 = red_fun(acc_fvec1, data_fvec1);
|
| 408 |
+
}
|
| 409 |
+
if (size - d > 0) {
|
| 410 |
+
bVec data_bvec = bVec::loadu(data + d, size - d);
|
| 411 |
+
auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
|
| 412 |
+
bVec data2_bvec = bVec::loadu(data2 + d, size - d);
|
| 413 |
+
auto [data2_fvec0, data2_fvec1] = convert_to_float<scalar_t>(data2_bvec);
|
| 414 |
+
bVec data3_bvec = bVec::loadu(data3 + d, size - d);
|
| 415 |
+
auto [data3_fvec0, data3_fvec1] = convert_to_float<scalar_t>(data3_bvec);
|
| 416 |
+
if (size - d > fVec::size()) {
|
| 417 |
+
data_fvec0 = map_fun(data_fvec0, data2_fvec0, data3_fvec0);
|
| 418 |
+
data_fvec1 = map_fun(data_fvec1, data2_fvec1, data3_fvec1);
|
| 419 |
+
acc_fvec0 = red_fun(acc_fvec0, data_fvec0);
|
| 420 |
+
acc_fvec1 = fVec::set(
|
| 421 |
+
acc_fvec1, red_fun(acc_fvec1, data_fvec1), size - d - fVec::size());
|
| 422 |
+
} else {
|
| 423 |
+
data_fvec0 = map_fun(data_fvec0, data2_fvec0, data3_fvec0);
|
| 424 |
+
acc_fvec0 =
|
| 425 |
+
fVec::set(acc_fvec0, red_fun(acc_fvec0, data_fvec0), size - d);
|
| 426 |
+
}
|
| 427 |
+
}
|
| 428 |
+
acc_fvec0 = red_fun(acc_fvec0, acc_fvec1);
|
| 429 |
+
return vec_reduce_all<float>(red_fun, acc_fvec0);
|
| 430 |
+
}
|
| 431 |
+
|
| 432 |
+
template <
|
| 433 |
+
typename scalar_t,
|
| 434 |
+
typename Op,
|
| 435 |
+
typename std::enable_if_t<
|
| 436 |
+
!(!detail::should_prefer_converting_through_float_v<scalar_t> &&
|
| 437 |
+
std::is_invocable_v<Op, vec::Vectorized<scalar_t>>),
|
| 438 |
+
int> = 0>
|
| 439 |
+
inline void map(
|
| 440 |
+
const Op& vec_fun,
|
| 441 |
+
scalar_t* output_data,
|
| 442 |
+
const scalar_t* input_data,
|
| 443 |
+
int64_t size) {
|
| 444 |
+
using bVec = vec::Vectorized<scalar_t>;
|
| 445 |
+
using fVec = vec::Vectorized<float>;
|
| 446 |
+
int64_t d = 0;
|
| 447 |
+
for (; d < size - (size % bVec::size()); d += bVec::size()) {
|
| 448 |
+
bVec data_bvec = bVec::loadu(input_data + d);
|
| 449 |
+
auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
|
| 450 |
+
fVec output_fvec0 = vec_fun(data_fvec0);
|
| 451 |
+
fVec output_fvec1 = vec_fun(data_fvec1);
|
| 452 |
+
bVec output_bvec = convert_from_float<scalar_t>(output_fvec0, output_fvec1);
|
| 453 |
+
output_bvec.store(output_data + d);
|
| 454 |
+
}
|
| 455 |
+
if (size - d > 0) {
|
| 456 |
+
bVec data_bvec = bVec::loadu(input_data + d, size - d);
|
| 457 |
+
auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
|
| 458 |
+
fVec output_fvec0 = vec_fun(data_fvec0);
|
| 459 |
+
fVec output_fvec1 = vec_fun(data_fvec1);
|
| 460 |
+
bVec output_bvec = convert_from_float<scalar_t>(output_fvec0, output_fvec1);
|
| 461 |
+
output_bvec.store(output_data + d, size - d);
|
| 462 |
+
}
|
| 463 |
+
}
|
| 464 |
+
|
| 465 |
+
template <
|
| 466 |
+
typename scalar_t,
|
| 467 |
+
typename Op,
|
| 468 |
+
typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
|
| 469 |
+
inline void map(
|
| 470 |
+
const Op& vec_fun,
|
| 471 |
+
scalar_t* output_data,
|
| 472 |
+
const float* input_data,
|
| 473 |
+
int64_t size) {
|
| 474 |
+
using bVec = vec::Vectorized<scalar_t>;
|
| 475 |
+
using fVec = vec::Vectorized<float>;
|
| 476 |
+
int64_t d = 0;
|
| 477 |
+
for (; d < size - (size % bVec::size()); d += bVec::size()) {
|
| 478 |
+
fVec data_fvec0 = fVec::loadu(input_data + d);
|
| 479 |
+
fVec data_fvec1 = fVec::loadu(input_data + d + fVec::size());
|
| 480 |
+
fVec output_fvec0 = vec_fun(data_fvec0);
|
| 481 |
+
fVec output_fvec1 = vec_fun(data_fvec1);
|
| 482 |
+
bVec output_bvec = convert_from_float<scalar_t>(output_fvec0, output_fvec1);
|
| 483 |
+
output_bvec.store(output_data + d);
|
| 484 |
+
}
|
| 485 |
+
if (size - d > 0) {
|
| 486 |
+
fVec data_fvec0, data_fvec1;
|
| 487 |
+
if (size - d > fVec::size()) {
|
| 488 |
+
data_fvec0 = fVec::loadu(input_data + d);
|
| 489 |
+
data_fvec1 =
|
| 490 |
+
fVec::loadu(input_data + d + fVec::size(), size - d - fVec::size());
|
| 491 |
+
} else {
|
| 492 |
+
// choose to align with behaviour of bVec::loadu(ptr, size),
|
| 493 |
+
// which leaves data_fvec1 uninitialized
|
| 494 |
+
data_fvec0 = fVec::loadu(input_data + d, size - d);
|
| 495 |
+
}
|
| 496 |
+
fVec output_fvec0 = vec_fun(data_fvec0);
|
| 497 |
+
fVec output_fvec1 = vec_fun(data_fvec1);
|
| 498 |
+
bVec output_bvec = convert_from_float<scalar_t>(output_fvec0, output_fvec1);
|
| 499 |
+
output_bvec.store(output_data + d, size - d);
|
| 500 |
+
}
|
| 501 |
+
}
|
| 502 |
+
|
| 503 |
+
template <
|
| 504 |
+
typename scalar_t,
|
| 505 |
+
typename Op,
|
| 506 |
+
typename std::enable_if_t<
|
| 507 |
+
!(!detail::should_prefer_converting_through_float_v<scalar_t> &&
|
| 508 |
+
std::is_invocable_v<
|
| 509 |
+
Op,
|
| 510 |
+
vec::Vectorized<scalar_t>,
|
| 511 |
+
vec::Vectorized<scalar_t>>),
|
| 512 |
+
int> = 0>
|
| 513 |
+
inline void map2(
|
| 514 |
+
const Op& vec_fun,
|
| 515 |
+
scalar_t* output_data,
|
| 516 |
+
const scalar_t* input_data,
|
| 517 |
+
const scalar_t* input_data2,
|
| 518 |
+
int64_t size) {
|
| 519 |
+
using bVec = vec::Vectorized<scalar_t>;
|
| 520 |
+
using fVec = vec::Vectorized<float>;
|
| 521 |
+
int64_t d = 0;
|
| 522 |
+
for (; d < size - (size % bVec::size()); d += bVec::size()) {
|
| 523 |
+
bVec data_bvec = bVec::loadu(input_data + d);
|
| 524 |
+
auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
|
| 525 |
+
bVec data2_bvec = bVec::loadu(input_data2 + d);
|
| 526 |
+
auto [data2_fvec0, data2_fvec1] = convert_to_float<scalar_t>(data2_bvec);
|
| 527 |
+
fVec output_fvec0 = vec_fun(data_fvec0, data2_fvec0);
|
| 528 |
+
fVec output_fvec1 = vec_fun(data_fvec1, data2_fvec1);
|
| 529 |
+
bVec output_bvec = convert_from_float<scalar_t>(output_fvec0, output_fvec1);
|
| 530 |
+
output_bvec.store(output_data + d);
|
| 531 |
+
}
|
| 532 |
+
if (size - d > 0) {
|
| 533 |
+
bVec data_bvec = bVec::loadu(input_data + d, size - d);
|
| 534 |
+
auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);
|
| 535 |
+
bVec data2_bvec = bVec::loadu(input_data2 + d, size - d);
|
| 536 |
+
auto [data2_fvec0, data2_fvec1] = convert_to_float<scalar_t>(data2_bvec);
|
| 537 |
+
fVec output_fvec0 = vec_fun(data_fvec0, data2_fvec0);
|
| 538 |
+
fVec output_fvec1 = vec_fun(data_fvec1, data2_fvec1);
|
| 539 |
+
bVec output_bvec = convert_from_float<scalar_t>(output_fvec0, output_fvec1);
|
| 540 |
+
output_bvec.store(output_data + d, size - d);
|
| 541 |
+
}
|
| 542 |
+
}
|
| 543 |
+
|
| 544 |
+
template <
|
| 545 |
+
typename scalar_t,
|
| 546 |
+
typename Op,
|
| 547 |
+
typename std::enable_if_t<
|
| 548 |
+
!(!detail::should_prefer_converting_through_float_v<scalar_t> &&
|
| 549 |
+
std::is_invocable_v<
|
| 550 |
+
Op,
|
| 551 |
+
vec::Vectorized<scalar_t>,
|
| 552 |
+
vec::Vectorized<scalar_t>,
|
| 553 |
+
vec::Vectorized<scalar_t>>),
|
| 554 |
+
int> = 0>
|
| 555 |
+
inline void map3(
|
| 556 |
+
const Op& vec_fun,
|
| 557 |
+
scalar_t* output_data,
|
| 558 |
+
const scalar_t* input_data1,
|
| 559 |
+
const scalar_t* input_data2,
|
| 560 |
+
const scalar_t* input_data3,
|
| 561 |
+
int64_t size) {
|
| 562 |
+
using bVec = vec::Vectorized<scalar_t>;
|
| 563 |
+
using fVec = vec::Vectorized<float>;
|
| 564 |
+
int64_t d = 0;
|
| 565 |
+
for (; d < size - (size % bVec::size()); d += bVec::size()) {
|
| 566 |
+
bVec data1_bvec = bVec::loadu(input_data1 + d);
|
| 567 |
+
auto [data1_fvec0, data1_fvec1] = convert_to_float<scalar_t>(data1_bvec);
|
| 568 |
+
bVec data2_bvec = bVec::loadu(input_data2 + d);
|
| 569 |
+
auto [data2_fvec0, data2_fvec1] = convert_to_float<scalar_t>(data2_bvec);
|
| 570 |
+
bVec data3_bvec = bVec::loadu(input_data3 + d);
|
| 571 |
+
auto [data3_fvec0, data3_fvec1] = convert_to_float<scalar_t>(data3_bvec);
|
| 572 |
+
fVec output_fvec0 = vec_fun(data1_fvec0, data2_fvec0, data3_fvec0);
|
| 573 |
+
fVec output_fvec1 = vec_fun(data1_fvec1, data2_fvec1, data3_fvec1);
|
| 574 |
+
bVec output_bvec = convert_from_float<scalar_t>(output_fvec0, output_fvec1);
|
| 575 |
+
output_bvec.store(output_data + d);
|
| 576 |
+
}
|
| 577 |
+
if (size - d > 0) {
|
| 578 |
+
bVec data1_bvec = bVec::loadu(input_data1 + d, size - d);
|
| 579 |
+
auto [data1_fvec0, data1_fvec1] = convert_to_float<scalar_t>(data1_bvec);
|
| 580 |
+
bVec data2_bvec = bVec::loadu(input_data2 + d, size - d);
|
| 581 |
+
auto [data2_fvec0, data2_fvec1] = convert_to_float<scalar_t>(data2_bvec);
|
| 582 |
+
bVec data3_bvec = bVec::loadu(input_data3 + d, size - d);
|
| 583 |
+
auto [data3_fvec0, data3_fvec1] = convert_to_float<scalar_t>(data3_bvec);
|
| 584 |
+
fVec output_fvec0 = vec_fun(data1_fvec0, data2_fvec0, data3_fvec0);
|
| 585 |
+
fVec output_fvec1 = vec_fun(data1_fvec1, data2_fvec1, data3_fvec1);
|
| 586 |
+
bVec output_bvec = convert_from_float<scalar_t>(output_fvec0, output_fvec1);
|
| 587 |
+
output_bvec.store(output_data + d, size - d);
|
| 588 |
+
}
|
| 589 |
+
}
|
| 590 |
+
|
| 591 |
+
template <
|
| 592 |
+
typename scalar_t,
|
| 593 |
+
typename Op,
|
| 594 |
+
typename std::enable_if_t<
|
| 595 |
+
!(!detail::should_prefer_converting_through_float_v<scalar_t> &&
|
| 596 |
+
std::is_invocable_v<
|
| 597 |
+
Op,
|
| 598 |
+
vec::Vectorized<scalar_t>,
|
| 599 |
+
vec::Vectorized<scalar_t>,
|
| 600 |
+
vec::Vectorized<scalar_t>,
|
| 601 |
+
vec::Vectorized<scalar_t>>),
|
| 602 |
+
int> = 0>
|
| 603 |
+
inline void map4(
|
| 604 |
+
const Op& vec_fun,
|
| 605 |
+
scalar_t* output_data,
|
| 606 |
+
const scalar_t* input_data1,
|
| 607 |
+
const scalar_t* input_data2,
|
| 608 |
+
const scalar_t* input_data3,
|
| 609 |
+
const scalar_t* input_data4,
|
| 610 |
+
int64_t size) {
|
| 611 |
+
using bVec = vec::Vectorized<scalar_t>;
|
| 612 |
+
using fVec = vec::Vectorized<float>;
|
| 613 |
+
int64_t d = 0;
|
| 614 |
+
for (; d < size - (size % bVec::size()); d += bVec::size()) {
|
| 615 |
+
bVec data1_bvec = bVec::loadu(input_data1 + d);
|
| 616 |
+
auto [data1_fvec0, data1_fvec1] = convert_to_float<scalar_t>(data1_bvec);
|
| 617 |
+
bVec data2_bvec = bVec::loadu(input_data2 + d);
|
| 618 |
+
auto [data2_fvec0, data2_fvec1] = convert_to_float<scalar_t>(data2_bvec);
|
| 619 |
+
bVec data3_bvec = bVec::loadu(input_data3 + d);
|
| 620 |
+
auto [data3_fvec0, data3_fvec1] = convert_to_float<scalar_t>(data3_bvec);
|
| 621 |
+
bVec data4_bvec = bVec::loadu(input_data4 + d);
|
| 622 |
+
auto [data4_fvec0, data4_fvec1] = convert_to_float<scalar_t>(data4_bvec);
|
| 623 |
+
fVec output_fvec0 =
|
| 624 |
+
vec_fun(data1_fvec0, data2_fvec0, data3_fvec0, data4_fvec0);
|
| 625 |
+
fVec output_fvec1 =
|
| 626 |
+
vec_fun(data1_fvec1, data2_fvec1, data3_fvec1, data4_fvec1);
|
| 627 |
+
bVec output_bvec = convert_from_float<scalar_t>(output_fvec0, output_fvec1);
|
| 628 |
+
output_bvec.store(output_data + d);
|
| 629 |
+
}
|
| 630 |
+
if (size - d > 0) {
|
| 631 |
+
bVec data1_bvec = bVec::loadu(input_data1 + d, size - d);
|
| 632 |
+
auto [data1_fvec0, data1_fvec1] = convert_to_float<scalar_t>(data1_bvec);
|
| 633 |
+
bVec data2_bvec = bVec::loadu(input_data2 + d, size - d);
|
| 634 |
+
auto [data2_fvec0, data2_fvec1] = convert_to_float<scalar_t>(data2_bvec);
|
| 635 |
+
bVec data3_bvec = bVec::loadu(input_data3 + d, size - d);
|
| 636 |
+
auto [data3_fvec0, data3_fvec1] = convert_to_float<scalar_t>(data3_bvec);
|
| 637 |
+
bVec data4_bvec = bVec::loadu(input_data4 + d, size - d);
|
| 638 |
+
auto [data4_fvec0, data4_fvec1] = convert_to_float<scalar_t>(data4_bvec);
|
| 639 |
+
fVec output_fvec0 =
|
| 640 |
+
vec_fun(data1_fvec0, data2_fvec0, data3_fvec0, data4_fvec0);
|
| 641 |
+
fVec output_fvec1 =
|
| 642 |
+
vec_fun(data1_fvec1, data2_fvec1, data3_fvec1, data4_fvec1);
|
| 643 |
+
bVec output_bvec = convert_from_float<scalar_t>(output_fvec0, output_fvec1);
|
| 644 |
+
output_bvec.store(output_data + d, size - d);
|
| 645 |
+
}
|
| 646 |
+
}
|
| 647 |
+
|
| 648 |
+
} // namespace at::vec
|
| 649 |
+
|
| 650 |
+
#else
|
| 651 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 652 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/intrinsics.h
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#include <torch/headeronly/cpu/vec/intrinsics.h>
|
| 3 |
+
|
| 4 |
+
#else
|
| 5 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 6 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec.h
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
#if defined(CPU_CAPABILITY_AVX512)
|
| 5 |
+
#include <ATen/cpu/vec/vec512/vec512.h>
|
| 6 |
+
#else
|
| 7 |
+
#include <ATen/cpu/vec/vec128/vec128.h>
|
| 8 |
+
#include <ATen/cpu/vec/vec256/vec256.h>
|
| 9 |
+
#endif
|
| 10 |
+
|
| 11 |
+
namespace at::vec {
|
| 12 |
+
// See Note [CPU_CAPABILITY namespace]
|
| 13 |
+
inline namespace CPU_CAPABILITY {
|
| 14 |
+
|
| 15 |
+
inline Vectorized<bool> convert_to_bool(Vectorized<int8_t> x) {
|
| 16 |
+
__at_align__ bool buffer[x.size()];
|
| 17 |
+
x.ne(Vectorized<int8_t>(0)).store(buffer);
|
| 18 |
+
|
| 19 |
+
Vectorized<bool> ret;
|
| 20 |
+
static_assert(x.size() == ret.size());
|
| 21 |
+
std::memcpy(ret, buffer, ret.size() * sizeof(bool));
|
| 22 |
+
return ret;
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
+
template <>
|
| 26 |
+
inline Vectorized<bool> Vectorized<bool>::loadu(const void* ptr) {
|
| 27 |
+
// See NOTE [Loading boolean values]
|
| 28 |
+
return convert_to_bool(Vectorized<int8_t>::loadu(ptr));
|
| 29 |
+
}
|
| 30 |
+
|
| 31 |
+
template <>
|
| 32 |
+
inline Vectorized<bool> Vectorized<bool>::loadu(
|
| 33 |
+
const void* ptr,
|
| 34 |
+
int64_t count) {
|
| 35 |
+
// See NOTE [Loading boolean values]
|
| 36 |
+
return convert_to_bool(Vectorized<int8_t>::loadu(ptr, count));
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
template <typename VT>
|
| 40 |
+
struct VecHoldType {
|
| 41 |
+
using hold_type = typename VT::value_type;
|
| 42 |
+
};
|
| 43 |
+
|
| 44 |
+
template <>
|
| 45 |
+
struct VecHoldType<Vectorized<BFloat16>> {
|
| 46 |
+
using hold_type = BFloat16;
|
| 47 |
+
};
|
| 48 |
+
|
| 49 |
+
template <>
|
| 50 |
+
struct VecHoldType<Vectorized<Half>> {
|
| 51 |
+
using hold_type = Half;
|
| 52 |
+
};
|
| 53 |
+
|
| 54 |
+
template <typename VT>
|
| 55 |
+
using vechold_type = typename VecHoldType<VT>::hold_type;
|
| 56 |
+
|
| 57 |
+
} // namespace CPU_CAPABILITY
|
| 58 |
+
} // namespace at::vec
|
| 59 |
+
|
| 60 |
+
#else
|
| 61 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 62 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec128/vec128_half_neon.h
ADDED
|
@@ -0,0 +1,627 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
// DO NOT DEFINE STATIC DATA IN THIS HEADER!
|
| 5 |
+
// See Note [Do not compile initializers with AVX]
|
| 6 |
+
|
| 7 |
+
#include <ATen/cpu/vec/intrinsics.h>
|
| 8 |
+
#include <ATen/cpu/vec/vec128/vec128_convert.h>
|
| 9 |
+
#include <ATen/cpu/vec/vec128/vec128_float_neon.h>
|
| 10 |
+
#include <ATen/cpu/vec/vec128/vec128_reduced_precision_common_neon.h>
|
| 11 |
+
#include <ATen/cpu/vec/vec_base.h>
|
| 12 |
+
#include <c10/util/Half.h>
|
| 13 |
+
#include <c10/util/irange.h>
|
| 14 |
+
|
| 15 |
+
namespace at::vec {
|
| 16 |
+
// See Note [CPU_CAPABILITY namespace]
|
| 17 |
+
inline namespace CPU_CAPABILITY {
|
| 18 |
+
|
| 19 |
+
// Right now contains only aarch64 implementation.
|
| 20 |
+
// Due to follow two reasons aarch32 is not currently supported.
|
| 21 |
+
// 1. Due to difference in ISA been aarch32 and aarch64, intrinsics
|
| 22 |
+
// that work for aarch64 dont work for aarch32.
|
| 23 |
+
// 2. Android NDK r21 has problems with compiling aarch32.
|
| 24 |
+
// Clang seg faults.
|
| 25 |
+
// https://github.com/android/ndk/issues/1248
|
| 26 |
+
// https://bugs.llvm.org/show_bug.cgi?id=45824
|
| 27 |
+
// Most likely we will do aarch32 support with inline asm.
|
| 28 |
+
#if !defined(C10_MOBILE) && defined(__aarch64__)
|
| 29 |
+
|
| 30 |
+
#ifdef __BIG_ENDIAN__
|
| 31 |
+
#error "Big endian is not supported."
|
| 32 |
+
#endif
|
| 33 |
+
|
| 34 |
+
template <int index, bool mask_val>
|
| 35 |
+
struct BlendHalfRegs {
|
| 36 |
+
static float16x8_t impl(
|
| 37 |
+
const float16x8_t& a,
|
| 38 |
+
const float16x8_t& b,
|
| 39 |
+
float16x8_t& res);
|
| 40 |
+
};
|
| 41 |
+
|
| 42 |
+
template <int index>
|
| 43 |
+
struct BlendHalfRegs<index, true> {
|
| 44 |
+
static float16x8_t impl(
|
| 45 |
+
const float16x8_t& a,
|
| 46 |
+
const float16x8_t& b,
|
| 47 |
+
float16x8_t& res) {
|
| 48 |
+
return vsetq_lane_f16(vgetq_lane_f16(b, index), res, index);
|
| 49 |
+
}
|
| 50 |
+
};
|
| 51 |
+
|
| 52 |
+
template <int index>
|
| 53 |
+
struct BlendHalfRegs<index, false> {
|
| 54 |
+
static float16x8_t impl(
|
| 55 |
+
const float16x8_t& a,
|
| 56 |
+
const float16x8_t& b,
|
| 57 |
+
float16x8_t& res) {
|
| 58 |
+
return vsetq_lane_f16(vgetq_lane_f16(a, index), res, index);
|
| 59 |
+
}
|
| 60 |
+
};
|
| 61 |
+
|
| 62 |
+
template <>
|
| 63 |
+
struct is_vec_specialized_for<c10::Half> : std::bool_constant<true> {};
|
| 64 |
+
|
| 65 |
+
// On ARM, Half type supports float16_t->Half constructor and Half->float16_t
|
| 66 |
+
// conversion
|
| 67 |
+
template <>
|
| 68 |
+
class Vectorized<c10::Half> : public Vectorized16<
|
| 69 |
+
float16x8_t,
|
| 70 |
+
c10::Half,
|
| 71 |
+
BlendHalfRegs,
|
| 72 |
+
Vectorized<c10::Half>> {
|
| 73 |
+
using Base = Vectorized16<
|
| 74 |
+
float16x8_t,
|
| 75 |
+
c10::Half,
|
| 76 |
+
BlendHalfRegs,
|
| 77 |
+
Vectorized<c10::Half>>;
|
| 78 |
+
friend Base;
|
| 79 |
+
|
| 80 |
+
private:
|
| 81 |
+
// We use these private map functions to implement various methods
|
| 82 |
+
Vectorized<c10::Half> map_with_vec_float_method(
|
| 83 |
+
Vectorized<float> (Vectorized<float>::*m)() const) const {
|
| 84 |
+
float32x4_t v00 = vcvt_f32_f16(vget_low_f16(values));
|
| 85 |
+
float32x4_t v01 = vcvt_f32_f16(vget_high_f16(values));
|
| 86 |
+
Vectorized<float> mv0 = (Vectorized<float>(v00).*m)();
|
| 87 |
+
Vectorized<float> mv1 = (Vectorized<float>(v01).*m)();
|
| 88 |
+
float16x4_t r00 = vcvt_f16_f32(mv0);
|
| 89 |
+
float16x4_t r01 = vcvt_f16_f32(mv1);
|
| 90 |
+
return Vectorized<c10::Half>(vcombine_f16(r00, r01));
|
| 91 |
+
}
|
| 92 |
+
|
| 93 |
+
Vectorized<c10::Half> map2_with_vec_float_method(
|
| 94 |
+
const Vectorized<c10::Half>& second,
|
| 95 |
+
Vectorized<float> (Vectorized<float>::*m)(const Vectorized<float>&)
|
| 96 |
+
const) const {
|
| 97 |
+
float32x4_t v00 = vcvt_f32_f16(vget_low_f16(values));
|
| 98 |
+
float32x4_t v01 = vcvt_f32_f16(vget_high_f16(values));
|
| 99 |
+
float32x4_t second_v00 = vcvt_f32_f16(vget_low_f16(second.values));
|
| 100 |
+
float32x4_t second_v01 = vcvt_f32_f16(vget_high_f16(second.values));
|
| 101 |
+
Vectorized<float> mv0 =
|
| 102 |
+
(Vectorized<float>(v00).*m)(Vectorized<float>(second_v00));
|
| 103 |
+
Vectorized<float> mv1 =
|
| 104 |
+
(Vectorized<float>(v01).*m)(Vectorized<float>(second_v01));
|
| 105 |
+
float16x4_t r00 = vcvt_f16_f32(mv0);
|
| 106 |
+
float16x4_t r01 = vcvt_f16_f32(mv1);
|
| 107 |
+
|
| 108 |
+
// Pack result into Vectorized<c10::Half>
|
| 109 |
+
return Vectorized<c10::Half>(vcombine_f16(r00, r01));
|
| 110 |
+
}
|
| 111 |
+
|
| 112 |
+
Vectorized<c10::Half> map2_bitmask_with_vec_float_method(
|
| 113 |
+
const Vectorized<c10::Half>& second,
|
| 114 |
+
Vectorized<float> (Vectorized<float>::*m)(const Vectorized<float>&)
|
| 115 |
+
const) const {
|
| 116 |
+
float32x4_t v00 = vcvt_f32_f16(vget_low_f16(values));
|
| 117 |
+
float32x4_t v01 = vcvt_f32_f16(vget_high_f16(values));
|
| 118 |
+
float32x4_t second_v00 = vcvt_f32_f16(vget_low_f16(second.values));
|
| 119 |
+
float32x4_t second_v01 = vcvt_f32_f16(vget_high_f16(second.values));
|
| 120 |
+
Vectorized<float> mv0 =
|
| 121 |
+
(Vectorized<float>(v00).*m)(Vectorized<float>(second_v00));
|
| 122 |
+
Vectorized<float> mv1 =
|
| 123 |
+
(Vectorized<float>(v01).*m)(Vectorized<float>(second_v01));
|
| 124 |
+
// Assume the operator returns a bitmask, not "real" floats, and
|
| 125 |
+
// just narrow the bits. All-ones is a NaN and will get mangled by
|
| 126 |
+
// conversion!
|
| 127 |
+
float16x4_t r00 =
|
| 128 |
+
vreinterpret_f16_u16(vmovn_u32(vreinterpretq_u32_f32(mv0)));
|
| 129 |
+
float16x4_t r01 =
|
| 130 |
+
vreinterpret_f16_u16(vmovn_u32(vreinterpretq_u32_f32(mv1)));
|
| 131 |
+
|
| 132 |
+
// Pack result into Vectorized<c10::Half>
|
| 133 |
+
return Vectorized<c10::Half>(vcombine_f16(r00, r01));
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
public:
|
| 137 |
+
using Vectorized16::Vectorized16;
|
| 138 |
+
|
| 139 |
+
Vectorized() = default;
|
| 140 |
+
|
| 141 |
+
// A ctor that accepts c10::Half is needed to fit interface with vec_base.h
|
| 142 |
+
// A second constructor that takes float16_t is also included
|
| 143 |
+
Vectorized(c10::Half val) : Vectorized((float16_t)val) {}
|
| 144 |
+
Vectorized(float16_t val) : Vectorized16(vdupq_n_f16(val)) {}
|
| 145 |
+
Vectorized(
|
| 146 |
+
value_type val0,
|
| 147 |
+
value_type val1,
|
| 148 |
+
value_type val2,
|
| 149 |
+
value_type val3,
|
| 150 |
+
value_type val4,
|
| 151 |
+
value_type val5,
|
| 152 |
+
value_type val6,
|
| 153 |
+
value_type val7)
|
| 154 |
+
: Vectorized16(
|
| 155 |
+
float16x8_t{val0, val1, val2, val3, val4, val5, val6, val7}) {}
|
| 156 |
+
|
| 157 |
+
static Vectorized<c10::Half> blendv(
|
| 158 |
+
const Vectorized<c10::Half>& a,
|
| 159 |
+
const Vectorized<c10::Half>& b,
|
| 160 |
+
const Vectorized<c10::Half>& mask) {
|
| 161 |
+
// Note: using blendv is very awkward because 0xFFFF is one of
|
| 162 |
+
// many NaN's in FP16 It's unfortunate that the mask has type Half
|
| 163 |
+
// (required from vec_base)
|
| 164 |
+
|
| 165 |
+
// TODO
|
| 166 |
+
// NB: This requires that each value, i.e., each uint value,
|
| 167 |
+
// of the mask either all be zeros or all be 1s.
|
| 168 |
+
// We perhaps need some kind of an assert?
|
| 169 |
+
// But that will affect performance.
|
| 170 |
+
|
| 171 |
+
// NOTE [vbslq_f16]: vbslq_f16 doesn't work on clang without
|
| 172 |
+
// __ARM_FEATURE_FP16_VECTOR_ARITHMETIC. vbslq_u16 generates the
|
| 173 |
+
// same instruction anyway. see https://godbolt.org/z/cY4a55Y7P
|
| 174 |
+
Vectorized<c10::Half> vec(mask.values);
|
| 175 |
+
vec.values = vreinterpretq_f16_u16(vbslq_u16(
|
| 176 |
+
vreinterpretq_u16_f16(vec.values),
|
| 177 |
+
vreinterpretq_u16_f16(b.values),
|
| 178 |
+
vreinterpretq_u16_f16(a.values)));
|
| 179 |
+
return vec;
|
| 180 |
+
}
|
| 181 |
+
static Vectorized<c10::Half> set(
|
| 182 |
+
const Vectorized<c10::Half>& a,
|
| 183 |
+
const Vectorized<c10::Half>& b,
|
| 184 |
+
int64_t count = size()) {
|
| 185 |
+
uint16_t pre_mask[size()] = {0};
|
| 186 |
+
for (int i = 0; i < count; i++) {
|
| 187 |
+
pre_mask[i] = 0xFFFF;
|
| 188 |
+
}
|
| 189 |
+
uint16x8_t mask = vld1q_u16(pre_mask);
|
| 190 |
+
|
| 191 |
+
// Using blendv is awkward because 0xFFFF is one of many NaN's in FP16
|
| 192 |
+
// so we directly use vbslq_u16 instead. (See NOTE [vbslq_f16] above.)
|
| 193 |
+
Vectorized<c10::Half> vec(vreinterpretq_f16_u16(vbslq_u16(
|
| 194 |
+
mask,
|
| 195 |
+
vreinterpretq_u16_f16(b.values),
|
| 196 |
+
vreinterpretq_u16_f16(a.values))));
|
| 197 |
+
|
| 198 |
+
return vec;
|
| 199 |
+
}
|
| 200 |
+
static Vectorized<c10::Half> loadu(const void* ptr, int64_t count = size()) {
|
| 201 |
+
if (count == size()) {
|
| 202 |
+
return vld1q_f16(reinterpret_cast<const float16_t*>(ptr));
|
| 203 |
+
}
|
| 204 |
+
__at_align__ float16_t tmp_values[size()];
|
| 205 |
+
for (const auto i : c10::irange(size())) {
|
| 206 |
+
tmp_values[i] = 0;
|
| 207 |
+
}
|
| 208 |
+
std::memcpy(
|
| 209 |
+
tmp_values,
|
| 210 |
+
reinterpret_cast<const float16_t*>(ptr),
|
| 211 |
+
count * sizeof(float16_t));
|
| 212 |
+
return vld1q_f16(reinterpret_cast<const float16_t*>(tmp_values));
|
| 213 |
+
}
|
| 214 |
+
void store(void* ptr, int64_t count = size()) const {
|
| 215 |
+
if (count == size()) {
|
| 216 |
+
vst1q_f16(reinterpret_cast<float16_t*>(ptr), values);
|
| 217 |
+
return;
|
| 218 |
+
} else {
|
| 219 |
+
float16_t tmp_values[size()];
|
| 220 |
+
vst1q_f16(reinterpret_cast<float16_t*>(tmp_values), values);
|
| 221 |
+
std::memcpy(ptr, tmp_values, count * sizeof(float16_t));
|
| 222 |
+
}
|
| 223 |
+
}
|
| 224 |
+
int zero_mask() const {
|
| 225 |
+
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
| 226 |
+
uint16x8_t is_zero_vec = vceqzq_f16(values);
|
| 227 |
+
const int16x8_t shift = vcombine_s16(
|
| 228 |
+
vcreate_s16(
|
| 229 |
+
0x0 | (int64_t(0x1) << 16) | (int64_t(0x2) << 32) |
|
| 230 |
+
(int64_t(0x3) << 48)),
|
| 231 |
+
vcreate_s16(
|
| 232 |
+
0x4 | (int64_t(0x5) << 16) | (int64_t(0x6) << 32) |
|
| 233 |
+
(int64_t(0x7) << 48)));
|
| 234 |
+
uint16x8_t bits_vec =
|
| 235 |
+
vshlq_u16(vandq_u16(is_zero_vec, vdupq_n_u16(1)), shift);
|
| 236 |
+
return vaddvq_u16(bits_vec);
|
| 237 |
+
#else // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
| 238 |
+
// use known working implementation.
|
| 239 |
+
__at_align__ value_type tmp[size()];
|
| 240 |
+
store(tmp);
|
| 241 |
+
int mask = 0;
|
| 242 |
+
for (int i = 0; i < size(); ++i) {
|
| 243 |
+
if (tmp[i] == 0) {
|
| 244 |
+
mask |= (1 << i);
|
| 245 |
+
}
|
| 246 |
+
}
|
| 247 |
+
return mask;
|
| 248 |
+
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
| 249 |
+
}
|
| 250 |
+
Vectorized<c10::Half> isnan() const {
|
| 251 |
+
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
| 252 |
+
return vreinterpretq_f16_u16(vmvnq_u16(vceqq_f16(values, values)));
|
| 253 |
+
#else
|
| 254 |
+
// NOTE: we could make this faster by doing vectorized checks of
|
| 255 |
+
// exponent/payload bits.
|
| 256 |
+
__at_align__ c10::Half tmp[size()];
|
| 257 |
+
__at_align__ c10::Half res[size()];
|
| 258 |
+
store(tmp);
|
| 259 |
+
for (const auto i : c10::irange(size())) {
|
| 260 |
+
if (_isnan(tmp[i])) {
|
| 261 |
+
std::memset(static_cast<void*>(&res[i]), 0xFF, sizeof(c10::Half));
|
| 262 |
+
} else {
|
| 263 |
+
std::memset(static_cast<void*>(&res[i]), 0, sizeof(c10::Half));
|
| 264 |
+
}
|
| 265 |
+
}
|
| 266 |
+
return loadu(res);
|
| 267 |
+
#endif
|
| 268 |
+
}
|
| 269 |
+
bool has_inf_nan() const {
|
| 270 |
+
__at_align__ c10::Half tmp[size()];
|
| 271 |
+
store(tmp);
|
| 272 |
+
for (const auto i : c10::irange(size())) {
|
| 273 |
+
if (_isnan(tmp[i]) || _isinf(tmp[i])) {
|
| 274 |
+
return true;
|
| 275 |
+
}
|
| 276 |
+
}
|
| 277 |
+
return false;
|
| 278 |
+
}
|
| 279 |
+
Vectorized<c10::Half> abs() const {
|
| 280 |
+
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
| 281 |
+
return Vectorized<c10::Half>(vabsq_f16(values));
|
| 282 |
+
#else
|
| 283 |
+
return map_with_vec_float_method(&Vectorized<float>::abs);
|
| 284 |
+
#endif
|
| 285 |
+
}
|
| 286 |
+
Vectorized<c10::Half> frac() const;
|
| 287 |
+
Vectorized<c10::Half> neg() const {
|
| 288 |
+
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
| 289 |
+
return Vectorized<c10::Half>(vnegq_f16(values));
|
| 290 |
+
#else
|
| 291 |
+
return map_with_vec_float_method(&Vectorized<float>::neg);
|
| 292 |
+
#endif
|
| 293 |
+
}
|
| 294 |
+
Vectorized<c10::Half> trunc() const {
|
| 295 |
+
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
| 296 |
+
return Vectorized<c10::Half>(vrndq_f16(values));
|
| 297 |
+
#else
|
| 298 |
+
return map_with_vec_float_method(&Vectorized<float>::trunc);
|
| 299 |
+
#endif
|
| 300 |
+
}
|
| 301 |
+
Vectorized<c10::Half> sqrt() const {
|
| 302 |
+
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
| 303 |
+
return Vectorized<c10::Half>(vsqrtq_f16(values));
|
| 304 |
+
#else
|
| 305 |
+
return map_with_vec_float_method(&Vectorized<float>::sqrt);
|
| 306 |
+
#endif
|
| 307 |
+
}
|
| 308 |
+
Vectorized<c10::Half> reciprocal() const {
|
| 309 |
+
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
| 310 |
+
auto ones = vdupq_n_f16(1.0f);
|
| 311 |
+
return Vectorized<c10::Half>(vdivq_f16(ones, values));
|
| 312 |
+
#else
|
| 313 |
+
return map_with_vec_float_method(&Vectorized<float>::reciprocal);
|
| 314 |
+
#endif
|
| 315 |
+
}
|
| 316 |
+
Vectorized<c10::Half> operator==(const Vectorized<c10::Half>& other) const {
|
| 317 |
+
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
| 318 |
+
return Vectorized<c10::Half>(
|
| 319 |
+
vreinterpretq_f16_u16(vceqq_f16(values, other.values)));
|
| 320 |
+
#else
|
| 321 |
+
return map2_bitmask_with_vec_float_method(
|
| 322 |
+
other, &Vectorized<float>::operator==);
|
| 323 |
+
#endif
|
| 324 |
+
}
|
| 325 |
+
|
| 326 |
+
Vectorized<c10::Half> operator!=(const Vectorized<c10::Half>& other) const {
|
| 327 |
+
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
| 328 |
+
return Vectorized<c10::Half>(
|
| 329 |
+
vreinterpretq_f16_u16(vmvnq_u16(vceqq_f16(values, other.values))));
|
| 330 |
+
#else
|
| 331 |
+
return map2_bitmask_with_vec_float_method(
|
| 332 |
+
other, &Vectorized<float>::operator!=);
|
| 333 |
+
#endif
|
| 334 |
+
}
|
| 335 |
+
|
| 336 |
+
Vectorized<c10::Half> operator<(const Vectorized<c10::Half>& other) const {
|
| 337 |
+
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
| 338 |
+
return Vectorized<c10::Half>(
|
| 339 |
+
vreinterpretq_f16_u16(vcltq_f16(values, other.values)));
|
| 340 |
+
#else
|
| 341 |
+
return map2_bitmask_with_vec_float_method(
|
| 342 |
+
other, &Vectorized<float>::operator<);
|
| 343 |
+
#endif
|
| 344 |
+
}
|
| 345 |
+
|
| 346 |
+
Vectorized<c10::Half> operator<=(const Vectorized<c10::Half>& other) const {
|
| 347 |
+
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
| 348 |
+
return Vectorized<c10::Half>(
|
| 349 |
+
vreinterpretq_f16_u16(vcleq_f16(values, other.values)));
|
| 350 |
+
#else
|
| 351 |
+
return map2_bitmask_with_vec_float_method(
|
| 352 |
+
other, &Vectorized<float>::operator<=);
|
| 353 |
+
#endif
|
| 354 |
+
}
|
| 355 |
+
|
| 356 |
+
Vectorized<c10::Half> operator>(const Vectorized<c10::Half>& other) const {
|
| 357 |
+
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
| 358 |
+
return Vectorized<c10::Half>(
|
| 359 |
+
vreinterpretq_f16_u16(vcgtq_f16(values, other.values)));
|
| 360 |
+
#else
|
| 361 |
+
return map2_bitmask_with_vec_float_method(
|
| 362 |
+
other, &Vectorized<float>::operator>);
|
| 363 |
+
#endif
|
| 364 |
+
}
|
| 365 |
+
|
| 366 |
+
Vectorized<c10::Half> operator>=(const Vectorized<c10::Half>& other) const {
|
| 367 |
+
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
| 368 |
+
return Vectorized<c10::Half>(
|
| 369 |
+
vreinterpretq_f16_u16(vcgeq_f16(values, other.values)));
|
| 370 |
+
#else
|
| 371 |
+
return map2_bitmask_with_vec_float_method(
|
| 372 |
+
other, &Vectorized<float>::operator>=);
|
| 373 |
+
#endif
|
| 374 |
+
}
|
| 375 |
+
|
| 376 |
+
Vectorized<c10::Half> eq(const Vectorized<c10::Half>& other) const;
|
| 377 |
+
Vectorized<c10::Half> ne(const Vectorized<c10::Half>& other) const;
|
| 378 |
+
Vectorized<c10::Half> gt(const Vectorized<c10::Half>& other) const;
|
| 379 |
+
Vectorized<c10::Half> ge(const Vectorized<c10::Half>& other) const;
|
| 380 |
+
Vectorized<c10::Half> lt(const Vectorized<c10::Half>& other) const;
|
| 381 |
+
Vectorized<c10::Half> le(const Vectorized<c10::Half>& other) const;
|
| 382 |
+
}; // Vectorized<Half>
|
| 383 |
+
|
| 384 |
+
inline std::tuple<Vectorized<float>, Vectorized<float>> convert_half_float(
|
| 385 |
+
const Vectorized<Half>& a) {
|
| 386 |
+
static_assert(Vectorized<Half>::size() == 2 * Vectorized<float>::size());
|
| 387 |
+
float16x8_t x = a;
|
| 388 |
+
float32x4_t x1 = vcvt_f32_f16(vget_low_f16(x));
|
| 389 |
+
float32x4_t x2 = vcvt_f32_f16(vget_high_f16(x));
|
| 390 |
+
return {Vectorized<float>(x1), Vectorized<float>(x2)};
|
| 391 |
+
}
|
| 392 |
+
inline Vectorized<Half> convert_float_half(
|
| 393 |
+
const Vectorized<float>& a,
|
| 394 |
+
const Vectorized<float>& b) {
|
| 395 |
+
static_assert(Vectorized<Half>::size() == 2 * Vectorized<float>::size());
|
| 396 |
+
float32x4_t x = a;
|
| 397 |
+
float32x4_t y = b;
|
| 398 |
+
float16x4_t x1 = vcvt_f16_f32(x);
|
| 399 |
+
float16x4_t x2 = vcvt_f16_f32(y);
|
| 400 |
+
return Vectorized<Half>(vcombine_f16(x1, x2));
|
| 401 |
+
}
|
| 402 |
+
|
| 403 |
+
template <typename Op>
|
| 404 |
+
Vectorized<c10::Half> binary_operator_via_float(
|
| 405 |
+
Op op,
|
| 406 |
+
const Vectorized<c10::Half>& a,
|
| 407 |
+
const Vectorized<c10::Half>& b) {
|
| 408 |
+
const auto [a_float_low, a_float_high] = convert_half_float(a);
|
| 409 |
+
const auto [b_float_low, b_float_high] = convert_half_float(b);
|
| 410 |
+
return convert_float_half(
|
| 411 |
+
op(a_float_low, b_float_low), op(a_float_high, b_float_high));
|
| 412 |
+
}
|
| 413 |
+
|
| 414 |
+
template <>
|
| 415 |
+
Vectorized<c10::Half> inline operator+(
|
| 416 |
+
const Vectorized<c10::Half>& a,
|
| 417 |
+
const Vectorized<c10::Half>& b) {
|
| 418 |
+
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
| 419 |
+
return Vectorized<c10::Half>(vaddq_f16(a, b));
|
| 420 |
+
#else
|
| 421 |
+
return binary_operator_via_float(std::plus<Vectorized<float>>(), a, b);
|
| 422 |
+
#endif
|
| 423 |
+
}
|
| 424 |
+
|
| 425 |
+
template <>
|
| 426 |
+
Vectorized<c10::Half> inline operator-(
|
| 427 |
+
const Vectorized<c10::Half>& a,
|
| 428 |
+
const Vectorized<c10::Half>& b) {
|
| 429 |
+
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
| 430 |
+
return Vectorized<c10::Half>(vsubq_f16(a, b));
|
| 431 |
+
#else
|
| 432 |
+
return binary_operator_via_float(std::minus<Vectorized<float>>(), a, b);
|
| 433 |
+
#endif
|
| 434 |
+
}
|
| 435 |
+
|
| 436 |
+
template <>
|
| 437 |
+
Vectorized<c10::Half> inline operator*(
|
| 438 |
+
const Vectorized<c10::Half>& a,
|
| 439 |
+
const Vectorized<c10::Half>& b) {
|
| 440 |
+
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
| 441 |
+
return Vectorized<c10::Half>(vmulq_f16(a, b));
|
| 442 |
+
#else
|
| 443 |
+
return binary_operator_via_float(std::multiplies<Vectorized<float>>(), a, b);
|
| 444 |
+
#endif
|
| 445 |
+
}
|
| 446 |
+
|
| 447 |
+
template <>
|
| 448 |
+
Vectorized<c10::Half> inline operator/(
|
| 449 |
+
const Vectorized<c10::Half>& a,
|
| 450 |
+
const Vectorized<c10::Half>& b) {
|
| 451 |
+
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
| 452 |
+
return Vectorized<c10::Half>(vdivq_f16(a, b));
|
| 453 |
+
#else
|
| 454 |
+
return binary_operator_via_float(std::divides<Vectorized<float>>(), a, b);
|
| 455 |
+
#endif
|
| 456 |
+
}
|
| 457 |
+
|
| 458 |
+
// frac. Implement this here so we can use subtraction
|
| 459 |
+
inline Vectorized<c10::Half> Vectorized<c10::Half>::frac() const {
|
| 460 |
+
return *this - this->trunc();
|
| 461 |
+
}
|
| 462 |
+
|
| 463 |
+
// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if
|
| 464 |
+
// either input is a NaN.
|
| 465 |
+
template <>
|
| 466 |
+
Vectorized<c10::Half> inline maximum(
|
| 467 |
+
const Vectorized<c10::Half>& a,
|
| 468 |
+
const Vectorized<c10::Half>& b) {
|
| 469 |
+
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
| 470 |
+
return Vectorized<c10::Half>(vmaxq_f16(a, b));
|
| 471 |
+
#else
|
| 472 |
+
return binary_operator_via_float(
|
| 473 |
+
static_cast<Vectorized<float> (*)(
|
| 474 |
+
const Vectorized<float>&, const Vectorized<float>&)>(&maximum),
|
| 475 |
+
a,
|
| 476 |
+
b);
|
| 477 |
+
#endif
|
| 478 |
+
}
|
| 479 |
+
|
| 480 |
+
// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if
|
| 481 |
+
// either input is a NaN.
|
| 482 |
+
template <>
|
| 483 |
+
Vectorized<c10::Half> inline minimum(
|
| 484 |
+
const Vectorized<c10::Half>& a,
|
| 485 |
+
const Vectorized<c10::Half>& b) {
|
| 486 |
+
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
| 487 |
+
return Vectorized<c10::Half>(vminq_f16(a, b));
|
| 488 |
+
#else
|
| 489 |
+
return binary_operator_via_float(
|
| 490 |
+
static_cast<Vectorized<float> (*)(
|
| 491 |
+
const Vectorized<float>&, const Vectorized<float>&)>(&minimum),
|
| 492 |
+
a,
|
| 493 |
+
b);
|
| 494 |
+
#endif
|
| 495 |
+
}
|
| 496 |
+
|
| 497 |
+
template <>
|
| 498 |
+
Vectorized<c10::Half> inline clamp(
|
| 499 |
+
const Vectorized<c10::Half>& a,
|
| 500 |
+
const Vectorized<c10::Half>& min,
|
| 501 |
+
const Vectorized<c10::Half>& max) {
|
| 502 |
+
return minimum(max, maximum(min, a));
|
| 503 |
+
}
|
| 504 |
+
|
| 505 |
+
template <>
|
| 506 |
+
Vectorized<c10::Half> inline clamp_max(
|
| 507 |
+
const Vectorized<c10::Half>& a,
|
| 508 |
+
const Vectorized<c10::Half>& max) {
|
| 509 |
+
return minimum(max, a);
|
| 510 |
+
}
|
| 511 |
+
|
| 512 |
+
template <>
|
| 513 |
+
Vectorized<c10::Half> inline clamp_min(
|
| 514 |
+
const Vectorized<c10::Half>& a,
|
| 515 |
+
const Vectorized<c10::Half>& min) {
|
| 516 |
+
return maximum(min, a);
|
| 517 |
+
}
|
| 518 |
+
|
| 519 |
+
template <>
|
| 520 |
+
Vectorized<c10::Half> inline operator&(
|
| 521 |
+
const Vectorized<c10::Half>& a,
|
| 522 |
+
const Vectorized<c10::Half>& b) {
|
| 523 |
+
return Vectorized<c10::Half>(vreinterpretq_f16_u16(
|
| 524 |
+
vandq_u16(vreinterpretq_u16_f16(a), vreinterpretq_u16_f16(b))));
|
| 525 |
+
}
|
| 526 |
+
|
| 527 |
+
template <>
|
| 528 |
+
Vectorized<c10::Half> inline operator|(
|
| 529 |
+
const Vectorized<c10::Half>& a,
|
| 530 |
+
const Vectorized<c10::Half>& b) {
|
| 531 |
+
return Vectorized<c10::Half>(vreinterpretq_f16_u16(
|
| 532 |
+
vorrq_u16(vreinterpretq_u16_f16(a), vreinterpretq_u16_f16(b))));
|
| 533 |
+
}
|
| 534 |
+
|
| 535 |
+
template <>
|
| 536 |
+
Vectorized<c10::Half> inline operator^(
|
| 537 |
+
const Vectorized<c10::Half>& a,
|
| 538 |
+
const Vectorized<c10::Half>& b) {
|
| 539 |
+
return Vectorized<c10::Half>(vreinterpretq_f16_u16(
|
| 540 |
+
veorq_u16(vreinterpretq_u16_f16(a), vreinterpretq_u16_f16(b))));
|
| 541 |
+
}
|
| 542 |
+
|
| 543 |
+
inline Vectorized<c10::Half> Vectorized<c10::Half>::eq(
|
| 544 |
+
const Vectorized<c10::Half>& other) const {
|
| 545 |
+
return (*this == other) & Vectorized<c10::Half>(1);
|
| 546 |
+
}
|
| 547 |
+
|
| 548 |
+
inline Vectorized<c10::Half> Vectorized<c10::Half>::ne(
|
| 549 |
+
const Vectorized<c10::Half>& other) const {
|
| 550 |
+
return (*this != other) & Vectorized<c10::Half>(1);
|
| 551 |
+
}
|
| 552 |
+
|
| 553 |
+
inline Vectorized<c10::Half> Vectorized<c10::Half>::gt(
|
| 554 |
+
const Vectorized<c10::Half>& other) const {
|
| 555 |
+
return (*this > other) & Vectorized<c10::Half>(1);
|
| 556 |
+
}
|
| 557 |
+
|
| 558 |
+
inline Vectorized<c10::Half> Vectorized<c10::Half>::ge(
|
| 559 |
+
const Vectorized<c10::Half>& other) const {
|
| 560 |
+
return (*this >= other) & Vectorized<c10::Half>(1);
|
| 561 |
+
}
|
| 562 |
+
|
| 563 |
+
inline Vectorized<c10::Half> Vectorized<c10::Half>::lt(
|
| 564 |
+
const Vectorized<c10::Half>& other) const {
|
| 565 |
+
return (*this < other) & Vectorized<c10::Half>(1);
|
| 566 |
+
}
|
| 567 |
+
|
| 568 |
+
inline Vectorized<c10::Half> Vectorized<c10::Half>::le(
|
| 569 |
+
const Vectorized<c10::Half>& other) const {
|
| 570 |
+
return (*this <= other) & Vectorized<c10::Half>(1);
|
| 571 |
+
}
|
| 572 |
+
|
| 573 |
+
template <>
|
| 574 |
+
Vectorized<c10::Half> inline fmadd(
|
| 575 |
+
const Vectorized<c10::Half>& a,
|
| 576 |
+
const Vectorized<c10::Half>& b,
|
| 577 |
+
const Vectorized<c10::Half>& c) {
|
| 578 |
+
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
| 579 |
+
return Vectorized<c10::Half>(vfmaq_f16(c, a, b));
|
| 580 |
+
#else
|
| 581 |
+
return a * b + c;
|
| 582 |
+
#endif
|
| 583 |
+
}
|
| 584 |
+
|
| 585 |
+
template <>
|
| 586 |
+
Vectorized<c10::Half> inline fnmadd(
|
| 587 |
+
const Vectorized<c10::Half>& a,
|
| 588 |
+
const Vectorized<c10::Half>& b,
|
| 589 |
+
const Vectorized<c10::Half>& c) {
|
| 590 |
+
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
| 591 |
+
return Vectorized<c10::Half>(vfmsq_f16(c, a, b));
|
| 592 |
+
#else
|
| 593 |
+
return -a * b + c;
|
| 594 |
+
#endif
|
| 595 |
+
}
|
| 596 |
+
|
| 597 |
+
template <>
|
| 598 |
+
Vectorized<c10::Half> inline fmsub(
|
| 599 |
+
const Vectorized<c10::Half>& a,
|
| 600 |
+
const Vectorized<c10::Half>& b,
|
| 601 |
+
const Vectorized<c10::Half>& c) {
|
| 602 |
+
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
| 603 |
+
return Vectorized<c10::Half>(vnegq_f16(vfmsq_f16(c, a, b)));
|
| 604 |
+
#else
|
| 605 |
+
return a * b - c;
|
| 606 |
+
#endif
|
| 607 |
+
}
|
| 608 |
+
|
| 609 |
+
template <>
|
| 610 |
+
Vectorized<c10::Half> inline fnmsub(
|
| 611 |
+
const Vectorized<c10::Half>& a,
|
| 612 |
+
const Vectorized<c10::Half>& b,
|
| 613 |
+
const Vectorized<c10::Half>& c) {
|
| 614 |
+
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
| 615 |
+
return Vectorized<c10::Half>(vnegq_f16(vfmaq_f16(c, a, b)));
|
| 616 |
+
#else
|
| 617 |
+
return -a * b - c;
|
| 618 |
+
#endif
|
| 619 |
+
}
|
| 620 |
+
#endif // !defined(C10_MOBILE) && defined(__aarch64__)
|
| 621 |
+
|
| 622 |
+
} // namespace CPU_CAPABILITY
|
| 623 |
+
} // namespace at::vec
|
| 624 |
+
|
| 625 |
+
#else
|
| 626 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 627 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec128/vec128_reduced_precision_common_neon.h
ADDED
|
@@ -0,0 +1,316 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
// Shared code for bfloat16 and float16.
|
| 4 |
+
|
| 5 |
+
// DO NOT DEFINE STATIC DATA IN THIS HEADER!
|
| 6 |
+
// See Note [Do not compile initializers with AVX]
|
| 7 |
+
|
| 8 |
+
namespace at::vec {
|
| 9 |
+
inline namespace CPU_CAPABILITY {
|
| 10 |
+
|
| 11 |
+
// Shared implementation between Vectorized<c10::Half> and
|
| 12 |
+
// Vectorized<c10::BFloat16>. Uses CRTP to allow derived class
|
| 13 |
+
// customization.
|
| 14 |
+
template <
|
| 15 |
+
typename VecT,
|
| 16 |
+
typename ValueT,
|
| 17 |
+
template <int, bool> typename BlendRegs,
|
| 18 |
+
typename Derived>
|
| 19 |
+
struct Vectorized16 {
|
| 20 |
+
protected:
|
| 21 |
+
VecT values;
|
| 22 |
+
|
| 23 |
+
public:
|
| 24 |
+
using value_type = ValueT;
|
| 25 |
+
using size_type = int;
|
| 26 |
+
static constexpr size_type size() {
|
| 27 |
+
static_assert(sizeof(VecT) == 8 * sizeof(value_type));
|
| 28 |
+
return 8;
|
| 29 |
+
}
|
| 30 |
+
|
| 31 |
+
protected:
|
| 32 |
+
Derived map2(
|
| 33 |
+
const Derived& second,
|
| 34 |
+
value_type (*const f)(value_type, value_type)) const {
|
| 35 |
+
__at_align__ value_type tmp_first[size()];
|
| 36 |
+
__at_align__ value_type tmp_second[size()];
|
| 37 |
+
static_cast<const Derived*>(this)->store(
|
| 38 |
+
tmp_first); // store this to tmp_first
|
| 39 |
+
second.store(tmp_second);
|
| 40 |
+
for (const auto i : c10::irange(size())) {
|
| 41 |
+
tmp_first[i] = f(tmp_first[i], tmp_second[i]);
|
| 42 |
+
}
|
| 43 |
+
return Derived::loadu(tmp_first);
|
| 44 |
+
}
|
| 45 |
+
|
| 46 |
+
public:
|
| 47 |
+
Vectorized16() = default;
|
| 48 |
+
Vectorized16(VecT v) : values(v) {}
|
| 49 |
+
|
| 50 |
+
operator VecT() const {
|
| 51 |
+
return values;
|
| 52 |
+
}
|
| 53 |
+
|
| 54 |
+
template <int64_t mask>
|
| 55 |
+
static Derived blend(const Derived& a, const Derived& b) {
|
| 56 |
+
Derived vec;
|
| 57 |
+
vec.values = BlendRegs < 0,
|
| 58 |
+
(mask & 0x01) != 0 > ::impl(a.values, b.values, vec.values);
|
| 59 |
+
vec.values = BlendRegs < 1,
|
| 60 |
+
(mask & 0x02) != 0 > ::impl(a.values, b.values, vec.values);
|
| 61 |
+
vec.values = BlendRegs < 2,
|
| 62 |
+
(mask & 0x04) != 0 > ::impl(a.values, b.values, vec.values);
|
| 63 |
+
vec.values = BlendRegs < 3,
|
| 64 |
+
(mask & 0x08) != 0 > ::impl(a.values, b.values, vec.values);
|
| 65 |
+
|
| 66 |
+
vec.values = BlendRegs < 4,
|
| 67 |
+
(mask & 0x10) != 0 > ::impl(a.values, b.values, vec.values);
|
| 68 |
+
vec.values = BlendRegs < 5,
|
| 69 |
+
(mask & 0x20) != 0 > ::impl(a.values, b.values, vec.values);
|
| 70 |
+
vec.values = BlendRegs < 6,
|
| 71 |
+
(mask & 0x40) != 0 > ::impl(a.values, b.values, vec.values);
|
| 72 |
+
vec.values = BlendRegs < 7,
|
| 73 |
+
(mask & 0x80) != 0 > ::impl(a.values, b.values, vec.values);
|
| 74 |
+
|
| 75 |
+
return vec;
|
| 76 |
+
}
|
| 77 |
+
|
| 78 |
+
template <typename step_t>
|
| 79 |
+
static Derived arange(
|
| 80 |
+
value_type base = 0,
|
| 81 |
+
step_t step = static_cast<step_t>(1)) {
|
| 82 |
+
const Derived base_vec(base);
|
| 83 |
+
const Derived step_vec(step);
|
| 84 |
+
const Derived step_sizes(
|
| 85 |
+
value_type(0),
|
| 86 |
+
value_type(1),
|
| 87 |
+
value_type(2),
|
| 88 |
+
value_type(3),
|
| 89 |
+
value_type(4),
|
| 90 |
+
value_type(5),
|
| 91 |
+
value_type(6),
|
| 92 |
+
value_type(7));
|
| 93 |
+
return fmadd(step_sizes, step_vec, base_vec);
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
// Very slow implementation of indexing.
|
| 97 |
+
// Only required because vec256_qint refers to this.
|
| 98 |
+
// Once we specialize that implementation for ARM
|
| 99 |
+
// this should be removed. TODO (kimishpatel)
|
| 100 |
+
value_type operator[](int idx) const {
|
| 101 |
+
__at_align__ value_type tmp[size()];
|
| 102 |
+
static_cast<const Derived*>(this)->store(tmp);
|
| 103 |
+
return tmp[idx];
|
| 104 |
+
}
|
| 105 |
+
|
| 106 |
+
int zero_mask() const {
|
| 107 |
+
__at_align__ value_type tmp[size()];
|
| 108 |
+
static_cast<const Derived*>(this)->store(tmp);
|
| 109 |
+
int mask = 0;
|
| 110 |
+
for (int i = 0; i < size(); ++i) {
|
| 111 |
+
if (tmp[i] == 0) {
|
| 112 |
+
mask |= (1 << i);
|
| 113 |
+
}
|
| 114 |
+
}
|
| 115 |
+
return mask;
|
| 116 |
+
}
|
| 117 |
+
|
| 118 |
+
Derived map(value_type (*const f)(value_type)) const {
|
| 119 |
+
__at_align__ value_type tmp[size()];
|
| 120 |
+
static_cast<const Derived*>(this)->store(tmp);
|
| 121 |
+
for (const auto i : c10::irange(size())) {
|
| 122 |
+
tmp[i] = f(tmp[i]);
|
| 123 |
+
}
|
| 124 |
+
return Derived::loadu(tmp);
|
| 125 |
+
}
|
| 126 |
+
|
| 127 |
+
Derived angle() const {
|
| 128 |
+
auto zero = Derived(0);
|
| 129 |
+
auto pi = Derived(c10::pi<value_type>);
|
| 130 |
+
auto tmp =
|
| 131 |
+
Derived::blendv(zero, pi, *static_cast<const Derived*>(this) < zero);
|
| 132 |
+
return Derived::blendv(
|
| 133 |
+
tmp,
|
| 134 |
+
*static_cast<const Derived*>(this),
|
| 135 |
+
static_cast<const Derived*>(this)->isnan());
|
| 136 |
+
}
|
| 137 |
+
Derived real() const {
|
| 138 |
+
return *this;
|
| 139 |
+
}
|
| 140 |
+
Derived imag() const {
|
| 141 |
+
return Derived(0);
|
| 142 |
+
}
|
| 143 |
+
Derived conj() const {
|
| 144 |
+
return *this;
|
| 145 |
+
}
|
| 146 |
+
|
| 147 |
+
// Sleef does not support FP16/BF16, so many math functions are applied by
|
| 148 |
+
// converting to FP32, applying the math function, and then converting back to
|
| 149 |
+
// FP16/BF16.
|
| 150 |
+
Derived acos() const {
|
| 151 |
+
return static_cast<const Derived*>(this)->map_with_vec_float_method(
|
| 152 |
+
&Vectorized<float>::acos);
|
| 153 |
+
}
|
| 154 |
+
Derived acosh() const {
|
| 155 |
+
return static_cast<const Derived*>(this)->map_with_vec_float_method(
|
| 156 |
+
&Vectorized<float>::acosh);
|
| 157 |
+
}
|
| 158 |
+
Derived asin() const {
|
| 159 |
+
return static_cast<const Derived*>(this)->map_with_vec_float_method(
|
| 160 |
+
&Vectorized<float>::asin);
|
| 161 |
+
}
|
| 162 |
+
Derived asinh() const {
|
| 163 |
+
return static_cast<const Derived*>(this)->map_with_vec_float_method(
|
| 164 |
+
&Vectorized<float>::asinh);
|
| 165 |
+
}
|
| 166 |
+
Derived atan() const {
|
| 167 |
+
return static_cast<const Derived*>(this)->map_with_vec_float_method(
|
| 168 |
+
&Vectorized<float>::atan);
|
| 169 |
+
}
|
| 170 |
+
Derived atanh() const {
|
| 171 |
+
return static_cast<const Derived*>(this)->map_with_vec_float_method(
|
| 172 |
+
&Vectorized<float>::atanh);
|
| 173 |
+
}
|
| 174 |
+
Derived atan2(const Derived& exp) const {
|
| 175 |
+
return static_cast<const Derived*>(this)->map2_with_vec_float_method(
|
| 176 |
+
exp, &Vectorized<float>::atan2);
|
| 177 |
+
}
|
| 178 |
+
Derived copysign(const Derived& sign) const {
|
| 179 |
+
return static_cast<const Derived*>(this)->map2_with_vec_float_method(
|
| 180 |
+
sign, &Vectorized<float>::copysign);
|
| 181 |
+
}
|
| 182 |
+
Derived erf() const {
|
| 183 |
+
return static_cast<const Derived*>(this)->map_with_vec_float_method(
|
| 184 |
+
&Vectorized<float>::erf);
|
| 185 |
+
}
|
| 186 |
+
Derived erfc() const {
|
| 187 |
+
return static_cast<const Derived*>(this)->map_with_vec_float_method(
|
| 188 |
+
&Vectorized<float>::erfc);
|
| 189 |
+
}
|
| 190 |
+
Derived erfinv() const {
|
| 191 |
+
return static_cast<const Derived*>(this)->map_with_vec_float_method(
|
| 192 |
+
&Vectorized<float>::erfinv);
|
| 193 |
+
}
|
| 194 |
+
Derived exp() const {
|
| 195 |
+
return static_cast<const Derived*>(this)->map_with_vec_float_method(
|
| 196 |
+
&Vectorized<float>::exp);
|
| 197 |
+
}
|
| 198 |
+
Derived exp2() const {
|
| 199 |
+
return static_cast<const Derived*>(this)->map_with_vec_float_method(
|
| 200 |
+
&Vectorized<float>::exp2);
|
| 201 |
+
}
|
| 202 |
+
Derived expm1() const {
|
| 203 |
+
return static_cast<const Derived*>(this)->map_with_vec_float_method(
|
| 204 |
+
&Vectorized<float>::expm1);
|
| 205 |
+
}
|
| 206 |
+
Derived exp_u20() const {
|
| 207 |
+
return static_cast<const Derived*>(this)->map_with_vec_float_method(
|
| 208 |
+
&Vectorized<float>::exp_u20);
|
| 209 |
+
}
|
| 210 |
+
Derived fexp_u20() const {
|
| 211 |
+
return static_cast<const Derived*>(this)->map_with_vec_float_method(
|
| 212 |
+
&Vectorized<float>::exp_u20);
|
| 213 |
+
}
|
| 214 |
+
Derived fmod(const Derived& q) const {
|
| 215 |
+
// This function is questionable with a conversion, so we use map2
|
| 216 |
+
return map2(q, std::fmod);
|
| 217 |
+
}
|
| 218 |
+
Derived hypot(const Derived& b) const {
|
| 219 |
+
return static_cast<const Derived*>(this)->map2_with_vec_float_method(
|
| 220 |
+
b, &Vectorized<float>::hypot);
|
| 221 |
+
}
|
| 222 |
+
Derived i0() const {
|
| 223 |
+
return static_cast<const Derived*>(this)->map_with_vec_float_method(
|
| 224 |
+
&Vectorized<float>::i0);
|
| 225 |
+
}
|
| 226 |
+
Derived i0e() const {
|
| 227 |
+
return static_cast<const Derived*>(this)->map_with_vec_float_method(
|
| 228 |
+
&Vectorized<float>::i0e);
|
| 229 |
+
}
|
| 230 |
+
Derived digamma() const {
|
| 231 |
+
return static_cast<const Derived*>(this)->map_with_vec_float_method(
|
| 232 |
+
&Vectorized<float>::digamma);
|
| 233 |
+
}
|
| 234 |
+
Derived igamma(const Derived& x) const {
|
| 235 |
+
return static_cast<const Derived*>(this)->map2_with_vec_float_method(
|
| 236 |
+
x, &Vectorized<float>::igamma);
|
| 237 |
+
}
|
| 238 |
+
Derived igammac(const Derived& x) const {
|
| 239 |
+
return static_cast<const Derived*>(this)->map2_with_vec_float_method(
|
| 240 |
+
x, &Vectorized<float>::igammac);
|
| 241 |
+
}
|
| 242 |
+
Derived log() const {
|
| 243 |
+
return static_cast<const Derived*>(this)->map_with_vec_float_method(
|
| 244 |
+
&Vectorized<float>::log);
|
| 245 |
+
}
|
| 246 |
+
Derived log10() const {
|
| 247 |
+
return static_cast<const Derived*>(this)->map_with_vec_float_method(
|
| 248 |
+
&Vectorized<float>::log10);
|
| 249 |
+
}
|
| 250 |
+
Derived log1p() const {
|
| 251 |
+
return static_cast<const Derived*>(this)->map_with_vec_float_method(
|
| 252 |
+
&Vectorized<float>::log1p);
|
| 253 |
+
}
|
| 254 |
+
Derived log2() const {
|
| 255 |
+
return static_cast<const Derived*>(this)->map_with_vec_float_method(
|
| 256 |
+
&Vectorized<float>::log2);
|
| 257 |
+
}
|
| 258 |
+
Derived nextafter(const Derived& b) const {
|
| 259 |
+
// This function does not make sense with conversion, so we use map2
|
| 260 |
+
return map2(b, std::nextafter);
|
| 261 |
+
}
|
| 262 |
+
Derived sin() const {
|
| 263 |
+
return static_cast<const Derived*>(this)->map_with_vec_float_method(
|
| 264 |
+
&Vectorized<float>::sin);
|
| 265 |
+
}
|
| 266 |
+
Derived sinh() const {
|
| 267 |
+
return static_cast<const Derived*>(this)->map_with_vec_float_method(
|
| 268 |
+
&Vectorized<float>::sinh);
|
| 269 |
+
}
|
| 270 |
+
Derived cos() const {
|
| 271 |
+
return static_cast<const Derived*>(this)->map_with_vec_float_method(
|
| 272 |
+
&Vectorized<float>::cos);
|
| 273 |
+
}
|
| 274 |
+
Derived cosh() const {
|
| 275 |
+
return static_cast<const Derived*>(this)->map_with_vec_float_method(
|
| 276 |
+
&Vectorized<float>::cosh);
|
| 277 |
+
}
|
| 278 |
+
Derived ceil() const {
|
| 279 |
+
// This function is questionable with a conversion, so we use map
|
| 280 |
+
return map(at::native::ceil_impl);
|
| 281 |
+
}
|
| 282 |
+
Derived floor() const {
|
| 283 |
+
// This function is questionable with a conversion, so we use map
|
| 284 |
+
return map(at::native::floor_impl);
|
| 285 |
+
}
|
| 286 |
+
Derived round() const {
|
| 287 |
+
// This function is questionable with a conversion, so we use map
|
| 288 |
+
return map(at::native::round_impl);
|
| 289 |
+
}
|
| 290 |
+
Derived tan() const {
|
| 291 |
+
return static_cast<const Derived*>(this)->map_with_vec_float_method(
|
| 292 |
+
&Vectorized<float>::tan);
|
| 293 |
+
}
|
| 294 |
+
Derived tanh() const {
|
| 295 |
+
return static_cast<const Derived*>(this)->map_with_vec_float_method(
|
| 296 |
+
&Vectorized<float>::tanh);
|
| 297 |
+
}
|
| 298 |
+
Derived lgamma() const {
|
| 299 |
+
return static_cast<const Derived*>(this)->map_with_vec_float_method(
|
| 300 |
+
&Vectorized<float>::lgamma);
|
| 301 |
+
}
|
| 302 |
+
Derived rsqrt() const {
|
| 303 |
+
return static_cast<const Derived*>(this)->sqrt().reciprocal();
|
| 304 |
+
}
|
| 305 |
+
Derived pow(const Derived& exp) const {
|
| 306 |
+
return static_cast<const Derived*>(this)->map2_with_vec_float_method(
|
| 307 |
+
exp, &Vectorized<float>::pow);
|
| 308 |
+
}
|
| 309 |
+
};
|
| 310 |
+
|
| 311 |
+
} // namespace CPU_CAPABILITY
|
| 312 |
+
} // namespace at::vec
|
| 313 |
+
|
| 314 |
+
#else
|
| 315 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 316 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec_base.h
ADDED
|
@@ -0,0 +1,1537 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
#if defined(__GNUC__) && __GNUC__ == 10 && __GNUC_MINOR__ <= 2 && \
|
| 4 |
+
defined(__ARM_FEATURE_SVE)
|
| 5 |
+
// Workaround for https: //gcc.gnu.org/bugzilla/show_bug.cgi?id=117161
|
| 6 |
+
#pragma GCC optimize("no-tree-vectorize")
|
| 7 |
+
#endif
|
| 8 |
+
|
| 9 |
+
// DO NOT DEFINE STATIC DATA IN THIS HEADER!
|
| 10 |
+
// See Note [Do not compile initializers with AVX]
|
| 11 |
+
//
|
| 12 |
+
// Note [Do not compile initializers with AVX]
|
| 13 |
+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 14 |
+
// If you define a static initializer in this file, the initialization will use
|
| 15 |
+
// AVX instructions because these object files are compiled with AVX enabled.
|
| 16 |
+
// We need to avoid non-trivial global data in these architecture specific files
|
| 17 |
+
// because there's no way to guard the global initializers with CPU capability
|
| 18 |
+
// detection.
|
| 19 |
+
//
|
| 20 |
+
// See https://github.com/pytorch/pytorch/issues/37577 for an instance
|
| 21 |
+
// of this bug in the past.
|
| 22 |
+
|
| 23 |
+
#include <algorithm>
|
| 24 |
+
#include <array>
|
| 25 |
+
#include <cassert>
|
| 26 |
+
#include <climits>
|
| 27 |
+
#include <cmath>
|
| 28 |
+
#include <cstring>
|
| 29 |
+
#include <functional>
|
| 30 |
+
#include <type_traits>
|
| 31 |
+
|
| 32 |
+
#include <ATen/NumericUtils.h>
|
| 33 |
+
#include <ATen/cpu/vec/intrinsics.h>
|
| 34 |
+
#include <ATen/native/Math.h>
|
| 35 |
+
#include <ATen/native/cpu/zmath.h>
|
| 36 |
+
#include <c10/macros/Macros.h>
|
| 37 |
+
#include <c10/util/BFloat16-math.h>
|
| 38 |
+
#include <c10/util/BFloat16.h>
|
| 39 |
+
#include <c10/util/Half.h>
|
| 40 |
+
#include <c10/util/Load.h>
|
| 41 |
+
#include <c10/util/TypeCast.h>
|
| 42 |
+
#include <c10/util/copysign.h>
|
| 43 |
+
#include <c10/util/irange.h>
|
| 44 |
+
|
| 45 |
+
#if defined(__GNUC__)
|
| 46 |
+
#define __FORCE_INLINE __attribute__((always_inline)) inline
|
| 47 |
+
#elif defined(_MSC_VER)
|
| 48 |
+
#define __FORCE_INLINE __forceinline
|
| 49 |
+
#endif
|
| 50 |
+
|
| 51 |
+
#if defined(_MSC_FULL_VER)
|
| 52 |
+
/*
|
| 53 |
+
https://learn.microsoft.com/en-us/cpp/overview/compiler-versions?view=msvc-170
|
| 54 |
+
Use _MSC_FULL_VER to identify current compiler is msvc,
|
| 55 |
+
Windows llvm will not have this definition.
|
| 56 |
+
*/
|
| 57 |
+
#define __msvc_cl__
|
| 58 |
+
#endif
|
| 59 |
+
|
| 60 |
+
// These macros helped us unify vec_base.h
|
| 61 |
+
#ifdef CPU_CAPABILITY_AVX512
|
| 62 |
+
#if defined(__GNUC__)
|
| 63 |
+
#define __at_align__ __attribute__((aligned(64)))
|
| 64 |
+
#elif defined(_WIN32)
|
| 65 |
+
#define __at_align__ __declspec(align(64))
|
| 66 |
+
#else
|
| 67 |
+
#define __at_align__
|
| 68 |
+
#endif
|
| 69 |
+
#define VECTOR_WIDTH 64
|
| 70 |
+
#define int_vector __m512i
|
| 71 |
+
#elif defined(__aarch64__) && \
|
| 72 |
+
!defined(CPU_CAPABILITY_SVE) // CPU_CAPABILITY_AVX512
|
| 73 |
+
// SVE code expects 256-vectors; leave that set for SVE?
|
| 74 |
+
#if defined(__GNUC__)
|
| 75 |
+
#define __at_align__ __attribute__((aligned(16)))
|
| 76 |
+
#elif defined(_WIN32)
|
| 77 |
+
#define __at_align__ __declspec(align(16))
|
| 78 |
+
#else
|
| 79 |
+
#define __at_align__
|
| 80 |
+
#endif
|
| 81 |
+
#define VECTOR_WIDTH 16
|
| 82 |
+
#else // CPU_CAPABILITY_AVX512
|
| 83 |
+
#if defined(__GNUC__)
|
| 84 |
+
#define __at_align__ __attribute__((aligned(32)))
|
| 85 |
+
#elif defined(_WIN32)
|
| 86 |
+
#define __at_align__ __declspec(align(32))
|
| 87 |
+
#else
|
| 88 |
+
#define __at_align__
|
| 89 |
+
#endif
|
| 90 |
+
#define VECTOR_WIDTH 32
|
| 91 |
+
#define int_vector __m256i
|
| 92 |
+
#endif // CPU_CAPABILITY_AVX512
|
| 93 |
+
|
| 94 |
+
namespace at::vec {
|
| 95 |
+
// See Note [CPU_CAPABILITY namespace]
|
| 96 |
+
inline namespace CPU_CAPABILITY {
|
| 97 |
+
// at::Half and at::BFloat16 should be treated as floating point
|
| 98 |
+
template <typename T>
|
| 99 |
+
struct is_floating_point
|
| 100 |
+
: std::integral_constant<
|
| 101 |
+
bool,
|
| 102 |
+
std::is_floating_point_v<T> || std::is_same_v<T, at::Half> ||
|
| 103 |
+
std::is_same_v<T, at::BFloat16>> {};
|
| 104 |
+
|
| 105 |
+
template <typename T>
|
| 106 |
+
constexpr bool is_floating_point_v = is_floating_point<T>::value;
|
| 107 |
+
|
| 108 |
+
template <typename T>
|
| 109 |
+
struct is_reduced_floating_point
|
| 110 |
+
: std::integral_constant<
|
| 111 |
+
bool,
|
| 112 |
+
std::is_same_v<T, at::Half> || std::is_same_v<T, at::BFloat16>> {};
|
| 113 |
+
|
| 114 |
+
template <typename T>
|
| 115 |
+
constexpr bool is_reduced_floating_point_v =
|
| 116 |
+
is_reduced_floating_point<T>::value;
|
| 117 |
+
|
| 118 |
+
template <typename T>
|
| 119 |
+
struct is_8bit_integer
|
| 120 |
+
: std::integral_constant<
|
| 121 |
+
bool,
|
| 122 |
+
std::is_same_v<T, unsigned char> || std::is_same_v<T, signed char>> {
|
| 123 |
+
};
|
| 124 |
+
|
| 125 |
+
template <typename T>
|
| 126 |
+
constexpr bool is_8bit_integer_v = is_8bit_integer<T>::value;
|
| 127 |
+
|
| 128 |
+
template <size_t n>
|
| 129 |
+
struct int_of_size;
|
| 130 |
+
|
| 131 |
+
#define DEFINE_INT_OF_SIZE(int_t) \
|
| 132 |
+
template <> \
|
| 133 |
+
struct int_of_size<sizeof(int_t)> { \
|
| 134 |
+
using type = int_t; \
|
| 135 |
+
}
|
| 136 |
+
|
| 137 |
+
DEFINE_INT_OF_SIZE(int64_t);
|
| 138 |
+
DEFINE_INT_OF_SIZE(int32_t);
|
| 139 |
+
DEFINE_INT_OF_SIZE(int16_t);
|
| 140 |
+
DEFINE_INT_OF_SIZE(int8_t);
|
| 141 |
+
|
| 142 |
+
#undef DEFINE_INT_OF_SIZE
|
| 143 |
+
|
| 144 |
+
template <typename T>
|
| 145 |
+
using int_same_size_t = typename int_of_size<sizeof(T)>::type;
|
| 146 |
+
|
| 147 |
+
/**
|
| 148 |
+
* Detect at compile time whether Vectorized has an explicit
|
| 149 |
+
* specialization for T. (You are required to specialize this type
|
| 150 |
+
* whenever you specialize Vectorized). Useful for generic algorithms
|
| 151 |
+
* to decide whether to rely on a specialization being fast. For
|
| 152 |
+
* example, they might choose to handle reduced-precision floating
|
| 153 |
+
* point types directly if they're supported, or convert through float
|
| 154 |
+
* if not.
|
| 155 |
+
*/
|
| 156 |
+
#if defined(__s390x__)
|
| 157 |
+
template <class T, class TEMP = void>
|
| 158 |
+
#else
|
| 159 |
+
template <typename T>
|
| 160 |
+
#endif
|
| 161 |
+
struct is_vec_specialized_for : std::bool_constant<false> {
|
| 162 |
+
};
|
| 163 |
+
|
| 164 |
+
template <typename T>
|
| 165 |
+
constexpr bool is_vec_specialized_for_v = is_vec_specialized_for<T>::value;
|
| 166 |
+
|
| 167 |
+
// NOTE: If you specialize Vectorized on a type, you must define all
|
| 168 |
+
// operations! You must also specialize is_vec_specialized_for for
|
| 169 |
+
// that type.
|
| 170 |
+
|
| 171 |
+
// emulates Vectorized types
|
| 172 |
+
#if defined(__s390x__)
|
| 173 |
+
template <class T, class TEMP = void>
|
| 174 |
+
#else
|
| 175 |
+
template <class T>
|
| 176 |
+
#endif
|
| 177 |
+
struct Vectorized {
|
| 178 |
+
private:
|
| 179 |
+
__at_align__ T values[VECTOR_WIDTH / sizeof(T)];
|
| 180 |
+
|
| 181 |
+
public:
|
| 182 |
+
using value_type = T;
|
| 183 |
+
using size_type = int;
|
| 184 |
+
|
| 185 |
+
static constexpr size_type kSize = VECTOR_WIDTH / sizeof(T);
|
| 186 |
+
static constexpr size_type size() {
|
| 187 |
+
return kSize;
|
| 188 |
+
}
|
| 189 |
+
Vectorized() : values{static_cast<T>(0)} {}
|
| 190 |
+
Vectorized(T val) {
|
| 191 |
+
for (int i = 0; i != size(); i++) {
|
| 192 |
+
values[i] = val;
|
| 193 |
+
}
|
| 194 |
+
}
|
| 195 |
+
template <
|
| 196 |
+
typename... Args,
|
| 197 |
+
typename = std::enable_if_t<(sizeof...(Args) == size())>>
|
| 198 |
+
Vectorized(Args... vals) : values{vals...} {}
|
| 199 |
+
Vectorized(const T (&arr)[kSize]) {
|
| 200 |
+
std::memcpy(values, arr, sizeof(values));
|
| 201 |
+
}
|
| 202 |
+
// This also implies const T& operator[](int idx) const
|
| 203 |
+
inline operator const T*() const {
|
| 204 |
+
return values;
|
| 205 |
+
}
|
| 206 |
+
// This also implies T& operator[](int idx)
|
| 207 |
+
inline operator T*() {
|
| 208 |
+
return values;
|
| 209 |
+
}
|
| 210 |
+
// Return the values as char* for type punning
|
| 211 |
+
auto as_bytes() const -> const char* {
|
| 212 |
+
return reinterpret_cast<const char*>(values);
|
| 213 |
+
}
|
| 214 |
+
template <int64_t mask_>
|
| 215 |
+
static Vectorized<T> blend(const Vectorized<T>& a, const Vectorized<T>& b) {
|
| 216 |
+
int64_t mask = mask_;
|
| 217 |
+
Vectorized vector;
|
| 218 |
+
for (const auto i : c10::irange(size())) {
|
| 219 |
+
if (mask & 0x01) {
|
| 220 |
+
vector[i] = b[i];
|
| 221 |
+
} else {
|
| 222 |
+
vector[i] = a[i];
|
| 223 |
+
}
|
| 224 |
+
mask = mask >> 1;
|
| 225 |
+
}
|
| 226 |
+
return vector;
|
| 227 |
+
}
|
| 228 |
+
// Workaround for https: //gcc.gnu.org/bugzilla/show_bug.cgi?id=117001
|
| 229 |
+
#if __GNUC__ <= 12 && !defined(__clang__) && defined(__ARM_FEATURE_SVE)
|
| 230 |
+
static Vectorized<T> __attribute__((optimize("-fno-tree-loop-vectorize")))
|
| 231 |
+
blendv(
|
| 232 |
+
const Vectorized<T>& a,
|
| 233 |
+
#else
|
| 234 |
+
static Vectorized<T> blendv(
|
| 235 |
+
const Vectorized<T>& a,
|
| 236 |
+
#endif
|
| 237 |
+
const Vectorized<T>& b,
|
| 238 |
+
const Vectorized<T>& mask) {
|
| 239 |
+
Vectorized vector;
|
| 240 |
+
int_same_size_t<T> buffer[size()];
|
| 241 |
+
mask.store(buffer);
|
| 242 |
+
for (const auto i : c10::irange(size())) {
|
| 243 |
+
if (buffer[i] & 0x01) {
|
| 244 |
+
vector[i] = b[i];
|
| 245 |
+
} else {
|
| 246 |
+
vector[i] = a[i];
|
| 247 |
+
}
|
| 248 |
+
}
|
| 249 |
+
return vector;
|
| 250 |
+
}
|
| 251 |
+
template <typename step_t> // step sometimes requires a higher precision type
|
| 252 |
+
// (e.g., T=int, step_t=double)
|
| 253 |
+
static Vectorized<T> arange(
|
| 254 |
+
T base = static_cast<T>(0),
|
| 255 |
+
step_t step = static_cast<step_t>(1)) {
|
| 256 |
+
Vectorized vector;
|
| 257 |
+
for (const auto i : c10::irange(size())) {
|
| 258 |
+
vector.values[i] = base + i * step;
|
| 259 |
+
}
|
| 260 |
+
return vector;
|
| 261 |
+
}
|
| 262 |
+
static Vectorized<T> set(
|
| 263 |
+
const Vectorized<T>& a,
|
| 264 |
+
const Vectorized<T>& b,
|
| 265 |
+
int64_t count = size()) {
|
| 266 |
+
Vectorized vector;
|
| 267 |
+
for (const auto i : c10::irange(size())) {
|
| 268 |
+
if (i < count) {
|
| 269 |
+
vector[i] = b[i];
|
| 270 |
+
} else {
|
| 271 |
+
vector[i] = a[i];
|
| 272 |
+
}
|
| 273 |
+
}
|
| 274 |
+
return vector;
|
| 275 |
+
}
|
| 276 |
+
static Vectorized<T> loadu(const void* ptr) {
|
| 277 |
+
Vectorized vector;
|
| 278 |
+
std::memcpy(vector.values, ptr, VECTOR_WIDTH);
|
| 279 |
+
return vector;
|
| 280 |
+
}
|
| 281 |
+
static Vectorized<T> loadu(const void* ptr, int64_t count) {
|
| 282 |
+
Vectorized vector;
|
| 283 |
+
std::memcpy(vector.values, ptr, count * sizeof(T));
|
| 284 |
+
return vector;
|
| 285 |
+
}
|
| 286 |
+
static Vectorized<T> loadu_one_fourth(const void* ptr) {
|
| 287 |
+
static_assert(
|
| 288 |
+
std::is_same_v<T, signed char> || std::is_same_v<T, unsigned char>,
|
| 289 |
+
"For byte types only");
|
| 290 |
+
return Vectorized::loadu(ptr, 8);
|
| 291 |
+
}
|
| 292 |
+
|
| 293 |
+
void store(void* ptr, int count = size()) const {
|
| 294 |
+
std::memcpy(ptr, values, count * sizeof(T));
|
| 295 |
+
}
|
| 296 |
+
int zero_mask() const {
|
| 297 |
+
// returns an integer mask where all zero elements are translated to 1-bit
|
| 298 |
+
// and others are translated to 0-bit
|
| 299 |
+
int mask = 0;
|
| 300 |
+
for (int i = 0; i < size(); ++i) {
|
| 301 |
+
if (values[i] == static_cast<T>(0)) {
|
| 302 |
+
mask |= (1 << i);
|
| 303 |
+
}
|
| 304 |
+
}
|
| 305 |
+
return mask;
|
| 306 |
+
}
|
| 307 |
+
Vectorized<T> isnan() const {
|
| 308 |
+
Vectorized<T> vector;
|
| 309 |
+
for (int64_t i = 0; i != size(); i++) {
|
| 310 |
+
if (_isnan(values[i])) {
|
| 311 |
+
std::memset(static_cast<void*>(vector.values + i), 0xFF, sizeof(T));
|
| 312 |
+
} else {
|
| 313 |
+
std::memset(static_cast<void*>(vector.values + i), 0, sizeof(T));
|
| 314 |
+
}
|
| 315 |
+
}
|
| 316 |
+
return vector;
|
| 317 |
+
}
|
| 318 |
+
bool has_inf_nan() const {
|
| 319 |
+
for (int64_t i = 0; i != size(); i++) {
|
| 320 |
+
if (_isnan(values[i]) || _isinf(values[i])) {
|
| 321 |
+
return true;
|
| 322 |
+
}
|
| 323 |
+
}
|
| 324 |
+
return false;
|
| 325 |
+
}
|
| 326 |
+
// MSVC versions between 14.36 and 14.42 has a loop unrolling bug on Windows
|
| 327 |
+
// Arm64
|
| 328 |
+
// See
|
| 329 |
+
// https://developercommunity.visualstudio.com/t/MSVC-loop-unrolling-problem-194033813-/10720692
|
| 330 |
+
#if defined(_WIN32) && defined(__aarch64__) && \
|
| 331 |
+
((_MSVC_VER >= 1936) && (_MSVC_VER <= 1942))
|
| 332 |
+
Vectorized<T> map(T (*const f)(T)) const {
|
| 333 |
+
Vectorized<T> ret;
|
| 334 |
+
for (int64_t i = 0; i < size(); i++) {
|
| 335 |
+
ret[i] = f(values[i]);
|
| 336 |
+
if (++i < size())
|
| 337 |
+
ret[i] = f(values[i]);
|
| 338 |
+
}
|
| 339 |
+
return ret;
|
| 340 |
+
}
|
| 341 |
+
T reduce(T (*const f)(T)) const {
|
| 342 |
+
T ret = 0;
|
| 343 |
+
for (int64_t i = 0; i < size(); i++) {
|
| 344 |
+
ret = f(ret, values[i]);
|
| 345 |
+
if (++i < size())
|
| 346 |
+
ret = f(ret, values[i]);
|
| 347 |
+
}
|
| 348 |
+
return ret;
|
| 349 |
+
}
|
| 350 |
+
#else
|
| 351 |
+
Vectorized<T> map(T (*const f)(T)) const {
|
| 352 |
+
Vectorized<T> ret;
|
| 353 |
+
for (int64_t i = 0; i != size(); i++) {
|
| 354 |
+
ret[i] = f(values[i]);
|
| 355 |
+
}
|
| 356 |
+
return ret;
|
| 357 |
+
}
|
| 358 |
+
T reduce(T (*const f)(T)) const {
|
| 359 |
+
T ret = 0;
|
| 360 |
+
for (int64_t i = 0; i != size(); i++) {
|
| 361 |
+
ret = f(ret, values[i]);
|
| 362 |
+
}
|
| 363 |
+
return ret;
|
| 364 |
+
}
|
| 365 |
+
#endif
|
| 366 |
+
Vectorized<T> map(T (*const f)(const T&)) const {
|
| 367 |
+
Vectorized<T> ret;
|
| 368 |
+
for (int64_t i = 0; i != size(); i++) {
|
| 369 |
+
ret[i] = f(values[i]);
|
| 370 |
+
}
|
| 371 |
+
return ret;
|
| 372 |
+
}
|
| 373 |
+
T reduce(T (*const f)(const T&)) const {
|
| 374 |
+
T ret = 0;
|
| 375 |
+
for (int64_t i = 0; i != size(); i++) {
|
| 376 |
+
ret = f(ret, values[i]);
|
| 377 |
+
}
|
| 378 |
+
return ret;
|
| 379 |
+
}
|
| 380 |
+
template <
|
| 381 |
+
typename other_t_abs = T,
|
| 382 |
+
typename std::enable_if_t<
|
| 383 |
+
!is_floating_point_v<other_t_abs> &&
|
| 384 |
+
!c10::is_complex<other_t_abs>::value,
|
| 385 |
+
int> = 0>
|
| 386 |
+
Vectorized<T> abs() const {
|
| 387 |
+
// other_t_abs is for SFINAE and clarity. Make sure it is not changed.
|
| 388 |
+
static_assert(std::is_same_v<other_t_abs, T>, "other_t_abs must be T");
|
| 389 |
+
return map([](T x) -> T { return x < static_cast<T>(0) ? -x : x; });
|
| 390 |
+
}
|
| 391 |
+
template <
|
| 392 |
+
typename float_t_abs = T,
|
| 393 |
+
typename std::enable_if_t<is_floating_point_v<float_t_abs>, int> = 0>
|
| 394 |
+
Vectorized<T> abs() const {
|
| 395 |
+
// float_t_abs is for SFINAE and clarity. Make sure it is not changed.
|
| 396 |
+
static_assert(std::is_same_v<float_t_abs, T>, "float_t_abs must be T");
|
| 397 |
+
// Specifically deal with floating-point because the generic code above
|
| 398 |
+
// won't handle -0.0 (which should result in 0.0) properly.
|
| 399 |
+
return map([](T x) -> T { return std::abs(x); });
|
| 400 |
+
}
|
| 401 |
+
template <
|
| 402 |
+
typename complex_t_abs = T,
|
| 403 |
+
typename std::enable_if_t<c10::is_complex<complex_t_abs>::value, int> = 0>
|
| 404 |
+
Vectorized<T> abs() const {
|
| 405 |
+
// complex_t_abs is for SFINAE and clarity. Make sure it is not changed.
|
| 406 |
+
static_assert(std::is_same_v<complex_t_abs, T>, "complex_t_abs must be T");
|
| 407 |
+
// Specifically map() does not perform the type conversion needed by abs.
|
| 408 |
+
return map([](T x) { return static_cast<T>(std::abs(x)); });
|
| 409 |
+
}
|
| 410 |
+
|
| 411 |
+
template <
|
| 412 |
+
typename other_t_sgn = T,
|
| 413 |
+
typename std::enable_if_t<c10::is_complex<other_t_sgn>::value, int> = 0>
|
| 414 |
+
Vectorized<T> sgn() const {
|
| 415 |
+
return map(at::native::sgn_impl);
|
| 416 |
+
}
|
| 417 |
+
|
| 418 |
+
template <
|
| 419 |
+
typename other_t_angle = T,
|
| 420 |
+
typename std::enable_if_t<!c10::is_complex<other_t_angle>::value, int> =
|
| 421 |
+
0>
|
| 422 |
+
Vectorized<T> angle() const {
|
| 423 |
+
// other_t_angle is for SFINAE and clarity. Make sure it is not changed.
|
| 424 |
+
static_assert(std::is_same_v<other_t_angle, T>, "other_t_angle must be T");
|
| 425 |
+
return map(at::native::angle_impl<T>); // compiler is unable to resolve the
|
| 426 |
+
// overload without <T>
|
| 427 |
+
}
|
| 428 |
+
template <
|
| 429 |
+
typename complex_t_angle = T,
|
| 430 |
+
typename std::enable_if_t<c10::is_complex<complex_t_angle>::value, int> =
|
| 431 |
+
0>
|
| 432 |
+
Vectorized<T> angle() const {
|
| 433 |
+
// complex_t_angle is for SFINAE and clarity. Make sure it is not changed.
|
| 434 |
+
static_assert(
|
| 435 |
+
std::is_same_v<complex_t_angle, T>, "complex_t_angle must be T");
|
| 436 |
+
return map([](T x) { return static_cast<T>(std::arg(x)); });
|
| 437 |
+
}
|
| 438 |
+
template <
|
| 439 |
+
typename other_t_real = T,
|
| 440 |
+
typename std::enable_if_t<!c10::is_complex<other_t_real>::value, int> = 0>
|
| 441 |
+
Vectorized<T> real() const {
|
| 442 |
+
// other_t_real is for SFINAE and clarity. Make sure it is not changed.
|
| 443 |
+
static_assert(std::is_same_v<other_t_real, T>, "other_t_real must be T");
|
| 444 |
+
return *this;
|
| 445 |
+
}
|
| 446 |
+
template <
|
| 447 |
+
typename complex_t_real = T,
|
| 448 |
+
typename std::enable_if_t<c10::is_complex<complex_t_real>::value, int> =
|
| 449 |
+
0>
|
| 450 |
+
Vectorized<T> real() const {
|
| 451 |
+
// complex_t_real is for SFINAE and clarity. Make sure it is not changed.
|
| 452 |
+
static_assert(
|
| 453 |
+
std::is_same_v<complex_t_real, T>, "complex_t_real must be T");
|
| 454 |
+
return map([](T x) { return static_cast<T>(x.real()); });
|
| 455 |
+
}
|
| 456 |
+
template <
|
| 457 |
+
typename other_t_imag = T,
|
| 458 |
+
typename std::enable_if_t<!c10::is_complex<other_t_imag>::value, int> = 0>
|
| 459 |
+
Vectorized<T> imag() const {
|
| 460 |
+
// other_t_imag is for SFINAE and clarity. Make sure it is not changed.
|
| 461 |
+
static_assert(std::is_same_v<other_t_imag, T>, "other_t_imag must be T");
|
| 462 |
+
return Vectorized(0);
|
| 463 |
+
}
|
| 464 |
+
template <
|
| 465 |
+
typename complex_t_imag = T,
|
| 466 |
+
typename std::enable_if_t<c10::is_complex<complex_t_imag>::value, int> =
|
| 467 |
+
0>
|
| 468 |
+
Vectorized<T> imag() const {
|
| 469 |
+
// complex_t_imag is for SFINAE and clarity. Make sure it is not changed.
|
| 470 |
+
static_assert(
|
| 471 |
+
std::is_same_v<complex_t_imag, T>, "complex_t_imag must be T");
|
| 472 |
+
return map([](T x) { return static_cast<T>(x.imag()); });
|
| 473 |
+
}
|
| 474 |
+
template <
|
| 475 |
+
typename other_t_conj = T,
|
| 476 |
+
typename std::enable_if_t<!c10::is_complex<other_t_conj>::value, int> = 0>
|
| 477 |
+
Vectorized<T> conj() const {
|
| 478 |
+
// other_t_conj is for SFINAE and clarity. Make sure it is not changed.
|
| 479 |
+
static_assert(std::is_same_v<other_t_conj, T>, "other_t_conj must be T");
|
| 480 |
+
return *this;
|
| 481 |
+
}
|
| 482 |
+
template <
|
| 483 |
+
typename complex_t_conj = T,
|
| 484 |
+
typename std::enable_if_t<c10::is_complex<complex_t_conj>::value, int> =
|
| 485 |
+
0>
|
| 486 |
+
Vectorized<T> conj() const {
|
| 487 |
+
// complex_t_conj is for SFINAE and clarity. Make sure it is not changed.
|
| 488 |
+
static_assert(
|
| 489 |
+
std::is_same_v<complex_t_conj, T>, "complex_t_conj must be T");
|
| 490 |
+
return map([](T x) { return static_cast<T>(std::conj(x)); });
|
| 491 |
+
}
|
| 492 |
+
Vectorized<T> acos() const {
|
| 493 |
+
return map(std::acos);
|
| 494 |
+
}
|
| 495 |
+
Vectorized<T> acosh() const {
|
| 496 |
+
return map(std::acosh);
|
| 497 |
+
}
|
| 498 |
+
Vectorized<T> asin() const {
|
| 499 |
+
return map(std::asin);
|
| 500 |
+
}
|
| 501 |
+
Vectorized<T> asinh() const {
|
| 502 |
+
return map(std::asinh);
|
| 503 |
+
}
|
| 504 |
+
Vectorized<T> atan() const {
|
| 505 |
+
return map(std::atan);
|
| 506 |
+
}
|
| 507 |
+
Vectorized<T> atanh() const {
|
| 508 |
+
return map(std::atanh);
|
| 509 |
+
}
|
| 510 |
+
Vectorized<T> atan2(const Vectorized<T>& exp) const {
|
| 511 |
+
Vectorized<T> ret;
|
| 512 |
+
for (const auto i : c10::irange(size())) {
|
| 513 |
+
ret[i] = std::atan2(values[i], exp[i]);
|
| 514 |
+
}
|
| 515 |
+
return ret;
|
| 516 |
+
}
|
| 517 |
+
template <
|
| 518 |
+
typename U = T,
|
| 519 |
+
typename std::enable_if_t<is_floating_point_v<U>, int> = 0>
|
| 520 |
+
Vectorized<T> copysign(const Vectorized<T>& sign) const {
|
| 521 |
+
Vectorized<T> ret;
|
| 522 |
+
for (size_type i = 0; i < size(); i++) {
|
| 523 |
+
ret[i] = c10::copysign(values[i], sign[i]);
|
| 524 |
+
}
|
| 525 |
+
return ret;
|
| 526 |
+
}
|
| 527 |
+
Vectorized<T> erf() const {
|
| 528 |
+
return map(std::erf);
|
| 529 |
+
}
|
| 530 |
+
Vectorized<T> erfc() const {
|
| 531 |
+
return map(std::erfc);
|
| 532 |
+
}
|
| 533 |
+
Vectorized<T> erfinv() const {
|
| 534 |
+
return map(calc_erfinv);
|
| 535 |
+
}
|
| 536 |
+
Vectorized<T> exp() const {
|
| 537 |
+
return map(std::exp);
|
| 538 |
+
}
|
| 539 |
+
Vectorized<T> exp2() const {
|
| 540 |
+
return map(exp2_impl);
|
| 541 |
+
}
|
| 542 |
+
Vectorized<T> expm1() const {
|
| 543 |
+
return map(std::expm1);
|
| 544 |
+
}
|
| 545 |
+
Vectorized<T> exp_u20() const {
|
| 546 |
+
return map(std::exp);
|
| 547 |
+
}
|
| 548 |
+
Vectorized<T> fexp_u20() const {
|
| 549 |
+
return map(std::exp);
|
| 550 |
+
}
|
| 551 |
+
Vectorized<T> frac() const {
|
| 552 |
+
return *this - this->trunc();
|
| 553 |
+
}
|
| 554 |
+
template <
|
| 555 |
+
typename U = T,
|
| 556 |
+
typename std::enable_if_t<is_floating_point_v<U>, int> = 0>
|
| 557 |
+
Vectorized<T> fmod(const Vectorized<T>& q) const {
|
| 558 |
+
// U is for SFINAE purposes only. Make sure it is not changed.
|
| 559 |
+
static_assert(std::is_same_v<U, T>, "U must be T");
|
| 560 |
+
Vectorized<T> ret;
|
| 561 |
+
for (const auto i : c10::irange(size())) {
|
| 562 |
+
ret[i] = std::fmod(values[i], q[i]);
|
| 563 |
+
}
|
| 564 |
+
return ret;
|
| 565 |
+
}
|
| 566 |
+
Vectorized<T> log() const {
|
| 567 |
+
return map(std::log);
|
| 568 |
+
}
|
| 569 |
+
Vectorized<T> log10() const {
|
| 570 |
+
return map(std::log10);
|
| 571 |
+
}
|
| 572 |
+
Vectorized<T> log1p() const {
|
| 573 |
+
return map(std::log1p);
|
| 574 |
+
}
|
| 575 |
+
template <
|
| 576 |
+
typename other_t_log2 = T,
|
| 577 |
+
typename std::enable_if_t<!c10::is_complex<other_t_log2>::value, int> = 0>
|
| 578 |
+
Vectorized<T> log2() const {
|
| 579 |
+
// other_t_log2 is for SFINAE and clarity. Make sure it is not changed.
|
| 580 |
+
static_assert(std::is_same_v<other_t_log2, T>, "other_t_log2 must be T");
|
| 581 |
+
return map(std::log2);
|
| 582 |
+
}
|
| 583 |
+
template <
|
| 584 |
+
typename complex_t_log2 = T,
|
| 585 |
+
typename std::enable_if_t<c10::is_complex<complex_t_log2>::value, int> =
|
| 586 |
+
0>
|
| 587 |
+
Vectorized<T> log2() const {
|
| 588 |
+
// complex_t_log2 is for SFINAE and clarity. Make sure it is not changed.
|
| 589 |
+
static_assert(
|
| 590 |
+
std::is_same_v<complex_t_log2, T>, "complex_t_log2 must be T");
|
| 591 |
+
const T log_2 = T(std::log(2.0));
|
| 592 |
+
return Vectorized(map(std::log)) / Vectorized(log_2);
|
| 593 |
+
}
|
| 594 |
+
Vectorized<T> ceil() const {
|
| 595 |
+
return map(at::native::ceil_impl);
|
| 596 |
+
}
|
| 597 |
+
Vectorized<T> cos() const {
|
| 598 |
+
return map(std::cos);
|
| 599 |
+
}
|
| 600 |
+
Vectorized<T> cosh() const {
|
| 601 |
+
return map(std::cosh);
|
| 602 |
+
}
|
| 603 |
+
Vectorized<T> floor() const {
|
| 604 |
+
return map(at::native::floor_impl);
|
| 605 |
+
}
|
| 606 |
+
Vectorized<T> hypot(const Vectorized<T>& b) const {
|
| 607 |
+
Vectorized<T> ret;
|
| 608 |
+
for (const auto i : c10::irange(size())) {
|
| 609 |
+
ret[i] = std::hypot(values[i], b[i]);
|
| 610 |
+
}
|
| 611 |
+
return ret;
|
| 612 |
+
}
|
| 613 |
+
Vectorized<T> i0() const {
|
| 614 |
+
return map(calc_i0);
|
| 615 |
+
}
|
| 616 |
+
Vectorized<T> i0e() const {
|
| 617 |
+
return map(calc_i0e);
|
| 618 |
+
}
|
| 619 |
+
Vectorized<T> digamma() const {
|
| 620 |
+
return map(calc_digamma);
|
| 621 |
+
}
|
| 622 |
+
Vectorized<T> igamma(const Vectorized<T>& x) const {
|
| 623 |
+
Vectorized<T> ret;
|
| 624 |
+
for (const auto i : c10::irange(size())) {
|
| 625 |
+
ret[i] = calc_igamma(values[i], x[i]);
|
| 626 |
+
}
|
| 627 |
+
return ret;
|
| 628 |
+
}
|
| 629 |
+
Vectorized<T> igammac(const Vectorized<T>& x) const {
|
| 630 |
+
Vectorized<T> ret;
|
| 631 |
+
for (const auto i : c10::irange(size())) {
|
| 632 |
+
ret[i] = calc_igammac(values[i], x[i]);
|
| 633 |
+
}
|
| 634 |
+
return ret;
|
| 635 |
+
}
|
| 636 |
+
Vectorized<T> neg() const {
|
| 637 |
+
// NB: the trailing return type is needed because we need to coerce the
|
| 638 |
+
// return value back to T in the case of unary operator- incurring a
|
| 639 |
+
// promotion
|
| 640 |
+
return map([](T x) -> T { return -x; });
|
| 641 |
+
}
|
| 642 |
+
Vectorized<T> nextafter(const Vectorized<T>& b) const {
|
| 643 |
+
Vectorized<T> ret;
|
| 644 |
+
for (const auto i : c10::irange(size())) {
|
| 645 |
+
ret[i] = std::nextafter(values[i], b[i]);
|
| 646 |
+
}
|
| 647 |
+
return ret;
|
| 648 |
+
}
|
| 649 |
+
Vectorized<T> round() const {
|
| 650 |
+
// We do not use std::round because we would like to round midway numbers to
|
| 651 |
+
// the nearest even integer.
|
| 652 |
+
return map(at::native::round_impl);
|
| 653 |
+
}
|
| 654 |
+
Vectorized<T> sin() const {
|
| 655 |
+
return map(std::sin);
|
| 656 |
+
}
|
| 657 |
+
Vectorized<T> sinh() const {
|
| 658 |
+
return map(std::sinh);
|
| 659 |
+
}
|
| 660 |
+
Vectorized<T> tan() const {
|
| 661 |
+
return map(std::tan);
|
| 662 |
+
}
|
| 663 |
+
Vectorized<T> tanh() const {
|
| 664 |
+
return map(std::tanh);
|
| 665 |
+
}
|
| 666 |
+
Vectorized<T> trunc() const {
|
| 667 |
+
return map(at::native::trunc_impl);
|
| 668 |
+
}
|
| 669 |
+
Vectorized<T> lgamma() const {
|
| 670 |
+
return map(std::lgamma);
|
| 671 |
+
}
|
| 672 |
+
Vectorized<T> sqrt() const {
|
| 673 |
+
return map(std::sqrt);
|
| 674 |
+
}
|
| 675 |
+
Vectorized<T> reciprocal() const {
|
| 676 |
+
return map([](T x) { return (T)1 / x; });
|
| 677 |
+
}
|
| 678 |
+
Vectorized<T> rsqrt() const {
|
| 679 |
+
return map([](T x) { return (T)1 / std::sqrt(x); });
|
| 680 |
+
}
|
| 681 |
+
Vectorized<T> pow(const Vectorized<T>& exp) const {
|
| 682 |
+
Vectorized<T> ret;
|
| 683 |
+
for (const auto i : c10::irange(size())) {
|
| 684 |
+
ret[i] = std::pow(values[i], exp[i]);
|
| 685 |
+
}
|
| 686 |
+
return ret;
|
| 687 |
+
}
|
| 688 |
+
T reduce_add() const {
|
| 689 |
+
return reduce([](T x, T y) -> T { return x + y; });
|
| 690 |
+
}
|
| 691 |
+
T reduce_max() const {
|
| 692 |
+
return reduce(std::max);
|
| 693 |
+
}
|
| 694 |
+
|
| 695 |
+
private:
|
| 696 |
+
template <typename Op>
|
| 697 |
+
inline Vectorized<T> binary_pred(const Vectorized<T>& other, Op op) const {
|
| 698 |
+
// All bits are set to 1 if the pred is true, otherwise 0.
|
| 699 |
+
Vectorized<T> vector;
|
| 700 |
+
for (int64_t i = 0; i != size(); i++) {
|
| 701 |
+
if (op(values[i], other.values[i])) {
|
| 702 |
+
std::memset(static_cast<void*>(vector.values + i), 0xFF, sizeof(T));
|
| 703 |
+
} else {
|
| 704 |
+
std::memset(static_cast<void*>(vector.values + i), 0, sizeof(T));
|
| 705 |
+
}
|
| 706 |
+
}
|
| 707 |
+
return vector;
|
| 708 |
+
}
|
| 709 |
+
|
| 710 |
+
public:
|
| 711 |
+
Vectorized<T> operator==(const Vectorized<T>& other) const {
|
| 712 |
+
return binary_pred(other, std::equal_to<T>());
|
| 713 |
+
}
|
| 714 |
+
Vectorized<T> operator!=(const Vectorized<T>& other) const {
|
| 715 |
+
return binary_pred(other, std::not_equal_to<T>());
|
| 716 |
+
}
|
| 717 |
+
Vectorized<T> operator>=(const Vectorized<T>& other) const {
|
| 718 |
+
return binary_pred(other, std::greater_equal<T>());
|
| 719 |
+
}
|
| 720 |
+
Vectorized<T> operator<=(const Vectorized<T>& other) const {
|
| 721 |
+
return binary_pred(other, std::less_equal<T>());
|
| 722 |
+
}
|
| 723 |
+
Vectorized<T> operator>(const Vectorized<T>& other) const {
|
| 724 |
+
return binary_pred(other, std::greater<T>());
|
| 725 |
+
}
|
| 726 |
+
Vectorized<T> operator<(const Vectorized<T>& other) const {
|
| 727 |
+
return binary_pred(other, std::less<T>());
|
| 728 |
+
}
|
| 729 |
+
|
| 730 |
+
private:
|
| 731 |
+
template <typename Op>
|
| 732 |
+
inline Vectorized<T> binary_pred_bool(const Vectorized<T>& other, Op op)
|
| 733 |
+
const {
|
| 734 |
+
// 1 if the pred is true, otherwise 0.
|
| 735 |
+
Vectorized<T> vector;
|
| 736 |
+
for (int i = 0; i != size(); ++i) {
|
| 737 |
+
vector[i] = static_cast<T>(op(values[i], other.values[i]));
|
| 738 |
+
}
|
| 739 |
+
return vector;
|
| 740 |
+
}
|
| 741 |
+
|
| 742 |
+
public:
|
| 743 |
+
Vectorized<T> eq(const Vectorized<T>& other) const {
|
| 744 |
+
return binary_pred_bool(other, std::equal_to<T>());
|
| 745 |
+
}
|
| 746 |
+
Vectorized<T> ne(const Vectorized<T>& other) const {
|
| 747 |
+
return binary_pred_bool(other, std::not_equal_to<T>());
|
| 748 |
+
}
|
| 749 |
+
Vectorized<T> gt(const Vectorized<T>& other) const {
|
| 750 |
+
return binary_pred_bool(other, std::greater<T>());
|
| 751 |
+
}
|
| 752 |
+
Vectorized<T> ge(const Vectorized<T>& other) const {
|
| 753 |
+
return binary_pred_bool(other, std::greater_equal<T>());
|
| 754 |
+
}
|
| 755 |
+
Vectorized<T> lt(const Vectorized<T>& other) const {
|
| 756 |
+
return binary_pred_bool(other, std::less<T>());
|
| 757 |
+
}
|
| 758 |
+
Vectorized<T> le(const Vectorized<T>& other) const {
|
| 759 |
+
return binary_pred_bool(other, std::less_equal<T>());
|
| 760 |
+
}
|
| 761 |
+
};
|
| 762 |
+
|
| 763 |
+
template <class T>
|
| 764 |
+
Vectorized<T> inline operator-(const Vectorized<T>& a) {
|
| 765 |
+
return a.neg();
|
| 766 |
+
}
|
| 767 |
+
|
| 768 |
+
// There is an implicit conversion that would make this work if
|
| 769 |
+
// these operators weren't template functions, but they are template
|
| 770 |
+
// functions (and can't be moved to be non-member friends defined in
|
| 771 |
+
// the class body as suggested in
|
| 772 |
+
// https://stackoverflow.com/questions/9787593/implicit-type-conversion-with-template/9788255#9788255
|
| 773 |
+
// because we have a lot of disparate specializations of
|
| 774 |
+
// Vectorized). So, just explicitly make scalars work.
|
| 775 |
+
#define VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC(name) \
|
| 776 |
+
template <class T> \
|
| 777 |
+
Vectorized<T> inline name(const Vectorized<T>& a, T b) { \
|
| 778 |
+
return name(a, Vectorized<T>(b)); \
|
| 779 |
+
} \
|
| 780 |
+
template <class T> \
|
| 781 |
+
Vectorized<T> inline name(T a, const Vectorized<T>& b) { \
|
| 782 |
+
return name(Vectorized<T>(a), b); \
|
| 783 |
+
}
|
| 784 |
+
#define VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(op) \
|
| 785 |
+
VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC(operator op)
|
| 786 |
+
|
| 787 |
+
template <class T>
|
| 788 |
+
Vectorized<T> inline operator+(const Vectorized<T>& a, const Vectorized<T>& b) {
|
| 789 |
+
Vectorized<T> c;
|
| 790 |
+
for (int i = 0; i != Vectorized<T>::size(); i++) {
|
| 791 |
+
c[i] = a[i] + b[i];
|
| 792 |
+
}
|
| 793 |
+
return c;
|
| 794 |
+
}
|
| 795 |
+
|
| 796 |
+
VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(+)
|
| 797 |
+
|
| 798 |
+
template <class T>
|
| 799 |
+
Vectorized<T> inline operator-(const Vectorized<T>& a, const Vectorized<T>& b) {
|
| 800 |
+
Vectorized<T> c;
|
| 801 |
+
for (int i = 0; i != Vectorized<T>::size(); i++) {
|
| 802 |
+
c[i] = a[i] - b[i];
|
| 803 |
+
}
|
| 804 |
+
return c;
|
| 805 |
+
}
|
| 806 |
+
|
| 807 |
+
VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(-)
|
| 808 |
+
|
| 809 |
+
template <class T>
|
| 810 |
+
Vectorized<T> inline operator*(const Vectorized<T>& a, const Vectorized<T>& b) {
|
| 811 |
+
Vectorized<T> c;
|
| 812 |
+
for (int i = 0; i != Vectorized<T>::size(); i++) {
|
| 813 |
+
c[i] = a[i] * b[i];
|
| 814 |
+
}
|
| 815 |
+
return c;
|
| 816 |
+
}
|
| 817 |
+
|
| 818 |
+
VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(*)
|
| 819 |
+
|
| 820 |
+
template <class T>
|
| 821 |
+
Vectorized<T> inline operator/(const Vectorized<T>& a, const Vectorized<T>& b)
|
| 822 |
+
__ubsan_ignore_float_divide_by_zero__ {
|
| 823 |
+
Vectorized<T> c;
|
| 824 |
+
for (int i = 0; i != Vectorized<T>::size(); i++) {
|
| 825 |
+
c[i] = a[i] / b[i];
|
| 826 |
+
}
|
| 827 |
+
return c;
|
| 828 |
+
}
|
| 829 |
+
|
| 830 |
+
VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(/)
|
| 831 |
+
|
| 832 |
+
template <class T, typename std::enable_if_t<!is_floating_point_v<T>, int> = 0>
|
| 833 |
+
Vectorized<T> inline operator%(const Vectorized<T>& a, const Vectorized<T>& b)
|
| 834 |
+
__ubsan_ignore_float_divide_by_zero__ {
|
| 835 |
+
return a - a / b * b;
|
| 836 |
+
}
|
| 837 |
+
|
| 838 |
+
VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(%)
|
| 839 |
+
|
| 840 |
+
template <class T>
|
| 841 |
+
Vectorized<T> inline operator||(
|
| 842 |
+
const Vectorized<T>& a,
|
| 843 |
+
const Vectorized<T>& b) {
|
| 844 |
+
Vectorized<T> c;
|
| 845 |
+
for (int i = 0; i != Vectorized<T>::size(); i++) {
|
| 846 |
+
c[i] = a[i] || b[i];
|
| 847 |
+
}
|
| 848 |
+
return c;
|
| 849 |
+
}
|
| 850 |
+
|
| 851 |
+
VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(||)
|
| 852 |
+
|
| 853 |
+
// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if
|
| 854 |
+
// either input is a NaN.
|
| 855 |
+
template <
|
| 856 |
+
class T,
|
| 857 |
+
typename std::enable_if_t<!c10::is_complex<T>::value, int> = 0>
|
| 858 |
+
Vectorized<T> inline maximum(const Vectorized<T>& a, const Vectorized<T>& b) {
|
| 859 |
+
Vectorized<T> c;
|
| 860 |
+
for (int i = 0; i != Vectorized<T>::size(); i++) {
|
| 861 |
+
c[i] = (a[i] > b[i]) ? a[i] : b[i];
|
| 862 |
+
if (_isnan(a[i])) {
|
| 863 |
+
// If either input is NaN, propagate a NaN.
|
| 864 |
+
// NOTE: The case where b[i] was NaN is handled correctly by the naive
|
| 865 |
+
// ternary operator above.
|
| 866 |
+
c[i] = a[i];
|
| 867 |
+
}
|
| 868 |
+
}
|
| 869 |
+
return c;
|
| 870 |
+
}
|
| 871 |
+
|
| 872 |
+
template <
|
| 873 |
+
class T,
|
| 874 |
+
typename std::enable_if_t<c10::is_complex<T>::value, int> = 0>
|
| 875 |
+
Vectorized<T> inline maximum(const Vectorized<T>& a, const Vectorized<T>& b) {
|
| 876 |
+
Vectorized<T> c;
|
| 877 |
+
for (int i = 0; i != Vectorized<T>::size(); i++) {
|
| 878 |
+
c[i] = (std::abs(a[i]) > std::abs(b[i])) ? a[i] : b[i];
|
| 879 |
+
if (_isnan(a[i])) {
|
| 880 |
+
// If either input is NaN, propagate a NaN.
|
| 881 |
+
// NOTE: The case where b[i] was NaN is handled correctly by the naive
|
| 882 |
+
// ternary operator above.
|
| 883 |
+
c[i] = a[i];
|
| 884 |
+
}
|
| 885 |
+
}
|
| 886 |
+
return c;
|
| 887 |
+
}
|
| 888 |
+
|
| 889 |
+
VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC(maximum)
|
| 890 |
+
|
| 891 |
+
// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if
|
| 892 |
+
// either input is a NaN.
|
| 893 |
+
template <
|
| 894 |
+
class T,
|
| 895 |
+
typename std::enable_if_t<!c10::is_complex<T>::value, int> = 0>
|
| 896 |
+
Vectorized<T> inline minimum(const Vectorized<T>& a, const Vectorized<T>& b) {
|
| 897 |
+
Vectorized<T> c;
|
| 898 |
+
for (int i = 0; i != Vectorized<T>::size(); i++) {
|
| 899 |
+
c[i] = (a[i] < b[i]) ? a[i] : b[i];
|
| 900 |
+
if (_isnan(a[i])) {
|
| 901 |
+
// If either input is NaN, propagate a NaN.
|
| 902 |
+
// NOTE: The case where b[i] was NaN is handled correctly by the naive
|
| 903 |
+
// ternary operator above.
|
| 904 |
+
c[i] = a[i];
|
| 905 |
+
}
|
| 906 |
+
}
|
| 907 |
+
return c;
|
| 908 |
+
}
|
| 909 |
+
|
| 910 |
+
template <
|
| 911 |
+
class T,
|
| 912 |
+
typename std::enable_if_t<c10::is_complex<T>::value, int> = 0>
|
| 913 |
+
Vectorized<T> inline minimum(const Vectorized<T>& a, const Vectorized<T>& b) {
|
| 914 |
+
Vectorized<T> c;
|
| 915 |
+
for (int i = 0; i != Vectorized<T>::size(); i++) {
|
| 916 |
+
c[i] = (std::abs(a[i]) < std::abs(b[i])) ? a[i] : b[i];
|
| 917 |
+
if (_isnan(a[i])) {
|
| 918 |
+
// If either input is NaN, propagate a NaN.
|
| 919 |
+
// NOTE: The case where b[i] was NaN is handled correctly by the naive
|
| 920 |
+
// ternary operator above.
|
| 921 |
+
c[i] = a[i];
|
| 922 |
+
}
|
| 923 |
+
}
|
| 924 |
+
return c;
|
| 925 |
+
}
|
| 926 |
+
|
| 927 |
+
VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC(minimum)
|
| 928 |
+
|
| 929 |
+
template <
|
| 930 |
+
class T,
|
| 931 |
+
typename std::enable_if_t<!c10::is_complex<T>::value, int> = 0>
|
| 932 |
+
Vectorized<T> inline clamp(
|
| 933 |
+
const Vectorized<T>& a,
|
| 934 |
+
const Vectorized<T>& min_vec,
|
| 935 |
+
const Vectorized<T>& max_vec) {
|
| 936 |
+
Vectorized<T> c;
|
| 937 |
+
for (int i = 0; i != Vectorized<T>::size(); i++) {
|
| 938 |
+
c[i] = std::min(std::max(a[i], min_vec[i]), max_vec[i]);
|
| 939 |
+
}
|
| 940 |
+
return c;
|
| 941 |
+
}
|
| 942 |
+
|
| 943 |
+
#define VECTORIZED_SUPPORT_SCALARS_FOR_TERNARY_FUNC(name) \
|
| 944 |
+
template <class T> \
|
| 945 |
+
Vectorized<T> inline name( \
|
| 946 |
+
const Vectorized<T>& a, const Vectorized<T>& b, T c) { \
|
| 947 |
+
return name(a, b, Vectorized<T>(c)); \
|
| 948 |
+
} \
|
| 949 |
+
\
|
| 950 |
+
template <class T> \
|
| 951 |
+
Vectorized<T> inline name( \
|
| 952 |
+
const Vectorized<T>& a, T b, const Vectorized<T>& c) { \
|
| 953 |
+
return name(a, Vectorized<T>(b), c); \
|
| 954 |
+
} \
|
| 955 |
+
\
|
| 956 |
+
template <class T> \
|
| 957 |
+
Vectorized<T> inline name(const Vectorized<T>& a, T b, T c) { \
|
| 958 |
+
return name(a, Vectorized<T>(b), Vectorized<T>(c)); \
|
| 959 |
+
} \
|
| 960 |
+
\
|
| 961 |
+
template <class T> \
|
| 962 |
+
Vectorized<T> inline name( \
|
| 963 |
+
T a, const Vectorized<T>& b, const Vectorized<T>& c) { \
|
| 964 |
+
return name(Vectorized<T>(a), b, c); \
|
| 965 |
+
} \
|
| 966 |
+
\
|
| 967 |
+
template <class T> \
|
| 968 |
+
Vectorized<T> inline name(T a, const Vectorized<T>& b, T c) { \
|
| 969 |
+
return name(Vectorized<T>(a), b, Vectorized<T>(c)); \
|
| 970 |
+
} \
|
| 971 |
+
\
|
| 972 |
+
template <class T> \
|
| 973 |
+
Vectorized<T> inline name(T a, T b, const Vectorized<T>& c) { \
|
| 974 |
+
return name(Vectorized<T>(a), Vectorized<T>(b), c); \
|
| 975 |
+
}
|
| 976 |
+
|
| 977 |
+
VECTORIZED_SUPPORT_SCALARS_FOR_TERNARY_FUNC(clamp)
|
| 978 |
+
|
| 979 |
+
template <
|
| 980 |
+
class T,
|
| 981 |
+
typename std::enable_if_t<!c10::is_complex<T>::value, int> = 0>
|
| 982 |
+
Vectorized<T> inline clamp_max(
|
| 983 |
+
const Vectorized<T>& a,
|
| 984 |
+
const Vectorized<T>& max_vec) {
|
| 985 |
+
Vectorized<T> c;
|
| 986 |
+
for (int i = 0; i != Vectorized<T>::size(); i++) {
|
| 987 |
+
c[i] = a[i] > max_vec[i] ? max_vec[i] : a[i];
|
| 988 |
+
}
|
| 989 |
+
return c;
|
| 990 |
+
}
|
| 991 |
+
|
| 992 |
+
VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC(clamp_max)
|
| 993 |
+
|
| 994 |
+
template <
|
| 995 |
+
class T,
|
| 996 |
+
typename std::enable_if_t<!c10::is_complex<T>::value, int> = 0>
|
| 997 |
+
Vectorized<T> inline clamp_min(
|
| 998 |
+
const Vectorized<T>& a,
|
| 999 |
+
const Vectorized<T>& min_vec) {
|
| 1000 |
+
Vectorized<T> c;
|
| 1001 |
+
for (int i = 0; i != Vectorized<T>::size(); i++) {
|
| 1002 |
+
c[i] = a[i] < min_vec[i] ? min_vec[i] : a[i];
|
| 1003 |
+
}
|
| 1004 |
+
return c;
|
| 1005 |
+
}
|
| 1006 |
+
|
| 1007 |
+
VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC(clamp_min)
|
| 1008 |
+
|
| 1009 |
+
struct Vectorizedi;
|
| 1010 |
+
|
| 1011 |
+
#if defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)
|
| 1012 |
+
template <class T, typename Op>
|
| 1013 |
+
static inline Vectorized<T> bitwise_binary_op(
|
| 1014 |
+
const Vectorized<T>& a,
|
| 1015 |
+
const Vectorized<T>& b,
|
| 1016 |
+
Op op) {
|
| 1017 |
+
int_vector buffer;
|
| 1018 |
+
#if defined(CPU_CAPABILITY_AVX2)
|
| 1019 |
+
int_vector a_buffer =
|
| 1020 |
+
_mm256_load_si256(reinterpret_cast<const int_vector*>((const T*)a));
|
| 1021 |
+
int_vector b_buffer =
|
| 1022 |
+
_mm256_load_si256(reinterpret_cast<const int_vector*>((const T*)b));
|
| 1023 |
+
#elif defined(CPU_CAPABILITY_AVX512)
|
| 1024 |
+
int_vector a_buffer =
|
| 1025 |
+
_mm512_load_si512(reinterpret_cast<const int_vector*>((const T*)a));
|
| 1026 |
+
int_vector b_buffer =
|
| 1027 |
+
_mm512_load_si512(reinterpret_cast<const int_vector*>((const T*)b));
|
| 1028 |
+
#endif
|
| 1029 |
+
buffer = op(a_buffer, b_buffer);
|
| 1030 |
+
__at_align__ T results[Vectorized<T>::size()];
|
| 1031 |
+
|
| 1032 |
+
#if defined(CPU_CAPABILITY_AVX2)
|
| 1033 |
+
_mm256_store_si256(reinterpret_cast<int_vector*>(results), buffer);
|
| 1034 |
+
#elif defined(CPU_CAPABILITY_AVX512)
|
| 1035 |
+
_mm512_store_si512(reinterpret_cast<int_vector*>(results), buffer);
|
| 1036 |
+
#endif
|
| 1037 |
+
return Vectorized<T>::loadu(results);
|
| 1038 |
+
}
|
| 1039 |
+
|
| 1040 |
+
template <
|
| 1041 |
+
class T,
|
| 1042 |
+
typename std::enable_if_t<
|
| 1043 |
+
!std::is_base_of<Vectorizedi, Vectorized<T>>::value,
|
| 1044 |
+
int> = 0>
|
| 1045 |
+
inline Vectorized<T> operator&(const Vectorized<T>& a, const Vectorized<T>& b) {
|
| 1046 |
+
// We enclose _mm512_and_si512 or _mm256_and_si256 with lambda because it is
|
| 1047 |
+
// always_inline
|
| 1048 |
+
#if defined(CPU_CAPABILITY_AVX2)
|
| 1049 |
+
return bitwise_binary_op(
|
| 1050 |
+
a, b, [](int_vector a, int_vector b) { return _mm256_and_si256(a, b); });
|
| 1051 |
+
#elif defined(CPU_CAPABILITY_AVX512)
|
| 1052 |
+
return bitwise_binary_op(
|
| 1053 |
+
a, b, [](int_vector a, int_vector b) { return _mm512_and_si512(a, b); });
|
| 1054 |
+
#endif
|
| 1055 |
+
}
|
| 1056 |
+
template <
|
| 1057 |
+
class T,
|
| 1058 |
+
typename std::enable_if_t<
|
| 1059 |
+
!std::is_base_of<Vectorizedi, Vectorized<T>>::value,
|
| 1060 |
+
int> = 0>
|
| 1061 |
+
inline Vectorized<T> operator|(const Vectorized<T>& a, const Vectorized<T>& b) {
|
| 1062 |
+
// We enclose _mm512_or_si512 or _mm256_or_si256 with lambda because it is
|
| 1063 |
+
// always_inline
|
| 1064 |
+
#if defined(CPU_CAPABILITY_AVX2)
|
| 1065 |
+
return bitwise_binary_op(
|
| 1066 |
+
a, b, [](int_vector a, int_vector b) { return _mm256_or_si256(a, b); });
|
| 1067 |
+
#elif defined(CPU_CAPABILITY_AVX512)
|
| 1068 |
+
return bitwise_binary_op(
|
| 1069 |
+
a, b, [](int_vector a, int_vector b) { return _mm512_or_si512(a, b); });
|
| 1070 |
+
#endif
|
| 1071 |
+
}
|
| 1072 |
+
template <
|
| 1073 |
+
class T,
|
| 1074 |
+
typename std::enable_if_t<
|
| 1075 |
+
!std::is_base_of<Vectorizedi, Vectorized<T>>::value,
|
| 1076 |
+
int> = 0>
|
| 1077 |
+
inline Vectorized<T> operator^(const Vectorized<T>& a, const Vectorized<T>& b) {
|
| 1078 |
+
// We enclose _mm512_xor_si512 or _mm256_xor_si256 with lambda because it is
|
| 1079 |
+
// always_inline
|
| 1080 |
+
#if defined(CPU_CAPABILITY_AVX2)
|
| 1081 |
+
return bitwise_binary_op(
|
| 1082 |
+
a, b, [](int_vector a, int_vector b) { return _mm256_xor_si256(a, b); });
|
| 1083 |
+
#elif defined(CPU_CAPABILITY_AVX512)
|
| 1084 |
+
return bitwise_binary_op(
|
| 1085 |
+
a, b, [](int_vector a, int_vector b) { return _mm512_xor_si512(a, b); });
|
| 1086 |
+
#endif
|
| 1087 |
+
}
|
| 1088 |
+
|
| 1089 |
+
#else
|
| 1090 |
+
|
| 1091 |
+
template <typename T>
|
| 1092 |
+
auto load(char const* data) -> T {
|
| 1093 |
+
T ret;
|
| 1094 |
+
std::memcpy(&ret, data, sizeof(ret));
|
| 1095 |
+
return ret;
|
| 1096 |
+
}
|
| 1097 |
+
|
| 1098 |
+
template <class T, typename Op>
|
| 1099 |
+
static inline Vectorized<T> bitwise_binary_op(
|
| 1100 |
+
const Vectorized<T>& a,
|
| 1101 |
+
const Vectorized<T>& b,
|
| 1102 |
+
Op op) {
|
| 1103 |
+
static constexpr uint32_t element_no = VECTOR_WIDTH / sizeof(intmax_t);
|
| 1104 |
+
__at_align__ intmax_t buffer[element_no];
|
| 1105 |
+
static_assert(
|
| 1106 |
+
VECTOR_WIDTH % sizeof(intmax_t) == 0,
|
| 1107 |
+
"VECTOR_WIDTH not a multiple of sizeof(intmax_t)");
|
| 1108 |
+
static_assert(
|
| 1109 |
+
sizeof(buffer) == sizeof(Vectorized<T>),
|
| 1110 |
+
"sizeof(buffer) must match sizeof(Vectorized<T>)");
|
| 1111 |
+
// We should be using memcpy in order to respect the strict aliasing rule
|
| 1112 |
+
// see: https://github.com/pytorch/pytorch/issues/66119
|
| 1113 |
+
// Using char* is defined in the C11 standard 6.5 Expression paragraph 7
|
| 1114 |
+
// (http://www.open-std.org/jtc1/sc22/wg14/www/docs/n1570.pdf)
|
| 1115 |
+
const auto* a_data = a.as_bytes();
|
| 1116 |
+
const auto* b_data = b.as_bytes();
|
| 1117 |
+
// load each intmax_t chunk and process; increase pointers by sizeof(intmax_t)
|
| 1118 |
+
for (auto& out : buffer) {
|
| 1119 |
+
out = op(load<intmax_t>(a_data), load<intmax_t>(b_data));
|
| 1120 |
+
a_data += sizeof(intmax_t);
|
| 1121 |
+
b_data += sizeof(intmax_t);
|
| 1122 |
+
}
|
| 1123 |
+
assert(a_data == a.as_bytes() + sizeof(a));
|
| 1124 |
+
assert(b_data == b.as_bytes() + sizeof(b));
|
| 1125 |
+
return Vectorized<T>::loadu(buffer);
|
| 1126 |
+
}
|
| 1127 |
+
|
| 1128 |
+
template <
|
| 1129 |
+
class T,
|
| 1130 |
+
typename std::
|
| 1131 |
+
enable_if_t<!std::is_base_of_v<Vectorizedi, Vectorized<T>>, int> = 0>
|
| 1132 |
+
inline Vectorized<T> operator&(const Vectorized<T>& a, const Vectorized<T>& b) {
|
| 1133 |
+
return bitwise_binary_op(a, b, std::bit_and<intmax_t>());
|
| 1134 |
+
}
|
| 1135 |
+
template <
|
| 1136 |
+
class T,
|
| 1137 |
+
typename std::
|
| 1138 |
+
enable_if_t<!std::is_base_of_v<Vectorizedi, Vectorized<T>>, int> = 0>
|
| 1139 |
+
inline Vectorized<T> operator|(const Vectorized<T>& a, const Vectorized<T>& b) {
|
| 1140 |
+
return bitwise_binary_op(a, b, std::bit_or<intmax_t>());
|
| 1141 |
+
}
|
| 1142 |
+
template <
|
| 1143 |
+
class T,
|
| 1144 |
+
typename std::
|
| 1145 |
+
enable_if_t<!std::is_base_of_v<Vectorizedi, Vectorized<T>>, int> = 0>
|
| 1146 |
+
inline Vectorized<T> operator^(const Vectorized<T>& a, const Vectorized<T>& b) {
|
| 1147 |
+
return bitwise_binary_op(a, b, std::bit_xor<intmax_t>());
|
| 1148 |
+
}
|
| 1149 |
+
|
| 1150 |
+
#endif // defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)
|
| 1151 |
+
|
| 1152 |
+
VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(&)
|
| 1153 |
+
VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(|)
|
| 1154 |
+
VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(^)
|
| 1155 |
+
|
| 1156 |
+
template <
|
| 1157 |
+
class T,
|
| 1158 |
+
typename std::
|
| 1159 |
+
enable_if_t<!std::is_base_of_v<Vectorizedi, Vectorized<T>>, int> = 0>
|
| 1160 |
+
inline Vectorized<T> operator~(const Vectorized<T>& a) {
|
| 1161 |
+
using int_t = int_same_size_t<T>;
|
| 1162 |
+
Vectorized<T> ones(c10::bit_cast<T>((int_t)(~(int_t)0))); // All bits are 1
|
| 1163 |
+
return a ^ ones;
|
| 1164 |
+
}
|
| 1165 |
+
|
| 1166 |
+
template <class T>
|
| 1167 |
+
Vectorized<T> inline operator<<(
|
| 1168 |
+
const Vectorized<T>& a,
|
| 1169 |
+
const Vectorized<T>& b) {
|
| 1170 |
+
constexpr T max_shift = sizeof(T) * CHAR_BIT;
|
| 1171 |
+
Vectorized<T> c;
|
| 1172 |
+
for (int i = 0; i != Vectorized<T>::size(); i++) {
|
| 1173 |
+
T shift = b[i];
|
| 1174 |
+
if ((static_cast<std::make_signed_t<T>>(shift) < 0) ||
|
| 1175 |
+
(shift >= max_shift)) {
|
| 1176 |
+
c[i] = 0;
|
| 1177 |
+
} else {
|
| 1178 |
+
c[i] = static_cast<std::make_unsigned_t<T>>(a[i]) << shift;
|
| 1179 |
+
}
|
| 1180 |
+
}
|
| 1181 |
+
return c;
|
| 1182 |
+
}
|
| 1183 |
+
|
| 1184 |
+
template <class T>
|
| 1185 |
+
Vectorized<T> inline operator>>(
|
| 1186 |
+
const Vectorized<T>& a,
|
| 1187 |
+
const Vectorized<T>& b) {
|
| 1188 |
+
// right shift value to retain sign bit for signed and no bits for unsigned
|
| 1189 |
+
constexpr T max_shift = sizeof(T) * CHAR_BIT - std::is_signed_v<T>;
|
| 1190 |
+
Vectorized<T> c;
|
| 1191 |
+
for (int i = 0; i != Vectorized<T>::size(); i++) {
|
| 1192 |
+
T shift = b[i];
|
| 1193 |
+
if ((static_cast<std::make_signed_t<T>>(shift) < 0) ||
|
| 1194 |
+
(shift >= max_shift)) {
|
| 1195 |
+
c[i] = a[i] >> max_shift;
|
| 1196 |
+
} else {
|
| 1197 |
+
c[i] = a[i] >> shift;
|
| 1198 |
+
}
|
| 1199 |
+
}
|
| 1200 |
+
return c;
|
| 1201 |
+
}
|
| 1202 |
+
|
| 1203 |
+
template <typename T>
|
| 1204 |
+
inline Vectorized<T>& operator+=(Vectorized<T>& a, const Vectorized<T>& b) {
|
| 1205 |
+
a = a + b;
|
| 1206 |
+
return a;
|
| 1207 |
+
}
|
| 1208 |
+
template <typename T>
|
| 1209 |
+
inline Vectorized<T>& operator-=(Vectorized<T>& a, const Vectorized<T>& b) {
|
| 1210 |
+
a = a - b;
|
| 1211 |
+
return a;
|
| 1212 |
+
}
|
| 1213 |
+
template <typename T>
|
| 1214 |
+
inline Vectorized<T>& operator/=(Vectorized<T>& a, const Vectorized<T>& b) {
|
| 1215 |
+
a = a / b;
|
| 1216 |
+
return a;
|
| 1217 |
+
}
|
| 1218 |
+
template <typename T>
|
| 1219 |
+
inline Vectorized<T>& operator%=(Vectorized<T>& a, const Vectorized<T>& b) {
|
| 1220 |
+
a = a % b;
|
| 1221 |
+
return a;
|
| 1222 |
+
}
|
| 1223 |
+
template <typename T>
|
| 1224 |
+
inline Vectorized<T>& operator*=(Vectorized<T>& a, const Vectorized<T>& b) {
|
| 1225 |
+
a = a * b;
|
| 1226 |
+
return a;
|
| 1227 |
+
}
|
| 1228 |
+
|
| 1229 |
+
template <typename T>
|
| 1230 |
+
inline Vectorized<T>& operator<<=(Vectorized<T>& a, const Vectorized<T>& b) {
|
| 1231 |
+
a = a << b;
|
| 1232 |
+
return a;
|
| 1233 |
+
}
|
| 1234 |
+
|
| 1235 |
+
template <typename T>
|
| 1236 |
+
inline Vectorized<T>& operator>>=(Vectorized<T>& a, const Vectorized<T>& b) {
|
| 1237 |
+
a = a >> b;
|
| 1238 |
+
return a;
|
| 1239 |
+
}
|
| 1240 |
+
|
| 1241 |
+
template <typename T>
|
| 1242 |
+
inline Vectorized<T> fmadd(
|
| 1243 |
+
const Vectorized<T>& a,
|
| 1244 |
+
const Vectorized<T>& b,
|
| 1245 |
+
const Vectorized<T>& c) {
|
| 1246 |
+
return a * b + c;
|
| 1247 |
+
}
|
| 1248 |
+
|
| 1249 |
+
VECTORIZED_SUPPORT_SCALARS_FOR_TERNARY_FUNC(fmadd)
|
| 1250 |
+
|
| 1251 |
+
template <typename T>
|
| 1252 |
+
inline Vectorized<T> fnmadd(
|
| 1253 |
+
const Vectorized<T>& a,
|
| 1254 |
+
const Vectorized<T>& b,
|
| 1255 |
+
const Vectorized<T>& c) {
|
| 1256 |
+
return -(a * b) + c;
|
| 1257 |
+
}
|
| 1258 |
+
|
| 1259 |
+
VECTORIZED_SUPPORT_SCALARS_FOR_TERNARY_FUNC(fnmadd)
|
| 1260 |
+
|
| 1261 |
+
template <typename T>
|
| 1262 |
+
inline Vectorized<T> fmsub(
|
| 1263 |
+
const Vectorized<T>& a,
|
| 1264 |
+
const Vectorized<T>& b,
|
| 1265 |
+
const Vectorized<T>& c) {
|
| 1266 |
+
return a * b - c;
|
| 1267 |
+
}
|
| 1268 |
+
|
| 1269 |
+
VECTORIZED_SUPPORT_SCALARS_FOR_TERNARY_FUNC(fmsub)
|
| 1270 |
+
|
| 1271 |
+
template <typename T>
|
| 1272 |
+
inline Vectorized<T> fnmsub(
|
| 1273 |
+
const Vectorized<T>& a,
|
| 1274 |
+
const Vectorized<T>& b,
|
| 1275 |
+
const Vectorized<T>& c) {
|
| 1276 |
+
return -(a * b) - c;
|
| 1277 |
+
}
|
| 1278 |
+
|
| 1279 |
+
VECTORIZED_SUPPORT_SCALARS_FOR_TERNARY_FUNC(fnmsub)
|
| 1280 |
+
|
| 1281 |
+
template <typename T>
|
| 1282 |
+
Vectorized<T> inline operator&&(
|
| 1283 |
+
const Vectorized<T>& a,
|
| 1284 |
+
const Vectorized<T>& b) {
|
| 1285 |
+
Vectorized<T> ret;
|
| 1286 |
+
for (int i = 0; i != Vectorized<T>::size(); i++) {
|
| 1287 |
+
ret[i] = a[i] && b[i];
|
| 1288 |
+
}
|
| 1289 |
+
return ret;
|
| 1290 |
+
}
|
| 1291 |
+
|
| 1292 |
+
VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(&&)
|
| 1293 |
+
|
| 1294 |
+
template <int64_t scale = 1, typename T = void>
|
| 1295 |
+
std::enable_if_t<
|
| 1296 |
+
scale == 1 || scale == 2 || scale == 4 || scale == 8,
|
| 1297 |
+
Vectorized<
|
| 1298 |
+
T>> inline gather(T const* base_addr, const Vectorized<int_same_size_t<T>>& vindex) {
|
| 1299 |
+
static constexpr int size = Vectorized<T>::size();
|
| 1300 |
+
int_same_size_t<T> index_arr[size];
|
| 1301 |
+
vindex.store(static_cast<void*>(index_arr));
|
| 1302 |
+
T buffer[size];
|
| 1303 |
+
for (const auto i : c10::irange(size)) {
|
| 1304 |
+
buffer[i] = base_addr[index_arr[i] * scale / sizeof(T)];
|
| 1305 |
+
}
|
| 1306 |
+
return Vectorized<T>::loadu(static_cast<void*>(buffer));
|
| 1307 |
+
}
|
| 1308 |
+
|
| 1309 |
+
template <int64_t scale = 1, typename T = void>
|
| 1310 |
+
std::
|
| 1311 |
+
enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<T>> inline mask_gather(
|
| 1312 |
+
const Vectorized<T>& src,
|
| 1313 |
+
T const* base_addr,
|
| 1314 |
+
const Vectorized<int_same_size_t<T>>& vindex,
|
| 1315 |
+
Vectorized<T>& mask) {
|
| 1316 |
+
static constexpr int size = Vectorized<T>::size();
|
| 1317 |
+
T src_arr[size];
|
| 1318 |
+
int_same_size_t<T> mask_arr[size]; // use int type so we can logical and
|
| 1319 |
+
int_same_size_t<T> index_arr[size];
|
| 1320 |
+
src.store(static_cast<void*>(src_arr));
|
| 1321 |
+
mask.store(static_cast<void*>(mask_arr));
|
| 1322 |
+
vindex.store(static_cast<void*>(index_arr));
|
| 1323 |
+
T buffer[size];
|
| 1324 |
+
for (const auto i : c10::irange(size)) {
|
| 1325 |
+
if (mask_arr[i] & 0x01) { // check highest bit
|
| 1326 |
+
buffer[i] = base_addr[index_arr[i] * scale / sizeof(T)];
|
| 1327 |
+
} else {
|
| 1328 |
+
buffer[i] = src_arr[i];
|
| 1329 |
+
}
|
| 1330 |
+
}
|
| 1331 |
+
mask = Vectorized<T>(static_cast<T>(0)); // "zero out" mask
|
| 1332 |
+
return Vectorized<T>::loadu(static_cast<void*>(buffer));
|
| 1333 |
+
}
|
| 1334 |
+
|
| 1335 |
+
// Cast a given vector to another type without changing the bits representation.
|
| 1336 |
+
// So a Vectorized<double> of 512 bits containing all ones can be cast to a
|
| 1337 |
+
// Vectorized<int64_t> of 512 bits containing all ones (i.e., eight negative
|
| 1338 |
+
// 1s). A Vec<double> of 256 bits containing all ones can be cast to a
|
| 1339 |
+
// Vec<int64_t> of 256 bits containing all ones (i.e., four negative 1s).
|
| 1340 |
+
// There is a struct here because we don't have static_if and I can't
|
| 1341 |
+
// partially specialize a templated function.
|
| 1342 |
+
template <typename dst_t, typename src_t>
|
| 1343 |
+
struct CastImpl {
|
| 1344 |
+
static inline Vectorized<dst_t> apply(const Vectorized<src_t>& src) {
|
| 1345 |
+
src_t src_arr[Vectorized<src_t>::size()];
|
| 1346 |
+
src.store(static_cast<void*>(src_arr));
|
| 1347 |
+
return Vectorized<dst_t>::loadu(static_cast<const void*>(src_arr));
|
| 1348 |
+
}
|
| 1349 |
+
};
|
| 1350 |
+
|
| 1351 |
+
template <typename scalar_t>
|
| 1352 |
+
struct CastImpl<scalar_t, scalar_t> {
|
| 1353 |
+
static inline Vectorized<scalar_t> apply(const Vectorized<scalar_t>& src) {
|
| 1354 |
+
return src;
|
| 1355 |
+
}
|
| 1356 |
+
};
|
| 1357 |
+
|
| 1358 |
+
template <typename dst_t, typename src_t>
|
| 1359 |
+
inline Vectorized<dst_t> cast(const Vectorized<src_t>& src) {
|
| 1360 |
+
return CastImpl<dst_t, src_t>::apply(src);
|
| 1361 |
+
}
|
| 1362 |
+
|
| 1363 |
+
template <typename T, typename IntType = int_same_size_t<T>>
|
| 1364 |
+
inline Vectorized<IntType> convert_to_int_of_same_size(
|
| 1365 |
+
const Vectorized<T>& src) {
|
| 1366 |
+
static_assert(sizeof(T) == sizeof(IntType));
|
| 1367 |
+
static constexpr int size = Vectorized<T>::size();
|
| 1368 |
+
|
| 1369 |
+
std::array<T, size> src_arr = {};
|
| 1370 |
+
src.store(static_cast<void*>(src_arr.data()));
|
| 1371 |
+
std::array<IntType, size> buffer;
|
| 1372 |
+
std::transform(
|
| 1373 |
+
src_arr.cbegin(), src_arr.cend(), buffer.begin(), [](const T& x) {
|
| 1374 |
+
return static_cast<IntType>(x);
|
| 1375 |
+
});
|
| 1376 |
+
return Vectorized<IntType>::loadu(static_cast<const void*>(buffer.data()));
|
| 1377 |
+
}
|
| 1378 |
+
|
| 1379 |
+
template <typename T, typename IntType = int_same_size_t<T>>
|
| 1380 |
+
inline Vectorized<T> convert_to_fp_of_same_size(
|
| 1381 |
+
const Vectorized<IntType>& src) {
|
| 1382 |
+
static_assert(sizeof(T) == sizeof(IntType));
|
| 1383 |
+
static constexpr int size = Vectorized<T>::size();
|
| 1384 |
+
|
| 1385 |
+
std::array<IntType, size> src_arr;
|
| 1386 |
+
src.store(static_cast<void*>(src_arr.data()));
|
| 1387 |
+
std::array<T, size> buffer;
|
| 1388 |
+
std::transform(
|
| 1389 |
+
src_arr.cbegin(), src_arr.cend(), buffer.begin(), [](const IntType& x) {
|
| 1390 |
+
return static_cast<T>(x);
|
| 1391 |
+
});
|
| 1392 |
+
return Vectorized<T>::loadu(static_cast<const void*>(buffer.data()));
|
| 1393 |
+
}
|
| 1394 |
+
|
| 1395 |
+
// clang-format off
|
| 1396 |
+
// Example inputs for AVX512:
|
| 1397 |
+
// a Vectorized<float> = {a0, b0, a1, b1, a2, b2, a3, b3, a4, b4, a5, b5, a6, b6, a7, b7}
|
| 1398 |
+
// b Vectorized<float> = {a8, b8, a9, b9, a10, b10, a11, b11, a12, b12, a13, b13, a14, b14, a15, b15}
|
| 1399 |
+
// returns:
|
| 1400 |
+
// Vectorized<float> = {a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15}
|
| 1401 |
+
// Vectorized<float> = {b0, b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, b15}
|
| 1402 |
+
// Example inputs for AVX2: a Vectorized<float> = {a0, b0, a1, b1, a2, b2, a3, b3}
|
| 1403 |
+
// b Vectorized<float> = {a4, b4, a5, b5, a6, b6, a7, b7}
|
| 1404 |
+
// returns: Vectorized<float> = {a0, a1, a2, a3, a4, a5, a6, a7}
|
| 1405 |
+
// Vectorized<float> = {b0, b1, b2, b3, b4, b5, b6, b7}
|
| 1406 |
+
// clang-format on
|
| 1407 |
+
template <typename T>
|
| 1408 |
+
inline std::enable_if_t<
|
| 1409 |
+
Vectorized<T>::size() % 2 == 0,
|
| 1410 |
+
std::pair<Vectorized<T>, Vectorized<T>>>
|
| 1411 |
+
deinterleave2(const Vectorized<T>& a, const Vectorized<T>& b) {
|
| 1412 |
+
static constexpr int size = Vectorized<T>::size();
|
| 1413 |
+
static constexpr int half_size = size / 2;
|
| 1414 |
+
T a_arr[size];
|
| 1415 |
+
T b_arr[size];
|
| 1416 |
+
T buffer1[size];
|
| 1417 |
+
T buffer2[size];
|
| 1418 |
+
a.store(static_cast<void*>(a_arr));
|
| 1419 |
+
b.store(static_cast<void*>(b_arr));
|
| 1420 |
+
for (const auto i : c10::irange(half_size)) {
|
| 1421 |
+
buffer1[i] = a_arr[i * 2];
|
| 1422 |
+
buffer1[half_size + i] = b_arr[i * 2];
|
| 1423 |
+
buffer2[i] = a_arr[i * 2 + 1];
|
| 1424 |
+
buffer2[half_size + i] = b_arr[i * 2 + 1];
|
| 1425 |
+
}
|
| 1426 |
+
return std::make_pair(
|
| 1427 |
+
Vectorized<T>::loadu(static_cast<void*>(buffer1)),
|
| 1428 |
+
Vectorized<T>::loadu(static_cast<void*>(buffer2)));
|
| 1429 |
+
}
|
| 1430 |
+
|
| 1431 |
+
VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC(deinterleave2)
|
| 1432 |
+
|
| 1433 |
+
// clang-format off
|
| 1434 |
+
// inverse operation of deinterleave2
|
| 1435 |
+
// Example inputs for AVX512:
|
| 1436 |
+
// a Vectorized<float> = {a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15}
|
| 1437 |
+
// b Vectorized<float> = {b0, b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, b15}
|
| 1438 |
+
// returns, for AVX512:
|
| 1439 |
+
// Vectorized<float> = {a0, b0, a1, b1, a2, b2, a3, b3, a4, b4, a5, b5, a6, b6, a7, b7}
|
| 1440 |
+
// Vectorized<float> = {a8, b8, a9, b9, a10, b10, a11, b11, a12, b12, a13, b13, a14, b14, a15, b15}
|
| 1441 |
+
// Example inputs for AVX2 : a Vectorized<float> = {a0, a1, a2, a3, a4, a5, a6, a7}
|
| 1442 |
+
// b Vectorized<float> = {b0, b1, b2, b3, b4, b5, b6, b7}
|
| 1443 |
+
// returns: Vectorized<float> = {a0, b0, a1, b1, a2, b2, a3, b3}
|
| 1444 |
+
// Vectorized<float> = {a4, b4, a5, b5, a6, b6, a7, b7}
|
| 1445 |
+
// clang-format on
|
| 1446 |
+
template <typename T>
|
| 1447 |
+
inline std::enable_if_t<
|
| 1448 |
+
Vectorized<T>::size() % 2 == 0,
|
| 1449 |
+
std::pair<Vectorized<T>, Vectorized<T>>>
|
| 1450 |
+
interleave2(const Vectorized<T>& a, const Vectorized<T>& b) {
|
| 1451 |
+
static constexpr int size = Vectorized<T>::size();
|
| 1452 |
+
static constexpr int half_size = size / 2;
|
| 1453 |
+
T a_arr[size];
|
| 1454 |
+
T b_arr[size];
|
| 1455 |
+
T buffer1[size];
|
| 1456 |
+
T buffer2[size];
|
| 1457 |
+
a.store(static_cast<void*>(a_arr));
|
| 1458 |
+
b.store(static_cast<void*>(b_arr));
|
| 1459 |
+
for (const auto i : c10::irange(half_size)) {
|
| 1460 |
+
buffer1[i * 2] = a_arr[i];
|
| 1461 |
+
buffer1[i * 2 + 1] = b_arr[i];
|
| 1462 |
+
buffer2[i * 2] = a_arr[half_size + i];
|
| 1463 |
+
buffer2[i * 2 + 1] = b_arr[half_size + i];
|
| 1464 |
+
}
|
| 1465 |
+
return std::make_pair(
|
| 1466 |
+
Vectorized<T>::loadu(static_cast<void*>(buffer1)),
|
| 1467 |
+
Vectorized<T>::loadu(static_cast<void*>(buffer2)));
|
| 1468 |
+
}
|
| 1469 |
+
|
| 1470 |
+
VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC(interleave2)
|
| 1471 |
+
|
| 1472 |
+
#undef VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC
|
| 1473 |
+
#undef VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP
|
| 1474 |
+
#undef VECTORIZED_SUPPORT_SCALARS_FOR_TERNARY_FUNC
|
| 1475 |
+
|
| 1476 |
+
template <typename src_T, typename dst_T>
|
| 1477 |
+
inline void convert(const src_T* src, dst_T* dst, int64_t n) {
|
| 1478 |
+
#ifndef _MSC_VER
|
| 1479 |
+
#pragma unroll
|
| 1480 |
+
#endif
|
| 1481 |
+
for ([[maybe_unused]] const auto i : c10::irange(n)) {
|
| 1482 |
+
*dst = c10::convert<dst_T>(c10::load(src));
|
| 1483 |
+
src++;
|
| 1484 |
+
dst++;
|
| 1485 |
+
}
|
| 1486 |
+
}
|
| 1487 |
+
|
| 1488 |
+
template <typename T>
|
| 1489 |
+
inline Vectorized<T> flip(const Vectorized<T>& data) {
|
| 1490 |
+
static constexpr int size = Vectorized<T>::size();
|
| 1491 |
+
T output[size];
|
| 1492 |
+
T buffer[size];
|
| 1493 |
+
data.store(static_cast<void*>(buffer));
|
| 1494 |
+
for (const auto i : c10::irange(size)) {
|
| 1495 |
+
output[i] = buffer[size - i - 1];
|
| 1496 |
+
}
|
| 1497 |
+
return Vectorized<T>::loadu(static_cast<void*>(output));
|
| 1498 |
+
}
|
| 1499 |
+
|
| 1500 |
+
// Transpose the `src` buffer of type `T` and size (M,N) into the `dst` buffer.
|
| 1501 |
+
// `ld_src` is the leading dimension of `src` and `ld_dst` is the leading
|
| 1502 |
+
// dimension of `dst`.
|
| 1503 |
+
template <typename T>
|
| 1504 |
+
inline void transpose_mxn(
|
| 1505 |
+
const T* src,
|
| 1506 |
+
int64_t ld_src,
|
| 1507 |
+
T* dst,
|
| 1508 |
+
int64_t ld_dst,
|
| 1509 |
+
int M,
|
| 1510 |
+
int N) {
|
| 1511 |
+
for (int i = 0; i < M; i++) {
|
| 1512 |
+
for (int j = 0; j < N; j++) {
|
| 1513 |
+
dst[j * ld_dst + i] = src[i * ld_src + j];
|
| 1514 |
+
}
|
| 1515 |
+
}
|
| 1516 |
+
}
|
| 1517 |
+
|
| 1518 |
+
template <typename T, int M, int N>
|
| 1519 |
+
inline void transpose_mxn(
|
| 1520 |
+
const T* src,
|
| 1521 |
+
int64_t ld_src,
|
| 1522 |
+
T* dst,
|
| 1523 |
+
int64_t ld_dst) {
|
| 1524 |
+
transpose_mxn<T>(src, ld_src, dst, ld_dst, M, N);
|
| 1525 |
+
}
|
| 1526 |
+
|
| 1527 |
+
} // namespace CPU_CAPABILITY
|
| 1528 |
+
} // namespace at::vec
|
| 1529 |
+
|
| 1530 |
+
// additional headers for more operations that depend on vec_base
|
| 1531 |
+
#include <ATen/cpu/vec/vec_convert.h>
|
| 1532 |
+
#include <ATen/cpu/vec/vec_mask.h>
|
| 1533 |
+
#include <ATen/cpu/vec/vec_n.h>
|
| 1534 |
+
|
| 1535 |
+
#else
|
| 1536 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 1537 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec_convert.h
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
#include <ATen/cpu/vec/vec_base.h>
|
| 5 |
+
#include <ATen/cpu/vec/vec_n.h>
|
| 6 |
+
|
| 7 |
+
namespace at::vec {
|
| 8 |
+
inline namespace CPU_CAPABILITY {
|
| 9 |
+
|
| 10 |
+
template <
|
| 11 |
+
typename dst_t,
|
| 12 |
+
int dst_n,
|
| 13 |
+
typename src_t,
|
| 14 |
+
int src_n,
|
| 15 |
+
typename Enabled = void>
|
| 16 |
+
struct VecConvert {
|
| 17 |
+
static inline VectorizedN<dst_t, dst_n> apply(
|
| 18 |
+
const VectorizedN<src_t, src_n>& src) {
|
| 19 |
+
constexpr int count = std::min(
|
| 20 |
+
VectorizedN<src_t, src_n>::size(), VectorizedN<dst_t, dst_n>::size());
|
| 21 |
+
__at_align__ src_t src_buf[VectorizedN<src_t, src_n>::size()];
|
| 22 |
+
src.store(src_buf);
|
| 23 |
+
__at_align__ dst_t dst_buf[VectorizedN<dst_t, dst_n>::size()];
|
| 24 |
+
for (int i = 0; i < count; i++) {
|
| 25 |
+
dst_buf[i] = static_cast<dst_t>(src_buf[i]);
|
| 26 |
+
}
|
| 27 |
+
return VectorizedN<dst_t, dst_n>::loadu(dst_buf, count);
|
| 28 |
+
}
|
| 29 |
+
};
|
| 30 |
+
|
| 31 |
+
template <typename dst_t, typename src_t>
|
| 32 |
+
inline std::enable_if_t<std::is_same_v<dst_t, src_t>, Vectorized<src_t>> convert(
|
| 33 |
+
const Vectorized<src_t>& src) {
|
| 34 |
+
return src;
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
template <typename dst_t, typename src_t>
|
| 38 |
+
inline std::enable_if_t<!std::is_same_v<dst_t, src_t>, Vectorized<dst_t>>
|
| 39 |
+
convert(const Vectorized<src_t>& src) {
|
| 40 |
+
return VecConvert<dst_t, 1, src_t, 1>::apply(src);
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
template <
|
| 44 |
+
typename dst_t,
|
| 45 |
+
int dst_n,
|
| 46 |
+
typename src_t,
|
| 47 |
+
int src_n,
|
| 48 |
+
std::enable_if_t<dst_n != 1, int> = 0>
|
| 49 |
+
inline VectorizedN<dst_t, dst_n> convert(const VectorizedN<src_t, src_n>& src) {
|
| 50 |
+
return VecConvert<dst_t, dst_n, src_t, src_n>::apply(src);
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
template <
|
| 54 |
+
typename dst_t,
|
| 55 |
+
int dst_n,
|
| 56 |
+
typename src_t,
|
| 57 |
+
int src_n,
|
| 58 |
+
bool keep = false,
|
| 59 |
+
std::enable_if_t<dst_n == 1, int> = 0>
|
| 60 |
+
inline std::conditional_t<keep, VectorizedN<dst_t, 1>, Vectorized<dst_t>>
|
| 61 |
+
convert(const VectorizedN<src_t, src_n>& src) {
|
| 62 |
+
return VecConvert<dst_t, dst_n, src_t, src_n>::apply(src);
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
} // namespace CPU_CAPABILITY
|
| 66 |
+
|
| 67 |
+
template <
|
| 68 |
+
typename scalar_t,
|
| 69 |
+
typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
|
| 70 |
+
inline std::tuple<Vectorized<float>, Vectorized<float>> convert_to_float(
|
| 71 |
+
const Vectorized<scalar_t>&);
|
| 72 |
+
|
| 73 |
+
template <
|
| 74 |
+
typename scalar_t,
|
| 75 |
+
typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
|
| 76 |
+
inline Vectorized<scalar_t> convert_from_float(
|
| 77 |
+
const Vectorized<float>&,
|
| 78 |
+
const Vectorized<float>&);
|
| 79 |
+
|
| 80 |
+
} // namespace at::vec
|
| 81 |
+
|
| 82 |
+
#else
|
| 83 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 84 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec_half.h
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
#include <ATen/cpu/vec/intrinsics.h>
|
| 5 |
+
#include <c10/util/Exception.h>
|
| 6 |
+
|
| 7 |
+
#include <torch/headeronly/cpu/vec/vec_half.h>
|
| 8 |
+
|
| 9 |
+
namespace at::vec {
|
| 10 |
+
// See Note [CPU_CAPABILITY namespace]
|
| 11 |
+
inline namespace CPU_CAPABILITY {
|
| 12 |
+
|
| 13 |
+
// Transpose a [2, 32] matrix to [32, 2]
|
| 14 |
+
// Note: the output leading dimension should be 2,
|
| 15 |
+
// that is, the output must be contiguous
|
| 16 |
+
template <typename scalar_t, typename = std::enable_if_t<sizeof(scalar_t) == 2>>
|
| 17 |
+
static inline void transpose_pad_2x32_block(
|
| 18 |
+
const scalar_t* src,
|
| 19 |
+
scalar_t* dst,
|
| 20 |
+
int64_t ld_src,
|
| 21 |
+
int krem = 2,
|
| 22 |
+
int nrem = 32) {
|
| 23 |
+
#if defined(CPU_CAPABILITY_AVX512)
|
| 24 |
+
__m512i r0, r1;
|
| 25 |
+
__m512i d0, d1;
|
| 26 |
+
// load
|
| 27 |
+
if (nrem < 32) {
|
| 28 |
+
__mmask32 mask_krem_v = (1LL << nrem) - 1;
|
| 29 |
+
r0 = _mm512_maskz_loadu_epi16(mask_krem_v, src);
|
| 30 |
+
// if krem is not 2, pad with zeros
|
| 31 |
+
if (krem == 2) {
|
| 32 |
+
r1 = _mm512_maskz_loadu_epi16(mask_krem_v, src + ld_src);
|
| 33 |
+
} else {
|
| 34 |
+
r1 = _mm512_setzero_si512();
|
| 35 |
+
}
|
| 36 |
+
} else {
|
| 37 |
+
r0 = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(src));
|
| 38 |
+
if (krem == 2) {
|
| 39 |
+
r1 = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(src + ld_src));
|
| 40 |
+
} else {
|
| 41 |
+
r1 = _mm512_setzero_si512();
|
| 42 |
+
}
|
| 43 |
+
}
|
| 44 |
+
// transpose
|
| 45 |
+
d0 = _mm512_unpacklo_epi16(r0, r1);
|
| 46 |
+
d1 = _mm512_unpackhi_epi16(r0, r1);
|
| 47 |
+
r0 = _mm512_shuffle_i32x4(d0, d1, 0x88);
|
| 48 |
+
r1 = _mm512_shuffle_i32x4(d0, d1, 0xdd);
|
| 49 |
+
d0 = _mm512_shuffle_i32x4(r0, r1, 0x88);
|
| 50 |
+
d1 = _mm512_shuffle_i32x4(r0, r1, 0xdd);
|
| 51 |
+
|
| 52 |
+
// store
|
| 53 |
+
if (nrem < 16) {
|
| 54 |
+
__mmask32 mask_rem_v = (1LL << (nrem * 2)) - 1;
|
| 55 |
+
_mm512_mask_storeu_epi16(dst, mask_rem_v, d0);
|
| 56 |
+
} else if (nrem == 16) {
|
| 57 |
+
_mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), d0);
|
| 58 |
+
} else if (nrem < 32) {
|
| 59 |
+
__mmask32 mask_rem_v = (1LL << (nrem * 2 - 32)) - 1;
|
| 60 |
+
_mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), d0);
|
| 61 |
+
_mm512_mask_storeu_epi16(
|
| 62 |
+
reinterpret_cast<__m512i*>(dst + 32), mask_rem_v, d1);
|
| 63 |
+
} else {
|
| 64 |
+
// normal store
|
| 65 |
+
_mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), d0);
|
| 66 |
+
_mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 32), d1);
|
| 67 |
+
}
|
| 68 |
+
#else
|
| 69 |
+
TORCH_CHECK(
|
| 70 |
+
false,
|
| 71 |
+
"transpose_pad_2x32_block is only supported when avx512 is supported")
|
| 72 |
+
#endif
|
| 73 |
+
}
|
| 74 |
+
|
| 75 |
+
// To use AMX to accelerate GEMM,
|
| 76 |
+
// reorder the memory format [K, N] -> [K/2, N, 2]
|
| 77 |
+
// Note: If K % 2 != 0, pad K implicitly
|
| 78 |
+
template <typename scalar_t, typename = std::enable_if_t<sizeof(scalar_t) == 2>>
|
| 79 |
+
static inline void pack_vnni2(
|
| 80 |
+
const scalar_t* src,
|
| 81 |
+
scalar_t* dst,
|
| 82 |
+
int64_t ld_src,
|
| 83 |
+
int64_t K,
|
| 84 |
+
int64_t N) {
|
| 85 |
+
#if defined(CPU_CAPABILITY_AVX512)
|
| 86 |
+
int64_t bk = 0;
|
| 87 |
+
int64_t _K = K / 2 * 2;
|
| 88 |
+
int64_t _N = N / 32 * 32;
|
| 89 |
+
for (; bk < _K; bk += 2) {
|
| 90 |
+
int64_t bn = 0;
|
| 91 |
+
for (; bn < _N; bn += 32) {
|
| 92 |
+
transpose_pad_2x32_block(
|
| 93 |
+
src + bk * ld_src + bn, dst + bk * N + bn * 2, ld_src);
|
| 94 |
+
}
|
| 95 |
+
int64_t nrem = N - bn;
|
| 96 |
+
if (nrem > 0) {
|
| 97 |
+
transpose_pad_2x32_block(
|
| 98 |
+
src + bk * ld_src + bn, dst + bk * N + bn * 2, ld_src, 2, nrem);
|
| 99 |
+
}
|
| 100 |
+
}
|
| 101 |
+
if (K % 2 == 1) {
|
| 102 |
+
int64_t bn = 0;
|
| 103 |
+
for (; bn < _N; bn += 32) {
|
| 104 |
+
transpose_pad_2x32_block(
|
| 105 |
+
src + bk * ld_src + bn, dst + bk * N + bn * 2, ld_src, 1);
|
| 106 |
+
}
|
| 107 |
+
int64_t nrem = N - bn;
|
| 108 |
+
if (nrem > 0) {
|
| 109 |
+
transpose_pad_2x32_block(
|
| 110 |
+
src + bk * ld_src + bn, dst + bk * N + bn * 2, ld_src, 1, nrem);
|
| 111 |
+
}
|
| 112 |
+
}
|
| 113 |
+
#else
|
| 114 |
+
TORCH_CHECK(false, "pack_vnni2 is only supported when avx512 is supported")
|
| 115 |
+
#endif
|
| 116 |
+
}
|
| 117 |
+
|
| 118 |
+
} // namespace CPU_CAPABILITY
|
| 119 |
+
} // namespace at::vec
|
| 120 |
+
|
| 121 |
+
#else
|
| 122 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 123 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec_mask.h
ADDED
|
@@ -0,0 +1,318 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
#include <ATen/cpu/vec/vec_base.h>
|
| 5 |
+
#include <ATen/cpu/vec/vec_n.h>
|
| 6 |
+
namespace at::vec {
|
| 7 |
+
inline namespace CPU_CAPABILITY {
|
| 8 |
+
|
| 9 |
+
/**
|
| 10 |
+
* The `VecMask` class provides a convenient interface for working with
|
| 11 |
+
* vectorized masks in SIMD operations. It encapsulates a `Vectorized<T, N>`
|
| 12 |
+
* mask that can be directly usable in masked vectorized operations. It provides
|
| 13 |
+
* various methods for manipulating and accessing the mask elements:
|
| 14 |
+
* 1. `from` and `to`: Conversion between a vector of boolean values and a
|
| 15 |
+
* vectorized mask.
|
| 16 |
+
* 2. `cast`: Casts the mask to a different base type.
|
| 17 |
+
* 3. `all_zero`: Checks if all mask elements are zero.
|
| 18 |
+
* 4. `is_masked`: Checks if a specific element is masked.
|
| 19 |
+
* 5. `loadu`: Loads data from memory using the mask.
|
| 20 |
+
* 6. `all_masked`: Checks if all mask elements are masked.
|
| 21 |
+
*
|
| 22 |
+
* Some helper template classes are provided to simplify the specialization of
|
| 23 |
+
* the `VecMask` for the specific CPU arch:
|
| 24 |
+
* 1. `VecMaskLoad`: Loads data from memory using the mask.
|
| 25 |
+
* 2. `VecMaskTo`: Converts the mask to boolean.
|
| 26 |
+
* 3. `VecMaskCast`: Casts the mask to a different base type.
|
| 27 |
+
*
|
| 28 |
+
*/
|
| 29 |
+
template <typename T, int N>
|
| 30 |
+
class VecMask;
|
| 31 |
+
|
| 32 |
+
template <
|
| 33 |
+
typename data_t,
|
| 34 |
+
int data_n,
|
| 35 |
+
typename mask_t,
|
| 36 |
+
int mask_n,
|
| 37 |
+
typename Enabled = void>
|
| 38 |
+
struct VecMaskLoad {
|
| 39 |
+
static inline VectorizedN<data_t, data_n> apply(
|
| 40 |
+
const data_t* ptr,
|
| 41 |
+
const VecMask<mask_t, mask_n>& vec_mask) {
|
| 42 |
+
constexpr typename VecMask<mask_t, mask_n>::size_type size =
|
| 43 |
+
VecMask<mask_t, mask_n>::size();
|
| 44 |
+
static_assert(VectorizedN<data_t, data_n>::size() >= size);
|
| 45 |
+
__at_align__ data_t data[size];
|
| 46 |
+
__at_align__ mask_t mask[size];
|
| 47 |
+
auto mask_ = VectorizedN<mask_t, mask_n>(vec_mask);
|
| 48 |
+
mask_.store(mask);
|
| 49 |
+
for (int i = 0; i < size; i++) {
|
| 50 |
+
data[i] = mask[i] ? ptr[i] : static_cast<data_t>(0);
|
| 51 |
+
}
|
| 52 |
+
return VectorizedN<data_t, data_n>::loadu(data, size);
|
| 53 |
+
}
|
| 54 |
+
};
|
| 55 |
+
|
| 56 |
+
template <
|
| 57 |
+
typename dst_t,
|
| 58 |
+
int dst_n,
|
| 59 |
+
typename src_t,
|
| 60 |
+
int src_n,
|
| 61 |
+
typename Enabled = void>
|
| 62 |
+
struct VecMaskTo {
|
| 63 |
+
static inline VecMask<dst_t, dst_n> apply(
|
| 64 |
+
const VecMask<src_t, src_n>& vec_mask) {
|
| 65 |
+
auto zeros = VectorizedN<dst_t, dst_n>(static_cast<dst_t>(0));
|
| 66 |
+
auto ones = VectorizedN<dst_t, dst_n>(static_cast<dst_t>(1));
|
| 67 |
+
return VectorizedN<dst_t, dst_n>::blendv(
|
| 68 |
+
zeros, ones, vec_mask.template cast<dst_t, dst_n>());
|
| 69 |
+
}
|
| 70 |
+
};
|
| 71 |
+
|
| 72 |
+
template <
|
| 73 |
+
typename dst_t,
|
| 74 |
+
int dst_n,
|
| 75 |
+
typename src_t,
|
| 76 |
+
int src_n,
|
| 77 |
+
typename Enabled = void>
|
| 78 |
+
struct VecMaskCast {
|
| 79 |
+
static inline VecMask<dst_t, dst_n> apply(
|
| 80 |
+
const VecMask<src_t, src_n>& vec_mask) {
|
| 81 |
+
return VecMask<dst_t, dst_n>::from(VectorizedN<src_t, src_n>(vec_mask));
|
| 82 |
+
}
|
| 83 |
+
};
|
| 84 |
+
|
| 85 |
+
template <typename T, int N>
|
| 86 |
+
struct VecMaskCast<T, N, T, N> {
|
| 87 |
+
static inline VecMask<T, N> apply(const VecMask<T, N>& vec_mask) {
|
| 88 |
+
return vec_mask;
|
| 89 |
+
}
|
| 90 |
+
};
|
| 91 |
+
|
| 92 |
+
template <typename T, int N>
|
| 93 |
+
struct VecMaskCheck {
|
| 94 |
+
static inline bool all_zero(const VectorizedN<T, N>& vec_mask) {
|
| 95 |
+
__at_align__ T mask[VectorizedN<T, N>::size()];
|
| 96 |
+
vec_mask.store(mask);
|
| 97 |
+
return std::all_of(mask, mask + VectorizedN<T, N>::size(), [](T m) {
|
| 98 |
+
return m == static_cast<T>(0);
|
| 99 |
+
});
|
| 100 |
+
}
|
| 101 |
+
|
| 102 |
+
static inline bool all_masked(const VectorizedN<T, N>& vec_mask) {
|
| 103 |
+
__at_align__ T mask[VectorizedN<T, N>::size()];
|
| 104 |
+
vec_mask.store(mask);
|
| 105 |
+
return std::all_of(mask, mask + VectorizedN<T, N>::size(), [](T m) {
|
| 106 |
+
return m != static_cast<T>(0);
|
| 107 |
+
});
|
| 108 |
+
}
|
| 109 |
+
|
| 110 |
+
static inline bool is_masked(const VectorizedN<T, N>& vec_mask, int i) {
|
| 111 |
+
__at_align__ T mask[VectorizedN<T, N>::size()];
|
| 112 |
+
vec_mask.store(mask);
|
| 113 |
+
return mask[i] != static_cast<T>(0);
|
| 114 |
+
}
|
| 115 |
+
};
|
| 116 |
+
|
| 117 |
+
template <typename T, int N>
|
| 118 |
+
class VecMask {
|
| 119 |
+
public:
|
| 120 |
+
using size_type = int;
|
| 121 |
+
static constexpr size_type size() {
|
| 122 |
+
return VectorizedN<T, N>::size();
|
| 123 |
+
}
|
| 124 |
+
|
| 125 |
+
private:
|
| 126 |
+
VectorizedN<T, N> mask_;
|
| 127 |
+
|
| 128 |
+
public:
|
| 129 |
+
VecMask() : mask_(static_cast<T>(0)) {}
|
| 130 |
+
VecMask(const VectorizedN<T, N>& mask) : mask_(mask) {}
|
| 131 |
+
|
| 132 |
+
template <int L = N, typename std::enable_if_t<L == 1, int> = 0>
|
| 133 |
+
VecMask(const Vectorized<T>& mask) : mask_(mask) {}
|
| 134 |
+
|
| 135 |
+
template <typename U, int L>
|
| 136 |
+
static VecMask<T, N> from(const VectorizedN<U, L>& b_vec) {
|
| 137 |
+
__at_align__ U b_buf[size()];
|
| 138 |
+
if constexpr (size() >= VectorizedN<U, L>::size()) {
|
| 139 |
+
b_vec.store(b_buf);
|
| 140 |
+
for (int i = VectorizedN<U, L>::size(); i < size(); i++) {
|
| 141 |
+
b_buf[i] = static_cast<U>(0);
|
| 142 |
+
}
|
| 143 |
+
} else {
|
| 144 |
+
b_vec.store(b_buf, size());
|
| 145 |
+
}
|
| 146 |
+
return from(b_buf);
|
| 147 |
+
}
|
| 148 |
+
|
| 149 |
+
template <typename U>
|
| 150 |
+
static VecMask<T, N> from(U b) {
|
| 151 |
+
using int_t = int_same_size_t<T>;
|
| 152 |
+
T mask = b ? c10::bit_cast<T>((int_t)(~(int_t)0)) : (T)0;
|
| 153 |
+
return VectorizedN<T, N>(mask);
|
| 154 |
+
}
|
| 155 |
+
|
| 156 |
+
template <typename U>
|
| 157 |
+
static VecMask<T, N> from(U* b) {
|
| 158 |
+
using int_t = int_same_size_t<T>;
|
| 159 |
+
__at_align__ T mask[size()];
|
| 160 |
+
#ifndef __msvc_cl__
|
| 161 |
+
#pragma unroll
|
| 162 |
+
#endif
|
| 163 |
+
for (int i = 0; i < size(); i++) {
|
| 164 |
+
*(int_t*)(mask + i) = b[i] ? ~(int_t)0 : (int_t)0;
|
| 165 |
+
}
|
| 166 |
+
return VectorizedN<T, N>(VectorizedN<T, N>::loadu(mask));
|
| 167 |
+
}
|
| 168 |
+
|
| 169 |
+
template <typename U>
|
| 170 |
+
static VecMask<T, N> from(U* b, int count) {
|
| 171 |
+
using int_t = int_same_size_t<T>;
|
| 172 |
+
__at_align__ T mask[size()];
|
| 173 |
+
#ifndef __msvc_cl__
|
| 174 |
+
#pragma unroll
|
| 175 |
+
#endif
|
| 176 |
+
for (int i = 0; i < count; i++) {
|
| 177 |
+
*(int_t*)(mask + i) = b[i] ? ~(int_t)0 : (int_t)0;
|
| 178 |
+
}
|
| 179 |
+
return VectorizedN<T, N>(VectorizedN<T, N>::loadu(mask, count));
|
| 180 |
+
}
|
| 181 |
+
|
| 182 |
+
static VecMask<T, N> blendv(
|
| 183 |
+
const VecMask<T, N>& c,
|
| 184 |
+
const VecMask<T, N>& b,
|
| 185 |
+
const VecMask<T, N>& a) {
|
| 186 |
+
VectorizedN<T, N> result = VectorizedN<T, N>::blendv(
|
| 187 |
+
VectorizedN<T, N>(c), VectorizedN<T, N>(b), VectorizedN<T, N>(a));
|
| 188 |
+
return result;
|
| 189 |
+
}
|
| 190 |
+
|
| 191 |
+
static VecMask<T, N> set(
|
| 192 |
+
const VecMask<T, N>& a,
|
| 193 |
+
const VecMask<T, N>& b,
|
| 194 |
+
int64_t count = size()) {
|
| 195 |
+
VectorizedN<T, N> result = VectorizedN<T, N>::set(
|
| 196 |
+
VectorizedN<T, N>(a), VectorizedN<T, N>(b), count);
|
| 197 |
+
return result;
|
| 198 |
+
}
|
| 199 |
+
|
| 200 |
+
void store(bool* b, int count = size()) {
|
| 201 |
+
constexpr int L =
|
| 202 |
+
(VectorizedN<T, N>::size() + Vectorized<bool>::size() - 1) /
|
| 203 |
+
Vectorized<bool>::size();
|
| 204 |
+
auto res = this->to<bool, L>();
|
| 205 |
+
res.store(b, count);
|
| 206 |
+
return;
|
| 207 |
+
}
|
| 208 |
+
|
| 209 |
+
template <typename U, int L, std::enable_if_t<L >= 2, int> = 0>
|
| 210 |
+
inline VectorizedN<U, L> to() const {
|
| 211 |
+
return VecMaskTo<U, L, T, N>::apply(*this);
|
| 212 |
+
}
|
| 213 |
+
|
| 214 |
+
template <typename U, int L, std::enable_if_t<L == 1, int> = 0>
|
| 215 |
+
inline Vectorized<U> to() const {
|
| 216 |
+
return VecMaskTo<U, L, T, N>::apply(*this);
|
| 217 |
+
}
|
| 218 |
+
|
| 219 |
+
template <typename U, int L>
|
| 220 |
+
inline VecMask<U, L> cast() const {
|
| 221 |
+
return VecMaskCast<U, L, T, N>::apply(*this);
|
| 222 |
+
}
|
| 223 |
+
|
| 224 |
+
inline bool all_zero() const {
|
| 225 |
+
return VecMaskCheck<T, N>::all_zero(mask_);
|
| 226 |
+
}
|
| 227 |
+
|
| 228 |
+
inline bool all_masked() const {
|
| 229 |
+
return VecMaskCheck<T, N>::all_masked(mask_);
|
| 230 |
+
}
|
| 231 |
+
|
| 232 |
+
inline bool is_masked(int i) const {
|
| 233 |
+
return VecMaskCheck<T, N>::is_masked(mask_, i);
|
| 234 |
+
}
|
| 235 |
+
|
| 236 |
+
inline operator VectorizedN<T, N>() const {
|
| 237 |
+
return mask_;
|
| 238 |
+
}
|
| 239 |
+
|
| 240 |
+
template <int L = N, typename std::enable_if_t<L == 1, int> = 0>
|
| 241 |
+
inline operator Vectorized<T>() const {
|
| 242 |
+
return mask_[0];
|
| 243 |
+
}
|
| 244 |
+
|
| 245 |
+
inline Vectorized<T> operator[](int i) const {
|
| 246 |
+
return mask_[i];
|
| 247 |
+
}
|
| 248 |
+
|
| 249 |
+
template <
|
| 250 |
+
typename U,
|
| 251 |
+
int L,
|
| 252 |
+
std::enable_if_t<L >= 2 && VectorizedN<U, L>::size() >= size(), int> = 0>
|
| 253 |
+
VectorizedN<U, L> loadu(const U* ptr) const {
|
| 254 |
+
return VecMaskLoad<U, L, T, N>::apply(ptr, *this);
|
| 255 |
+
}
|
| 256 |
+
|
| 257 |
+
template <
|
| 258 |
+
typename U,
|
| 259 |
+
int L,
|
| 260 |
+
std::enable_if_t<L == 1 && Vectorized<U>::size() >= size(), int> = 0>
|
| 261 |
+
Vectorized<U> loadu(const U* ptr) const {
|
| 262 |
+
return VecMaskLoad<U, L, T, N>::apply(ptr, *this);
|
| 263 |
+
}
|
| 264 |
+
};
|
| 265 |
+
|
| 266 |
+
#define VEC_MASK_DEFINE_UNARY_OP_GLOBAL(op) \
|
| 267 |
+
template <typename T, int N> \
|
| 268 |
+
inline VecMask<T, N> op(const VecMask<T, N>& a) { \
|
| 269 |
+
return op(VectorizedN<T, N>(a)); \
|
| 270 |
+
}
|
| 271 |
+
|
| 272 |
+
#define VEC_MASK_DEFINE_BINARY_OP_GLOBAL(op) \
|
| 273 |
+
template < \
|
| 274 |
+
typename T, \
|
| 275 |
+
int N, \
|
| 276 |
+
typename V, \
|
| 277 |
+
int M, \
|
| 278 |
+
std::enable_if_t<VecMask<T, N>::size() == VecMask<V, M>::size(), int> = \
|
| 279 |
+
0> \
|
| 280 |
+
inline VecMask<T, N> op(const VecMask<T, N>& a, const VecMask<V, M>& b) { \
|
| 281 |
+
return op( \
|
| 282 |
+
VectorizedN<T, N>(a), VectorizedN<T, N>(b.template cast<T, N>())); \
|
| 283 |
+
}
|
| 284 |
+
|
| 285 |
+
#define VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL(op, EXPR) \
|
| 286 |
+
template < \
|
| 287 |
+
typename T, \
|
| 288 |
+
int N, \
|
| 289 |
+
typename V, \
|
| 290 |
+
int M, \
|
| 291 |
+
std::enable_if_t<VecMask<T, N>::size() == VecMask<V, M>::size(), int> = \
|
| 292 |
+
0> \
|
| 293 |
+
inline VecMask<T, N> op(const VecMask<T, N>& a, const VecMask<V, M>& b) { \
|
| 294 |
+
return EXPR; \
|
| 295 |
+
}
|
| 296 |
+
|
| 297 |
+
VEC_MASK_DEFINE_UNARY_OP_GLOBAL(operator~)
|
| 298 |
+
VEC_MASK_DEFINE_BINARY_OP_GLOBAL(operator&)
|
| 299 |
+
VEC_MASK_DEFINE_BINARY_OP_GLOBAL(operator|)
|
| 300 |
+
VEC_MASK_DEFINE_BINARY_OP_GLOBAL(operator^)
|
| 301 |
+
VEC_MASK_DEFINE_BINARY_OP_GLOBAL(operator*)
|
| 302 |
+
VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL(operator>, a & ~b)
|
| 303 |
+
VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL(operator<, ~a& b)
|
| 304 |
+
VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL(operator==, ~(a ^ b))
|
| 305 |
+
VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL(operator>=, (a == b) | (a > b))
|
| 306 |
+
VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL(operator<=, (a == b) | (a < b))
|
| 307 |
+
VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL(operator!=, (a ^ b))
|
| 308 |
+
|
| 309 |
+
#undef VEC_MASK_DEFINE_UNARY_OP_GLOBAL
|
| 310 |
+
#undef VEC_MASK_DEFINE_BINARY_OP_GLOBAL
|
| 311 |
+
#undef VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL
|
| 312 |
+
|
| 313 |
+
} // namespace CPU_CAPABILITY
|
| 314 |
+
} // namespace at::vec
|
| 315 |
+
|
| 316 |
+
#else
|
| 317 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 318 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec_n.h
ADDED
|
@@ -0,0 +1,412 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
#include <ATen/cpu/vec/vec_base.h>
|
| 5 |
+
#include <array>
|
| 6 |
+
|
| 7 |
+
namespace at::vec {
|
| 8 |
+
inline namespace CPU_CAPABILITY {
|
| 9 |
+
|
| 10 |
+
/**
|
| 11 |
+
* @brief A class template representing a vectorized type with
|
| 12 |
+
* `N * Vectorized<T>::size()` elements, aiming to support vectors of
|
| 13 |
+
* arbitrary size. A specific use case of it is to represent vectors
|
| 14 |
+
* converted from data types with different sizes but with the same
|
| 15 |
+
* number of vector elements, e.g., `VectorizedN<float, 2>` can be
|
| 16 |
+
* a vector converted from two `Vectorized<bfloat16>`, `VectorizedN<int64_t, 2>`
|
| 17 |
+
* can be a vector converted from two `Vectorized<int32_t>` etc.
|
| 18 |
+
*
|
| 19 |
+
* It supports most of the operations of `Vectorized<T>`
|
| 20 |
+
* and the implementation delegates to `Vectorized<T>` with loops over `N`.
|
| 21 |
+
*
|
| 22 |
+
* @tparam T The underlying type of the vectorized elements.
|
| 23 |
+
* @tparam N The number of underlying `Vectorized<T>`.
|
| 24 |
+
*/
|
| 25 |
+
template <typename T, int N>
|
| 26 |
+
class VectorizedN {
|
| 27 |
+
public:
|
| 28 |
+
using value_type = T;
|
| 29 |
+
using size_type = int;
|
| 30 |
+
|
| 31 |
+
static constexpr size_type size_T = sizeof(T);
|
| 32 |
+
static constexpr size_type size() {
|
| 33 |
+
return Vectorized<T>::size() * N;
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
private:
|
| 37 |
+
std::array<Vectorized<T>, N> values;
|
| 38 |
+
|
| 39 |
+
public:
|
| 40 |
+
// methods not implemented yet:
|
| 41 |
+
// variadic constructor, operator T*, as_bytes, zero_mask
|
| 42 |
+
|
| 43 |
+
#define VECTORIZEDN_DEFINE_UNARY_OP(op) \
|
| 44 |
+
VectorizedN<T, N> op() const { \
|
| 45 |
+
return unary_op([](const Vectorized<T>& a) { return a.op(); }); \
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
#define VECTORIZEDN_DEFINE_BINARY_OP(op) \
|
| 49 |
+
VectorizedN<T, N> op(const VectorizedN<T, N>& other) const { \
|
| 50 |
+
return binary_op( \
|
| 51 |
+
other, [](const Vectorized<T>& a, const Vectorized<T>& b) { \
|
| 52 |
+
return a.op(b); \
|
| 53 |
+
}); \
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
template <typename Op>
|
| 57 |
+
inline VectorizedN<T, N> unary_op(Op op) const {
|
| 58 |
+
VectorizedN<T, N> result;
|
| 59 |
+
#ifndef _MSC_VER
|
| 60 |
+
#pragma unroll
|
| 61 |
+
#endif
|
| 62 |
+
for (int i = 0; i < N; ++i) {
|
| 63 |
+
result.values[i] = op(values[i]);
|
| 64 |
+
}
|
| 65 |
+
return result;
|
| 66 |
+
}
|
| 67 |
+
|
| 68 |
+
template <typename Op>
|
| 69 |
+
inline VectorizedN<T, N> binary_op(const VectorizedN<T, N>& other, Op op)
|
| 70 |
+
const {
|
| 71 |
+
VectorizedN<T, N> result;
|
| 72 |
+
#ifndef _MSC_VER
|
| 73 |
+
#pragma unroll
|
| 74 |
+
#endif
|
| 75 |
+
for (int i = 0; i < N; ++i) {
|
| 76 |
+
result.values[i] = op(values[i], other.values[i]);
|
| 77 |
+
}
|
| 78 |
+
return result;
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
template <typename Op>
|
| 82 |
+
inline VectorizedN<T, N> ternary_op(
|
| 83 |
+
const VectorizedN<T, N>& other,
|
| 84 |
+
const VectorizedN<T, N>& other2,
|
| 85 |
+
Op op) const {
|
| 86 |
+
VectorizedN<T, N> result;
|
| 87 |
+
#ifndef _MSC_VER
|
| 88 |
+
#pragma unroll
|
| 89 |
+
#endif
|
| 90 |
+
for (int i = 0; i < N; ++i) {
|
| 91 |
+
result.values[i] = op(values[i], other.values[i], other2.values[i]);
|
| 92 |
+
}
|
| 93 |
+
return result;
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
VectorizedN() = default;
|
| 97 |
+
|
| 98 |
+
explicit VectorizedN(T val) {
|
| 99 |
+
for (int i = 0; i < N; ++i) {
|
| 100 |
+
values[i] = Vectorized<T>(val);
|
| 101 |
+
}
|
| 102 |
+
}
|
| 103 |
+
|
| 104 |
+
template <int L = N, typename std::enable_if_t<L == 1, int> = 0>
|
| 105 |
+
VectorizedN(const Vectorized<T>& val) : values({val}) {}
|
| 106 |
+
|
| 107 |
+
template <int L = N, typename std::enable_if_t<L == 2, int> = 0>
|
| 108 |
+
VectorizedN(const Vectorized<T>& val_0, const Vectorized<T>& val_1)
|
| 109 |
+
: values({val_0, val_1}) {}
|
| 110 |
+
|
| 111 |
+
template <int L = N, typename std::enable_if_t<L == 1, int> = 0>
|
| 112 |
+
inline operator Vectorized<T>() const {
|
| 113 |
+
return values[0];
|
| 114 |
+
}
|
| 115 |
+
|
| 116 |
+
inline const Vectorized<T>& operator[](int i) const {
|
| 117 |
+
return values[i];
|
| 118 |
+
}
|
| 119 |
+
|
| 120 |
+
inline Vectorized<T>& operator[](int i) {
|
| 121 |
+
return values[i];
|
| 122 |
+
}
|
| 123 |
+
|
| 124 |
+
template <int64_t mask>
|
| 125 |
+
static VectorizedN<T, N> blend(
|
| 126 |
+
const VectorizedN<T, N>& a,
|
| 127 |
+
const VectorizedN<T, N>& b) {
|
| 128 |
+
VectorizedN<T, N> result;
|
| 129 |
+
for (int i = 0; i < N; ++i) {
|
| 130 |
+
result.values[i] =
|
| 131 |
+
Vectorized<T>::template blend<mask>(a.values[i], b.values[i]);
|
| 132 |
+
}
|
| 133 |
+
return result;
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
static VectorizedN<T, N> blendv(
|
| 137 |
+
const VectorizedN<T, N>& a,
|
| 138 |
+
const VectorizedN<T, N>& b,
|
| 139 |
+
const VectorizedN<T, N>& mask) {
|
| 140 |
+
VectorizedN<T, N> result;
|
| 141 |
+
for (int i = 0; i < N; ++i) {
|
| 142 |
+
result.values[i] =
|
| 143 |
+
Vectorized<T>::blendv(a.values[i], b.values[i], mask.values[i]);
|
| 144 |
+
}
|
| 145 |
+
return result;
|
| 146 |
+
}
|
| 147 |
+
|
| 148 |
+
template <typename step_t>
|
| 149 |
+
static VectorizedN<T, N> arange(
|
| 150 |
+
T base = static_cast<T>(0),
|
| 151 |
+
step_t step = static_cast<step_t>(1)) {
|
| 152 |
+
VectorizedN<T, N> result;
|
| 153 |
+
for (int i = 0; i < N; ++i) {
|
| 154 |
+
result.values[i] = Vectorized<T>::arange(base, step);
|
| 155 |
+
base += step * Vectorized<T>::size();
|
| 156 |
+
}
|
| 157 |
+
return result;
|
| 158 |
+
}
|
| 159 |
+
|
| 160 |
+
static VectorizedN<T, N> set(
|
| 161 |
+
const VectorizedN<T, N>& a,
|
| 162 |
+
const VectorizedN<T, N>& b,
|
| 163 |
+
int64_t count = size()) {
|
| 164 |
+
VectorizedN<T, N> result;
|
| 165 |
+
for (int i = 0; i < N; ++i) {
|
| 166 |
+
if (count > 0) {
|
| 167 |
+
result.values[i] = Vectorized<T>::set(
|
| 168 |
+
a.values[i],
|
| 169 |
+
b.values[i],
|
| 170 |
+
std::min(count, (int64_t)Vectorized<T>::size()));
|
| 171 |
+
count -= Vectorized<T>::size();
|
| 172 |
+
} else {
|
| 173 |
+
result.values[i] = a.values[i];
|
| 174 |
+
}
|
| 175 |
+
}
|
| 176 |
+
return result;
|
| 177 |
+
}
|
| 178 |
+
|
| 179 |
+
static VectorizedN<T, N> loadu(const void* ptr) {
|
| 180 |
+
VectorizedN<T, N> result;
|
| 181 |
+
for (int i = 0; i < N; ++i) {
|
| 182 |
+
result.values[i] = Vectorized<T>::loadu(ptr);
|
| 183 |
+
ptr = static_cast<const T*>(ptr) + Vectorized<T>::size();
|
| 184 |
+
}
|
| 185 |
+
return result;
|
| 186 |
+
}
|
| 187 |
+
|
| 188 |
+
static VectorizedN<T, N> loadu(const void* ptr, int64_t count) {
|
| 189 |
+
VectorizedN<T, N> result;
|
| 190 |
+
for (int i = 0; i < N; ++i) {
|
| 191 |
+
if (count > 0) {
|
| 192 |
+
result.values[i] = Vectorized<T>::loadu(
|
| 193 |
+
ptr, std::min(count, (int64_t)Vectorized<T>::size()));
|
| 194 |
+
ptr = static_cast<const T*>(ptr) + Vectorized<T>::size();
|
| 195 |
+
count -= Vectorized<T>::size();
|
| 196 |
+
} else {
|
| 197 |
+
result.values[i] = Vectorized<T>((T)1);
|
| 198 |
+
}
|
| 199 |
+
}
|
| 200 |
+
return result;
|
| 201 |
+
}
|
| 202 |
+
|
| 203 |
+
void store(void* ptr) const {
|
| 204 |
+
for (int i = 0; i < N; ++i) {
|
| 205 |
+
values[i].store(ptr);
|
| 206 |
+
ptr = static_cast<T*>(ptr) + Vectorized<T>::size();
|
| 207 |
+
}
|
| 208 |
+
}
|
| 209 |
+
|
| 210 |
+
void store(void* ptr, int count) const {
|
| 211 |
+
for (int i = 0; i < N; ++i) {
|
| 212 |
+
values[i].store(ptr, std::min(count, (int)Vectorized<T>::size()));
|
| 213 |
+
ptr = static_cast<T*>(ptr) + Vectorized<T>::size();
|
| 214 |
+
count -= Vectorized<T>::size();
|
| 215 |
+
if (count <= 0) {
|
| 216 |
+
break;
|
| 217 |
+
}
|
| 218 |
+
}
|
| 219 |
+
}
|
| 220 |
+
|
| 221 |
+
bool has_inf_nan() const {
|
| 222 |
+
for (int i = 0; i < N; ++i) {
|
| 223 |
+
if (values[i].has_inf_nan()) {
|
| 224 |
+
return true;
|
| 225 |
+
}
|
| 226 |
+
}
|
| 227 |
+
return false;
|
| 228 |
+
}
|
| 229 |
+
|
| 230 |
+
VectorizedN<T, N> map(T (*const f)(T)) const {
|
| 231 |
+
VectorizedN<T, N> result;
|
| 232 |
+
for (int i = 0; i < N; ++i) {
|
| 233 |
+
result.values[i] = values[i].map(f);
|
| 234 |
+
}
|
| 235 |
+
return result;
|
| 236 |
+
}
|
| 237 |
+
|
| 238 |
+
VectorizedN<T, N> map(T (*const f)(const T&)) const {
|
| 239 |
+
VectorizedN<T, N> result;
|
| 240 |
+
for (int i = 0; i < N; ++i) {
|
| 241 |
+
result.values[i] = values[i].map(f);
|
| 242 |
+
}
|
| 243 |
+
return result;
|
| 244 |
+
}
|
| 245 |
+
|
| 246 |
+
VECTORIZEDN_DEFINE_UNARY_OP(isnan)
|
| 247 |
+
VECTORIZEDN_DEFINE_UNARY_OP(abs)
|
| 248 |
+
VECTORIZEDN_DEFINE_UNARY_OP(sgn)
|
| 249 |
+
VECTORIZEDN_DEFINE_UNARY_OP(angle)
|
| 250 |
+
VECTORIZEDN_DEFINE_UNARY_OP(real)
|
| 251 |
+
VECTORIZEDN_DEFINE_UNARY_OP(imag)
|
| 252 |
+
VECTORIZEDN_DEFINE_UNARY_OP(conj)
|
| 253 |
+
VECTORIZEDN_DEFINE_UNARY_OP(acos)
|
| 254 |
+
VECTORIZEDN_DEFINE_UNARY_OP(acosh)
|
| 255 |
+
VECTORIZEDN_DEFINE_UNARY_OP(asin)
|
| 256 |
+
VECTORIZEDN_DEFINE_UNARY_OP(asinh)
|
| 257 |
+
VECTORIZEDN_DEFINE_UNARY_OP(atan)
|
| 258 |
+
VECTORIZEDN_DEFINE_UNARY_OP(atanh)
|
| 259 |
+
VECTORIZEDN_DEFINE_BINARY_OP(atan2)
|
| 260 |
+
VECTORIZEDN_DEFINE_BINARY_OP(copysign)
|
| 261 |
+
VECTORIZEDN_DEFINE_UNARY_OP(erf)
|
| 262 |
+
VECTORIZEDN_DEFINE_UNARY_OP(erfc)
|
| 263 |
+
VECTORIZEDN_DEFINE_UNARY_OP(erfinv)
|
| 264 |
+
VECTORIZEDN_DEFINE_UNARY_OP(exp)
|
| 265 |
+
VECTORIZEDN_DEFINE_UNARY_OP(exp2)
|
| 266 |
+
VECTORIZEDN_DEFINE_UNARY_OP(expm1)
|
| 267 |
+
VECTORIZEDN_DEFINE_UNARY_OP(exp_u20)
|
| 268 |
+
VECTORIZEDN_DEFINE_UNARY_OP(fexp_u20)
|
| 269 |
+
VECTORIZEDN_DEFINE_UNARY_OP(frac)
|
| 270 |
+
VECTORIZEDN_DEFINE_BINARY_OP(fmod)
|
| 271 |
+
VECTORIZEDN_DEFINE_UNARY_OP(log)
|
| 272 |
+
VECTORIZEDN_DEFINE_UNARY_OP(log10)
|
| 273 |
+
VECTORIZEDN_DEFINE_UNARY_OP(log1p)
|
| 274 |
+
VECTORIZEDN_DEFINE_UNARY_OP(log2)
|
| 275 |
+
VECTORIZEDN_DEFINE_UNARY_OP(ceil)
|
| 276 |
+
VECTORIZEDN_DEFINE_UNARY_OP(cos)
|
| 277 |
+
VECTORIZEDN_DEFINE_UNARY_OP(cosh)
|
| 278 |
+
VECTORIZEDN_DEFINE_UNARY_OP(floor)
|
| 279 |
+
VECTORIZEDN_DEFINE_BINARY_OP(hypot)
|
| 280 |
+
VECTORIZEDN_DEFINE_UNARY_OP(i0)
|
| 281 |
+
VECTORIZEDN_DEFINE_UNARY_OP(i0e)
|
| 282 |
+
VECTORIZEDN_DEFINE_UNARY_OP(digamma)
|
| 283 |
+
VECTORIZEDN_DEFINE_BINARY_OP(igamma)
|
| 284 |
+
VECTORIZEDN_DEFINE_BINARY_OP(igammac)
|
| 285 |
+
VECTORIZEDN_DEFINE_UNARY_OP(neg)
|
| 286 |
+
VECTORIZEDN_DEFINE_BINARY_OP(nextafter)
|
| 287 |
+
VECTORIZEDN_DEFINE_UNARY_OP(round)
|
| 288 |
+
VECTORIZEDN_DEFINE_UNARY_OP(sin)
|
| 289 |
+
VECTORIZEDN_DEFINE_UNARY_OP(sinh)
|
| 290 |
+
VECTORIZEDN_DEFINE_UNARY_OP(tan)
|
| 291 |
+
VECTORIZEDN_DEFINE_UNARY_OP(tanh)
|
| 292 |
+
VECTORIZEDN_DEFINE_UNARY_OP(trunc)
|
| 293 |
+
VECTORIZEDN_DEFINE_UNARY_OP(lgamma)
|
| 294 |
+
VECTORIZEDN_DEFINE_UNARY_OP(sqrt)
|
| 295 |
+
VECTORIZEDN_DEFINE_UNARY_OP(reciprocal)
|
| 296 |
+
VECTORIZEDN_DEFINE_UNARY_OP(rsqrt)
|
| 297 |
+
VECTORIZEDN_DEFINE_BINARY_OP(pow)
|
| 298 |
+
VECTORIZEDN_DEFINE_BINARY_OP(operator==)
|
| 299 |
+
VECTORIZEDN_DEFINE_BINARY_OP(operator!=)
|
| 300 |
+
VECTORIZEDN_DEFINE_BINARY_OP(operator>=)
|
| 301 |
+
VECTORIZEDN_DEFINE_BINARY_OP(operator<=)
|
| 302 |
+
VECTORIZEDN_DEFINE_BINARY_OP(operator>)
|
| 303 |
+
VECTORIZEDN_DEFINE_BINARY_OP(operator<)
|
| 304 |
+
VECTORIZEDN_DEFINE_BINARY_OP(eq)
|
| 305 |
+
VECTORIZEDN_DEFINE_BINARY_OP(ne)
|
| 306 |
+
VECTORIZEDN_DEFINE_BINARY_OP(gt)
|
| 307 |
+
VECTORIZEDN_DEFINE_BINARY_OP(ge)
|
| 308 |
+
VECTORIZEDN_DEFINE_BINARY_OP(lt)
|
| 309 |
+
VECTORIZEDN_DEFINE_BINARY_OP(le)
|
| 310 |
+
|
| 311 |
+
#undef VECTORIZEDN_DEFINE_UNARY_OP
|
| 312 |
+
#undef VECTORIZEDN_DEFINE_BINARY_OP
|
| 313 |
+
};
|
| 314 |
+
|
| 315 |
+
#define VECTORIZEDN_DEFINE_UNARY_OP_GLOBAL(op) \
|
| 316 |
+
template <typename T, int N> \
|
| 317 |
+
inline VectorizedN<T, N> op(const VectorizedN<T, N>& a) { \
|
| 318 |
+
return a.unary_op([](const Vectorized<T>& a) { return op(a); }); \
|
| 319 |
+
}
|
| 320 |
+
|
| 321 |
+
#define VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(op) \
|
| 322 |
+
template <typename T, int N> \
|
| 323 |
+
inline VectorizedN<T, N> op( \
|
| 324 |
+
const VectorizedN<T, N>& a, const VectorizedN<T, N>& b) { \
|
| 325 |
+
return a.binary_op(b, [](const Vectorized<T>& a, const Vectorized<T>& b) { \
|
| 326 |
+
return op(a, b); \
|
| 327 |
+
}); \
|
| 328 |
+
}
|
| 329 |
+
|
| 330 |
+
#define VECTORIZEDN_DEFINE_TERNARY_OP_GLOBAL(op) \
|
| 331 |
+
template <typename T, int N> \
|
| 332 |
+
inline VectorizedN<T, N> op( \
|
| 333 |
+
const VectorizedN<T, N>& a, \
|
| 334 |
+
const VectorizedN<T, N>& b, \
|
| 335 |
+
const VectorizedN<T, N>& c) { \
|
| 336 |
+
return a.ternary_op( \
|
| 337 |
+
b, \
|
| 338 |
+
c, \
|
| 339 |
+
[](const Vectorized<T>& a, \
|
| 340 |
+
const Vectorized<T>& b, \
|
| 341 |
+
const Vectorized<T>& c) { return op(a, b, c); }); \
|
| 342 |
+
}
|
| 343 |
+
|
| 344 |
+
#define VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(op) \
|
| 345 |
+
template <typename T, int N> \
|
| 346 |
+
inline VectorizedN<T, N>& op( \
|
| 347 |
+
VectorizedN<T, N>& a, const VectorizedN<T, N>& b) { \
|
| 348 |
+
a = a.binary_op(b, [](const Vectorized<T>& a, const Vectorized<T>& b) { \
|
| 349 |
+
return op(a, b); \
|
| 350 |
+
}); \
|
| 351 |
+
return a; \
|
| 352 |
+
}
|
| 353 |
+
|
| 354 |
+
VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator+)
|
| 355 |
+
VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator-)
|
| 356 |
+
VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator*)
|
| 357 |
+
VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator/)
|
| 358 |
+
VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator%)
|
| 359 |
+
VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator||)
|
| 360 |
+
VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator<<)
|
| 361 |
+
VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator>>)
|
| 362 |
+
VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(maximum)
|
| 363 |
+
VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(minimum)
|
| 364 |
+
VECTORIZEDN_DEFINE_TERNARY_OP_GLOBAL(fmadd)
|
| 365 |
+
VECTORIZEDN_DEFINE_TERNARY_OP_GLOBAL(fmsub)
|
| 366 |
+
VECTORIZEDN_DEFINE_TERNARY_OP_GLOBAL(clamp)
|
| 367 |
+
VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(clamp_max)
|
| 368 |
+
VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(clamp_min)
|
| 369 |
+
VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator&)
|
| 370 |
+
VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator|)
|
| 371 |
+
VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator^)
|
| 372 |
+
VECTORIZEDN_DEFINE_UNARY_OP_GLOBAL(operator~)
|
| 373 |
+
|
| 374 |
+
VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(operator+=)
|
| 375 |
+
VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(operator-=)
|
| 376 |
+
VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(operator*=)
|
| 377 |
+
VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(operator/=)
|
| 378 |
+
VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(operator%=)
|
| 379 |
+
VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(operator<<=)
|
| 380 |
+
VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(operator>>=)
|
| 381 |
+
|
| 382 |
+
#undef VECTORIZEDN_DEFINE_UNARY_OP_GLOBAL
|
| 383 |
+
#undef VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL
|
| 384 |
+
#undef VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL
|
| 385 |
+
|
| 386 |
+
template <typename T, int N, typename OpVec>
|
| 387 |
+
inline T vec_reduce_all(const OpVec& vec_fun, VectorizedN<T, N> acc_vec) {
|
| 388 |
+
Vectorized<T> vec_result = acc_vec[0];
|
| 389 |
+
for (int i = 1; i < N; i++) {
|
| 390 |
+
vec_result = vec_fun(vec_result, acc_vec[i]);
|
| 391 |
+
}
|
| 392 |
+
return vec_reduce_all(vec_fun, vec_result);
|
| 393 |
+
}
|
| 394 |
+
|
| 395 |
+
template <typename T, int N>
|
| 396 |
+
std::ostream& operator<<(std::ostream& stream, const VectorizedN<T, N>& vec_n) {
|
| 397 |
+
stream << "vec_n[";
|
| 398 |
+
for (int i = 0; i < N; ++i) {
|
| 399 |
+
if (i != 0) {
|
| 400 |
+
stream << ", ";
|
| 401 |
+
}
|
| 402 |
+
stream << vec_n[i];
|
| 403 |
+
}
|
| 404 |
+
stream << ']';
|
| 405 |
+
return stream;
|
| 406 |
+
}
|
| 407 |
+
} // namespace CPU_CAPABILITY
|
| 408 |
+
} // namespace at::vec
|
| 409 |
+
|
| 410 |
+
#else
|
| 411 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 412 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/cpu/vec/vec_quant.h
ADDED
|
@@ -0,0 +1,258 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
#include <ATen/cpu/vec/intrinsics.h>
|
| 5 |
+
#include <c10/util/Exception.h>
|
| 6 |
+
|
| 7 |
+
namespace at::vec {
|
| 8 |
+
// See Note [CPU_CAPABILITY namespace]
|
| 9 |
+
inline namespace CPU_CAPABILITY {
|
| 10 |
+
|
| 11 |
+
// Transpose a [4, 64] block to [64, 4] (with contiguous output, ld=4)
|
| 12 |
+
template <typename scalar_t, typename = std::enable_if_t<sizeof(scalar_t) == 1>>
|
| 13 |
+
static inline void transpose_pad_4x64_block(
|
| 14 |
+
const scalar_t* src,
|
| 15 |
+
scalar_t* dst,
|
| 16 |
+
int64_t ld_src,
|
| 17 |
+
int krem = 4,
|
| 18 |
+
int nrem = 64) {
|
| 19 |
+
#if defined(CPU_CAPABILITY_AVX512)
|
| 20 |
+
__m512i r[4];
|
| 21 |
+
// Load with mask if partial
|
| 22 |
+
if (nrem < 64) {
|
| 23 |
+
__mmask64 mask = (1ULL << nrem) - 1;
|
| 24 |
+
for (int i = 0; i < krem; ++i) {
|
| 25 |
+
r[i] = _mm512_maskz_loadu_epi8(mask, src + i * ld_src);
|
| 26 |
+
}
|
| 27 |
+
for (int i = krem; i < 4; ++i) {
|
| 28 |
+
r[i] = _mm512_setzero_si512();
|
| 29 |
+
}
|
| 30 |
+
} else {
|
| 31 |
+
for (int i = 0; i < krem; ++i) {
|
| 32 |
+
r[i] = _mm512_loadu_si512(
|
| 33 |
+
reinterpret_cast<const __m512i*>(src + i * ld_src));
|
| 34 |
+
}
|
| 35 |
+
for (int i = krem; i < 4; ++i) {
|
| 36 |
+
r[i] = _mm512_setzero_si512();
|
| 37 |
+
}
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
// Transpose 4x64 bytes using unpack and shuffle
|
| 41 |
+
__m512i t0 = _mm512_unpacklo_epi8(r[0], r[1]);
|
| 42 |
+
__m512i t1 = _mm512_unpackhi_epi8(r[0], r[1]);
|
| 43 |
+
__m512i t2 = _mm512_unpacklo_epi8(r[2], r[3]);
|
| 44 |
+
__m512i t3 = _mm512_unpackhi_epi8(r[2], r[3]);
|
| 45 |
+
|
| 46 |
+
__m512i u0 = _mm512_unpacklo_epi16(t0, t2);
|
| 47 |
+
__m512i u1 = _mm512_unpackhi_epi16(t0, t2);
|
| 48 |
+
__m512i u2 = _mm512_unpacklo_epi16(t1, t3);
|
| 49 |
+
__m512i u3 = _mm512_unpackhi_epi16(t1, t3);
|
| 50 |
+
|
| 51 |
+
__m512i v0 = _mm512_shuffle_i32x4(u0, u1, 0x88);
|
| 52 |
+
__m512i v1 = _mm512_shuffle_i32x4(u0, u1, 0xdd);
|
| 53 |
+
__m512i v2 = _mm512_shuffle_i32x4(u2, u3, 0x88);
|
| 54 |
+
__m512i v3 = _mm512_shuffle_i32x4(u2, u3, 0xdd);
|
| 55 |
+
|
| 56 |
+
__m512i r0 = _mm512_shuffle_i32x4(v0, v2, 0x88);
|
| 57 |
+
__m512i r1 = _mm512_shuffle_i32x4(v1, v3, 0x88);
|
| 58 |
+
__m512i r2 = _mm512_shuffle_i32x4(v0, v2, 0xdd);
|
| 59 |
+
__m512i r3 = _mm512_shuffle_i32x4(v1, v3, 0xdd);
|
| 60 |
+
|
| 61 |
+
// Store output
|
| 62 |
+
if (nrem < 16) {
|
| 63 |
+
__mmask64 mask = (1ULL << (nrem * 4)) - 1;
|
| 64 |
+
_mm512_mask_storeu_epi8(dst, mask, r0);
|
| 65 |
+
} else if (nrem == 16) {
|
| 66 |
+
_mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), r0);
|
| 67 |
+
} else if (nrem < 32) {
|
| 68 |
+
int n_bytes1 = 64;
|
| 69 |
+
int n_bytes2 = (nrem * 4) - n_bytes1;
|
| 70 |
+
__mmask64 mask = (1ULL << n_bytes2) - 1;
|
| 71 |
+
_mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), r0);
|
| 72 |
+
_mm512_mask_storeu_epi8(reinterpret_cast<__m512i*>(dst + 64), mask, r1);
|
| 73 |
+
} else if (nrem == 32) {
|
| 74 |
+
_mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), r0);
|
| 75 |
+
_mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 64), r1);
|
| 76 |
+
} else if (nrem < 48) {
|
| 77 |
+
int n_bytes1 = 64 * 2;
|
| 78 |
+
int n_bytes2 = (nrem * 4) - n_bytes1;
|
| 79 |
+
__mmask64 mask = (1ULL << n_bytes2) - 1;
|
| 80 |
+
_mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), r0);
|
| 81 |
+
_mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 64), r1);
|
| 82 |
+
_mm512_mask_storeu_epi8(reinterpret_cast<__m512i*>(dst + 64 * 2), mask, r2);
|
| 83 |
+
} else if (nrem == 48) {
|
| 84 |
+
_mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), r0);
|
| 85 |
+
_mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 64), r1);
|
| 86 |
+
_mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 64 * 2), r2);
|
| 87 |
+
} else if (nrem < 64) {
|
| 88 |
+
int n_bytes1 = 64 * 3;
|
| 89 |
+
int n_bytes2 = (nrem * 4) - n_bytes1;
|
| 90 |
+
__mmask64 mask = (1ULL << n_bytes2) - 1;
|
| 91 |
+
_mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), r0);
|
| 92 |
+
_mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 64), r1);
|
| 93 |
+
_mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 64 * 2), r2);
|
| 94 |
+
_mm512_mask_storeu_epi8(reinterpret_cast<__m512i*>(dst + 64 * 3), mask, r3);
|
| 95 |
+
} else {
|
| 96 |
+
// normal case, nrem == 64
|
| 97 |
+
_mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), r0);
|
| 98 |
+
_mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 64), r1);
|
| 99 |
+
_mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 64 * 2), r2);
|
| 100 |
+
_mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 64 * 3), r3);
|
| 101 |
+
}
|
| 102 |
+
#else
|
| 103 |
+
TORCH_CHECK(
|
| 104 |
+
false,
|
| 105 |
+
"transpose_pad_4x64_block is only supported when AVX-512 is supported")
|
| 106 |
+
#endif
|
| 107 |
+
}
|
| 108 |
+
|
| 109 |
+
// Reorder [K, N] → [K/4, N, 4] (VNNI4-style layout for bit8)
|
| 110 |
+
template <typename scalar_t, typename = std::enable_if_t<sizeof(scalar_t) == 1>>
|
| 111 |
+
static inline void pack_vnni4(
|
| 112 |
+
const scalar_t* src,
|
| 113 |
+
scalar_t* dst,
|
| 114 |
+
int64_t ld_src,
|
| 115 |
+
int64_t K,
|
| 116 |
+
int64_t N) {
|
| 117 |
+
#if defined(CPU_CAPABILITY_AVX512)
|
| 118 |
+
int64_t bk = 0;
|
| 119 |
+
int64_t _K = K / 4 * 4;
|
| 120 |
+
int64_t _N = N / 64 * 64;
|
| 121 |
+
for (; bk < _K; bk += 4) {
|
| 122 |
+
int64_t bn = 0;
|
| 123 |
+
for (; bn < _N; bn += 64) {
|
| 124 |
+
transpose_pad_4x64_block(
|
| 125 |
+
src + bk * ld_src + bn, dst + bk * N + bn * 4, ld_src);
|
| 126 |
+
}
|
| 127 |
+
int64_t nrem = N - bn;
|
| 128 |
+
if (nrem > 0) {
|
| 129 |
+
transpose_pad_4x64_block(
|
| 130 |
+
src + bk * ld_src + bn, dst + bk * N + bn * 4, ld_src, 4, nrem);
|
| 131 |
+
}
|
| 132 |
+
}
|
| 133 |
+
|
| 134 |
+
// Handle leftover K rows (< 4)
|
| 135 |
+
if (K % 4 != 0) {
|
| 136 |
+
int krem = K - bk;
|
| 137 |
+
int64_t bn = 0;
|
| 138 |
+
for (; bn < _N; bn += 64) {
|
| 139 |
+
transpose_pad_4x64_block(
|
| 140 |
+
src + bk * ld_src + bn, dst + bk * N + bn * 4, ld_src, krem);
|
| 141 |
+
}
|
| 142 |
+
int64_t nrem = N - bn;
|
| 143 |
+
if (nrem > 0) {
|
| 144 |
+
transpose_pad_4x64_block(
|
| 145 |
+
src + bk * ld_src + bn, dst + bk * N + bn * 4, ld_src, krem, nrem);
|
| 146 |
+
}
|
| 147 |
+
}
|
| 148 |
+
#else
|
| 149 |
+
TORCH_CHECK(false, "pack_vnni4 is only supported when AVX-512 is supported")
|
| 150 |
+
#endif
|
| 151 |
+
}
|
| 152 |
+
|
| 153 |
+
// This is a helper function for transpose_pack_vnni4
|
| 154 |
+
// Transform a [4, 16] block (with incontiguous output)
|
| 155 |
+
// Src:
|
| 156 |
+
// a1 a2 a3 a4 a5 a6 a7 a8 a9 a10 a11 a12 a13 a14 a15 a16
|
| 157 |
+
// b1 b2 b3 b4 b5 b6 b7 b8 b9 b10 b11 b12 b13 b14 b15 b16
|
| 158 |
+
// c1 c2 c3 c4 c5 c6 c7 c8 c9 c10 c11 c12 c13 c14 c15 c16
|
| 159 |
+
// d1 d2 d3 d4 d5 d6 d7 d8 d9 d10 d11 d12 d13 d14 d15 d16
|
| 160 |
+
// Dst:
|
| 161 |
+
// a1 a2 a3 a4 b1 b2 b3 b4 c1 c2 c3 c4 d1 d2 d3 d4
|
| 162 |
+
// a5 a6 a7 a8 b5 b6 b7 b8 c5 c6 c7 c8 d5 d6 d7 d8
|
| 163 |
+
// a9 a10 a11 a12 b9 b10 b11 b12 c9 c10 c11 c12 d9 d10 d11 d12
|
| 164 |
+
// a13 a14 a15 a16 b13 b14 b15 b16 c13 c14 c15 c16 d13 d14 d15 d16
|
| 165 |
+
template <typename scalar_t, typename = std::enable_if_t<sizeof(scalar_t) == 1>>
|
| 166 |
+
static inline void transpose_vnni4_pad_4x16_block(
|
| 167 |
+
const scalar_t* src,
|
| 168 |
+
scalar_t* dst,
|
| 169 |
+
int64_t ld_src,
|
| 170 |
+
int64_t ld_dst,
|
| 171 |
+
int krem = 4) {
|
| 172 |
+
#if defined(CPU_CAPABILITY_AVX512)
|
| 173 |
+
__m128i r[4];
|
| 174 |
+
for (int i = 0; i < krem; ++i) {
|
| 175 |
+
r[i] = _mm_loadu_si128(reinterpret_cast<const __m128i*>(src + i * ld_src));
|
| 176 |
+
}
|
| 177 |
+
for (int i = krem; i < 4; ++i) {
|
| 178 |
+
r[i] = _mm_setzero_si128();
|
| 179 |
+
}
|
| 180 |
+
|
| 181 |
+
// Transpose 4x16 bytes using unpack and shuffle
|
| 182 |
+
__m128i t0 = _mm_unpacklo_epi32(r[0], r[1]);
|
| 183 |
+
__m128i t1 = _mm_unpackhi_epi32(r[0], r[1]);
|
| 184 |
+
__m128i t2 = _mm_unpacklo_epi32(r[2], r[3]);
|
| 185 |
+
__m128i t3 = _mm_unpackhi_epi32(r[2], r[3]);
|
| 186 |
+
|
| 187 |
+
__m128i r0 = _mm_unpacklo_epi64(t0, t2);
|
| 188 |
+
__m128i r1 = _mm_unpackhi_epi64(t0, t2);
|
| 189 |
+
__m128i r2 = _mm_unpacklo_epi64(t1, t3);
|
| 190 |
+
__m128i r3 = _mm_unpackhi_epi64(t1, t3);
|
| 191 |
+
|
| 192 |
+
// Store output
|
| 193 |
+
if (krem == 4) {
|
| 194 |
+
// normal case
|
| 195 |
+
_mm_storeu_si128(reinterpret_cast<__m128i*>(dst), r0);
|
| 196 |
+
_mm_storeu_si128(reinterpret_cast<__m128i*>(dst + ld_dst), r1);
|
| 197 |
+
_mm_storeu_si128(reinterpret_cast<__m128i*>(dst + ld_dst * 2), r2);
|
| 198 |
+
_mm_storeu_si128(reinterpret_cast<__m128i*>(dst + ld_dst * 3), r3);
|
| 199 |
+
} else {
|
| 200 |
+
// masked case
|
| 201 |
+
__mmask16 mask = (1ULL << (krem * 4)) - 1;
|
| 202 |
+
_mm_mask_storeu_epi8(dst, mask, r0);
|
| 203 |
+
_mm_mask_storeu_epi8(reinterpret_cast<__m128i*>(dst + ld_dst), mask, r1);
|
| 204 |
+
_mm_mask_storeu_epi8(
|
| 205 |
+
reinterpret_cast<__m128i*>(dst + ld_dst * 2), mask, r2);
|
| 206 |
+
_mm_mask_storeu_epi8(
|
| 207 |
+
reinterpret_cast<__m128i*>(dst + ld_dst * 3), mask, r3);
|
| 208 |
+
}
|
| 209 |
+
#else
|
| 210 |
+
TORCH_CHECK(
|
| 211 |
+
false,
|
| 212 |
+
"transpose_vnni4_pad_4x16_block is only supported when AVX-512 is supported")
|
| 213 |
+
#endif
|
| 214 |
+
}
|
| 215 |
+
|
| 216 |
+
// Do the transpose packing fusion with VNNI4
|
| 217 |
+
// Reorder [K, N] → [N/4, K, 4] (VNNI4-style layout for bit8)
|
| 218 |
+
template <typename scalar_t, typename = std::enable_if_t<sizeof(scalar_t) == 1>>
|
| 219 |
+
static inline void transpose_pack_vnni4(
|
| 220 |
+
const scalar_t* src,
|
| 221 |
+
scalar_t* dst,
|
| 222 |
+
int64_t ld_src,
|
| 223 |
+
int64_t K,
|
| 224 |
+
int64_t N) {
|
| 225 |
+
#if defined(CPU_CAPABILITY_AVX512)
|
| 226 |
+
TORCH_CHECK(
|
| 227 |
+
N % 16 == 0, "N needs to be multiple of 16 for transpose_pack_vnni4");
|
| 228 |
+
int64_t bk = 0;
|
| 229 |
+
int64_t _K = K / 4 * 4;
|
| 230 |
+
for (; bk < _K; bk += 4) {
|
| 231 |
+
int64_t bn = 0;
|
| 232 |
+
for (; bn < N; bn += 16) {
|
| 233 |
+
transpose_vnni4_pad_4x16_block(
|
| 234 |
+
src + bk * ld_src + bn, dst + bn * K + bk * 4, ld_src, K * 4);
|
| 235 |
+
}
|
| 236 |
+
}
|
| 237 |
+
|
| 238 |
+
// Handle leftover K rows (< 4)
|
| 239 |
+
if (K % 4 != 0) {
|
| 240 |
+
int krem = K - bk;
|
| 241 |
+
int64_t bn = 0;
|
| 242 |
+
for (; bn < N; bn += 16) {
|
| 243 |
+
transpose_vnni4_pad_4x16_block(
|
| 244 |
+
src + bk * ld_src + bn, dst + bn * K + bk * 4, ld_src, K * 4, krem);
|
| 245 |
+
}
|
| 246 |
+
}
|
| 247 |
+
#else
|
| 248 |
+
TORCH_CHECK(
|
| 249 |
+
false, "transpose_pack_vnni4 is only supported when AVX-512 is supported")
|
| 250 |
+
#endif
|
| 251 |
+
}
|
| 252 |
+
|
| 253 |
+
} // namespace CPU_CAPABILITY
|
| 254 |
+
} // namespace at::vec
|
| 255 |
+
|
| 256 |
+
#else
|
| 257 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 258 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/ADInterpreters.h
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
#include <ATen/functorch/Interpreter.h>
|
| 4 |
+
|
| 5 |
+
namespace at::functorch {
|
| 6 |
+
|
| 7 |
+
// These are the interpreters for our AD transforms
|
| 8 |
+
// (grad, vjp and jvp).
|
| 9 |
+
// See NOTE: [functorch interpreter stack] for more details.
|
| 10 |
+
|
| 11 |
+
struct TORCH_API GradInterpreterPtr {
|
| 12 |
+
explicit GradInterpreterPtr(const Interpreter* base): base_(base) { TORCH_INTERNAL_ASSERT(base->key() == TransformType::Grad); }
|
| 13 |
+
TransformType key() const { return base_->key(); }
|
| 14 |
+
int64_t level() const { return base_->level(); }
|
| 15 |
+
void processImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack);
|
| 16 |
+
void sendToNextInterpreterImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool grad_special_case);
|
| 17 |
+
bool prevGradMode() const {
|
| 18 |
+
return std::get<GradInterpreterMeta>(base_->meta()).prevGradMode_;
|
| 19 |
+
}
|
| 20 |
+
Tensor lift(const Tensor& tensor) const;
|
| 21 |
+
private:
|
| 22 |
+
const Interpreter* base_;
|
| 23 |
+
};
|
| 24 |
+
|
| 25 |
+
struct TORCH_API JvpInterpreterPtr {
|
| 26 |
+
explicit JvpInterpreterPtr(const Interpreter* base): base_(base) { TORCH_INTERNAL_ASSERT(base->key() == TransformType::Jvp); }
|
| 27 |
+
TransformType key() const { return base_->key(); }
|
| 28 |
+
int64_t level() const { return base_->level(); }
|
| 29 |
+
void processImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack);
|
| 30 |
+
void sendToNextInterpreterImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool grad_special_case);
|
| 31 |
+
bool prevFwdGradMode() const {
|
| 32 |
+
return std::get<JvpInterpreterMeta>(base_->meta()).prevFwdGradMode_;
|
| 33 |
+
}
|
| 34 |
+
Tensor lift(const Tensor& tensor) const;
|
| 35 |
+
private:
|
| 36 |
+
const Interpreter* base_;
|
| 37 |
+
};
|
| 38 |
+
|
| 39 |
+
} // namespace at::functorch
|
| 40 |
+
|
| 41 |
+
#else
|
| 42 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 43 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/BatchedFallback.h
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
// Copyright (c) Facebook, Inc. and its affiliates.
|
| 3 |
+
// All rights reserved.
|
| 4 |
+
//
|
| 5 |
+
// This source code is licensed under the BSD-style license found in the
|
| 6 |
+
// LICENSE file in the root directory of this source tree.
|
| 7 |
+
|
| 8 |
+
#pragma once
|
| 9 |
+
#include <ATen/ATen.h>
|
| 10 |
+
#include <ATen/core/op_registration/op_registration.h>
|
| 11 |
+
#include <torch/library.h>
|
| 12 |
+
|
| 13 |
+
namespace at::functorch {
|
| 14 |
+
|
| 15 |
+
// This file contains code for the vmap fallback (also known as the
|
| 16 |
+
// BatchedTensor fallback or the Batched fallback). This code runs
|
| 17 |
+
// when an operation doesn't have a batching rule implemented.
|
| 18 |
+
|
| 19 |
+
// If an operator doesn't have a batching rule implemented then we fallback
|
| 20 |
+
// to this implementation. The fallback doesn't work on out= variants or
|
| 21 |
+
// view operations; that is, it works for out-of-place operations and
|
| 22 |
+
// in-place non-view operations.
|
| 23 |
+
//
|
| 24 |
+
// For out-of-place operations, the fallback effectively takes all of the
|
| 25 |
+
// BatchedTensors in `stack`, slices them, and runs `op` on all of the
|
| 26 |
+
// corresponding slices to produce slices of the outputs. The output slices
|
| 27 |
+
// then get `torch.stack`ed to create the
|
| 28 |
+
// final returns.
|
| 29 |
+
//
|
| 30 |
+
// The performance of the fallback is not very good because it introduces an
|
| 31 |
+
// extra copy from stacking the sliced outputs. Because of this, we prefer to
|
| 32 |
+
// write batching rules for operators whenever possible.
|
| 33 |
+
void batchedTensorForLoopFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack);
|
| 34 |
+
void batchedNestedTensorForLoopFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack);
|
| 35 |
+
|
| 36 |
+
void vmapErrorFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack);
|
| 37 |
+
|
| 38 |
+
// The vmap fallback emits a warning by default, but it may be disabled if
|
| 39 |
+
// the user finds it to be too annoying.
|
| 40 |
+
TORCH_API bool isVmapFallbackWarningEnabled();
|
| 41 |
+
TORCH_API void setVmapFallbackWarningEnabled(bool enabled);
|
| 42 |
+
|
| 43 |
+
// Used for testing. The vmap fallback is enabled by default. When it is disabled,
|
| 44 |
+
// it raises an error.
|
| 45 |
+
TORCH_API bool isVmapFallbackEnabled();
|
| 46 |
+
TORCH_API void setVmapFallbackEnabled(bool enabled);
|
| 47 |
+
|
| 48 |
+
template <typename A> A vector_to_result(const std::vector<IValue>& buffer) {
|
| 49 |
+
return buffer[0].to<A>();
|
| 50 |
+
}
|
| 51 |
+
template <typename A, typename B> std::tuple<A, B> vector_to_result(const std::vector<IValue>& buffer) {
|
| 52 |
+
return std::make_tuple(buffer[0].to<A>(), buffer[1].to<B>());
|
| 53 |
+
}
|
| 54 |
+
template <typename A, typename B, typename C> std::tuple<A, B, C> vector_to_result(const std::vector<IValue>& buffer) {
|
| 55 |
+
return std::make_tuple(buffer[0].to<A>(), buffer[1].to<B>(), buffer[2].to<B>());
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
// slow_fallback is a way to call the vmap fallback inside some boxed kernel.
|
| 59 |
+
// There is probably some better way to metaprogram this.
|
| 60 |
+
template <typename Ret>
|
| 61 |
+
Ret slow_fallback(const c10::OperatorHandle& op, ArrayRef<IValue> args) {
|
| 62 |
+
std::vector<IValue> stack(args.begin(), args.end());
|
| 63 |
+
batchedTensorForLoopFallback(op, &stack);
|
| 64 |
+
return vector_to_result<Ret>(stack);
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
template <typename A, typename B>
|
| 68 |
+
std::tuple<A, B> slow_fallback(const c10::OperatorHandle& op, ArrayRef<IValue> args) {
|
| 69 |
+
std::vector<IValue> stack(args.begin(), args.end());
|
| 70 |
+
batchedTensorForLoopFallback(op, &stack);
|
| 71 |
+
return vector_to_result<A, B>(stack);
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
template <typename A, typename B, typename C>
|
| 75 |
+
std::tuple<A, B, C> slow_fallback(const c10::OperatorHandle& op, ArrayRef<IValue> args) {
|
| 76 |
+
std::vector<IValue> stack(args.begin(), args.end());
|
| 77 |
+
batchedTensorForLoopFallback(op, &stack);
|
| 78 |
+
return vector_to_result<A, B, C>(stack);
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
} // namespace at::functorch
|
| 83 |
+
|
| 84 |
+
#else
|
| 85 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 86 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/BatchedTensorImpl.h
ADDED
|
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
// Copyright (c) Facebook, Inc. and its affiliates.
|
| 3 |
+
// All rights reserved.
|
| 4 |
+
//
|
| 5 |
+
// This source code is licensed under the BSD-style license found in the
|
| 6 |
+
// LICENSE file in the root directory of this source tree.
|
| 7 |
+
|
| 8 |
+
#pragma once
|
| 9 |
+
|
| 10 |
+
#include <bitset>
|
| 11 |
+
|
| 12 |
+
#include <ATen/ArrayRef.h>
|
| 13 |
+
#include <ATen/SmallVector.h>
|
| 14 |
+
#include <ATen/Tensor.h>
|
| 15 |
+
|
| 16 |
+
namespace at::functorch {
|
| 17 |
+
|
| 18 |
+
using Tensor = at::Tensor;
|
| 19 |
+
|
| 20 |
+
// We assume this in a few other places in the codebase,
|
| 21 |
+
// but there isn't a centralized definition.
|
| 22 |
+
constexpr int64_t kVmapMaxTensorDims = 64;
|
| 23 |
+
|
| 24 |
+
// The valid vmap levels range from [0, 64). This effectively means that we
|
| 25 |
+
// support a maximum of 64 nested vmaps.
|
| 26 |
+
constexpr int64_t kVmapNumLevels = 64;
|
| 27 |
+
|
| 28 |
+
// Store this number of elements of BatchDims on the stack. Most people will
|
| 29 |
+
// probably use <= 5 nested vmaps, but adjust this number as necessary.
|
| 30 |
+
constexpr int64_t kBatchDimsStackSize = 5;
|
| 31 |
+
|
| 32 |
+
// A BatchedTensorImpl holds an underlying Tensor and a single batch dim
|
| 33 |
+
// NB: We use the term "BatchedTensor" to mean a Tensor that is backed with a
|
| 34 |
+
// BatchedTensorImpl.
|
| 35 |
+
//
|
| 36 |
+
// The batch dimensions are treated as being "private"; they are not user-visible.
|
| 37 |
+
// For example, in the following Tensor,
|
| 38 |
+
// bt = BatchedTensorImpl(ones(2, 3, 5, 7), lvl=1, dim=0)
|
| 39 |
+
// dimension 0 is batch dimension.
|
| 40 |
+
//
|
| 41 |
+
// bt.sizes() returns (5, 7); bt.sum(0) performs a reduction over the (public)
|
| 42 |
+
// dim 0, which is equivalent to dim 3 in the underlying ones(2, 3, 5, 7) tensor.
|
| 43 |
+
struct TORCH_API BatchedTensorImpl : public c10::TensorImpl {
|
| 44 |
+
explicit BatchedTensorImpl(at::DispatchKeySet key_set, Tensor value, int64_t dim, int64_t level);
|
| 45 |
+
|
| 46 |
+
// Returns batch dimension of this tensor
|
| 47 |
+
int64_t bdim() const { return bdim_; }
|
| 48 |
+
|
| 49 |
+
// Returns batch dimension of this tensor
|
| 50 |
+
int64_t level() const { return level_; }
|
| 51 |
+
|
| 52 |
+
// BatchedTensorImpl wraps a Tensor
|
| 53 |
+
const Tensor& value() const { return value_; }
|
| 54 |
+
|
| 55 |
+
// Given a public dimension index, return the dimension index in the underlying
|
| 56 |
+
// value() tensor.
|
| 57 |
+
// For example, if we have
|
| 58 |
+
// bt = BatchedTensorImpl(ones(2, 3, 5, 7), lvl=1, dim=0)
|
| 59 |
+
// bt.actualDim(0) -> 1
|
| 60 |
+
// bt.actualDim(1) -> 2
|
| 61 |
+
// bt.actualDim(2) -> 3
|
| 62 |
+
// bt.actualDim(3) -> Error
|
| 63 |
+
int64_t actualDim(int64_t dim, bool wrap_dim = true) const;
|
| 64 |
+
|
| 65 |
+
IntArrayRef sizes_custom() const override;
|
| 66 |
+
SymIntArrayRef sym_sizes_custom() const override;
|
| 67 |
+
int64_t size_custom(int64_t d) const override;
|
| 68 |
+
c10::SymInt sym_size_custom(int64_t d) const override;
|
| 69 |
+
// We have to override this because we opted into CustomStrides
|
| 70 |
+
IntArrayRef strides_custom() const override;
|
| 71 |
+
SymIntArrayRef sym_strides_custom() const override;
|
| 72 |
+
// Override a bunch of methods inherited from TensorImpl to return error messages.
|
| 73 |
+
c10::SymBool sym_is_contiguous_custom(at::MemoryFormat memory_format) const override;
|
| 74 |
+
void set_size(int64_t dim, int64_t new_size) override;
|
| 75 |
+
void set_stride(int64_t dim, int64_t new_stride) override;
|
| 76 |
+
c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
|
| 77 |
+
const c10::VariableVersion& version_counter,
|
| 78 |
+
bool allow_tensor_metadata_change) const override;
|
| 79 |
+
c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
|
| 80 |
+
c10::VariableVersion&& version_counter,
|
| 81 |
+
bool allow_tensor_metadata_change) const override;
|
| 82 |
+
void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override;
|
| 83 |
+
#ifdef DEBUG
|
| 84 |
+
bool has_storage() const override;
|
| 85 |
+
#endif
|
| 86 |
+
|
| 87 |
+
void refreshTensorMetadata();
|
| 88 |
+
|
| 89 |
+
// Used in torchdim. torchdim uses non-lexical BatchedTensor; the way it
|
| 90 |
+
// accomplishes this is a hack where it is able to modify the levels of
|
| 91 |
+
// BatchedTensor to match the level of the current vmap transform.
|
| 92 |
+
void _unsafe_set_level(int64_t level) {
|
| 93 |
+
level_ = level;
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
// Used in batching rule for in-place view operations that can change
|
| 97 |
+
// the index of the bdim (think squeeze_, unsqueeze_)
|
| 98 |
+
void unsafe_set_bdim(int64_t bdim) {
|
| 99 |
+
// NB: you MUST call refreshTensorMetadata after doing this.
|
| 100 |
+
bdim_ = bdim;
|
| 101 |
+
}
|
| 102 |
+
private:
|
| 103 |
+
// see NOTE: [BatchedTensorImpl levels invariant]
|
| 104 |
+
void checkInvariants() const;
|
| 105 |
+
const char* tensorimpl_type_name() const override;
|
| 106 |
+
|
| 107 |
+
Tensor value_;
|
| 108 |
+
|
| 109 |
+
int64_t level_;
|
| 110 |
+
int64_t bdim_;
|
| 111 |
+
};
|
| 112 |
+
|
| 113 |
+
// NB: We use the term "BatchedTensor" to mean a Tensor that is backed with a
|
| 114 |
+
// BatchedTensorImpl.
|
| 115 |
+
inline bool isBatchedTensor(const Tensor& tensor) {
|
| 116 |
+
return tensor.unsafeGetTensorImpl()->key_set().has(DispatchKey::FuncTorchBatched) ||
|
| 117 |
+
tensor.unsafeGetTensorImpl()->key_set().has(DispatchKey::BatchedNestedTensor);
|
| 118 |
+
}
|
| 119 |
+
|
| 120 |
+
// It is unsafe to call this on a Tensor that is not backed by a
|
| 121 |
+
// BatchedTensorImpl. Please use `maybeGetBatchedImpl` whenever possible.
|
| 122 |
+
inline BatchedTensorImpl* unsafeGetBatchedImpl(const Tensor& tensor) {
|
| 123 |
+
return static_cast<BatchedTensorImpl*>(tensor.unsafeGetTensorImpl());
|
| 124 |
+
}
|
| 125 |
+
|
| 126 |
+
inline BatchedTensorImpl* maybeGetBatchedImpl(const Tensor& tensor) {
|
| 127 |
+
if (!isBatchedTensor(tensor)) {
|
| 128 |
+
return nullptr;
|
| 129 |
+
}
|
| 130 |
+
return unsafeGetBatchedImpl(tensor);
|
| 131 |
+
}
|
| 132 |
+
|
| 133 |
+
// Returns a bitset. If bit i is set, then that means dim i is a batchdim.
|
| 134 |
+
inline std::bitset<kVmapMaxTensorDims> createBatchDimBitset(int64_t dim) {
|
| 135 |
+
std::bitset<kVmapMaxTensorDims> is_bdim;
|
| 136 |
+
is_bdim.set(dim);
|
| 137 |
+
return is_bdim;
|
| 138 |
+
}
|
| 139 |
+
|
| 140 |
+
// Creates a bitset for the given level
|
| 141 |
+
inline std::bitset<kVmapNumLevels> createVmapLevelsBitset(int64_t level) {
|
| 142 |
+
std::bitset<kVmapNumLevels> result;
|
| 143 |
+
result.set(level);
|
| 144 |
+
return result;
|
| 145 |
+
}
|
| 146 |
+
|
| 147 |
+
// Use this to construct a BatchedTensor from a regular Tensor
|
| 148 |
+
TORCH_API Tensor makeBatched(Tensor tensor, int64_t dim, int64_t level);
|
| 149 |
+
|
| 150 |
+
// Adds a batch dim to `tensor`, returning a BatchedTensor
|
| 151 |
+
TORCH_API Tensor addBatchDim(Tensor tensor, int64_t dim, int64_t level);
|
| 152 |
+
|
| 153 |
+
// Certain dispatch keys must be propagated to the BatchedTensor (or, in general,
|
| 154 |
+
// any wrapper Tensor subclasses). This is because there are methods on Tensor
|
| 155 |
+
// that skip dispatch and check for the presence of a dispatch key (e.g. is_cpu()).
|
| 156 |
+
// TODO: should probably contain more (or all?) backend keys
|
| 157 |
+
constexpr DispatchKeySet kKeysToPropagateToWrapper({
|
| 158 |
+
DispatchKey::Negative,
|
| 159 |
+
DispatchKey::Conjugate,
|
| 160 |
+
DispatchKey::XLA,
|
| 161 |
+
DispatchKey::XPU,
|
| 162 |
+
DispatchKey::HPU,
|
| 163 |
+
DispatchKey::CUDA,
|
| 164 |
+
DispatchKey::CPU,
|
| 165 |
+
DispatchKey::PrivateUse1,
|
| 166 |
+
DispatchKey::SparseCPU,
|
| 167 |
+
DispatchKey::SparseCUDA,
|
| 168 |
+
DispatchKey::SparseCsrCPU,
|
| 169 |
+
DispatchKey::SparseCsrCUDA,
|
| 170 |
+
});
|
| 171 |
+
|
| 172 |
+
inline DispatchKeySet getKeysToPropagateToWrapper(const Tensor& tensor, DispatchKeySet to_propagate=kKeysToPropagateToWrapper) {
|
| 173 |
+
auto key_set = tensor.unsafeGetTensorImpl()->key_set();
|
| 174 |
+
return key_set & kKeysToPropagateToWrapper;
|
| 175 |
+
}
|
| 176 |
+
|
| 177 |
+
} // namespace at::functorch
|
| 178 |
+
|
| 179 |
+
#else
|
| 180 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 181 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/DynamicLayer.h
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
// Copyright (c) Facebook, Inc. and its affiliates.
|
| 3 |
+
// All rights reserved.
|
| 4 |
+
//
|
| 5 |
+
// This source code is licensed under the BSD-style license found in the
|
| 6 |
+
// LICENSE file in the root directory of this source tree.
|
| 7 |
+
|
| 8 |
+
#pragma once
|
| 9 |
+
#include <ATen/functorch/Macros.h>
|
| 10 |
+
#include <c10/core/DispatchKey.h>
|
| 11 |
+
#include <ATen/core/function_schema.h>
|
| 12 |
+
#include <optional>
|
| 13 |
+
#include <c10/core/impl/LocalDispatchKeySet.h>
|
| 14 |
+
#include <ATen/functorch/Interpreter.h>
|
| 15 |
+
#include <ATen/functorch/VmapInterpreter.h>
|
| 16 |
+
#include <ATen/functorch/ADInterpreters.h>
|
| 17 |
+
#include <ATen/functorch/FunctionalizeInterpreter.h>
|
| 18 |
+
|
| 19 |
+
// Forward declared
|
| 20 |
+
namespace c10 { struct AutogradMetaInterface; }
|
| 21 |
+
|
| 22 |
+
namespace at::functorch {
|
| 23 |
+
|
| 24 |
+
// This file contains the implementation of functorch's interpreter stack.
|
| 25 |
+
// See NOTE: [functorch interpreter stack] first before reading on.
|
| 26 |
+
//
|
| 27 |
+
// NB: the functorch interpreter stack is also referred to as:
|
| 28 |
+
// - the "dynamic layer stack" -- an older name for "interpreter" was
|
| 29 |
+
// "dynamic layer".
|
| 30 |
+
// - the "functorch mode stack". You can think of each functorch transform as a
|
| 31 |
+
// "mode" (in the same sense as torch_dispatch mode or torch_function mode),
|
| 32 |
+
// and functorch being an implementation of a "mode stack" where the modes
|
| 33 |
+
// may be arbitrary composed.
|
| 34 |
+
|
| 35 |
+
// DynamicLayer is basically the same thing as an Interpreter.
|
| 36 |
+
// It represents a functorch transform and it holds an Interpreter,
|
| 37 |
+
// which contains metadata related to the transform and instructions on
|
| 38 |
+
// how to perform the transform.
|
| 39 |
+
//
|
| 40 |
+
// TODO: we can excise DynamicLayer in favor of Interpreter,
|
| 41 |
+
// But I am going to leave it for now as a compatibility shim to avoid
|
| 42 |
+
// needing to refactor a lot of callsites...
|
| 43 |
+
struct TORCH_API DynamicLayer {
|
| 44 |
+
explicit DynamicLayer(
|
| 45 |
+
TransformType transform_type,
|
| 46 |
+
int64_t layerId,
|
| 47 |
+
std::optional<c10::SymInt> batchSize = std::nullopt,
|
| 48 |
+
std::optional<RandomnessType> randomness = std::nullopt,
|
| 49 |
+
std::optional<bool> prev_grad_mode = std::nullopt,
|
| 50 |
+
std::optional<bool> pre_fwd_grad_mode = std::nullopt,
|
| 51 |
+
std::optional<bool> functionalize_add_back_views = std::nullopt);
|
| 52 |
+
|
| 53 |
+
TransformType key() const;
|
| 54 |
+
int64_t layerId() const;
|
| 55 |
+
|
| 56 |
+
const Interpreter& interpreter() const { return interpreter_; }
|
| 57 |
+
Interpreter& interpreter() { return interpreter_; }
|
| 58 |
+
|
| 59 |
+
// Only valid for vmap
|
| 60 |
+
c10::SymInt batchSize() const;
|
| 61 |
+
RandomnessType randomness() const;
|
| 62 |
+
|
| 63 |
+
private:
|
| 64 |
+
Interpreter interpreter_;
|
| 65 |
+
};
|
| 66 |
+
|
| 67 |
+
TORCH_API int64_t initAndPushDynamicLayer(
|
| 68 |
+
TransformType transform_type,
|
| 69 |
+
std::optional<c10::SymInt> batch_size = std::nullopt,
|
| 70 |
+
std::optional<RandomnessType> randomness = std::nullopt,
|
| 71 |
+
std::optional<bool> prev_grad_mode = std::nullopt,
|
| 72 |
+
std::optional<bool> prev_fwd_grad_mode = std::nullopt,
|
| 73 |
+
std::optional<bool> functionalize_add_back_views = std::nullopt);
|
| 74 |
+
TORCH_API DynamicLayer popDynamicLayerAndDeleteMetadata();
|
| 75 |
+
TORCH_API std::optional<DynamicLayer> maybeCurrentDynamicLayer();
|
| 76 |
+
TORCH_API const std::vector<DynamicLayer>& getDynamicLayerStack();
|
| 77 |
+
TORCH_API void setDynamicLayerStack(const std::vector<DynamicLayer>& stack);
|
| 78 |
+
TORCH_API void setDynamicLayerFrontBackKeysIncluded(bool included);
|
| 79 |
+
|
| 80 |
+
// NOTE: [Life handles and lexically scoped transforms]
|
| 81 |
+
// functorch transforms are lexically scoped.
|
| 82 |
+
// Given a level, we store a "life handle" that is a boolean that tells us if the
|
| 83 |
+
// transform with that level is active or not.
|
| 84 |
+
//
|
| 85 |
+
// functorch's TensorWrapper (for grad transforms) stores a life handle.
|
| 86 |
+
// If a TensorWrapper escapes from the scope of the transform, then somehow
|
| 87 |
+
// it must know it escaped; it can tell by querying the life handle.
|
| 88 |
+
TORCH_API const std::shared_ptr<bool>& getLifeHandleForLevel(int64_t level);
|
| 89 |
+
|
| 90 |
+
// Returns if an operator is in-place. An operator is inplace if:
|
| 91 |
+
// 1. The first argument is a Tensor and it is being written to
|
| 92 |
+
// 2. The first argument is being returned
|
| 93 |
+
// 3. No other arguments are aliased
|
| 94 |
+
// Here is an example of an in-place operator:
|
| 95 |
+
// add_(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
|
| 96 |
+
TORCH_API bool isInplaceOp(const c10::FunctionSchema& schema);
|
| 97 |
+
|
| 98 |
+
// Given the indices of unwrapped inputs and the schema, this returns the indices of any outputs that should remain unwrapped
|
| 99 |
+
TORCH_API std::optional<size_t> findAliasedOutput(const FunctionSchema& schema, const int64_t immutable_input);
|
| 100 |
+
|
| 101 |
+
TORCH_API Tensor unwrapIfDead(const Tensor& tensor);
|
| 102 |
+
TORCH_API bool isDeadTensorWrapper(const Tensor& tensor);
|
| 103 |
+
|
| 104 |
+
// Pretty printers
|
| 105 |
+
TORCH_API std::ostream& operator<<(std::ostream& os, const DynamicLayer& layer);
|
| 106 |
+
TORCH_API std::ostream& operator<<(std::ostream& os, const std::vector<DynamicLayer>& dynamicLayerStack);
|
| 107 |
+
|
| 108 |
+
// While a functorch transform is active, torch.autograd.function._SingleLevelFunction
|
| 109 |
+
// is disabled by default. The following two APIs are APIs for enabling
|
| 110 |
+
// it. These are not user-facing APIs. We can delete this in the future, but
|
| 111 |
+
// it is useful for debugging when something goes wrong with the
|
| 112 |
+
// autograd.Function <> functorch interaction, which uses _SingleLevelFunction,
|
| 113 |
+
// because it leads to loud errors if something is incorrect.
|
| 114 |
+
TORCH_API void setSingleLevelAutogradFunctionAllowed(bool allowed);
|
| 115 |
+
TORCH_API bool getSingleLevelAutogradFunctionAllowed();
|
| 116 |
+
|
| 117 |
+
// While a functorch grad transform is active, Tensor.requires_grad_() gets
|
| 118 |
+
// disabled. These two functions are the mechanism to controlling that.
|
| 119 |
+
TORCH_API void setInplaceRequiresGradAllowed(bool allowed);
|
| 120 |
+
TORCH_API bool getInplaceRequiresGradAllowed();
|
| 121 |
+
|
| 122 |
+
TORCH_API DynamicLayer popDynamicLayer();
|
| 123 |
+
TORCH_API int64_t pushDynamicLayer(DynamicLayer&& layer);
|
| 124 |
+
|
| 125 |
+
} // namespace at::functorch
|
| 126 |
+
|
| 127 |
+
#else
|
| 128 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 129 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/FunctionalizeInterpreter.h
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
#include <ATen/functorch/Interpreter.h>
|
| 4 |
+
|
| 5 |
+
namespace at::functorch {
|
| 6 |
+
|
| 7 |
+
// This is the interpreter that handles the functionalize() transform.
|
| 8 |
+
// See NOTE: [functorch interpreter stack] for more details.
|
| 9 |
+
|
| 10 |
+
struct FunctionalizeInterpreterPtr {
|
| 11 |
+
explicit FunctionalizeInterpreterPtr(const Interpreter* base): base_(base) { TORCH_INTERNAL_ASSERT(base->key() == TransformType::Functionalize); }
|
| 12 |
+
TransformType key() const { return base_->key(); }
|
| 13 |
+
int64_t level() const { return base_->level(); }
|
| 14 |
+
void processImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack);
|
| 15 |
+
void sendToNextInterpreterImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool grad_special_case);
|
| 16 |
+
bool functionalizeAddBackViews() const {
|
| 17 |
+
return std::get<FunctionalizeInterpreterMeta>(base_->meta()).functionalizeAddBackViews_;
|
| 18 |
+
}
|
| 19 |
+
private:
|
| 20 |
+
const Interpreter* base_;
|
| 21 |
+
};
|
| 22 |
+
|
| 23 |
+
} // namespace at::functorch
|
| 24 |
+
|
| 25 |
+
#else
|
| 26 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 27 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/Interpreter.h
ADDED
|
@@ -0,0 +1,358 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
#include <ATen/functorch/Macros.h>
|
| 5 |
+
#include <ATen/core/dispatch/Dispatcher.h>
|
| 6 |
+
#include <c10/core/impl/LocalDispatchKeySet.h>
|
| 7 |
+
#include <c10/util/Exception.h>
|
| 8 |
+
#include <optional>
|
| 9 |
+
#include <bitset>
|
| 10 |
+
#include <utility>
|
| 11 |
+
#include <variant>
|
| 12 |
+
|
| 13 |
+
#include <nlohmann/json.hpp>
|
| 14 |
+
|
| 15 |
+
namespace at::functorch {
|
| 16 |
+
|
| 17 |
+
// NOTE: [functorch interpreter stack]
|
| 18 |
+
//
|
| 19 |
+
// functorch's dispatching system uses a stack of interpreters.
|
| 20 |
+
// Historically we've referred to this as the "DynamicLayerStack".
|
| 21 |
+
//
|
| 22 |
+
// An interpreter is something that reads in the code it is passed
|
| 23 |
+
// and then executes it. We have a different interpreter per-transform:
|
| 24 |
+
// the "VmapInterpreter" is responsible for reading in operators (like aten::mv)
|
| 25 |
+
// and executing the batched version of it (the batching rule for aten::mv).
|
| 26 |
+
//
|
| 27 |
+
// Concretely, each interpreter is responsible for two things:
|
| 28 |
+
//
|
| 29 |
+
// 1) process(ophandle, stack)
|
| 30 |
+
// Given an operator handle and a stack of arguments, the interpreter is
|
| 31 |
+
// responsible for figuring out how to execute the operation under the semantics
|
| 32 |
+
// of the interpreter. For e.g. VmapInterpreter, this is figuring out how to call
|
| 33 |
+
// the batching rule.
|
| 34 |
+
//
|
| 35 |
+
// The batching rules are stored as kernels on the FuncTorchBatched key, so the way
|
| 36 |
+
// VmapInterpreter calls the batching rule is roughly: (A) exclude all
|
| 37 |
+
// dispatch keys aside from the Batched key, (B) redispatch so we get to the
|
| 38 |
+
// Batched key.
|
| 39 |
+
//
|
| 40 |
+
// 2) sendToNextInterpreter(ophandle, stack)
|
| 41 |
+
// The VmapInterpreter, when it sees aten::mv, will process it into a call to
|
| 42 |
+
// aten::mm. It then needs to send the call to aten::mm to the next interpreter
|
| 43 |
+
// in the interpreter stack.
|
| 44 |
+
//
|
| 45 |
+
// The VmapInterpreter just does this via a call to ophandle.callBoxed(stack)
|
| 46 |
+
// and most Interpreters will implement it this way.
|
| 47 |
+
|
| 48 |
+
enum class RandomnessType {
|
| 49 |
+
Error, // always errors when calling a random function
|
| 50 |
+
Same, // randomness appears the same across batches
|
| 51 |
+
Different, // randomness appears different across batches
|
| 52 |
+
END
|
| 53 |
+
};
|
| 54 |
+
|
| 55 |
+
enum class TransformType {
|
| 56 |
+
Torch, // Unused
|
| 57 |
+
Vmap,
|
| 58 |
+
Grad, // reverse-mode AD, aka vjp
|
| 59 |
+
Jvp, // forward-mode AD
|
| 60 |
+
Functionalize,
|
| 61 |
+
};
|
| 62 |
+
|
| 63 |
+
std::ostream& operator<<(std::ostream& os, const TransformType& t);
|
| 64 |
+
|
| 65 |
+
// NOTE: [Interpreter "subclassing" design]
|
| 66 |
+
//
|
| 67 |
+
// How are various Interpreters for different transforms (vmap, grad, ...)
|
| 68 |
+
// implemented?
|
| 69 |
+
//
|
| 70 |
+
// Accessing interpreters is in the hot-path of functorch so we have a constraint
|
| 71 |
+
// that this code must be as fast as possible.
|
| 72 |
+
//
|
| 73 |
+
// As a result, we stay away from virtual methods and this causes our code
|
| 74 |
+
// to look a little funny.
|
| 75 |
+
//
|
| 76 |
+
// `Interpreter` is the struct for Interpreters. It holds ALL of the
|
| 77 |
+
// relevant information (what type of interpreter it is and the metadata).
|
| 78 |
+
// Metadata for each interpreter is represented as a Union (std::variant)
|
| 79 |
+
// of all possible metadata (VmapInterpreterMeta, GradInterpreterMeta, ...).
|
| 80 |
+
//
|
| 81 |
+
// Given an Interpreter, how do I get a "VmapInterpreter"? You may wish to do this
|
| 82 |
+
// if you want to access the metadata fields (like batchSize and randomness).
|
| 83 |
+
//
|
| 84 |
+
// Each type of interpreter (e.g. Vmap) has a convenience struct
|
| 85 |
+
// (e.g. VmapInterpreterPtr) associated with it.
|
| 86 |
+
//
|
| 87 |
+
// Construct the convenience struct with VmapInterpreterPtr(Interpreter*),
|
| 88 |
+
// and then one can access methods on VmapInterpreterPtr like so:
|
| 89 |
+
// >>> VmapInterpreterPtr(&interpreter).batchSize()
|
| 90 |
+
//
|
| 91 |
+
// Finally, Interpreter::process switches on the type of the interpreter
|
| 92 |
+
// and calls one of {Transform}Interpreter::processImpl under the hood.
|
| 93 |
+
// Same for Interpreter::sendToNextInterpreter :)
|
| 94 |
+
|
| 95 |
+
struct VmapInterpreterMeta {
|
| 96 |
+
explicit VmapInterpreterMeta(c10::SymInt batchSize, RandomnessType randomness) :
|
| 97 |
+
batchSize_(std::move(batchSize)), randomness_(randomness) {}
|
| 98 |
+
|
| 99 |
+
c10::SymInt batchSize_;
|
| 100 |
+
RandomnessType randomness_;
|
| 101 |
+
|
| 102 |
+
VmapInterpreterMeta() = default;
|
| 103 |
+
VmapInterpreterMeta(const VmapInterpreterMeta&) = default;
|
| 104 |
+
VmapInterpreterMeta(VmapInterpreterMeta&&) = default;
|
| 105 |
+
VmapInterpreterMeta& operator=(const VmapInterpreterMeta&) = default;
|
| 106 |
+
VmapInterpreterMeta& operator=(VmapInterpreterMeta&&) = default;
|
| 107 |
+
~VmapInterpreterMeta() = default;
|
| 108 |
+
|
| 109 |
+
template <typename T>
|
| 110 |
+
friend void to_json(T& json_j, const VmapInterpreterMeta& json_t) {
|
| 111 |
+
TORCH_CHECK(
|
| 112 |
+
!json_t.batchSize_.is_heap_allocated(),
|
| 113 |
+
"Serialization for heap-allocated SymInt is not implemented yet"
|
| 114 |
+
);
|
| 115 |
+
json_j["batchSize"] = json_t.batchSize_.as_int_unchecked();
|
| 116 |
+
json_j["randomness"] = static_cast<int64_t>(json_t.randomness_);
|
| 117 |
+
}
|
| 118 |
+
|
| 119 |
+
template <typename T>
|
| 120 |
+
friend void from_json(const T& json_j, VmapInterpreterMeta& json_t) {
|
| 121 |
+
json_t.batchSize_ = c10::SymInt(SymInt::Unchecked::UNCHECKED, json_j["batchSize"]);
|
| 122 |
+
json_t.randomness_ = static_cast<RandomnessType>(json_j["randomness"]);
|
| 123 |
+
}
|
| 124 |
+
};
|
| 125 |
+
|
| 126 |
+
struct GradInterpreterMeta {
|
| 127 |
+
explicit GradInterpreterMeta(bool prevGradMode): prevGradMode_(prevGradMode) {}
|
| 128 |
+
GradInterpreterMeta() = default;
|
| 129 |
+
GradInterpreterMeta(const GradInterpreterMeta&) = default;
|
| 130 |
+
GradInterpreterMeta(GradInterpreterMeta&&) = default;
|
| 131 |
+
GradInterpreterMeta& operator=(const GradInterpreterMeta&) = default;
|
| 132 |
+
GradInterpreterMeta& operator=(GradInterpreterMeta&&) = default;
|
| 133 |
+
~GradInterpreterMeta() = default;
|
| 134 |
+
|
| 135 |
+
bool prevGradMode_;
|
| 136 |
+
template <typename T>
|
| 137 |
+
friend void to_json(T& json_j, const GradInterpreterMeta& json_t) {
|
| 138 |
+
json_j["prevGradMode"] = json_t.prevGradMode_;
|
| 139 |
+
}
|
| 140 |
+
|
| 141 |
+
template <typename T>
|
| 142 |
+
friend void from_json(const T& json_j, GradInterpreterMeta& json_t) {
|
| 143 |
+
json_t.prevGradMode_ = json_j["prevGradMode"];
|
| 144 |
+
}
|
| 145 |
+
};
|
| 146 |
+
|
| 147 |
+
struct JvpInterpreterMeta {
|
| 148 |
+
explicit JvpInterpreterMeta(bool prevFwdGradMode) : prevFwdGradMode_(prevFwdGradMode) {}
|
| 149 |
+
JvpInterpreterMeta() = default;
|
| 150 |
+
JvpInterpreterMeta(const JvpInterpreterMeta&) = default;
|
| 151 |
+
JvpInterpreterMeta(JvpInterpreterMeta&&) = default;
|
| 152 |
+
JvpInterpreterMeta& operator=(const JvpInterpreterMeta&) = default;
|
| 153 |
+
JvpInterpreterMeta& operator=(JvpInterpreterMeta&&) = default;
|
| 154 |
+
~JvpInterpreterMeta() = default;
|
| 155 |
+
|
| 156 |
+
bool prevFwdGradMode_;
|
| 157 |
+
template <typename T>
|
| 158 |
+
friend void to_json(T& json_j, const JvpInterpreterMeta& json_t) {
|
| 159 |
+
json_j["prevFwdGradMode"] = json_t.prevFwdGradMode_;
|
| 160 |
+
}
|
| 161 |
+
|
| 162 |
+
template <typename T>
|
| 163 |
+
friend void from_json(const T& json_j, JvpInterpreterMeta& json_t) {
|
| 164 |
+
json_t.prevFwdGradMode_ = json_j["prevFwdGradMode"];
|
| 165 |
+
}
|
| 166 |
+
};
|
| 167 |
+
|
| 168 |
+
struct FunctionalizeInterpreterMeta {
|
| 169 |
+
explicit FunctionalizeInterpreterMeta(bool functionalizeAddBackViews) :
|
| 170 |
+
functionalizeAddBackViews_(functionalizeAddBackViews) {}
|
| 171 |
+
FunctionalizeInterpreterMeta() = default;
|
| 172 |
+
FunctionalizeInterpreterMeta(const FunctionalizeInterpreterMeta&) = default;
|
| 173 |
+
FunctionalizeInterpreterMeta(FunctionalizeInterpreterMeta&&) = default;
|
| 174 |
+
FunctionalizeInterpreterMeta& operator=(const FunctionalizeInterpreterMeta&) = default;
|
| 175 |
+
FunctionalizeInterpreterMeta& operator=(FunctionalizeInterpreterMeta&&) = default;
|
| 176 |
+
~FunctionalizeInterpreterMeta() = default;
|
| 177 |
+
|
| 178 |
+
bool functionalizeAddBackViews_;
|
| 179 |
+
template <typename T>
|
| 180 |
+
friend void to_json(T& json_j, const FunctionalizeInterpreterMeta& json_t) {
|
| 181 |
+
json_j["functionalizeAddBackViews"] = json_t.functionalizeAddBackViews_;
|
| 182 |
+
}
|
| 183 |
+
|
| 184 |
+
template <typename T>
|
| 185 |
+
friend void from_json(const T& json_j, FunctionalizeInterpreterMeta& json_t) {
|
| 186 |
+
json_t.functionalizeAddBackViews_ = json_j["functionalizeAddBackViews"];
|
| 187 |
+
}
|
| 188 |
+
};
|
| 189 |
+
|
| 190 |
+
typedef std::variant<
|
| 191 |
+
int64_t,
|
| 192 |
+
GradInterpreterMeta,
|
| 193 |
+
JvpInterpreterMeta,
|
| 194 |
+
VmapInterpreterMeta,
|
| 195 |
+
FunctionalizeInterpreterMeta
|
| 196 |
+
> InterpreterMeta;
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
struct Interpreter {
|
| 200 |
+
// factory functions
|
| 201 |
+
static Interpreter Vmap(int64_t level, c10::SymInt batchSize, RandomnessType randomness) {
|
| 202 |
+
return Interpreter(TransformType::Vmap, level, VmapInterpreterMeta(std::move(batchSize), randomness));
|
| 203 |
+
}
|
| 204 |
+
static Interpreter Grad(int64_t level, bool prevGradMode) {
|
| 205 |
+
return Interpreter(TransformType::Grad, level, GradInterpreterMeta(prevGradMode));
|
| 206 |
+
}
|
| 207 |
+
static Interpreter Jvp(int64_t level, bool prevFwdGradMode) {
|
| 208 |
+
return Interpreter(TransformType::Jvp, level, JvpInterpreterMeta(prevFwdGradMode));
|
| 209 |
+
}
|
| 210 |
+
static Interpreter Functionalize(int64_t level, bool functionalizeAddBackViews) {
|
| 211 |
+
return Interpreter(TransformType::Functionalize, level, FunctionalizeInterpreterMeta(functionalizeAddBackViews));
|
| 212 |
+
}
|
| 213 |
+
|
| 214 |
+
// methods
|
| 215 |
+
TransformType key() const { return type_; }
|
| 216 |
+
int64_t level() const { return level_; }
|
| 217 |
+
const InterpreterMeta& meta() const { return meta_; }
|
| 218 |
+
|
| 219 |
+
void process(const c10::OperatorHandle& op, torch::jit::Stack* stack);
|
| 220 |
+
void sendToNextInterpreter(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool grad_special_case);
|
| 221 |
+
|
| 222 |
+
void saveLocalDispatchKeySet(c10::impl::LocalDispatchKeySet keyset) {
|
| 223 |
+
TORCH_INTERNAL_ASSERT(!savedLocalDispatchKeySet_.has_value());
|
| 224 |
+
savedLocalDispatchKeySet_ = keyset;
|
| 225 |
+
}
|
| 226 |
+
void clearSavedLocalDispatchKeySet() {
|
| 227 |
+
TORCH_INTERNAL_ASSERT(savedLocalDispatchKeySet_.has_value());
|
| 228 |
+
savedLocalDispatchKeySet_ = std::nullopt;
|
| 229 |
+
}
|
| 230 |
+
c10::impl::LocalDispatchKeySet getSavedLocalDispatchKeySet() const {
|
| 231 |
+
TORCH_INTERNAL_ASSERT(savedLocalDispatchKeySet_.has_value());
|
| 232 |
+
return *savedLocalDispatchKeySet_;
|
| 233 |
+
}
|
| 234 |
+
|
| 235 |
+
// An Interpreter is alive if we are currently inside the ongoing transform
|
| 236 |
+
// for the interpreter. For example, vmap(f)(x); inside of f, the vmap's
|
| 237 |
+
// corresponding Interpreter is alive, even when it is not on the DynamicLayerStack.
|
| 238 |
+
bool is_alive() const {
|
| 239 |
+
return *is_alive_;
|
| 240 |
+
}
|
| 241 |
+
const std::shared_ptr<bool>& is_alive_ptr() const {
|
| 242 |
+
return is_alive_;
|
| 243 |
+
}
|
| 244 |
+
void set_is_alive(bool alive) {
|
| 245 |
+
*is_alive_ = alive;
|
| 246 |
+
}
|
| 247 |
+
|
| 248 |
+
// Please don't use this
|
| 249 |
+
explicit Interpreter() = default;
|
| 250 |
+
|
| 251 |
+
template <typename T>
|
| 252 |
+
friend void to_json(T& json_j, const Interpreter& json_t) {
|
| 253 |
+
json_j["type"] = static_cast<int64_t>(json_t.type_);
|
| 254 |
+
json_j["level"] = json_t.level_;
|
| 255 |
+
if (json_t.savedLocalDispatchKeySet_) {
|
| 256 |
+
json_j["savedLocalDispatchKeySet"] = {
|
| 257 |
+
{"included", json_t.savedLocalDispatchKeySet_->included_.raw_repr()},
|
| 258 |
+
{"excluded", json_t.savedLocalDispatchKeySet_->excluded_.raw_repr()}
|
| 259 |
+
};
|
| 260 |
+
} else {
|
| 261 |
+
json_j["savedLocalDispatchKeySet"] = nlohmann::json();
|
| 262 |
+
}
|
| 263 |
+
json_j["is_alive"] = *json_t.is_alive_;
|
| 264 |
+
std::visit([&](auto&& arg) {
|
| 265 |
+
using V = std::decay_t<decltype(arg)>;
|
| 266 |
+
if constexpr (std::is_same_v<V, int64_t>) {
|
| 267 |
+
json_j["meta"] = {{"Torch", arg}};
|
| 268 |
+
} else if constexpr (std::is_same_v<V, GradInterpreterMeta>) {
|
| 269 |
+
json_j["meta"] = {{"Grad", arg}};
|
| 270 |
+
} else if constexpr (std::is_same_v<V, JvpInterpreterMeta>) {
|
| 271 |
+
json_j["meta"] = {{"Jvp", arg}};
|
| 272 |
+
} else if constexpr (std::is_same_v<V, VmapInterpreterMeta>) {
|
| 273 |
+
json_j["meta"] = {{"Vmap", arg}};
|
| 274 |
+
} else if constexpr (std::is_same_v<V, FunctionalizeInterpreterMeta>) {
|
| 275 |
+
json_j["meta"] = {{"Functionalize", arg}};
|
| 276 |
+
} else {
|
| 277 |
+
static_assert(false && sizeof(V), "unknown variant case");
|
| 278 |
+
}
|
| 279 |
+
}, json_t.meta_);
|
| 280 |
+
}
|
| 281 |
+
|
| 282 |
+
template <typename T>
|
| 283 |
+
friend void from_json(const T& json_j, Interpreter& json_t) {
|
| 284 |
+
json_t.type_ = static_cast<TransformType>(json_j["type"]);
|
| 285 |
+
json_t.level_ = json_j["level"];
|
| 286 |
+
auto savedLocalDispatchKeySet = json_j["savedLocalDispatchKeySet"];
|
| 287 |
+
if (savedLocalDispatchKeySet.is_null()) {
|
| 288 |
+
json_t.savedLocalDispatchKeySet_ = std::nullopt;
|
| 289 |
+
} else {
|
| 290 |
+
c10::impl::PODLocalDispatchKeySet pod;
|
| 291 |
+
pod.set_included(DispatchKeySet::from_raw_repr(savedLocalDispatchKeySet["included"].template get<uint64_t>()));
|
| 292 |
+
pod.set_excluded(DispatchKeySet::from_raw_repr(savedLocalDispatchKeySet["excluded"].template get<uint64_t>()));
|
| 293 |
+
json_t.savedLocalDispatchKeySet_ = c10::impl::LocalDispatchKeySet(pod);
|
| 294 |
+
}
|
| 295 |
+
json_t.is_alive_ = std::make_shared<bool>(json_j["is_alive"]);
|
| 296 |
+
auto meta = json_j["meta"];
|
| 297 |
+
if (meta.contains("Torch")) {
|
| 298 |
+
json_t.meta_.emplace<int64_t>(meta["Torch"].template get<int64_t>());
|
| 299 |
+
} else if (meta.contains("Grad")) {
|
| 300 |
+
json_t.meta_.emplace<GradInterpreterMeta>(meta["Grad"].template get<GradInterpreterMeta>());
|
| 301 |
+
} else if (meta.contains("Jvp")) {
|
| 302 |
+
json_t.meta_.emplace<JvpInterpreterMeta>(meta["Jvp"].template get<JvpInterpreterMeta>());
|
| 303 |
+
} else if (meta.contains("Vmap")) {
|
| 304 |
+
json_t.meta_.emplace<VmapInterpreterMeta>(meta["Vmap"].template get<VmapInterpreterMeta>());
|
| 305 |
+
} else if (meta.contains("Functionalize")) {
|
| 306 |
+
json_t.meta_.emplace<FunctionalizeInterpreterMeta>(meta["Functionalize"].template get<FunctionalizeInterpreterMeta>());
|
| 307 |
+
} else {
|
| 308 |
+
TORCH_CHECK(false, "unknown interpreter metadata type");
|
| 309 |
+
}
|
| 310 |
+
}
|
| 311 |
+
|
| 312 |
+
std::string serialize() const {
|
| 313 |
+
return nlohmann::json(*this).dump();
|
| 314 |
+
}
|
| 315 |
+
|
| 316 |
+
static Interpreter deserialize(const std::string& serialized) {
|
| 317 |
+
return nlohmann::json::parse(serialized).get<Interpreter>();
|
| 318 |
+
}
|
| 319 |
+
|
| 320 |
+
private:
|
| 321 |
+
explicit Interpreter(TransformType type, int64_t level, InterpreterMeta meta):
|
| 322 |
+
type_(type), level_(level), is_alive_(std::make_shared<bool>(false)), meta_(std::move(meta)) {}
|
| 323 |
+
|
| 324 |
+
// fields
|
| 325 |
+
TransformType type_{};
|
| 326 |
+
int64_t level_{};
|
| 327 |
+
std::optional<c10::impl::LocalDispatchKeySet> savedLocalDispatchKeySet_;
|
| 328 |
+
std::shared_ptr<bool> is_alive_;
|
| 329 |
+
InterpreterMeta meta_;
|
| 330 |
+
};
|
| 331 |
+
|
| 332 |
+
// Applies the following for-loop:
|
| 333 |
+
// for i in range(begin, end):
|
| 334 |
+
// args[i] = func(args[i])
|
| 335 |
+
void foreachTensorInplace(std::vector<IValue>& args, int64_t begin, int64_t end,
|
| 336 |
+
std::function<Tensor(const Tensor&)> func);
|
| 337 |
+
|
| 338 |
+
// Applies the following for-loop:
|
| 339 |
+
// for i in range(begin, end):
|
| 340 |
+
// if use_flag_relative[i] == 1: <-- treats use_flag_relative as a bitset
|
| 341 |
+
// args[i] = func(args[i], i - begin, true)
|
| 342 |
+
// args[i] = func(args[i], i - begin)
|
| 343 |
+
void foreachTensorInplaceWithFlag(std::vector<IValue>& args, int64_t begin, int64_t end,
|
| 344 |
+
const std::bitset<64> use_flag_relative, const std::function<Tensor(const Tensor&, bool)>& func);
|
| 345 |
+
|
| 346 |
+
std::vector<int64_t> findUnwrappedInputs(std::vector<IValue>& args, int64_t begin, int64_t end);
|
| 347 |
+
|
| 348 |
+
DispatchKeySet keysToExcludeWhenEnteringDynamicLayer(TransformType key);
|
| 349 |
+
|
| 350 |
+
void setup_dispatch_key_tls(TransformType key, DispatchKeySet include);
|
| 351 |
+
|
| 352 |
+
void sanityCheckStack(const c10::OperatorHandle& op, torch::jit::Stack* stack);
|
| 353 |
+
|
| 354 |
+
} // namespace at::functorch
|
| 355 |
+
|
| 356 |
+
#else
|
| 357 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 358 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/PlumbingHelper.h
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
// Copyright (c) Facebook, Inc. and its affiliates.
|
| 3 |
+
// All rights reserved.
|
| 4 |
+
//
|
| 5 |
+
// This source code is licensed under the BSD-style license found in the
|
| 6 |
+
// LICENSE file in the root directory of this source tree.
|
| 7 |
+
#pragma once
|
| 8 |
+
#include <ATen/Tensor.h>
|
| 9 |
+
#include <ATen/functorch/BatchedTensorImpl.h>
|
| 10 |
+
#include <ATen/functorch/DynamicLayer.h>
|
| 11 |
+
|
| 12 |
+
// NOTE: [vmap plumbing]
|
| 13 |
+
//
|
| 14 |
+
// Here's how "batching rules" work.
|
| 15 |
+
// - we register kernels to the Batched key
|
| 16 |
+
// - these kernels have the same signatures as the original operators.
|
| 17 |
+
// For example, at::sin(Tensor self) accepts a Tensor, and the batched kernel
|
| 18 |
+
// must also accept a Tensor
|
| 19 |
+
// - However, it is more natural for users to write a batching rule like the
|
| 20 |
+
// following: sin_batch_rule(Tensor self, std::optional<int> self_bdim)
|
| 21 |
+
// - There is some codegenerated layer (the "plumbing") that wraps the user
|
| 22 |
+
// defined batching rule (e.g. sin_batch_rule) in a kernel that can be
|
| 23 |
+
// registered to the Batched key.
|
| 24 |
+
//
|
| 25 |
+
// The plumbing is responsible for wrapping a batching rule into a form that may
|
| 26 |
+
// be registered as the kernel for the batched key.
|
| 27 |
+
|
| 28 |
+
namespace at::functorch {
|
| 29 |
+
|
| 30 |
+
void vmap_check_escaped(const std::optional<DynamicLayer> &layer, const char* what);
|
| 31 |
+
|
| 32 |
+
// Create a BatchedTensor given a tensor, bdim, and level
|
| 33 |
+
TORCH_API Tensor makeBatched(Tensor tensor, std::optional<int64_t> bdim, int64_t level);
|
| 34 |
+
|
| 35 |
+
// Given a Tensor that may or may not be a BatchedTensor, unwrap it.
|
| 36 |
+
// If `tensor` is not a BatchedTensor, or is a BatchedTensor but the level
|
| 37 |
+
// doesn't match, then this returns (tensor, std::nullopt).
|
| 38 |
+
// Otherwise, it returns (unwrap(tensor), bdim).
|
| 39 |
+
TORCH_API std::tuple<Tensor, std::optional<int64_t>> unwrapTensorAtLevel(const Tensor& tensor, int64_t level);
|
| 40 |
+
|
| 41 |
+
// Creates a vector of BatchedTensor
|
| 42 |
+
TORCH_API std::vector<Tensor> makeBatchedVector(std::vector<Tensor> tensors, std::optional<int64_t> bdim, int64_t level);
|
| 43 |
+
|
| 44 |
+
// Returns True if ANY tensor in tensors is batched at level
|
| 45 |
+
TORCH_API bool isBatchedAtLevel(ITensorListRef tensors, int64_t level);
|
| 46 |
+
TORCH_API bool isBatchedAtLevel(const c10::List<std::optional<Tensor>>& maybe_tensors, int64_t level);
|
| 47 |
+
TORCH_API bool isBatchedAtLevel(const Tensor& tensor, int64_t level);
|
| 48 |
+
TORCH_API bool isBatchedAtLevel(const std::optional<Tensor>& maybe_tensor, int64_t level);
|
| 49 |
+
|
| 50 |
+
// Convenience helper. Returns true if any tensor is batched at level
|
| 51 |
+
TORCH_API bool areAnyBatchedAtLevel(ArrayRef<std::optional<Tensor>> maybe_tensors, int64_t level);
|
| 52 |
+
|
| 53 |
+
inline bool ivalueParticipatesInCurrentLevel(const IValue& ivalue) {
|
| 54 |
+
if (ivalue.isTensor()) {
|
| 55 |
+
auto maybe_level = maybeCurrentDynamicLayer();
|
| 56 |
+
TORCH_INTERNAL_ASSERT(maybe_level.has_value());
|
| 57 |
+
auto current_level = maybe_level->layerId();
|
| 58 |
+
return isBatchedAtLevel(ivalue.toTensor(), current_level);
|
| 59 |
+
}
|
| 60 |
+
// TODO: should really check this
|
| 61 |
+
return false;
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
} // namespace at::functorch
|
| 65 |
+
|
| 66 |
+
#else
|
| 67 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 68 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/functorch/TensorWrapper.h
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
// Copyright (c) Facebook, Inc. and its affiliates.
|
| 3 |
+
// All rights reserved.
|
| 4 |
+
//
|
| 5 |
+
// This source code is licensed under the BSD-style license found in the
|
| 6 |
+
// LICENSE file in the root directory of this source tree.
|
| 7 |
+
|
| 8 |
+
#pragma once
|
| 9 |
+
|
| 10 |
+
#include <ATen/functorch/Macros.h>
|
| 11 |
+
#include <ATen/Tensor.h>
|
| 12 |
+
#include <ATen/functorch/Interpreter.h>
|
| 13 |
+
|
| 14 |
+
namespace at::functorch {
|
| 15 |
+
|
| 16 |
+
// NOTE: [functorch's TensorWrapper]
|
| 17 |
+
//
|
| 18 |
+
// Taking better suggestions for a name. TensorWrapper is the wrapper Tensor
|
| 19 |
+
// Subclass for functorch's grad-based transforms (grad, vjp, jvp). It is
|
| 20 |
+
// analogous to how vmap uses BatchedTensor as the wrapper Tensor subclass.
|
| 21 |
+
//
|
| 22 |
+
// If you're familiar with the Tensor-Variable merge, TensorWrapper is effectively
|
| 23 |
+
// another Variable.
|
| 24 |
+
//
|
| 25 |
+
// Consider grad(grad(torch.sin))(x). This wraps `x` as TensorWrapper(TensorWrapper(x)).
|
| 26 |
+
// The reason why is so that each TensorWrapper can hold its own AutogradMeta and
|
| 27 |
+
// participate in a **separate** autograd graph.
|
| 28 |
+
//
|
| 29 |
+
// There are alternative designs we could have chosen (e.g. each grad transform
|
| 30 |
+
// stores a weak map of Tensor -> AutogradMeta); the benefit of the TensorWrapper
|
| 31 |
+
// design is that we can reuse existing VariableType kernels (i.e. Autograd kernels)
|
| 32 |
+
// without much modification. Since a TensorWrapper looks like a regular Tensor,
|
| 33 |
+
// the VariableType kernel can pull out the AutogradMeta struct from where it
|
| 34 |
+
// expects and extend the autograd graph
|
| 35 |
+
|
| 36 |
+
struct TORCH_API TensorWrapper : public c10::TensorImpl {
|
| 37 |
+
explicit TensorWrapper(
|
| 38 |
+
c10::DispatchKeySet key_set,
|
| 39 |
+
Tensor value,
|
| 40 |
+
int64_t level,
|
| 41 |
+
std::shared_ptr<bool> is_alive,
|
| 42 |
+
bool is_immutable = false, // if true, this came from an operation that aliases an immutable tensor
|
| 43 |
+
bool use_value_sizes_strides = true);
|
| 44 |
+
|
| 45 |
+
void refreshMetadata();
|
| 46 |
+
|
| 47 |
+
const Tensor& value() const {
|
| 48 |
+
return value_;
|
| 49 |
+
}
|
| 50 |
+
std::optional<int64_t> level() const {
|
| 51 |
+
if (is_alive()) {
|
| 52 |
+
return level_;
|
| 53 |
+
}
|
| 54 |
+
return {};
|
| 55 |
+
}
|
| 56 |
+
bool is_immutable() const {
|
| 57 |
+
return is_immutable_;
|
| 58 |
+
}
|
| 59 |
+
bool is_alive() const;
|
| 60 |
+
|
| 61 |
+
// Overrides necessary for autograd
|
| 62 |
+
c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
|
| 63 |
+
const c10::VariableVersion& version_counter,
|
| 64 |
+
bool allow_tensor_metadata_change) const override;
|
| 65 |
+
c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
|
| 66 |
+
c10::VariableVersion&& version_counter,
|
| 67 |
+
bool allow_tensor_metadata_change) const override;
|
| 68 |
+
void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override;
|
| 69 |
+
|
| 70 |
+
private:
|
| 71 |
+
const char* tensorimpl_type_name() const override;
|
| 72 |
+
Tensor value_;
|
| 73 |
+
int64_t level_;
|
| 74 |
+
bool is_immutable_;
|
| 75 |
+
|
| 76 |
+
// TensorWrapper receives a boolean flag on whether or not the Grad Interpreter
|
| 77 |
+
// that created it is still alive or not.
|
| 78 |
+
// If the Grad Interpreter is no longer alive then it attempts to behave like
|
| 79 |
+
// a regular Tensor.
|
| 80 |
+
//
|
| 81 |
+
// When we exit the level, this wrapper may be marked as "not alive".
|
| 82 |
+
// Wrappers that are not alive:
|
| 83 |
+
// 1) May still have autograd metadata on them
|
| 84 |
+
// 2) Forward dispatches to the underlying value()
|
| 85 |
+
std::shared_ptr<bool> is_alive_;
|
| 86 |
+
};
|
| 87 |
+
|
| 88 |
+
// There are two variants of makeTensorWrapper: one that accepts a level
|
| 89 |
+
// and one that accepts an Interpreter.
|
| 90 |
+
//
|
| 91 |
+
// The one that accepts a level tries to automatically get the life handle from the
|
| 92 |
+
// interpreter on the DynamicLayerStack.
|
| 93 |
+
// It needs to be used with caution: if the interpreter is not on the
|
| 94 |
+
// DynamicLayerStack, then we won't be able to find the life handle.
|
| 95 |
+
//
|
| 96 |
+
// In practice this isn't a problem: when we're constructing TensorWrapper in
|
| 97 |
+
// Python, the corresponding interpreter is on the stack.
|
| 98 |
+
TORCH_API Tensor makeTensorWrapper(const Tensor& tensor, int64_t level, bool is_immutable=false);
|
| 99 |
+
TORCH_API Tensor makeTensorWrapper(const Tensor& tensor, const Interpreter& interpreter, bool is_immutable=false);
|
| 100 |
+
TORCH_API TensorWrapper* maybeGetTensorWrapper(const Tensor& tensor);
|
| 101 |
+
TORCH_API void dumpTensor(std::ostream & ss, const Tensor& tensor);
|
| 102 |
+
TORCH_API void dumpTensorCout(const Tensor& tensor);
|
| 103 |
+
|
| 104 |
+
} // namespace at::functorch
|
| 105 |
+
|
| 106 |
+
#else
|
| 107 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 108 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/ao_sparse/quantized/cpu/fbgemm_utils.h
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
#include <ATen/Tensor.h>
|
| 5 |
+
#include <c10/core/QScheme.h>
|
| 6 |
+
|
| 7 |
+
#ifdef USE_FBGEMM
|
| 8 |
+
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wextra-semi")
|
| 9 |
+
#include <fbgemm/Fbgemm.h>
|
| 10 |
+
#include <fbgemm/FbgemmSparse.h>
|
| 11 |
+
#include <ATen/native/ao_sparse/quantized/cpu/packed_params.h>
|
| 12 |
+
C10_DIAGNOSTIC_POP()
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
namespace ao::sparse {
|
| 16 |
+
|
| 17 |
+
struct TORCH_API PackedLinearWeight
|
| 18 |
+
: public LinearPackedParamsBase {
|
| 19 |
+
PackedLinearWeight(std::unique_ptr<fbgemm::BCSRMatrix<int8_t>> w,
|
| 20 |
+
std::optional<at::Tensor> bias,
|
| 21 |
+
std::vector<int32_t> col_offsets,
|
| 22 |
+
std::vector<float> w_scale,
|
| 23 |
+
std::vector<int32_t> w_zp,
|
| 24 |
+
c10::QScheme q_scheme,
|
| 25 |
+
const int64_t out_features_block_size /* block sparsity size across output_features */,
|
| 26 |
+
const int64_t in_features_block_size /* block sparsity size across input_features */)
|
| 27 |
+
: LinearPackedParamsBase(
|
| 28 |
+
out_features_block_size,
|
| 29 |
+
in_features_block_size),
|
| 30 |
+
w(std::move(w)),
|
| 31 |
+
bias_(std::move(bias)),
|
| 32 |
+
col_offsets(std::move(col_offsets)),
|
| 33 |
+
w_scale(std::move(w_scale)),
|
| 34 |
+
w_zp(std::move(w_zp)),
|
| 35 |
+
q_scheme(q_scheme) {}
|
| 36 |
+
std::unique_ptr<fbgemm::BCSRMatrix<int8_t>> w;
|
| 37 |
+
std::optional<at::Tensor> bias_;
|
| 38 |
+
std::vector<int32_t> col_offsets;
|
| 39 |
+
std::vector<float> w_scale;
|
| 40 |
+
std::vector<int32_t> w_zp;
|
| 41 |
+
c10::QScheme q_scheme;
|
| 42 |
+
|
| 43 |
+
at::Tensor apply(
|
| 44 |
+
const at::Tensor& input,
|
| 45 |
+
double output_scale,
|
| 46 |
+
int64_t output_zero_point) override;
|
| 47 |
+
at::Tensor apply_relu(
|
| 48 |
+
const at::Tensor& input,
|
| 49 |
+
double output_scale,
|
| 50 |
+
int64_t output_zero_point) override;
|
| 51 |
+
|
| 52 |
+
at::Tensor apply_dynamic(const at::Tensor& input) override {
|
| 53 |
+
TORCH_INTERNAL_ASSERT(
|
| 54 |
+
false,
|
| 55 |
+
"Sparse quantized dynamic linear with fused relu is not yet "
|
| 56 |
+
"supported on qnnpack backend.");
|
| 57 |
+
return at::Tensor();
|
| 58 |
+
}
|
| 59 |
+
at::Tensor apply_dynamic_relu(const at::Tensor& input) override {
|
| 60 |
+
TORCH_INTERNAL_ASSERT(
|
| 61 |
+
false,
|
| 62 |
+
"Sparse quantized dynamic linear with fused relu is not yet "
|
| 63 |
+
"supported on qnnpack backend.");
|
| 64 |
+
return at::Tensor();
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
LinearPackedSerializationType unpack() override;
|
| 68 |
+
|
| 69 |
+
BCSRSerializationType serialize() override;
|
| 70 |
+
|
| 71 |
+
static c10::intrusive_ptr<LinearPackedParamsBase> deserialize(
|
| 72 |
+
const BCSRSerializationType& serialized);
|
| 73 |
+
|
| 74 |
+
std::optional<at::Tensor> bias() override {
|
| 75 |
+
return bias_;
|
| 76 |
+
}
|
| 77 |
+
|
| 78 |
+
static c10::intrusive_ptr<LinearPackedParamsBase> prepack(
|
| 79 |
+
const at::Tensor& weight,
|
| 80 |
+
const std::optional<at::Tensor>& bias,
|
| 81 |
+
const int64_t out_features_block_size,
|
| 82 |
+
const int64_t in_features_block_size);
|
| 83 |
+
|
| 84 |
+
private:
|
| 85 |
+
template <bool ReluFused>
|
| 86 |
+
at::Tensor apply_impl(
|
| 87 |
+
const at::Tensor& input,
|
| 88 |
+
double output_scale,
|
| 89 |
+
int64_t output_zero_point);
|
| 90 |
+
};
|
| 91 |
+
|
| 92 |
+
} // namespace ao::sparse
|
| 93 |
+
|
| 94 |
+
#endif // USE_FBGEMM
|
| 95 |
+
|
| 96 |
+
namespace ao::sparse {
|
| 97 |
+
int register_linear_params();
|
| 98 |
+
} // namespace ao::sparse
|
| 99 |
+
|
| 100 |
+
#else
|
| 101 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 102 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/ao_sparse/quantized/cpu/packed_params.h
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
#include <cstdint>
|
| 5 |
+
|
| 6 |
+
#include <ATen/core/ivalue.h>
|
| 7 |
+
#include <c10/util/Exception.h>
|
| 8 |
+
|
| 9 |
+
namespace ao::sparse {
|
| 10 |
+
|
| 11 |
+
// <Weight, bias, out_features_block_size, in_features_block_size>
|
| 12 |
+
using LinearPackedSerializationType =
|
| 13 |
+
std::tuple<at::Tensor, std::optional<at::Tensor>, std::vector<int64_t>>;
|
| 14 |
+
|
| 15 |
+
#define SPARSE_LINEAR_PACKED_PARAM_SERIALIZATION_VERSION 2
|
| 16 |
+
|
| 17 |
+
using BCSRSerializationType =
|
| 18 |
+
std::tuple<
|
| 19 |
+
int64_t, // Serialization Version
|
| 20 |
+
std::optional<at::Tensor>, // Bias
|
| 21 |
+
int64_t, // Out Features (Row) Block Size
|
| 22 |
+
int64_t, // In Features (Column) Block Size
|
| 23 |
+
at::Tensor, // Weight Scales (single element vector if per-tensor) (float)
|
| 24 |
+
at::Tensor, // Wrapper for Weight Zero Points (single element vector if per-tensor) (int8_t)
|
| 25 |
+
bool, // Quantization Scheme (true: per tensor, false: per channel)
|
| 26 |
+
at::Tensor, // Wrapper for Row Block Indices (int8_t, int16_t, or int32_t)
|
| 27 |
+
at::Tensor, // Wrapper for Column Block Indices (int8_t, int16_t, or int32_t)
|
| 28 |
+
at::Tensor, // Wrapper for Non-Zero Weight Values, each +128 (uint8_t)
|
| 29 |
+
int64_t, // Number of Output Channels
|
| 30 |
+
int64_t // Number of Input Channels
|
| 31 |
+
>;
|
| 32 |
+
|
| 33 |
+
using BCSR =
|
| 34 |
+
std::tuple<
|
| 35 |
+
std::vector<int8_t>, // Non-Zero Weight Values
|
| 36 |
+
std::vector<int32_t>, // Compressed Row Block Indices
|
| 37 |
+
std::vector<int32_t> // Column Block Indices
|
| 38 |
+
>;
|
| 39 |
+
|
| 40 |
+
struct LinearPackedParamsBase : public torch::jit::CustomClassHolder {
|
| 41 |
+
public:
|
| 42 |
+
LinearPackedParamsBase(
|
| 43 |
+
const int64_t out_features_block_size,
|
| 44 |
+
const int64_t in_features_block_size)
|
| 45 |
+
: out_features_block_size_(out_features_block_size),
|
| 46 |
+
in_features_block_size_(in_features_block_size) {}
|
| 47 |
+
|
| 48 |
+
virtual at::Tensor apply(
|
| 49 |
+
const at::Tensor& input,
|
| 50 |
+
double output_scale,
|
| 51 |
+
int64_t output_zero_point) = 0;
|
| 52 |
+
virtual at::Tensor apply_relu(
|
| 53 |
+
const at::Tensor& input,
|
| 54 |
+
double output_scale,
|
| 55 |
+
int64_t output_zero_point) = 0;
|
| 56 |
+
|
| 57 |
+
virtual at::Tensor apply_dynamic(const at::Tensor& input) = 0;
|
| 58 |
+
virtual at::Tensor apply_dynamic_relu(const at::Tensor& input) = 0;
|
| 59 |
+
|
| 60 |
+
virtual LinearPackedSerializationType unpack() = 0;
|
| 61 |
+
|
| 62 |
+
virtual BCSRSerializationType serialize() = 0;
|
| 63 |
+
|
| 64 |
+
virtual std::optional<at::Tensor> bias() = 0;
|
| 65 |
+
|
| 66 |
+
virtual void set_bias(const std::optional<at::Tensor>& bias) {
|
| 67 |
+
TORCH_CHECK(false, "set_bias is not implemented for this packed parameter type");
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
protected:
|
| 71 |
+
const int64_t out_features_block_size_, in_features_block_size_;
|
| 72 |
+
};
|
| 73 |
+
|
| 74 |
+
} // namespace ao::sparse
|
| 75 |
+
|
| 76 |
+
#else
|
| 77 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 78 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/ao_sparse/quantized/cpu/qnnpack_utils.h
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
#include <ATen/Tensor.h>
|
| 5 |
+
#include <c10/core/QScheme.h>
|
| 6 |
+
|
| 7 |
+
#ifdef USE_PYTORCH_QNNPACK
|
| 8 |
+
// TODO: Refacto QnnpackUtils.h so as to separate code
|
| 9 |
+
// needed for quantized op from the generic qnnpack specific
|
| 10 |
+
// quantization utilities.
|
| 11 |
+
#include <ATen/native/ao_sparse/quantized/cpu/packed_params.h>
|
| 12 |
+
#include <ATen/native/quantized/cpu/QnnpackUtils.h>
|
| 13 |
+
#include <pack_block_sparse.h>
|
| 14 |
+
|
| 15 |
+
namespace ao::sparse {
|
| 16 |
+
|
| 17 |
+
struct TORCH_API PackedLinearWeightQnnp : public LinearPackedParamsBase {
|
| 18 |
+
PackedLinearWeightQnnp(const at::Tensor& weight, const std::optional<at::Tensor>& bias, const int64_t out_features_block_size /* block sparsity size across output_features */, const int64_t in_features_block_size /* block sparsity size across input_features */);
|
| 19 |
+
explicit PackedLinearWeightQnnp(const BCSRSerializationType& serialized);
|
| 20 |
+
std::optional<at::Tensor> orig_bias_;
|
| 21 |
+
// Separate copy of bias exist so that we can fill in zeros when
|
| 22 |
+
// optional bias does not exist. This is to compy with qnnpack operator that
|
| 23 |
+
// expects bias to be present.
|
| 24 |
+
// In case bias is present bias_ is just a reference to orig_bias_
|
| 25 |
+
at::Tensor bias_;
|
| 26 |
+
c10::QScheme q_scheme_;
|
| 27 |
+
double input_scale_{};
|
| 28 |
+
std::unique_ptr<qnnpack::BCSRMatrix> bcsr_matrix_;
|
| 29 |
+
at::Tensor w_scales_;
|
| 30 |
+
std::vector<uint8_t> w_zero_points_;
|
| 31 |
+
std::vector<float> requantization_scales_;
|
| 32 |
+
std::unique_ptr<pytorch_qnnp_operator, QnnpackOperatorDeleter>
|
| 33 |
+
sparse_linear_op_{nullptr};
|
| 34 |
+
int64_t output_channels_;
|
| 35 |
+
int64_t input_channels_;
|
| 36 |
+
// Deserialized Tensors are stored to maintain the lifetime of underlying
|
| 37 |
+
// BCSR data.
|
| 38 |
+
// These are left empty if PackedLinearWeightQnnp is created via prepacking
|
| 39 |
+
// rather than deserializing.
|
| 40 |
+
at::Tensor deserialized_bcsr_row_block_indices_;
|
| 41 |
+
at::Tensor deserialized_bcsr_col_block_indices_;
|
| 42 |
+
at::Tensor deserialized_bcsr_weight_values_;
|
| 43 |
+
|
| 44 |
+
at::Tensor apply(
|
| 45 |
+
const at::Tensor& input,
|
| 46 |
+
double output_scale,
|
| 47 |
+
int64_t output_zero_point) override {
|
| 48 |
+
TORCH_CHECK(
|
| 49 |
+
false, "Static quantized sparse linear unimplemented on QNNPACK");
|
| 50 |
+
}
|
| 51 |
+
at::Tensor apply_relu(
|
| 52 |
+
const at::Tensor& input,
|
| 53 |
+
double output_scale,
|
| 54 |
+
int64_t output_zero_point) override {
|
| 55 |
+
TORCH_CHECK(
|
| 56 |
+
false, "Static quantized sparse linear unimplemented on QNNPACK");
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
at::Tensor apply_dynamic(const at::Tensor& input) override;
|
| 60 |
+
at::Tensor apply_dynamic_relu(const at::Tensor& input) override;
|
| 61 |
+
|
| 62 |
+
LinearPackedSerializationType unpack() override;
|
| 63 |
+
|
| 64 |
+
BCSRSerializationType serialize() override;
|
| 65 |
+
|
| 66 |
+
static c10::intrusive_ptr<LinearPackedParamsBase> deserialize(
|
| 67 |
+
const BCSRSerializationType& serialized);
|
| 68 |
+
|
| 69 |
+
std::optional<at::Tensor> bias() override {
|
| 70 |
+
return orig_bias_;
|
| 71 |
+
}
|
| 72 |
+
|
| 73 |
+
static c10::intrusive_ptr<LinearPackedParamsBase> prepack(
|
| 74 |
+
const at::Tensor& weight,
|
| 75 |
+
const std::optional<at::Tensor>& bias,
|
| 76 |
+
const int64_t out_features_block_size,
|
| 77 |
+
const int64_t in_features_block_size);
|
| 78 |
+
|
| 79 |
+
private:
|
| 80 |
+
template <bool ReluFused>
|
| 81 |
+
at::Tensor apply_impl(
|
| 82 |
+
const at::Tensor& input,
|
| 83 |
+
double output_scale,
|
| 84 |
+
int64_t output_zero_point);
|
| 85 |
+
template <bool ReluFused>
|
| 86 |
+
at::Tensor apply_dynamic_impl(const at::Tensor& input);
|
| 87 |
+
};
|
| 88 |
+
|
| 89 |
+
} // namespace ao::sparse
|
| 90 |
+
|
| 91 |
+
#endif // USE_PYTORCH_QNNPACK
|
| 92 |
+
|
| 93 |
+
#else
|
| 94 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 95 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/AtomicAddFloat.h
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#ifndef ATOMIC_ADD_FLOAT
|
| 3 |
+
#define ATOMIC_ADD_FLOAT
|
| 4 |
+
|
| 5 |
+
#if (defined(__x86_64__) || defined(__i386__) || defined(__aarch64__))
|
| 6 |
+
#include <ATen/native/cpu/Intrinsics.h>
|
| 7 |
+
#else
|
| 8 |
+
#define _mm_pause()
|
| 9 |
+
#endif
|
| 10 |
+
|
| 11 |
+
#include <atomic>
|
| 12 |
+
|
| 13 |
+
static inline void cpu_atomic_add_float(float* dst, float fvalue)
|
| 14 |
+
{
|
| 15 |
+
typedef union {
|
| 16 |
+
unsigned intV;
|
| 17 |
+
float floatV;
|
| 18 |
+
} uf32_t;
|
| 19 |
+
|
| 20 |
+
uf32_t new_value, old_value;
|
| 21 |
+
std::atomic<unsigned>* dst_intV = (std::atomic<unsigned>*)dst;
|
| 22 |
+
|
| 23 |
+
old_value.floatV = *dst;
|
| 24 |
+
new_value.floatV = old_value.floatV + fvalue;
|
| 25 |
+
|
| 26 |
+
unsigned* old_intV = &old_value.intV;
|
| 27 |
+
while (!std::atomic_compare_exchange_strong(dst_intV, old_intV, new_value.intV)) {
|
| 28 |
+
#ifdef __aarch64__
|
| 29 |
+
__asm__ __volatile__("yield;" : : : "memory");
|
| 30 |
+
#else
|
| 31 |
+
_mm_pause();
|
| 32 |
+
#endif
|
| 33 |
+
old_value.floatV = *dst;
|
| 34 |
+
new_value.floatV = old_value.floatV + fvalue;
|
| 35 |
+
}
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
#endif
|
| 39 |
+
|
| 40 |
+
#else
|
| 41 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 42 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/ChannelShuffleKernel.h
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
#include <ATen/native/DispatchStub.h>
|
| 4 |
+
#include <cstdint>
|
| 5 |
+
|
| 6 |
+
namespace at {
|
| 7 |
+
class TensorBase;
|
| 8 |
+
}
|
| 9 |
+
|
| 10 |
+
namespace at::native {
|
| 11 |
+
|
| 12 |
+
using channel_shuffle_fn = void(*)(TensorBase&, const TensorBase&, int64_t);
|
| 13 |
+
DECLARE_DISPATCH(channel_shuffle_fn, channel_shuffle_kernel)
|
| 14 |
+
|
| 15 |
+
} // at::native
|
| 16 |
+
|
| 17 |
+
#else
|
| 18 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 19 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/DepthwiseConvKernel.h
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
#include <ATen/native/DispatchStub.h>
|
| 5 |
+
#include <c10/util/ArrayRef.h>
|
| 6 |
+
|
| 7 |
+
/*
|
| 8 |
+
Depthwise 3x3 Winograd convolution operator
|
| 9 |
+
*/
|
| 10 |
+
|
| 11 |
+
namespace at {
|
| 12 |
+
class Tensor;
|
| 13 |
+
|
| 14 |
+
namespace native {
|
| 15 |
+
|
| 16 |
+
using convolution_depthwise3x3_winograd_fn =
|
| 17 |
+
Tensor (*)(const Tensor &, const Tensor &, const Tensor &, IntArrayRef, IntArrayRef, int64_t);
|
| 18 |
+
|
| 19 |
+
DECLARE_DISPATCH(convolution_depthwise3x3_winograd_fn, convolution_depthwise3x3_winograd_stub)
|
| 20 |
+
|
| 21 |
+
} // namespace native
|
| 22 |
+
} // namespace at
|
| 23 |
+
|
| 24 |
+
#else
|
| 25 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 26 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/Elu.h
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
// On Windows, math.h needs to be included with _USE_MATH_DEFINES defined to
|
| 5 |
+
// access constants such as M_SQRT2 and M_2_SQRTPI.
|
| 6 |
+
#ifdef _WIN32
|
| 7 |
+
#define _USE_MATH_DEFINES
|
| 8 |
+
#include <cmath>
|
| 9 |
+
#endif // _WIN32
|
| 10 |
+
|
| 11 |
+
#include <ATen/cpu/vec/vec.h>
|
| 12 |
+
#include <c10/util/BFloat16.h> // For c10::is_reduced_floating_point_v.
|
| 13 |
+
|
| 14 |
+
namespace at::native {
|
| 15 |
+
inline namespace CPU_CAPABILITY {
|
| 16 |
+
/**
|
| 17 |
+
* Return a function object that calculates ELU with the given
|
| 18 |
+
* parameters on its input element. ParamT is the type of the input
|
| 19 |
+
* and output to the ELU, and MathT is the type (possibly
|
| 20 |
+
* higher-precision, e.g. float if ParamT is reduced-precision float)
|
| 21 |
+
* in which to do intermediate calculations.
|
| 22 |
+
*/
|
| 23 |
+
template <typename ParamT, typename MathT=ParamT>
|
| 24 |
+
auto get_scalar_elu_elementwise_func(MathT alpha, MathT scale, MathT input_scale) {
|
| 25 |
+
const auto negcoef = alpha * scale;
|
| 26 |
+
const auto poscoef = scale;
|
| 27 |
+
const auto negiptcoef = input_scale;
|
| 28 |
+
return [negcoef, negiptcoef, poscoef](ParamT a) -> ParamT {
|
| 29 |
+
return MathT(a) < MathT(0)
|
| 30 |
+
? std::expm1(MathT(a) * negiptcoef) * negcoef
|
| 31 |
+
: MathT(a) * poscoef;
|
| 32 |
+
};
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
/**
|
| 36 |
+
* Return a function object that calculates ELU with the given
|
| 37 |
+
* parameters on its input element. The function object takes and
|
| 38 |
+
* returns Vectorized<T>.
|
| 39 |
+
*/
|
| 40 |
+
template <typename T, std::enable_if_t<!c10::is_reduced_floating_point_v<T>, bool> = true>
|
| 41 |
+
auto get_vectorized_elu_elementwise_func(T alpha, T scale, T input_scale) {
|
| 42 |
+
const vec::Vectorized<T> negcoef_vec(alpha * scale);
|
| 43 |
+
const vec::Vectorized<T> poscoef_vec(scale);
|
| 44 |
+
const vec::Vectorized<T> negiptcoef_vec(input_scale);
|
| 45 |
+
const vec::Vectorized<T> zero_vec(static_cast<T>(0));
|
| 46 |
+
return [negcoef_vec, poscoef_vec, negiptcoef_vec, zero_vec](vec::Vectorized<T> a) -> vec::Vectorized<T> {
|
| 47 |
+
const auto cmp = a >= zero_vec;
|
| 48 |
+
if (!cmp.zero_mask()) {
|
| 49 |
+
return a * poscoef_vec;
|
| 50 |
+
} else {
|
| 51 |
+
return vec::Vectorized<T>::blendv((a * negiptcoef_vec).expm1() * negcoef_vec, a * poscoef_vec, cmp);
|
| 52 |
+
}
|
| 53 |
+
};
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
/**
|
| 57 |
+
* Return a function object that calculates ELU with the given
|
| 58 |
+
* parameters on its input element. The function object takes and
|
| 59 |
+
* returns Vectorized<ParamT>, and Vectorized<MathT> is the type
|
| 60 |
+
* (possibly higher-precision) in which to do intermediate
|
| 61 |
+
* calculations.
|
| 62 |
+
*/
|
| 63 |
+
template <typename T, std::enable_if_t<c10::is_reduced_floating_point_v<T>, bool> = true>
|
| 64 |
+
auto get_vectorized_elu_elementwise_func(float alpha, float scale, float input_scale) {
|
| 65 |
+
// Takes float->float.
|
| 66 |
+
const auto float_func = get_vectorized_elu_elementwise_func<float>(alpha, scale, input_scale);
|
| 67 |
+
return [float_func](vec::Vectorized<T> a) -> vec::Vectorized<T> {
|
| 68 |
+
auto [a0, a1] = vec::convert_to_float<T>(a);
|
| 69 |
+
auto res0 = float_func(a0);
|
| 70 |
+
auto res1 = float_func(a1);
|
| 71 |
+
return vec::convert_from_float<T>(res0, res1);
|
| 72 |
+
};
|
| 73 |
+
}
|
| 74 |
+
} // namespace CPU_CAPABILITY
|
| 75 |
+
} // namespace at::native
|
| 76 |
+
|
| 77 |
+
#else
|
| 78 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 79 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/GridSamplerKernel.h
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
#include <ATen/native/DispatchStub.h>
|
| 5 |
+
|
| 6 |
+
#include <array>
|
| 7 |
+
#include <cstdint>
|
| 8 |
+
|
| 9 |
+
namespace at {
|
| 10 |
+
class TensorBase;
|
| 11 |
+
}
|
| 12 |
+
|
| 13 |
+
namespace at::native {
|
| 14 |
+
|
| 15 |
+
using forward_2d_fn = void (*) (
|
| 16 |
+
const TensorBase &output,
|
| 17 |
+
const TensorBase &input,
|
| 18 |
+
const TensorBase &grid,
|
| 19 |
+
int64_t interpolation_mode,
|
| 20 |
+
int64_t padding_mode,
|
| 21 |
+
bool align_corners);
|
| 22 |
+
using backward_2d_fn = void (*) (
|
| 23 |
+
const TensorBase &grad_input,
|
| 24 |
+
const TensorBase &grad_grid,
|
| 25 |
+
const TensorBase &grad_output,
|
| 26 |
+
const TensorBase &input,
|
| 27 |
+
const TensorBase &grid,
|
| 28 |
+
int64_t interpolation_mode,
|
| 29 |
+
int64_t padding_mode,
|
| 30 |
+
bool align_corners,
|
| 31 |
+
std::array<bool, 2> output_mask);
|
| 32 |
+
DECLARE_DISPATCH(forward_2d_fn, grid_sampler_2d_cpu_kernel)
|
| 33 |
+
DECLARE_DISPATCH(backward_2d_fn, grid_sampler_2d_backward_cpu_kernel)
|
| 34 |
+
|
| 35 |
+
} // namespace at::native
|
| 36 |
+
|
| 37 |
+
#else
|
| 38 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 39 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/IndexKernelUtils.h
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
#include <ATen/native/TensorIterator.h>
|
| 4 |
+
#include <c10/util/irange.h>
|
| 5 |
+
|
| 6 |
+
namespace at::native {
|
| 7 |
+
|
| 8 |
+
inline bool is_constant_index(int ntensor, const int64_t* strides) {
|
| 9 |
+
AT_ASSERT(ntensor >= 3);
|
| 10 |
+
for (const auto arg : c10::irange(2, ntensor)) {
|
| 11 |
+
if (strides[arg] != 0) {
|
| 12 |
+
return false;
|
| 13 |
+
}
|
| 14 |
+
}
|
| 15 |
+
return true;
|
| 16 |
+
}
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
struct Indexer {
|
| 20 |
+
Indexer(int64_t num_indexers, char** indexers, const int64_t* indexer_strides,
|
| 21 |
+
IntArrayRef original_sizes, IntArrayRef original_strides)
|
| 22 |
+
: num_indexers(num_indexers)
|
| 23 |
+
, indexers(indexers)
|
| 24 |
+
, indexer_strides(indexer_strides)
|
| 25 |
+
, original_strides(original_strides.data())
|
| 26 |
+
, original_sizes(original_sizes.data()) {
|
| 27 |
+
AT_ASSERT(static_cast<int64_t>(original_strides.size()) == num_indexers);
|
| 28 |
+
AT_ASSERT(static_cast<int64_t>(original_sizes.size()) == num_indexers);
|
| 29 |
+
}
|
| 30 |
+
|
| 31 |
+
int64_t num_indexers;
|
| 32 |
+
char** indexers;
|
| 33 |
+
const int64_t* indexer_strides;
|
| 34 |
+
const int64_t* original_strides;
|
| 35 |
+
const int64_t* original_sizes;
|
| 36 |
+
|
| 37 |
+
int64_t get(int64_t idx) {
|
| 38 |
+
int64_t offset = 0;
|
| 39 |
+
for (const auto j : c10::irange(num_indexers)) {
|
| 40 |
+
int64_t value = *(int64_t*)&indexers[j][idx * indexer_strides[j]];
|
| 41 |
+
int64_t size = original_sizes[j];
|
| 42 |
+
TORCH_CHECK_INDEX(value >= -size && value < size,
|
| 43 |
+
"index ", value, " is out of bounds for dimension ", j, " with size ", size);
|
| 44 |
+
if (value < 0) {
|
| 45 |
+
value += size;
|
| 46 |
+
}
|
| 47 |
+
offset += value * original_strides[j];
|
| 48 |
+
}
|
| 49 |
+
return offset;
|
| 50 |
+
}
|
| 51 |
+
};
|
| 52 |
+
|
| 53 |
+
template <typename scalar_t, typename func_t>
|
| 54 |
+
void cpu_index_kernel(TensorIteratorBase& iter, IntArrayRef index_size, IntArrayRef index_stride,
|
| 55 |
+
const func_t& f, bool serial_execution=false)
|
| 56 |
+
{
|
| 57 |
+
int ntensor = iter.ntensors();
|
| 58 |
+
// When launch the index parallel version, set a relative small grain size less than the INTERNAL::GRAIN_SIZE
|
| 59 |
+
// to make the whole available thread numbers get more balanced work load and a better cache location.
|
| 60 |
+
// The grain size here is chosen by the op benchmark to overcome the thread launch overhead
|
| 61 |
+
const int index_parallel_grain_size = 3000;
|
| 62 |
+
auto loop = [&](char** data, const int64_t* strides, int64_t n) {
|
| 63 |
+
auto indexer = Indexer(ntensor - 2, &data[2], &strides[2], index_size, index_stride);
|
| 64 |
+
char* dst = data[0];
|
| 65 |
+
char* src = data[1];
|
| 66 |
+
if (is_constant_index(ntensor, strides)) {
|
| 67 |
+
// specialization for when every element uses the same index
|
| 68 |
+
int64_t offset = indexer.get(0);
|
| 69 |
+
for (const auto i : c10::irange(n)) {
|
| 70 |
+
f(dst + strides[0] * i, src + strides[1] * i, offset);
|
| 71 |
+
}
|
| 72 |
+
} else {
|
| 73 |
+
for (const auto i : c10::irange(n)) {
|
| 74 |
+
int64_t offset = indexer.get(i);
|
| 75 |
+
f(dst + strides[0] * i, src + strides[1] * i, offset);
|
| 76 |
+
}
|
| 77 |
+
}
|
| 78 |
+
};
|
| 79 |
+
if (serial_execution) {
|
| 80 |
+
iter.serial_for_each(loop, {0, iter.numel()});
|
| 81 |
+
} else {
|
| 82 |
+
iter.for_each(loop, index_parallel_grain_size);
|
| 83 |
+
}
|
| 84 |
+
}
|
| 85 |
+
} // at
|
| 86 |
+
// native
|
| 87 |
+
|
| 88 |
+
#else
|
| 89 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 90 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/Intrinsics.h
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
#if defined(__clang__) && (defined(__x86_64__) || defined(__i386__))
|
| 5 |
+
/* Clang-compatible compiler, targeting x86/x86-64 */
|
| 6 |
+
#include <x86intrin.h>
|
| 7 |
+
#elif defined(_MSC_VER)
|
| 8 |
+
/* Microsoft C/C++-compatible compiler */
|
| 9 |
+
#include <intrin.h>
|
| 10 |
+
#if _MSC_VER <= 1900
|
| 11 |
+
#define _mm256_extract_epi64(X, Y) (((uint64_t*)&X)[Y])
|
| 12 |
+
#endif
|
| 13 |
+
#elif defined(__GNUC__) && (defined(__x86_64__) || defined(__i386__))
|
| 14 |
+
/* GCC-compatible compiler, targeting x86/x86-64 */
|
| 15 |
+
#include <x86intrin.h>
|
| 16 |
+
#elif defined(__GNUC__) && defined(__ARM_NEON__)
|
| 17 |
+
/* GCC-compatible compiler, targeting ARM with NEON */
|
| 18 |
+
#include <arm_neon.h>
|
| 19 |
+
#elif defined(__GNUC__) && defined(__IWMMXT__)
|
| 20 |
+
/* GCC-compatible compiler, targeting ARM with WMMX */
|
| 21 |
+
#include <mmintrin.h>
|
| 22 |
+
#elif (defined(__GNUC__) || defined(__xlC__)) && \
|
| 23 |
+
(defined(__VEC__) || defined(__ALTIVEC__))
|
| 24 |
+
/* XLC or GCC-compatible compiler, targeting PowerPC with VMX/VSX */
|
| 25 |
+
#include <altivec.h>
|
| 26 |
+
/* We need to undef those tokens defined by <altivec.h> to avoid conflicts
|
| 27 |
+
with the C++ types. => Can still use __bool/__vector */
|
| 28 |
+
#undef bool
|
| 29 |
+
#undef vector
|
| 30 |
+
#undef pixel
|
| 31 |
+
#elif defined(__GNUC__) && defined(__SPE__)
|
| 32 |
+
/* GCC-compatible compiler, targeting PowerPC with SPE */
|
| 33 |
+
#include <spe.h>
|
| 34 |
+
#endif
|
| 35 |
+
|
| 36 |
+
#else
|
| 37 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 38 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/IsContiguous.h
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
namespace at::native { inline namespace CPU_CAPABILITY {
|
| 5 |
+
|
| 6 |
+
// n: number of function arguments (arity)
|
| 7 |
+
// traits: function_traits (see FunctionTraits.h)
|
| 8 |
+
// s: index of scalar argument or -1
|
| 9 |
+
template <int n, int stride_index, typename traits, int s=-1>
|
| 10 |
+
struct IsContiguous {
|
| 11 |
+
static bool eval(const int64_t* strides) {
|
| 12 |
+
using type = typename traits::template arg<n - 1>::type;
|
| 13 |
+
return strides[stride_index] == (s == n ? 0 : sizeof(type)) &&
|
| 14 |
+
IsContiguous<n - 1, stride_index - 1, traits, s>::eval(strides);
|
| 15 |
+
}
|
| 16 |
+
};
|
| 17 |
+
|
| 18 |
+
// will be called when there is an output exists
|
| 19 |
+
template <typename traits, int s>
|
| 20 |
+
struct IsContiguous<0, 0, traits, s> {
|
| 21 |
+
static bool eval(const int64_t* strides) {
|
| 22 |
+
return strides[0] == sizeof(typename traits::result_type);
|
| 23 |
+
}
|
| 24 |
+
};
|
| 25 |
+
|
| 26 |
+
// will be called when there is no output
|
| 27 |
+
template <typename traits, int s>
|
| 28 |
+
struct IsContiguous<0, -1, traits, s> {
|
| 29 |
+
static bool eval(const int64_t* /*strides*/) {
|
| 30 |
+
return true;
|
| 31 |
+
}
|
| 32 |
+
};
|
| 33 |
+
|
| 34 |
+
// output and all inputs are contiguous
|
| 35 |
+
template <
|
| 36 |
+
typename traits,
|
| 37 |
+
std::enable_if_t<std::is_void_v<typename traits::result_type>>* =
|
| 38 |
+
nullptr>
|
| 39 |
+
static inline bool is_contiguous(const int64_t* strides) {
|
| 40 |
+
return IsContiguous<traits::arity, traits::arity - 1, traits>::eval(strides);
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
template <typename traits,
|
| 44 |
+
std::enable_if_t<!std::is_void_v<typename traits::result_type>>* = nullptr>
|
| 45 |
+
static inline bool is_contiguous(const int64_t* strides) {
|
| 46 |
+
return IsContiguous<traits::arity, traits::arity, traits>::eval(strides);
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
// input at `s` is scalar (stride 0); output and other inputs are contiguous
|
| 50 |
+
// NB: output is typically at strides[0] so first input corresponds to s=1
|
| 51 |
+
template <typename traits, int s,
|
| 52 |
+
std::enable_if_t<std::is_void_v<typename traits::result_type>>* = nullptr>
|
| 53 |
+
static inline bool is_contiguous_scalar(const int64_t* strides) {
|
| 54 |
+
static_assert(s > 0 && s <= traits::arity, "scalar argument index out of bounds");
|
| 55 |
+
return IsContiguous<traits::arity, traits::arity - 1, traits, s>::eval(strides);
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
template <typename traits, int s,
|
| 59 |
+
std::enable_if_t<!std::is_void_v<typename traits::result_type>>* = nullptr>
|
| 60 |
+
static inline bool is_contiguous_scalar(const int64_t* strides) {
|
| 61 |
+
static_assert(s > 0 && s <= traits::arity, "scalar argument index out of bounds");
|
| 62 |
+
return IsContiguous<traits::arity, traits::arity, traits, s>::eval(strides);
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
}}
|
| 66 |
+
|
| 67 |
+
#else
|
| 68 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 69 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/LogSoftmaxKernelImpl.h
ADDED
|
@@ -0,0 +1,342 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
#include <ATen/OpMathType.h>
|
| 5 |
+
#include <ATen/Parallel.h>
|
| 6 |
+
#include <ATen/cpu/vec/functional.h>
|
| 7 |
+
#include <ATen/cpu/vec/vec.h>
|
| 8 |
+
#include <c10/util/irange.h>
|
| 9 |
+
|
| 10 |
+
#include <algorithm>
|
| 11 |
+
#include <cmath>
|
| 12 |
+
#include <cstdint>
|
| 13 |
+
#include <limits>
|
| 14 |
+
#include <memory>
|
| 15 |
+
#include <type_traits>
|
| 16 |
+
|
| 17 |
+
namespace at::native {
|
| 18 |
+
inline namespace CPU_CAPABILITY {
|
| 19 |
+
template <typename scalar_t>
|
| 20 |
+
int64_t vec_log_softmax_lastdim_chunk_size(int64_t grain_size, int64_t outer_size, int64_t dim_size) {
|
| 21 |
+
// Coincidentally, at::internal::GRAIN_SIZE is 32768, which is equal to the
|
| 22 |
+
// size of L1D cache on many processors. Some processors have 48 KB L1D cache
|
| 23 |
+
// nowadays, so maybe in the future, we can leverage the knowledge of a
|
| 24 |
+
// machine's L1D cache size.
|
| 25 |
+
int64_t MAX_CHUNK_SIZE = std::max<int64_t>(
|
| 26 |
+
1,
|
| 27 |
+
grain_size / (sizeof(scalar_t) * dim_size));
|
| 28 |
+
return std::min<int64_t>(MAX_CHUNK_SIZE, outer_size);
|
| 29 |
+
}
|
| 30 |
+
|
| 31 |
+
template <typename scalar_t>
|
| 32 |
+
void serial_vec_log_softmax_lastdim_range(
|
| 33 |
+
const scalar_t* input_data_base,
|
| 34 |
+
scalar_t* output_data_base,
|
| 35 |
+
int64_t dim_size,
|
| 36 |
+
int64_t chunk_size,
|
| 37 |
+
int64_t begin,
|
| 38 |
+
int64_t end) {
|
| 39 |
+
if (end <= begin) {
|
| 40 |
+
return;
|
| 41 |
+
}
|
| 42 |
+
using Vec = vec::Vectorized<vec::vec_scalar_t<scalar_t>>;
|
| 43 |
+
// MSVC requires such a declaration of dynamic arrays
|
| 44 |
+
// Source: https://stackoverflow.com/a/33423538
|
| 45 |
+
auto tmp_sum_scalar = std::make_unique<scalar_t[]>(chunk_size);
|
| 46 |
+
auto max_input_arr = std::make_unique<scalar_t[]>(chunk_size);
|
| 47 |
+
for (int64_t ii = begin; ii < end; ii += chunk_size) {
|
| 48 |
+
int64_t loop_end = chunk_size;
|
| 49 |
+
if (ii + chunk_size > end) {
|
| 50 |
+
loop_end = end - ii;
|
| 51 |
+
}
|
| 52 |
+
for (const auto j : c10::irange(loop_end)) {
|
| 53 |
+
int64_t i = ii + j;
|
| 54 |
+
const scalar_t* input_data = input_data_base + i * dim_size;
|
| 55 |
+
max_input_arr[j] = vec::reduce_all<scalar_t>(
|
| 56 |
+
[](Vec& x, Vec& y) { return vec::maximum(x, y); },
|
| 57 |
+
input_data,
|
| 58 |
+
dim_size);
|
| 59 |
+
}
|
| 60 |
+
for (const auto j : c10::irange(loop_end)) {
|
| 61 |
+
int64_t i = ii + j;
|
| 62 |
+
const scalar_t* input_data = input_data_base + i * dim_size;
|
| 63 |
+
scalar_t max_input = max_input_arr[j];
|
| 64 |
+
tmp_sum_scalar[j] = vec::map_reduce_all<scalar_t>(
|
| 65 |
+
[max_input](Vec x) { return (x - Vec(max_input)).exp(); },
|
| 66 |
+
[](Vec x, Vec y) { return x + y; },
|
| 67 |
+
input_data,
|
| 68 |
+
dim_size);
|
| 69 |
+
}
|
| 70 |
+
// See [Note AVX-SSE transitions] for why this should call the
|
| 71 |
+
// vectorized version (aside from perf improvements).
|
| 72 |
+
vec::map(
|
| 73 |
+
[](Vec x) { return x.log(); },
|
| 74 |
+
tmp_sum_scalar.get(),
|
| 75 |
+
tmp_sum_scalar.get(),
|
| 76 |
+
loop_end);
|
| 77 |
+
for (const auto j : c10::irange(loop_end)) {
|
| 78 |
+
int64_t i = ii + j;
|
| 79 |
+
const scalar_t* input_data = input_data_base + i * dim_size;
|
| 80 |
+
scalar_t* output_data = output_data_base + i * dim_size;
|
| 81 |
+
scalar_t tmp_sum = tmp_sum_scalar[j];
|
| 82 |
+
scalar_t max_input = max_input_arr[j];
|
| 83 |
+
|
| 84 |
+
// It's necessary to keep the order of the operations below.
|
| 85 |
+
// In some cases that input is large digits and the difference
|
| 86 |
+
// is small, if we compute `max_input` plus `tmp_sum` before,
|
| 87 |
+
// there would be a numerical problem. See an example in
|
| 88 |
+
// https://github.com/pytorch/pytorch/issues/11752#issuecomment-422883379
|
| 89 |
+
vec::map(
|
| 90 |
+
[tmp_sum, max_input](Vec x) {
|
| 91 |
+
return x - Vec(max_input) - Vec(tmp_sum);
|
| 92 |
+
},
|
| 93 |
+
output_data,
|
| 94 |
+
input_data,
|
| 95 |
+
dim_size);
|
| 96 |
+
}
|
| 97 |
+
}
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
// Can't include ATen/Parallel.h.
|
| 101 |
+
// TODO: find a way to have only one copy of divup.
|
| 102 |
+
inline int64_t divup(int64_t x, int64_t y) {
|
| 103 |
+
return (x + y - 1) / y;
|
| 104 |
+
}
|
| 105 |
+
|
| 106 |
+
template <typename scalar_t, int64_t BLOCK_SIZE = 128 * 1024>
|
| 107 |
+
std::pair<int64_t,int64_t> vec_logsoftmax_chunk_size_and_num_chunks(int64_t inner_size, int64_t dim_size) {
|
| 108 |
+
using Vec = vec::Vectorized<scalar_t>;
|
| 109 |
+
int64_t MAX_CHUNK_SIZE = std::max<int64_t>(BLOCK_SIZE / dim_size / sizeof(scalar_t), Vec::size());
|
| 110 |
+
MAX_CHUNK_SIZE = MAX_CHUNK_SIZE / Vec::size() * Vec::size();
|
| 111 |
+
int64_t CHUNK_SIZE = std::min<int64_t>(MAX_CHUNK_SIZE, inner_size);
|
| 112 |
+
int64_t num_chunks = divup(inner_size, CHUNK_SIZE);
|
| 113 |
+
return {CHUNK_SIZE, num_chunks};
|
| 114 |
+
}
|
| 115 |
+
|
| 116 |
+
template <typename scalar_t>
|
| 117 |
+
std::enable_if_t<std::is_same_v<scalar_t, at::opmath_type<scalar_t>>, void>
|
| 118 |
+
serial_vec_logsoftmax_range(
|
| 119 |
+
const scalar_t* input_data_base,
|
| 120 |
+
scalar_t* output_data_base,
|
| 121 |
+
int64_t inner_size,
|
| 122 |
+
int64_t chunk_size,
|
| 123 |
+
int64_t num_chunks,
|
| 124 |
+
int64_t dim_size,
|
| 125 |
+
int64_t begin,
|
| 126 |
+
int64_t end) {
|
| 127 |
+
using Vec = vec::Vectorized<scalar_t>;
|
| 128 |
+
// thread local temp buffer which holds vertical reduction result: max and sum.
|
| 129 |
+
auto buffer = std::make_unique<scalar_t []>(chunk_size * 2);
|
| 130 |
+
scalar_t* input_max_data = buffer.get();
|
| 131 |
+
scalar_t* tmp_sum_data = buffer.get() + chunk_size;
|
| 132 |
+
|
| 133 |
+
for (int64_t i = begin; i < end; i++) {
|
| 134 |
+
int64_t outer_idx = i / num_chunks;
|
| 135 |
+
int64_t k = i % num_chunks;
|
| 136 |
+
int64_t inner_idx_begin = k * chunk_size;
|
| 137 |
+
int64_t size = std::min(chunk_size, inner_size - inner_idx_begin);
|
| 138 |
+
|
| 139 |
+
// init
|
| 140 |
+
Vec zero_vec = Vec(scalar_t(0));
|
| 141 |
+
Vec min_vec = Vec(-std::numeric_limits<scalar_t>::infinity());
|
| 142 |
+
int64_t d0 = 0;
|
| 143 |
+
for (; d0 < size - (size % Vec::size()); d0 += Vec::size()) {
|
| 144 |
+
min_vec.store(input_max_data + d0);
|
| 145 |
+
zero_vec.store(tmp_sum_data + d0);
|
| 146 |
+
}
|
| 147 |
+
for (; d0 < size; d0++) {
|
| 148 |
+
input_max_data[d0] = -std::numeric_limits<scalar_t>::infinity();
|
| 149 |
+
tmp_sum_data[d0] = scalar_t(0);
|
| 150 |
+
}
|
| 151 |
+
|
| 152 |
+
// compute max
|
| 153 |
+
for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) {
|
| 154 |
+
const scalar_t* input_ptr = input_data_base + outer_idx * dim_size * inner_size
|
| 155 |
+
+ dim_idx * inner_size + inner_idx_begin;
|
| 156 |
+
|
| 157 |
+
int64_t d1 = 0;
|
| 158 |
+
for (; d1 < size - (size % Vec::size()); d1 += Vec::size()) {
|
| 159 |
+
Vec data_vec = Vec::loadu(input_ptr + d1);
|
| 160 |
+
Vec max_vec = Vec::loadu(input_max_data + d1);
|
| 161 |
+
max_vec = Vec::blendv(max_vec, data_vec, data_vec > max_vec);
|
| 162 |
+
max_vec.store(input_max_data + d1);
|
| 163 |
+
}
|
| 164 |
+
for (; d1 < size; d1++) {
|
| 165 |
+
scalar_t data_val = input_ptr[d1];
|
| 166 |
+
scalar_t max_val = input_max_data[d1];
|
| 167 |
+
input_max_data[d1] = data_val > max_val ? data_val : max_val;
|
| 168 |
+
}
|
| 169 |
+
}
|
| 170 |
+
|
| 171 |
+
// compute sum of (x - max).exp()
|
| 172 |
+
for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) {
|
| 173 |
+
const scalar_t* input_ptr = input_data_base + outer_idx * dim_size * inner_size
|
| 174 |
+
+ dim_idx * inner_size + inner_idx_begin;
|
| 175 |
+
|
| 176 |
+
int64_t d2 = 0;
|
| 177 |
+
for (; d2 < size - (size % Vec::size()); d2 += Vec::size()) {
|
| 178 |
+
Vec data_vec = Vec::loadu(input_ptr + d2);
|
| 179 |
+
Vec sum_vec = Vec::loadu(tmp_sum_data + d2);
|
| 180 |
+
Vec max_vec = Vec::loadu(input_max_data + d2);
|
| 181 |
+
sum_vec += (data_vec - max_vec).exp();
|
| 182 |
+
sum_vec.store(tmp_sum_data + d2);
|
| 183 |
+
}
|
| 184 |
+
for (; d2 < size; d2++) {
|
| 185 |
+
scalar_t data_val = input_ptr[d2];
|
| 186 |
+
scalar_t max_val = input_max_data[d2];
|
| 187 |
+
tmp_sum_data[d2] += std::exp(data_val - max_val);
|
| 188 |
+
}
|
| 189 |
+
}
|
| 190 |
+
|
| 191 |
+
// apply log
|
| 192 |
+
vec::map([](Vec x) { return x.log(); }, tmp_sum_data, tmp_sum_data, size);
|
| 193 |
+
|
| 194 |
+
// compute x - max - sum
|
| 195 |
+
for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) {
|
| 196 |
+
int64_t offset = outer_idx * dim_size * inner_size + dim_idx * inner_size + inner_idx_begin;
|
| 197 |
+
const scalar_t* input_ptr = input_data_base + offset;
|
| 198 |
+
scalar_t* output_ptr = output_data_base + offset;
|
| 199 |
+
|
| 200 |
+
int64_t d3 = 0;
|
| 201 |
+
for (; d3 < size - (size % Vec::size()); d3 += Vec::size()) {
|
| 202 |
+
Vec data_vec = Vec::loadu(input_ptr + d3);
|
| 203 |
+
Vec max_vec = Vec::loadu(input_max_data + d3);
|
| 204 |
+
Vec sum_vec = Vec::loadu(tmp_sum_data + d3);
|
| 205 |
+
Vec out_vec = data_vec - max_vec - sum_vec;
|
| 206 |
+
out_vec.store(output_ptr + d3);
|
| 207 |
+
}
|
| 208 |
+
for (; d3 < size; d3++) {
|
| 209 |
+
output_ptr[d3] = input_ptr[d3] - input_max_data[d3] - tmp_sum_data[d3];
|
| 210 |
+
}
|
| 211 |
+
}
|
| 212 |
+
}
|
| 213 |
+
}
|
| 214 |
+
|
| 215 |
+
template <typename scalar_t>
|
| 216 |
+
std::enable_if_t<!std::is_same_v<scalar_t, at::opmath_type<scalar_t>>, void>
|
| 217 |
+
serial_vec_logsoftmax_range(
|
| 218 |
+
const scalar_t* input_data_base,
|
| 219 |
+
scalar_t* output_data_base,
|
| 220 |
+
int64_t inner_size,
|
| 221 |
+
int64_t chunk_size,
|
| 222 |
+
int64_t num_chunks,
|
| 223 |
+
int64_t dim_size,
|
| 224 |
+
int64_t begin,
|
| 225 |
+
int64_t end) {
|
| 226 |
+
using Vec = vec::Vectorized<scalar_t>;
|
| 227 |
+
using fVec = vec::Vectorized<float>;
|
| 228 |
+
auto buffer = std::make_unique<float []>(chunk_size * 2);
|
| 229 |
+
float* input_max_data = buffer.get();
|
| 230 |
+
float* tmp_sum_data = buffer.get() + chunk_size;
|
| 231 |
+
|
| 232 |
+
// thread local buffer that holds input data in float32 to save next 2 dtype conversion
|
| 233 |
+
auto input_buffer = std::make_unique<float []>(dim_size * chunk_size);
|
| 234 |
+
float* input_buffer_data = input_buffer.get();
|
| 235 |
+
|
| 236 |
+
// init
|
| 237 |
+
for (int64_t i = begin; i < end; i++) {
|
| 238 |
+
int64_t outer_idx = i / num_chunks;
|
| 239 |
+
int64_t k = i % num_chunks;
|
| 240 |
+
int64_t inner_idx_begin = k * chunk_size;
|
| 241 |
+
int64_t size = std::min(chunk_size, inner_size - inner_idx_begin);
|
| 242 |
+
|
| 243 |
+
fVec zero_fvec = fVec(float(0));
|
| 244 |
+
fVec min_fvec = fVec(-std::numeric_limits<float>::infinity());
|
| 245 |
+
int64_t d0 = 0;
|
| 246 |
+
for (; d0 < size - (size % Vec::size()); d0 += Vec::size()) {
|
| 247 |
+
min_fvec.store(input_max_data + d0);
|
| 248 |
+
min_fvec.store(input_max_data + d0 + fVec::size());
|
| 249 |
+
zero_fvec.store(tmp_sum_data + d0);
|
| 250 |
+
zero_fvec.store(tmp_sum_data + d0 + fVec::size());
|
| 251 |
+
}
|
| 252 |
+
for (; d0 < size; d0++) {
|
| 253 |
+
input_max_data[d0] = -std::numeric_limits<float>::infinity();
|
| 254 |
+
tmp_sum_data[d0] = float(0);
|
| 255 |
+
}
|
| 256 |
+
|
| 257 |
+
// compute max
|
| 258 |
+
for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) {
|
| 259 |
+
const scalar_t* input_ptr = input_data_base + outer_idx * dim_size * inner_size
|
| 260 |
+
+ dim_idx * inner_size + inner_idx_begin;
|
| 261 |
+
float* input_buffer_ptr = input_buffer_data + dim_idx * chunk_size;
|
| 262 |
+
|
| 263 |
+
int64_t d1 = 0;
|
| 264 |
+
for (; d1 < size - (size % Vec::size()); d1 += Vec::size()) {
|
| 265 |
+
Vec data_vec = Vec::loadu(input_ptr + d1);
|
| 266 |
+
auto [data_fvec0, data_fvec1] = vec::convert_to_float<scalar_t>(data_vec);
|
| 267 |
+
fVec max_fvec0 = fVec::loadu(input_max_data + d1);
|
| 268 |
+
fVec max_fvec1 = fVec::loadu(input_max_data + d1 + fVec::size());
|
| 269 |
+
max_fvec0 = fVec::blendv(max_fvec0, data_fvec0, data_fvec0 > max_fvec0);
|
| 270 |
+
max_fvec1 = fVec::blendv(max_fvec1, data_fvec1, data_fvec1 > max_fvec1);
|
| 271 |
+
max_fvec0.store(input_max_data + d1);
|
| 272 |
+
max_fvec1.store(input_max_data + d1 + fVec::size());
|
| 273 |
+
|
| 274 |
+
// cache the 'converted' float input
|
| 275 |
+
data_fvec0.store(input_buffer_ptr + d1);
|
| 276 |
+
data_fvec1.store(input_buffer_ptr + d1 + fVec::size());
|
| 277 |
+
}
|
| 278 |
+
for (; d1 < size; d1++) {
|
| 279 |
+
float data_val = float(input_ptr[d1]);
|
| 280 |
+
float max_val = input_max_data[d1];
|
| 281 |
+
input_max_data[d1] = data_val > max_val ? data_val : max_val;
|
| 282 |
+
input_buffer_ptr[d1] = data_val;
|
| 283 |
+
}
|
| 284 |
+
}
|
| 285 |
+
|
| 286 |
+
// compute sum of (x - max).exp()
|
| 287 |
+
for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) {
|
| 288 |
+
float* input_buffer_ptr = input_buffer_data + dim_idx * chunk_size;
|
| 289 |
+
|
| 290 |
+
int64_t d2 = 0;
|
| 291 |
+
for (; d2 < size - (size % Vec::size()); d2 += Vec::size()) {
|
| 292 |
+
fVec data_fvec0 = fVec::loadu(input_buffer_ptr + d2);
|
| 293 |
+
fVec data_fvec1 = fVec::loadu(input_buffer_ptr + d2 + fVec::size());
|
| 294 |
+
fVec sum_fvec0 = fVec::loadu(tmp_sum_data + d2);
|
| 295 |
+
fVec sum_fvec1 = fVec::loadu(tmp_sum_data + d2 + fVec::size());
|
| 296 |
+
fVec max_fvec0 = fVec::loadu(input_max_data + d2);
|
| 297 |
+
fVec max_fvec1 = fVec::loadu(input_max_data + d2 + fVec::size());
|
| 298 |
+
sum_fvec0 += (data_fvec0 - max_fvec0).exp();
|
| 299 |
+
sum_fvec1 += (data_fvec1 - max_fvec1).exp();
|
| 300 |
+
sum_fvec0.store(tmp_sum_data + d2);
|
| 301 |
+
sum_fvec1.store(tmp_sum_data + d2 + fVec::size());
|
| 302 |
+
}
|
| 303 |
+
for (; d2 < size; d2++) {
|
| 304 |
+
float data_val = input_buffer_ptr[d2];
|
| 305 |
+
float max_val = input_max_data[d2];
|
| 306 |
+
tmp_sum_data[d2] += std::exp(data_val - max_val);
|
| 307 |
+
}
|
| 308 |
+
}
|
| 309 |
+
|
| 310 |
+
// apply log
|
| 311 |
+
vec::map([](fVec x) { return x.log(); }, tmp_sum_data, tmp_sum_data, size);
|
| 312 |
+
|
| 313 |
+
// compute x - max - sum
|
| 314 |
+
for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) {
|
| 315 |
+
float* input_buffer_ptr = input_buffer_data + dim_idx * chunk_size;
|
| 316 |
+
scalar_t* output_ptr = output_data_base + outer_idx * dim_size * inner_size
|
| 317 |
+
+ dim_idx * inner_size + inner_idx_begin;
|
| 318 |
+
|
| 319 |
+
int64_t d3 = 0;
|
| 320 |
+
for (; d3 < size - (size % Vec::size()); d3 += Vec::size()) {
|
| 321 |
+
fVec data_fvec0 = fVec::loadu(input_buffer_ptr + d3);
|
| 322 |
+
fVec data_fvec1 = fVec::loadu(input_buffer_ptr + d3 + fVec::size());
|
| 323 |
+
fVec max_fvec0 = fVec::loadu(input_max_data + d3);
|
| 324 |
+
fVec max_fvec1 = fVec::loadu(input_max_data + d3 + fVec::size());
|
| 325 |
+
fVec sum_fvec0 = fVec::loadu(tmp_sum_data + d3);
|
| 326 |
+
fVec sum_fvec1 = fVec::loadu(tmp_sum_data + d3 + fVec::size());
|
| 327 |
+
fVec out_fvec0 = data_fvec0 - max_fvec0 - sum_fvec0;
|
| 328 |
+
fVec out_fvec1 = data_fvec1 - max_fvec1 - sum_fvec1;
|
| 329 |
+
Vec out_vec = vec::convert_from_float<scalar_t>(out_fvec0, out_fvec1);
|
| 330 |
+
out_vec.store(output_ptr + d3);
|
| 331 |
+
}
|
| 332 |
+
for (; d3 < size; d3++) {
|
| 333 |
+
output_ptr[d3] = scalar_t(input_buffer_ptr[d3] - input_max_data[d3] - tmp_sum_data[d3]);
|
| 334 |
+
}
|
| 335 |
+
}
|
| 336 |
+
}
|
| 337 |
+
} // namespace CPU_CAPABILITY
|
| 338 |
+
}} // namespace at::native
|
| 339 |
+
|
| 340 |
+
#else
|
| 341 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 342 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/Loops.h
ADDED
|
@@ -0,0 +1,400 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
// This file provides two functions to help write elementwise kernels:
|
| 5 |
+
//
|
| 6 |
+
// cpu_kernel(TensorIterator iter, <lambda>)
|
| 7 |
+
// cpu_kernel_vec(TensorIterator iter, <lambda>, <vec_lambda>)
|
| 8 |
+
//
|
| 9 |
+
// Both functions may generate vectorized code. The cpu_kernel implementation
|
| 10 |
+
// relies on the compiler's auto-vectorization. The cpu_kernel_vec
|
| 11 |
+
// implementation uses x86 SIMD intrinsics when available. These functions
|
| 12 |
+
// are only intended to be used in the ATen/native/cpu subdirectory, since files
|
| 13 |
+
// in other directories are not compiled with AVX/AVX2 enabled. See README.md
|
| 14 |
+
// for more details.
|
| 15 |
+
//
|
| 16 |
+
// For example, to write a multiplication kernel for float:
|
| 17 |
+
//
|
| 18 |
+
// cpu_kernel(iter, [](float a, float b) { return a * b; });
|
| 19 |
+
//
|
| 20 |
+
// Or you may write:
|
| 21 |
+
//
|
| 22 |
+
// cpu_kernel_vec(iter,
|
| 23 |
+
// [](float a, float b) { return a * b; },
|
| 24 |
+
// [](Vectorized<float> a, Vectorized<float> b) { return a * b; });
|
| 25 |
+
//
|
| 26 |
+
// See BinaryOpsKernel.cpp for the complete implementation
|
| 27 |
+
//
|
| 28 |
+
//
|
| 29 |
+
|
| 30 |
+
#include <cstdint>
|
| 31 |
+
#include <c10/util/C++17.h>
|
| 32 |
+
#include <c10/util/Load.h>
|
| 33 |
+
#include <c10/util/irange.h>
|
| 34 |
+
#include <ATen/detail/FunctionTraits.h>
|
| 35 |
+
#include <ATen/native/cpu/IsContiguous.h>
|
| 36 |
+
#include <ATen/native/TensorIterator.h>
|
| 37 |
+
#include <ATen/native/TensorIteratorDynamicCasting.h>
|
| 38 |
+
#include <ATen/cpu/vec/vec.h>
|
| 39 |
+
|
| 40 |
+
#include <tuple>
|
| 41 |
+
#include <utility>
|
| 42 |
+
|
| 43 |
+
namespace at::native { inline namespace CPU_CAPABILITY {
|
| 44 |
+
|
| 45 |
+
using namespace vec;
|
| 46 |
+
|
| 47 |
+
template <typename traits, std::size_t... INDEX>
|
| 48 |
+
typename traits::ArgsTuple
|
| 49 |
+
dereference_impl(char* C10_RESTRICT data[], const int64_t* strides, int64_t i,
|
| 50 |
+
std::index_sequence<INDEX...> /*unused*/) {
|
| 51 |
+
return std::make_tuple(
|
| 52 |
+
c10::load<typename traits::template arg<INDEX>::type>(
|
| 53 |
+
data[INDEX] + i * strides[INDEX])...);
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
template <typename traits>
|
| 57 |
+
typename traits::ArgsTuple
|
| 58 |
+
dereference(char* C10_RESTRICT data[], const int64_t* strides, int64_t i) {
|
| 59 |
+
using Indices = std::make_index_sequence<traits::arity>;
|
| 60 |
+
return dereference_impl<traits>(data, strides, i, Indices{});
|
| 61 |
+
}
|
| 62 |
+
|
| 63 |
+
template <typename traits, std::size_t... INDEX>
|
| 64 |
+
typename traits::ArgsTuple
|
| 65 |
+
dereference_vec_impl(char* C10_RESTRICT data[],
|
| 66 |
+
const typename traits::result_type& opt_scalar,
|
| 67 |
+
size_t S,
|
| 68 |
+
int64_t i,
|
| 69 |
+
std::index_sequence<INDEX...> /*unused*/) {
|
| 70 |
+
using Vec = typename traits::result_type;
|
| 71 |
+
using scalar_t = typename Vec::value_type;
|
| 72 |
+
return std::make_tuple(
|
| 73 |
+
S == INDEX + 1 ?
|
| 74 |
+
opt_scalar :
|
| 75 |
+
Vec::loadu(data[INDEX] + i * sizeof(scalar_t))...);
|
| 76 |
+
}
|
| 77 |
+
|
| 78 |
+
template <typename traits>
|
| 79 |
+
typename traits::ArgsTuple
|
| 80 |
+
dereference_vec(char* C10_RESTRICT data[], const typename traits::result_type& opt_scalar, size_t S, int64_t i) {
|
| 81 |
+
using Indices = std::make_index_sequence<traits::arity>;
|
| 82 |
+
return dereference_vec_impl<traits>(data, opt_scalar, S, i, Indices{});
|
| 83 |
+
}
|
| 84 |
+
|
| 85 |
+
template <typename func_t,
|
| 86 |
+
std::enable_if_t<!std::is_void_v<typename function_traits<func_t>::result_type>>* = nullptr>
|
| 87 |
+
inline void
|
| 88 |
+
execute_op(char* C10_RESTRICT data[], const int64_t* strides, int64_t i, int64_t n, func_t&& op) {
|
| 89 |
+
using traits = function_traits<func_t>;
|
| 90 |
+
using result_type = typename traits::result_type;
|
| 91 |
+
for (; i < n; i++) {
|
| 92 |
+
result_type* out_ptr = (result_type*)(data[0] + i * strides[0]);
|
| 93 |
+
*out_ptr = std::apply(op, dereference<traits>(
|
| 94 |
+
&data[1],
|
| 95 |
+
&strides[1],
|
| 96 |
+
i));
|
| 97 |
+
}
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
template <typename func_t,
|
| 101 |
+
std::enable_if_t<std::is_void_v<typename function_traits<func_t>::result_type>>* = nullptr>
|
| 102 |
+
inline void
|
| 103 |
+
execute_op(char* C10_RESTRICT data[], const int64_t* strides, int64_t i, int64_t n, func_t&& op) {
|
| 104 |
+
using traits = function_traits<func_t>;
|
| 105 |
+
for (; i < n; i++) {
|
| 106 |
+
std::apply(op, dereference<traits>(
|
| 107 |
+
&data[0],
|
| 108 |
+
&strides[0],
|
| 109 |
+
i));
|
| 110 |
+
}
|
| 111 |
+
}
|
| 112 |
+
|
| 113 |
+
// Basic loop operation (one output, N inputs). May be auto-vectorized
|
| 114 |
+
// by the compiler. Supports inputs and outputs of different types.
|
| 115 |
+
template <typename func_t>
|
| 116 |
+
inline void
|
| 117 |
+
basic_loop(char* C10_RESTRICT data[], const int64_t* strides_, int64_t i, int64_t n, func_t&& op) {
|
| 118 |
+
using traits = function_traits<func_t>;
|
| 119 |
+
constexpr int ntensors = traits::arity + 1;
|
| 120 |
+
|
| 121 |
+
// Copying strides to temporary array helps auto vectorization in older GCC
|
| 122 |
+
// versions.
|
| 123 |
+
int64_t strides[ntensors];
|
| 124 |
+
for (const auto arg : c10::irange(ntensors)) {
|
| 125 |
+
strides[arg] = strides_[arg];
|
| 126 |
+
}
|
| 127 |
+
|
| 128 |
+
execute_op(data, strides, i, n, std::forward<func_t>(op));
|
| 129 |
+
}
|
| 130 |
+
|
| 131 |
+
// the recursive variadic template for iterating over the returned tuple
|
| 132 |
+
template<class T, size_t N>
|
| 133 |
+
struct TupleOutput {
|
| 134 |
+
static void handle(char *C10_RESTRICT data[], const int64_t *strides, int64_t i,
|
| 135 |
+
const T &tuple) {
|
| 136 |
+
TupleOutput<T, N - 1>::handle(data, strides, i, tuple);
|
| 137 |
+
|
| 138 |
+
auto output = std::get<N - 1>(tuple);
|
| 139 |
+
using output_type = decltype(output);
|
| 140 |
+
output_type * out_ptr = (output_type *)(data[N - 1] + i * strides[N - 1]);
|
| 141 |
+
*out_ptr = output;
|
| 142 |
+
}
|
| 143 |
+
};
|
| 144 |
+
|
| 145 |
+
// Base case for the above recursive template
|
| 146 |
+
template<class T>
|
| 147 |
+
struct TupleOutput<T, 1> {
|
| 148 |
+
static void handle(char *C10_RESTRICT data[], const int64_t *strides, int64_t i,
|
| 149 |
+
const T &tuple) {
|
| 150 |
+
auto output = std::get<0>(tuple);
|
| 151 |
+
using output_type = decltype(output);
|
| 152 |
+
output_type* out_ptr = (output_type *)(data[0] + i * strides[0]);
|
| 153 |
+
*out_ptr = output;
|
| 154 |
+
}
|
| 155 |
+
};
|
| 156 |
+
|
| 157 |
+
template<class... Args>
|
| 158 |
+
void handle_tuple_outputs(char* C10_RESTRICT data[],
|
| 159 |
+
const int64_t* strides,
|
| 160 |
+
int64_t i,
|
| 161 |
+
const std::tuple<Args...> &tuple) {
|
| 162 |
+
TupleOutput<decltype(tuple), sizeof...(Args)>::handle(data, strides, i, tuple);
|
| 163 |
+
}
|
| 164 |
+
|
| 165 |
+
// Loop operation for `cpu_kernel_multiple_outputs`.
|
| 166 |
+
// 1. Use `std::apply` to make dynamic method invocation
|
| 167 |
+
// for the lambda passed in `cpu_kernel_multiple_outputs`.
|
| 168 |
+
// 2. Iterate over the members of the returned tuple, set the corresponding
|
| 169 |
+
// output tensor by the tuple member in `handle_tuple_outputs` function.
|
| 170 |
+
template <typename func_t>
|
| 171 |
+
inline void
|
| 172 |
+
multiple_outputs_loop(char* C10_RESTRICT data[], const int64_t* strides_, int64_t i, int64_t n, func_t&& op) {
|
| 173 |
+
using traits = function_traits<func_t>;
|
| 174 |
+
|
| 175 |
+
using result_type = typename traits::result_type;
|
| 176 |
+
constexpr int num_outputs = std::tuple_size_v<result_type>;
|
| 177 |
+
constexpr int ntensors = traits::arity + num_outputs;
|
| 178 |
+
|
| 179 |
+
// Copying strides to temporary array helps auto vectorization in older GCC
|
| 180 |
+
// versions.
|
| 181 |
+
int64_t strides[ntensors];
|
| 182 |
+
for (const auto arg : c10::irange(ntensors)) {
|
| 183 |
+
strides[arg] = strides_[arg];
|
| 184 |
+
}
|
| 185 |
+
|
| 186 |
+
for (; i < n; i++) {
|
| 187 |
+
auto output = std::apply(op, dereference<traits>(
|
| 188 |
+
&data[num_outputs],
|
| 189 |
+
&strides[num_outputs],
|
| 190 |
+
i));
|
| 191 |
+
handle_tuple_outputs(data, strides, i, output);
|
| 192 |
+
}
|
| 193 |
+
}
|
| 194 |
+
|
| 195 |
+
// Explicitly vectorized loop implementation. All inputs and outputs must be
|
| 196 |
+
// the same type and contiguous with one exception: a single input may be
|
| 197 |
+
// a scalar (stride 0). It's position is indicated by the argument `S`. If `S`
|
| 198 |
+
// is 0, then there are no scalar inputs.
|
| 199 |
+
template <typename func_t, typename vec_func_t>
|
| 200 |
+
inline void
|
| 201 |
+
vectorized_loop(char** C10_RESTRICT data_, int64_t n, int64_t S, func_t&& op, vec_func_t&& vop) {
|
| 202 |
+
using traits = function_traits<vec_func_t>;
|
| 203 |
+
using scalar_t = typename function_traits<func_t>::result_type;
|
| 204 |
+
using Vec = Vectorized<scalar_t>;
|
| 205 |
+
constexpr int ntensors = traits::arity + 1;
|
| 206 |
+
|
| 207 |
+
char* C10_RESTRICT data[ntensors];
|
| 208 |
+
for (const auto arg : c10::irange(ntensors)) {
|
| 209 |
+
data[arg] = data_[arg];
|
| 210 |
+
}
|
| 211 |
+
|
| 212 |
+
Vec opt_scalar = Vec(S > 0 ? c10::load((scalar_t*)data[S]) : scalar_t(0));
|
| 213 |
+
int64_t i = 0;
|
| 214 |
+
for (; i <= n - 2 * Vec::size(); i += 2 * Vec::size()) {
|
| 215 |
+
auto args1 = dereference_vec<traits>(&data[1], opt_scalar, S, i);
|
| 216 |
+
auto args2 = dereference_vec<traits>(&data[1], opt_scalar, S, i + Vec::size());
|
| 217 |
+
auto out1 = std::apply(vop, std::move(args1));
|
| 218 |
+
auto out2 = std::apply(vop, std::move(args2));
|
| 219 |
+
out1.store(data[0] + i * sizeof(scalar_t));
|
| 220 |
+
out2.store(data[0] + (i + Vec::size()) * sizeof(scalar_t));
|
| 221 |
+
}
|
| 222 |
+
if (i < n) {
|
| 223 |
+
int64_t strides[ntensors];
|
| 224 |
+
for (const auto arg : c10::irange(ntensors)) {
|
| 225 |
+
strides[arg] = (S > 0 && arg == S) ? 0 : sizeof(scalar_t);
|
| 226 |
+
}
|
| 227 |
+
basic_loop(data, strides, i, n, std::forward<func_t>(op));
|
| 228 |
+
}
|
| 229 |
+
}
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
template <typename traits, typename cb_t>
|
| 233 |
+
inline void unroll_contiguous_scalar_checks(
|
| 234 |
+
const int64_t* /*strides*/,
|
| 235 |
+
std::index_sequence<> /*unused*/,
|
| 236 |
+
cb_t&& cb) {
|
| 237 |
+
cb(0);
|
| 238 |
+
}
|
| 239 |
+
|
| 240 |
+
template <typename traits, typename cb_t, size_t INDEX0, size_t ...INDEX>
|
| 241 |
+
inline void unroll_contiguous_scalar_checks(
|
| 242 |
+
const int64_t* strides,
|
| 243 |
+
std::index_sequence<INDEX0, INDEX...> /*unused*/,
|
| 244 |
+
cb_t&& cb) {
|
| 245 |
+
if (is_contiguous_scalar<traits, INDEX0 + 1>(strides)) {
|
| 246 |
+
cb(INDEX0 + 1);
|
| 247 |
+
} else {
|
| 248 |
+
unroll_contiguous_scalar_checks<traits>(strides, std::index_sequence<INDEX...>{}, std::forward<cb_t>(cb));
|
| 249 |
+
}
|
| 250 |
+
}
|
| 251 |
+
|
| 252 |
+
template <typename op_t, typename vop_t>
|
| 253 |
+
struct VectorizedLoop2d {
|
| 254 |
+
op_t op;
|
| 255 |
+
vop_t vop;
|
| 256 |
+
|
| 257 |
+
using traits = function_traits<op_t>;
|
| 258 |
+
static constexpr int ntensors = traits::arity + 1;
|
| 259 |
+
using data_t = std::array<char*, ntensors>;
|
| 260 |
+
|
| 261 |
+
VectorizedLoop2d(op_t op, vop_t vop):
|
| 262 |
+
op(std::move(op)), vop(std::move(vop)) {}
|
| 263 |
+
|
| 264 |
+
static void advance(data_t &data, const int64_t *outer_strides) {
|
| 265 |
+
for (const auto arg : c10::irange(data.size())) {
|
| 266 |
+
data[arg] += outer_strides[arg];
|
| 267 |
+
}
|
| 268 |
+
}
|
| 269 |
+
|
| 270 |
+
void operator()(char** base, const int64_t *strides, int64_t size0, int64_t size1) {
|
| 271 |
+
data_t data;
|
| 272 |
+
std::copy_n(base, ntensors, data.data());
|
| 273 |
+
const int64_t *outer_strides = &strides[ntensors];
|
| 274 |
+
|
| 275 |
+
if (is_contiguous<traits>(strides)) {
|
| 276 |
+
for ([[maybe_unused]] const auto i : c10::irange(size1)) {
|
| 277 |
+
vectorized_loop(data.data(), size0, 0, op, vop);
|
| 278 |
+
advance(data, outer_strides);
|
| 279 |
+
}
|
| 280 |
+
} else {
|
| 281 |
+
using Indices = std::make_index_sequence<traits::arity>;
|
| 282 |
+
unroll_contiguous_scalar_checks<traits>(strides, Indices{}, [&](size_t idx) {
|
| 283 |
+
if (idx) {
|
| 284 |
+
for ([[maybe_unused]] const auto i : c10::irange(size1)) {
|
| 285 |
+
vectorized_loop(data.data(), size0, idx, op, vop);
|
| 286 |
+
advance(data, outer_strides);
|
| 287 |
+
}
|
| 288 |
+
} else {
|
| 289 |
+
for ([[maybe_unused]] const auto i : c10::irange(size1)) {
|
| 290 |
+
basic_loop(data.data(), strides, 0, size0, op);
|
| 291 |
+
advance(data, outer_strides);
|
| 292 |
+
}
|
| 293 |
+
}
|
| 294 |
+
});
|
| 295 |
+
}
|
| 296 |
+
}
|
| 297 |
+
};
|
| 298 |
+
|
| 299 |
+
template <typename op_t, typename vop_t>
|
| 300 |
+
VectorizedLoop2d<op_t, vop_t> make_vectorized_loop2d(
|
| 301 |
+
op_t &&op, vop_t &&vop) {
|
| 302 |
+
return VectorizedLoop2d<op_t, vop_t>(std::forward<op_t>(op), std::forward<vop_t>(vop));
|
| 303 |
+
}
|
| 304 |
+
|
| 305 |
+
template <typename func_t>
|
| 306 |
+
void cpu_kernel(TensorIteratorBase& iter, func_t&& op, int64_t grain_size = at::internal::GRAIN_SIZE) {
|
| 307 |
+
using traits = function_traits<func_t>;
|
| 308 |
+
// this could be extended to work with void return types
|
| 309 |
+
TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity);
|
| 310 |
+
TORCH_INTERNAL_ASSERT(iter.noutputs() == 1);
|
| 311 |
+
// dynamic casting not currently supported on CPU
|
| 312 |
+
TORCH_INTERNAL_ASSERT(!needs_dynamic_casting<func_t>::check(iter));
|
| 313 |
+
|
| 314 |
+
iter.for_each([&](char** data, const int64_t* strides, int64_t n) {
|
| 315 |
+
// basic loop can handle 1d slices with arbitrary strides, and 1d slices is all that
|
| 316 |
+
// iter.for_each is ever sending to the loop lambda
|
| 317 |
+
basic_loop(data, strides, 0, n, op);
|
| 318 |
+
}, grain_size);
|
| 319 |
+
iter.cast_outputs();
|
| 320 |
+
}
|
| 321 |
+
|
| 322 |
+
// This function helps write elementwise kernels that requires multiple outputs.
|
| 323 |
+
// It follows the similar structure of cpu_kernel.
|
| 324 |
+
// Instead of `basic_loop` function, a new `multiple_outputs_loop` function is
|
| 325 |
+
// manipulated to handle multiple return values.
|
| 326 |
+
// For now `needs_dynamic_casting` check is not added as the passed lambda (`func_t`)
|
| 327 |
+
// of `multiple_outputs_loop` returns `std::tuple` instead of `scalar_t`.
|
| 328 |
+
// The `gpu_kernel_multiple_outputs` is also implemented without this check,
|
| 329 |
+
// We could extend `needs_dynamic_casting` to support both `std::tuple` and
|
| 330 |
+
// `thrust::tuple` in the future.
|
| 331 |
+
template <typename func_t>
|
| 332 |
+
void cpu_kernel_multiple_outputs(TensorIteratorBase& iter, func_t&& op, int64_t grain_size = at::internal::GRAIN_SIZE) {
|
| 333 |
+
using traits = function_traits<func_t>;
|
| 334 |
+
TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity);
|
| 335 |
+
|
| 336 |
+
iter.for_each([&](char** data, const int64_t* strides, int64_t n) {
|
| 337 |
+
multiple_outputs_loop(data, strides, 0, n, op);
|
| 338 |
+
}, grain_size);
|
| 339 |
+
iter.cast_outputs();
|
| 340 |
+
}
|
| 341 |
+
|
| 342 |
+
template <bool check_dynamic_cast=true, typename func_t, typename vec_func_t>
|
| 343 |
+
void cpu_kernel_vec(TensorIteratorBase& iter, func_t&& op, vec_func_t&& vop, int64_t grain_size = at::internal::GRAIN_SIZE) {
|
| 344 |
+
using traits = function_traits<func_t>;
|
| 345 |
+
// this could be extended to work with void return types
|
| 346 |
+
TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity);
|
| 347 |
+
TORCH_INTERNAL_ASSERT(iter.noutputs() == 1);
|
| 348 |
+
// dynamic casting not currently supported on CPU, but some kernels (like Fill)
|
| 349 |
+
// explicitly dynamic_cast, so we give the opt-out of checking.
|
| 350 |
+
if constexpr (check_dynamic_cast) {
|
| 351 |
+
TORCH_INTERNAL_ASSERT(!needs_dynamic_casting<func_t>::check(iter));
|
| 352 |
+
}
|
| 353 |
+
|
| 354 |
+
iter.for_each(make_vectorized_loop2d(std::forward<func_t>(op), std::forward<vec_func_t>(vop)), grain_size);
|
| 355 |
+
iter.cast_outputs();
|
| 356 |
+
}
|
| 357 |
+
|
| 358 |
+
template <typename func_t>
|
| 359 |
+
void cpu_serial_kernel(TensorIteratorBase& iter, func_t&& op, const Range& range) {
|
| 360 |
+
using traits = function_traits<func_t>;
|
| 361 |
+
constexpr bool result_void = std::is_void_v<typename traits::result_type>;
|
| 362 |
+
TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity &&
|
| 363 |
+
((result_void && iter.noutputs() == 0) || (!result_void && iter.noutputs() == 1)));
|
| 364 |
+
// dynamic casting not currently supported on CPU
|
| 365 |
+
TORCH_INTERNAL_ASSERT(!needs_dynamic_casting<func_t>::check(iter));
|
| 366 |
+
|
| 367 |
+
iter.serial_for_each([&](char** data, const int64_t* strides, int64_t n) {
|
| 368 |
+
basic_loop(data, strides, 0, n, op);
|
| 369 |
+
}, range);
|
| 370 |
+
iter.cast_outputs();
|
| 371 |
+
}
|
| 372 |
+
|
| 373 |
+
template <typename func_t>
|
| 374 |
+
void cpu_serial_kernel(TensorIteratorBase& iter, func_t&& op) {
|
| 375 |
+
cpu_serial_kernel(iter, std::forward<func_t>(op), {0, iter.numel()});
|
| 376 |
+
}
|
| 377 |
+
|
| 378 |
+
template <typename func_t, typename vec_func_t>
|
| 379 |
+
void cpu_serial_kernel_vec(TensorIteratorBase& iter, func_t&& op, vec_func_t&& vop, const Range& range) {
|
| 380 |
+
using traits = function_traits<func_t>;
|
| 381 |
+
// this could be extended to work with void return types
|
| 382 |
+
TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity);
|
| 383 |
+
TORCH_INTERNAL_ASSERT(iter.noutputs() == 1);
|
| 384 |
+
// dynamic casting not currently supported on CPU
|
| 385 |
+
TORCH_INTERNAL_ASSERT(!needs_dynamic_casting<func_t>::check(iter));
|
| 386 |
+
|
| 387 |
+
iter.serial_for_each(make_vectorized_loop2d(std::forward<func_t>(op), std::forward<vec_func_t>(vop)), range);
|
| 388 |
+
iter.cast_outputs();
|
| 389 |
+
}
|
| 390 |
+
|
| 391 |
+
template <typename func_t, typename vec_func_t>
|
| 392 |
+
void cpu_serial_kernel_vec(TensorIteratorBase& iter, func_t&& op, vec_func_t&& vop) {
|
| 393 |
+
cpu_serial_kernel_vec(iter, std::forward<func_t>(op), std::forward<vec_func_t>(vop), {0, iter.numel()});
|
| 394 |
+
}
|
| 395 |
+
|
| 396 |
+
}} // namespace at::native::<anonymous>
|
| 397 |
+
|
| 398 |
+
#else
|
| 399 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 400 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/MaxUnpoolKernel.h
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
#include <ATen/native/DispatchStub.h>
|
| 4 |
+
|
| 5 |
+
namespace at {
|
| 6 |
+
class Tensor;
|
| 7 |
+
|
| 8 |
+
namespace native {
|
| 9 |
+
|
| 10 |
+
using max_unpooling_fn = void(*)(Tensor&, const Tensor&, const Tensor&);
|
| 11 |
+
|
| 12 |
+
DECLARE_DISPATCH(max_unpooling_fn, max_unpool2d_kernel)
|
| 13 |
+
DECLARE_DISPATCH(max_unpooling_fn, max_unpool3d_kernel)
|
| 14 |
+
|
| 15 |
+
}} // at::native
|
| 16 |
+
|
| 17 |
+
#else
|
| 18 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 19 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/PixelShuffleKernel.h
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
#include <ATen/native/DispatchStub.h>
|
| 4 |
+
|
| 5 |
+
namespace at {
|
| 6 |
+
class TensorBase;
|
| 7 |
+
}
|
| 8 |
+
|
| 9 |
+
namespace at::native {
|
| 10 |
+
|
| 11 |
+
using pixel_shuffle_fn = void(*)(TensorBase&, const TensorBase&, int64_t);
|
| 12 |
+
DECLARE_DISPATCH(pixel_shuffle_fn, pixel_shuffle_kernel)
|
| 13 |
+
DECLARE_DISPATCH(pixel_shuffle_fn, pixel_unshuffle_kernel)
|
| 14 |
+
|
| 15 |
+
} // at::native
|
| 16 |
+
|
| 17 |
+
#else
|
| 18 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 19 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/ReduceUtils.h
ADDED
|
@@ -0,0 +1,242 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
#include <ATen/Parallel.h>
|
| 5 |
+
#include <ATen/NumericUtils.h>
|
| 6 |
+
#include <ATen/cpu/vec/vec.h>
|
| 7 |
+
#include <ATen/cpu/vec/functional.h>
|
| 8 |
+
#include <ATen/native/ReductionType.h>
|
| 9 |
+
#include <c10/util/irange.h>
|
| 10 |
+
#include <ATen/OpMathType.h>
|
| 11 |
+
#include <ATen/native/cpu/utils.h>
|
| 12 |
+
|
| 13 |
+
namespace at::native {
|
| 14 |
+
inline namespace CPU_CAPABILITY {
|
| 15 |
+
|
| 16 |
+
using namespace vec;
|
| 17 |
+
|
| 18 |
+
#define AT_DISPATCH_REDUCTION_TYPES(op, ...) \
|
| 19 |
+
[&] { \
|
| 20 |
+
switch (op) { \
|
| 21 |
+
case ReductionType::SUM: { \
|
| 22 |
+
static constexpr auto reduce = ReductionType::SUM; \
|
| 23 |
+
return __VA_ARGS__(); \
|
| 24 |
+
} \
|
| 25 |
+
case ReductionType::MEAN: { \
|
| 26 |
+
static constexpr auto reduce = ReductionType::MEAN; \
|
| 27 |
+
return __VA_ARGS__(); \
|
| 28 |
+
} \
|
| 29 |
+
case ReductionType::MIN: { \
|
| 30 |
+
static constexpr auto reduce = ReductionType::MIN; \
|
| 31 |
+
return __VA_ARGS__(); \
|
| 32 |
+
} \
|
| 33 |
+
case ReductionType::MAX: { \
|
| 34 |
+
static constexpr auto reduce = ReductionType::MAX; \
|
| 35 |
+
return __VA_ARGS__(); \
|
| 36 |
+
} \
|
| 37 |
+
case ReductionType::PROD: { \
|
| 38 |
+
static constexpr auto reduce = ReductionType::PROD; \
|
| 39 |
+
return __VA_ARGS__(); \
|
| 40 |
+
} \
|
| 41 |
+
} \
|
| 42 |
+
}()
|
| 43 |
+
|
| 44 |
+
template <typename scalar_t, ReductionType reduce>
|
| 45 |
+
inline vec_scalar_t<scalar_t> init_value() {
|
| 46 |
+
using acc_t = vec_scalar_t<scalar_t>;
|
| 47 |
+
acc_t val;
|
| 48 |
+
if (reduce == ReductionType::SUM ||
|
| 49 |
+
reduce == ReductionType::MEAN) {
|
| 50 |
+
val = static_cast<acc_t>(0);
|
| 51 |
+
} else if (reduce == ReductionType::PROD) {
|
| 52 |
+
val = static_cast<acc_t>(1);
|
| 53 |
+
} else if (reduce == ReductionType::MAX) {
|
| 54 |
+
val = -std::numeric_limits<acc_t>::infinity();
|
| 55 |
+
} else {
|
| 56 |
+
TORCH_INTERNAL_ASSERT(reduce == ReductionType::MIN);
|
| 57 |
+
val = std::numeric_limits<acc_t>::infinity();
|
| 58 |
+
}
|
| 59 |
+
return val;
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
template <typename scalar_t, ReductionType reduce>
|
| 63 |
+
inline vec_scalar_t<scalar_t> init_value(const std::optional<Scalar>& initial) {
|
| 64 |
+
using acc_t = vec_scalar_t<scalar_t>;
|
| 65 |
+
if (initial.has_value()) {
|
| 66 |
+
return initial.value().to<acc_t>();
|
| 67 |
+
} else {
|
| 68 |
+
return init_value<scalar_t, reduce>();
|
| 69 |
+
}
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
template <typename scalar_t>
|
| 73 |
+
inline void init(scalar_t* out, int64_t size, const vec_scalar_t<scalar_t>& val) {
|
| 74 |
+
using Vec = Vectorized<vec_scalar_t<scalar_t>>;
|
| 75 |
+
map<scalar_t>(
|
| 76 |
+
[val](Vec x) { return Vec(val); },
|
| 77 |
+
out,
|
| 78 |
+
out,
|
| 79 |
+
size);
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
template <typename scalar_t, ReductionType reduce>
|
| 83 |
+
inline void init(scalar_t* out, int64_t size, const std::optional<Scalar>& initial) {
|
| 84 |
+
using acc_t = vec_scalar_t<scalar_t>;
|
| 85 |
+
acc_t val = init_value<scalar_t, reduce>(initial);
|
| 86 |
+
init(out, size, val);
|
| 87 |
+
}
|
| 88 |
+
|
| 89 |
+
// overload with `include_self`, used by scatter_reduce
|
| 90 |
+
template <typename scalar_t, ReductionType reduce>
|
| 91 |
+
inline void init(scalar_t* out, int64_t size, bool include_self = false) {
|
| 92 |
+
using acc_t = vec_scalar_t<scalar_t>;
|
| 93 |
+
if (!include_self) {
|
| 94 |
+
acc_t val = init_value<scalar_t, reduce>();
|
| 95 |
+
init(out, size, val);
|
| 96 |
+
}
|
| 97 |
+
}
|
| 98 |
+
|
| 99 |
+
template <typename scalar_t, ReductionType reduce>
|
| 100 |
+
inline void _init(scalar_t* self_ptr, at::opmath_type<scalar_t>* buffer_ptr, int64_t size, bool include_self) {
|
| 101 |
+
if (!include_self) {
|
| 102 |
+
init<at::opmath_type<scalar_t>, reduce>(buffer_ptr, size, include_self);
|
| 103 |
+
} else {
|
| 104 |
+
vec::convert(self_ptr, buffer_ptr, size);
|
| 105 |
+
}
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
template <typename scalar_t>
|
| 109 |
+
inline std::enable_if_t<!std::is_same_v<scalar_t, Vec2>, scalar_t>
|
| 110 |
+
_max(const scalar_t& x, const scalar_t& y) {
|
| 111 |
+
return at::_isnan(y) ? y : std::max(x, y);
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
template <typename scalar_t>
|
| 115 |
+
inline Vectorized<scalar_t> _max(const Vectorized<scalar_t>& x, const Vectorized<scalar_t>& y) {
|
| 116 |
+
// vec::maximum propagates NaN
|
| 117 |
+
return vec::maximum(x, y);
|
| 118 |
+
}
|
| 119 |
+
|
| 120 |
+
template <typename vec_t>
|
| 121 |
+
inline std::enable_if_t<std::is_same_v<vec_t, Vec2>, Vec2>
|
| 122 |
+
_max(const vec_t& x, const vec_t& y) {
|
| 123 |
+
// vec::maximum propagates NaN
|
| 124 |
+
return maximum(x, y);
|
| 125 |
+
}
|
| 126 |
+
|
| 127 |
+
template <typename scalar_t>
|
| 128 |
+
inline std::enable_if_t<!std::is_same_v<scalar_t, Vec2>, scalar_t>
|
| 129 |
+
_min(const scalar_t& x, const scalar_t& y) {
|
| 130 |
+
return at::_isnan(y) ? y : std::min(x, y);
|
| 131 |
+
}
|
| 132 |
+
|
| 133 |
+
template <typename scalar_t>
|
| 134 |
+
inline Vectorized<scalar_t> _min(const Vectorized<scalar_t>& x, const Vectorized<scalar_t>& y) {
|
| 135 |
+
// vec::minimum propagates NaN
|
| 136 |
+
return vec::minimum(x, y);
|
| 137 |
+
}
|
| 138 |
+
|
| 139 |
+
template <typename vec_t>
|
| 140 |
+
inline std::enable_if_t<std::is_same_v<vec_t, Vec2>, Vec2>
|
| 141 |
+
_min(const vec_t& x, const vec_t& y) {
|
| 142 |
+
// vec::minimum propagates NaN
|
| 143 |
+
return minimum(x, y);
|
| 144 |
+
}
|
| 145 |
+
|
| 146 |
+
template <typename scalar_t, typename accumut, typename Op,
|
| 147 |
+
typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
|
| 148 |
+
inline void map_acc(
|
| 149 |
+
const Op& vec_fun,
|
| 150 |
+
accumut* output_data,
|
| 151 |
+
const accumut* input_data,
|
| 152 |
+
const scalar_t* input_data2,
|
| 153 |
+
int64_t size) {
|
| 154 |
+
using Vec = vec::Vectorized<scalar_t>;
|
| 155 |
+
using aVec = vec::Vectorized<accumut>;
|
| 156 |
+
int64_t d = 0;
|
| 157 |
+
constexpr int64_t kVecSize = Vec::size();
|
| 158 |
+
constexpr int64_t kaVecSize = aVec::size();
|
| 159 |
+
for (d = 0; d < size - (size % kVecSize); d += kVecSize) {
|
| 160 |
+
Vec data2_vec = Vec::loadu(input_data2 + d);
|
| 161 |
+
auto [data2_avec0, data2_avec1] = convert_to_float<scalar_t>(data2_vec);
|
| 162 |
+
aVec input_vec0 = aVec::loadu(input_data + d);
|
| 163 |
+
aVec input_vec1 = aVec::loadu(input_data + d + kaVecSize);
|
| 164 |
+
vec_fun(input_vec0, data2_avec0).store(output_data + d);
|
| 165 |
+
vec_fun(input_vec1, data2_avec1).store(output_data + d + kaVecSize);
|
| 166 |
+
}
|
| 167 |
+
if (size - d > 0) {
|
| 168 |
+
int64_t tail_size = size - d;
|
| 169 |
+
Vec data2_vec = Vec::loadu(input_data2 + d, tail_size);
|
| 170 |
+
auto [data2_avec0, data2_avec1] = convert_to_float<scalar_t>(data2_vec);
|
| 171 |
+
if (tail_size > kaVecSize) {
|
| 172 |
+
aVec input_vec0 = aVec::loadu(input_data + d);
|
| 173 |
+
aVec input_vec1 = aVec::loadu(input_data + d + kaVecSize, tail_size - kaVecSize);
|
| 174 |
+
vec_fun(input_vec0, data2_avec0).store(output_data + d);
|
| 175 |
+
vec_fun(input_vec1, data2_avec1).store(output_data + d + kaVecSize, tail_size - kaVecSize);
|
| 176 |
+
} else {
|
| 177 |
+
aVec input_vec0 = aVec::loadu(input_data + d, tail_size);
|
| 178 |
+
vec_fun(input_vec0, data2_avec0).store(output_data + d, tail_size);
|
| 179 |
+
}
|
| 180 |
+
}
|
| 181 |
+
}
|
| 182 |
+
|
| 183 |
+
// for Max and Min, propagate NaN:
|
| 184 |
+
template <typename T, ReductionType reduce>
|
| 185 |
+
inline T update(const T& x, const T& y) {
|
| 186 |
+
if (reduce == ReductionType::SUM ||
|
| 187 |
+
reduce == ReductionType::MEAN) {
|
| 188 |
+
return x + y;
|
| 189 |
+
} else if (reduce == ReductionType::PROD) {
|
| 190 |
+
return x * y;
|
| 191 |
+
} else if (reduce == ReductionType::MAX) {
|
| 192 |
+
return _max(x, y);
|
| 193 |
+
} else {
|
| 194 |
+
TORCH_INTERNAL_ASSERT(reduce == ReductionType::MIN);
|
| 195 |
+
return _min(x, y);
|
| 196 |
+
}
|
| 197 |
+
}
|
| 198 |
+
|
| 199 |
+
template <typename scalar_t, ReductionType reduce>
|
| 200 |
+
inline void update(scalar_t* out, const scalar_t* data, int64_t K) {
|
| 201 |
+
using Vec = vec::Vectorized<vec_scalar_t<scalar_t>>;
|
| 202 |
+
map2<scalar_t>(
|
| 203 |
+
[](Vec x, Vec y) { return update<Vec, reduce>(x, y); },
|
| 204 |
+
out,
|
| 205 |
+
out,
|
| 206 |
+
data,
|
| 207 |
+
K);
|
| 208 |
+
}
|
| 209 |
+
|
| 210 |
+
template <typename scalar_t, ReductionType reduce,
|
| 211 |
+
typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
|
| 212 |
+
inline void update(at::opmath_type<scalar_t>* out, const scalar_t* data, int64_t K) {
|
| 213 |
+
using opmath_t = at::opmath_type<scalar_t>;
|
| 214 |
+
using Vec = vec::Vectorized<opmath_t>;
|
| 215 |
+
map_acc<scalar_t, opmath_t>(
|
| 216 |
+
[](Vec x, Vec y) { return update<Vec, reduce>(x, y); },
|
| 217 |
+
out,
|
| 218 |
+
out,
|
| 219 |
+
data,
|
| 220 |
+
K);
|
| 221 |
+
}
|
| 222 |
+
|
| 223 |
+
template <typename scalar_t, ReductionType reduce>
|
| 224 |
+
inline void write(scalar_t* out, int64_t count, int64_t K) {
|
| 225 |
+
using Vec = vec::Vectorized<vec_scalar_t<scalar_t>>;
|
| 226 |
+
if (reduce == ReductionType::MEAN) {
|
| 227 |
+
if (count > 0) {
|
| 228 |
+
vec::map<scalar_t>(
|
| 229 |
+
[count](Vec x) { return x / Vec(count); },
|
| 230 |
+
out,
|
| 231 |
+
out,
|
| 232 |
+
K);
|
| 233 |
+
}
|
| 234 |
+
}
|
| 235 |
+
}
|
| 236 |
+
|
| 237 |
+
} // namespace CPU_CAPABILITY
|
| 238 |
+
} // namespace at::native
|
| 239 |
+
|
| 240 |
+
#else
|
| 241 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 242 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/ReducedPrecisionFloatGemvFastPathKernel.h
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
#include <ATen/native/DispatchStub.h>
|
| 5 |
+
#include <c10/macros/Macros.h>
|
| 6 |
+
#include <c10/util/BFloat16.h>
|
| 7 |
+
#include <c10/util/Half.h>
|
| 8 |
+
|
| 9 |
+
namespace at::native {
|
| 10 |
+
#if !defined(C10_MOBILE)
|
| 11 |
+
using fp16_gemv_fn = void(*)(int, int, float, const Half*, int, const Half*, int, float, Half*, int);
|
| 12 |
+
DECLARE_DISPATCH(fp16_gemv_fn, fp16_gemv_trans_stub)
|
| 13 |
+
|
| 14 |
+
using bf16_gemv_fn = void(*)(int, int, BFloat16, const BFloat16*, int, const BFloat16*, int, BFloat16, BFloat16*, int);
|
| 15 |
+
DECLARE_DISPATCH(bf16_gemv_fn, bf16_gemv_trans_stub)
|
| 16 |
+
|
| 17 |
+
using fp16_dot_fn = float(*)(const int64_t, const Half*, const int64_t, const Half*, const int64_t);
|
| 18 |
+
DECLARE_DISPATCH(fp16_dot_fn, fp16_dot_stub)
|
| 19 |
+
|
| 20 |
+
using bf16_dot_fn = float(*)(const int64_t, const BFloat16*, const int64_t, const BFloat16*, const int64_t);
|
| 21 |
+
DECLARE_DISPATCH(bf16_dot_fn, bf16_dot_stub)
|
| 22 |
+
|
| 23 |
+
inline namespace CPU_CAPABILITY {
|
| 24 |
+
float fp16_dot_with_fp32_arith(const Half* vec1, const Half* vec2, int64_t len);
|
| 25 |
+
float bf16_dot_with_fp32_arith(const BFloat16* vec1, const BFloat16* vec2, int64_t len);
|
| 26 |
+
} // inline namespace CPU_CAPABILITY
|
| 27 |
+
#endif // !defined(C10_MOBILE)
|
| 28 |
+
} // namespace at::native
|
| 29 |
+
|
| 30 |
+
#else
|
| 31 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 32 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/SampledAddmmKernel.h
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
#include <ATen/core/Tensor.h>
|
| 5 |
+
#include <ATen/native/DispatchStub.h>
|
| 6 |
+
|
| 7 |
+
namespace at::native {
|
| 8 |
+
|
| 9 |
+
using sampled_addmm_sparse_csr_fn = void(*)(const Tensor&, const Tensor&, const Scalar&, const Scalar&, const Tensor&);
|
| 10 |
+
|
| 11 |
+
DECLARE_DISPATCH(sampled_addmm_sparse_csr_fn, sampled_addmm_sparse_csr_stub)
|
| 12 |
+
|
| 13 |
+
} // at::native
|
| 14 |
+
|
| 15 |
+
#else
|
| 16 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 17 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/SerialStackImpl.h
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
// Copyright 2004-present Facebook. All Rights Reserved.
|
| 3 |
+
#pragma once
|
| 4 |
+
|
| 5 |
+
#include <ATen/core/Tensor.h>
|
| 6 |
+
|
| 7 |
+
#include <ATen/MemoryOverlap.h>
|
| 8 |
+
#include <ATen/Parallel.h>
|
| 9 |
+
#include <ATen/TensorIterator.h>
|
| 10 |
+
#include <ATen/cpu/vec/functional.h>
|
| 11 |
+
#include <ATen/cpu/vec/vec.h>
|
| 12 |
+
#include <c10/util/irange.h>
|
| 13 |
+
|
| 14 |
+
namespace at::native::detail {
|
| 15 |
+
|
| 16 |
+
struct InputMeta {
|
| 17 |
+
void* data_ptr;
|
| 18 |
+
int64_t inner_size;
|
| 19 |
+
|
| 20 |
+
InputMeta(const Tensor& t, int64_t dim, int64_t inner)
|
| 21 |
+
: data_ptr(t.data_ptr()), inner_size(t.sizes()[dim] * inner) {}
|
| 22 |
+
};
|
| 23 |
+
|
| 24 |
+
// This kernel is used by two TensorList types:
|
| 25 |
+
// 1. stack_serial_kernel uses at::ArrayRef<Tensor>
|
| 26 |
+
// 2. Static runtime calls this kernel directly (csrc/jit/runtime/static/ops.cpp) with
|
| 27 |
+
// ProcessedNodeInputWrapper.
|
| 28 |
+
// When making changes, make sure that they are compatible with both types!
|
| 29 |
+
template <typename scalar_t, typename TensorListType>
|
| 30 |
+
void stack_serial_kernel_impl(Tensor& result, TensorListType tensors, int64_t dim) {
|
| 31 |
+
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
|
| 32 |
+
dim >= 0 && dim <= result.dim(),
|
| 33 |
+
"dim out of range in stack_serial_kernel_impl");
|
| 34 |
+
int64_t outer =
|
| 35 |
+
result.numel() / (result.sizes()[dim] * result.strides()[dim]);
|
| 36 |
+
scalar_t* result_data = result.data_ptr<scalar_t>();
|
| 37 |
+
int64_t ninputs = tensors.size();
|
| 38 |
+
std::vector<InputMeta> inputs;
|
| 39 |
+
inputs.reserve(ninputs);
|
| 40 |
+
for (const auto& tensor : tensors) {
|
| 41 |
+
inputs.emplace_back(tensor, dim, tensor.strides()[dim]);
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
using Vec = vec::Vectorized<scalar_t>;
|
| 45 |
+
scalar_t* result_ptr = result_data;
|
| 46 |
+
for (const auto i : c10::irange(outer)) {
|
| 47 |
+
for (const auto j : c10::irange(ninputs)) {
|
| 48 |
+
int64_t local_inner = inputs[j].inner_size;
|
| 49 |
+
scalar_t* input_ptr = (scalar_t*)(inputs[j].data_ptr) + i * local_inner;
|
| 50 |
+
|
| 51 |
+
if (local_inner < Vec::size()) {
|
| 52 |
+
for (const auto k : c10::irange(local_inner)) {
|
| 53 |
+
result_ptr[k] = input_ptr[k];
|
| 54 |
+
}
|
| 55 |
+
} else {
|
| 56 |
+
vec::map(
|
| 57 |
+
[](Vec x) { return x; }, result_ptr, input_ptr, local_inner);
|
| 58 |
+
}
|
| 59 |
+
result_ptr += local_inner;
|
| 60 |
+
}
|
| 61 |
+
}
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
// Checks to see whether native stack can be invoked under these conditions:
|
| 65 |
+
// - result and input tensors are contiguous
|
| 66 |
+
// - only one thread is used
|
| 67 |
+
// - no type promotion has to occur
|
| 68 |
+
// - tensors dtype is Double or Float
|
| 69 |
+
template <typename TensorListType>
|
| 70 |
+
bool can_use_native_serial_stack_impl(Tensor& result, TensorListType tensors, int64_t dim) {
|
| 71 |
+
TORCH_CHECK(!tensors.empty(), "expected a non-empty list of Tensors");
|
| 72 |
+
const Tensor& first_tensor = tensors[0];
|
| 73 |
+
// stack dimension should be in range [0,firstTensor.dim())
|
| 74 |
+
// dim == firstTensor.dim() is a valid input, but it is handled by default code path
|
| 75 |
+
// that uses unsqueeze
|
| 76 |
+
if (dim >= first_tensor.dim()) return false;
|
| 77 |
+
// Native stack doesn't apply any tensor is skipped.
|
| 78 |
+
if (first_tensor.numel() == 0 && first_tensor.dim() == 1) return false;
|
| 79 |
+
// there should be no type promotion
|
| 80 |
+
if (result.dtype() != first_tensor.dtype()) return false;
|
| 81 |
+
|
| 82 |
+
auto first_tensor_mem_format = first_tensor.suggest_memory_format();
|
| 83 |
+
ScalarType dtype = first_tensor.scalar_type();
|
| 84 |
+
|
| 85 |
+
if (!result.is_contiguous(first_tensor_mem_format)) {
|
| 86 |
+
return false;
|
| 87 |
+
}
|
| 88 |
+
|
| 89 |
+
// fast path only works for Double and Float
|
| 90 |
+
if (dtype != ScalarType::Double && dtype != ScalarType::Float) {
|
| 91 |
+
return false;
|
| 92 |
+
}
|
| 93 |
+
|
| 94 |
+
// check remainder of inputs
|
| 95 |
+
#ifndef STRIP_ERROR_MESSAGES
|
| 96 |
+
auto const &first_tensor_shape = first_tensor.sizes();
|
| 97 |
+
#endif
|
| 98 |
+
for (const auto i : c10::irange(1, tensors.size())) {
|
| 99 |
+
auto const &tensor = tensors[i];
|
| 100 |
+
TORCH_CHECK(tensors[i].sizes() == first_tensor.sizes(),
|
| 101 |
+
"stack expects each tensor to be equal size, but got ", first_tensor_shape,
|
| 102 |
+
" at entry 0 and ", tensor.sizes(), " at entry ", i);
|
| 103 |
+
|
| 104 |
+
// every tensor must be contiguous
|
| 105 |
+
// tensor sizes and strides must be the same
|
| 106 |
+
// there should be no type promotion
|
| 107 |
+
if (!tensor.is_contiguous(first_tensor_mem_format) ||
|
| 108 |
+
tensor.strides() != first_tensor.strides() ||
|
| 109 |
+
tensor.dtype() != dtype) {
|
| 110 |
+
return false;
|
| 111 |
+
}
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
// fast native stack should only be used when it is not worth using multiple threads
|
| 115 |
+
// or there is only one thread. Note that we aren't checking result.numel() here because
|
| 116 |
+
// it may not have been resized and we want to defer that cost till later.
|
| 117 |
+
int64_t numel_in_stack = first_tensor.numel() * tensors.size();
|
| 118 |
+
return numel_in_stack < at::internal::GRAIN_SIZE || at::get_num_threads() == 1;
|
| 119 |
+
}
|
| 120 |
+
|
| 121 |
+
template <typename TensorListType, bool should_skip_overlap_check>
|
| 122 |
+
struct CanUseNativeSerialStack;
|
| 123 |
+
|
| 124 |
+
template <typename TensorListType>
|
| 125 |
+
struct CanUseNativeSerialStack<TensorListType, false> {
|
| 126 |
+
static bool call(Tensor& result, TensorListType tensors, int64_t dim) {
|
| 127 |
+
// Inputs cannot alias the output tensor
|
| 128 |
+
for (const auto i : c10::irange(tensors.size())) {
|
| 129 |
+
auto lap = at::get_overlap_status(result, tensors[i]);
|
| 130 |
+
TORCH_CHECK(lap != at::MemOverlapStatus::Partial &&
|
| 131 |
+
lap != at::MemOverlapStatus::Full, 0,
|
| 132 |
+
"unsupported operation: the input tensors cannot refer to any of the "
|
| 133 |
+
"output memory locations. Found overlap in input tensor ", i);
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
return can_use_native_serial_stack_impl(result, tensors, dim);
|
| 137 |
+
}
|
| 138 |
+
};
|
| 139 |
+
|
| 140 |
+
template <typename TensorListType>
|
| 141 |
+
struct CanUseNativeSerialStack<TensorListType, true> {
|
| 142 |
+
static bool call(Tensor& result, TensorListType tensors, int64_t dim) {
|
| 143 |
+
return can_use_native_serial_stack_impl(result, tensors, dim);
|
| 144 |
+
}
|
| 145 |
+
};
|
| 146 |
+
|
| 147 |
+
} // namespace at::native::detail
|
| 148 |
+
|
| 149 |
+
#else
|
| 150 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 151 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/SoftmaxKernel.h
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
#include <ATen/native/DispatchStub.h>
|
| 5 |
+
#include <cstdint>
|
| 6 |
+
|
| 7 |
+
namespace at {
|
| 8 |
+
class Tensor;
|
| 9 |
+
|
| 10 |
+
namespace native {
|
| 11 |
+
|
| 12 |
+
using forward_fn = void (*)(const Tensor&, const Tensor&);
|
| 13 |
+
using backward_fn = void(*)(const Tensor &, const Tensor &, const Tensor&);
|
| 14 |
+
|
| 15 |
+
DECLARE_DISPATCH(forward_fn, softmax_lastdim_kernel)
|
| 16 |
+
DECLARE_DISPATCH(forward_fn, log_softmax_lastdim_kernel)
|
| 17 |
+
DECLARE_DISPATCH(backward_fn, softmax_backward_lastdim_kernel)
|
| 18 |
+
DECLARE_DISPATCH(backward_fn, log_softmax_backward_lastdim_kernel)
|
| 19 |
+
|
| 20 |
+
using forward_fn_with_dim = void(*)(const Tensor &, const Tensor &, const int64_t);
|
| 21 |
+
using backward_fn_with_dim =
|
| 22 |
+
void (*)(const Tensor&, const Tensor&, const Tensor&, const int64_t);
|
| 23 |
+
|
| 24 |
+
DECLARE_DISPATCH(forward_fn_with_dim, softmax_kernel)
|
| 25 |
+
DECLARE_DISPATCH(forward_fn_with_dim, log_softmax_kernel)
|
| 26 |
+
DECLARE_DISPATCH(backward_fn_with_dim, softmax_backward_kernel)
|
| 27 |
+
DECLARE_DISPATCH(backward_fn_with_dim, log_softmax_backward_kernel)
|
| 28 |
+
}
|
| 29 |
+
}
|
| 30 |
+
|
| 31 |
+
#else
|
| 32 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 33 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/SpmmReduceKernel.h
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
#include <ATen/core/Tensor.h>
|
| 5 |
+
#include <ATen/native/DispatchStub.h>
|
| 6 |
+
#include <ATen/native/ReductionType.h>
|
| 7 |
+
|
| 8 |
+
namespace at::native {
|
| 9 |
+
|
| 10 |
+
using spmm_reduce_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, ReductionType op);
|
| 11 |
+
using spmm_reduce_arg_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, ReductionType op);
|
| 12 |
+
using spmm_reduce_backward_input_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, ReductionType op);
|
| 13 |
+
using spmm_reduce_backward_input_arg_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, ReductionType op);
|
| 14 |
+
using spmm_reduce_backward_other_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, ReductionType op);
|
| 15 |
+
|
| 16 |
+
DECLARE_DISPATCH(spmm_reduce_fn, spmm_reduce_stub)
|
| 17 |
+
DECLARE_DISPATCH(spmm_reduce_arg_fn, spmm_reduce_arg_stub)
|
| 18 |
+
DECLARE_DISPATCH(spmm_reduce_backward_input_fn, spmm_reduce_backward_input_stub)
|
| 19 |
+
DECLARE_DISPATCH(spmm_reduce_backward_input_arg_fn, spmm_reduce_backward_input_arg_stub)
|
| 20 |
+
DECLARE_DISPATCH(spmm_reduce_backward_other_fn, spmm_reduce_backward_other_stub)
|
| 21 |
+
DECLARE_DISPATCH(spmm_reduce_backward_input_arg_fn, spmm_reduce_backward_other_arg_stub)
|
| 22 |
+
|
| 23 |
+
} // at::native
|
| 24 |
+
|
| 25 |
+
#else
|
| 26 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 27 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/StackKernel.h
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
// Copyright 2004-present Facebook. All Rights Reserved.
|
| 3 |
+
#pragma once
|
| 4 |
+
|
| 5 |
+
#include <ATen/core/Tensor.h>
|
| 6 |
+
#include <ATen/native/DispatchStub.h>
|
| 7 |
+
|
| 8 |
+
namespace at::native {
|
| 9 |
+
|
| 10 |
+
using stack_serial_fn = void(*)(Tensor &, TensorList, int64_t);
|
| 11 |
+
DECLARE_DISPATCH(stack_serial_fn, stack_serial_stub)
|
| 12 |
+
|
| 13 |
+
} // namespace at::native
|
| 14 |
+
|
| 15 |
+
#else
|
| 16 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 17 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/UpSampleKernelAVXAntialias.h
ADDED
|
@@ -0,0 +1,1381 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
/*
|
| 3 |
+
The Python Imaging Library (PIL) is
|
| 4 |
+
|
| 5 |
+
Copyright © 1997-2011 by Secret Labs AB
|
| 6 |
+
Copyright © 1995-2011 by Fredrik Lundh
|
| 7 |
+
|
| 8 |
+
Pillow is the friendly PIL fork. It is
|
| 9 |
+
|
| 10 |
+
Copyright © 2010-2022 by Alex Clark and contributors
|
| 11 |
+
|
| 12 |
+
Like PIL, Pillow is licensed under the open source HPND License
|
| 13 |
+
*/
|
| 14 |
+
|
| 15 |
+
// This code is heavily inspired from PILLOW-SIMD's implementation:
|
| 16 |
+
// https://github.com/uploadcare/pillow-simd/blob/simd/master/src/libImaging/Resample.c
|
| 17 |
+
|
| 18 |
+
#pragma once
|
| 19 |
+
#ifdef CPU_CAPABILITY_AVX2
|
| 20 |
+
// TODO: This file only supports AVX2. We could split the AVX kernels into
|
| 21 |
+
// smaller logical blocks in order to port them into the Vec.h logic. This would
|
| 22 |
+
// allow to support other vectorization architectures and perhaps also support
|
| 23 |
+
// the non-vectorized fallback (we'd need to make sure it's not slower than the
|
| 24 |
+
// current fallback).
|
| 25 |
+
|
| 26 |
+
#include <ATen/core/Tensor.h>
|
| 27 |
+
#include <ATen/cpu/vec/intrinsics.h>
|
| 28 |
+
#include <c10/util/irange.h>
|
| 29 |
+
|
| 30 |
+
#ifndef AT_PER_OPERATOR_HEADERS
|
| 31 |
+
#include <ATen/Functions.h>
|
| 32 |
+
#else
|
| 33 |
+
#include <ATen/ops/empty.h>
|
| 34 |
+
#endif
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
namespace {
|
| 38 |
+
|
| 39 |
+
inline __m128i mm_cvtsi32_si128(const uint8_t* C10_RESTRICT ptr, bool i32_aligned) {
|
| 40 |
+
int32_t v;
|
| 41 |
+
if (i32_aligned) {
|
| 42 |
+
v = *(const int32_t*)ptr;
|
| 43 |
+
} else {
|
| 44 |
+
std::memcpy(&v, ptr, 4);
|
| 45 |
+
}
|
| 46 |
+
return _mm_cvtsi32_si128(v);
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
inline __m128i mm_cvtepu8_epi32(const uint8_t* C10_RESTRICT ptr, bool i32_aligned) {
|
| 50 |
+
return _mm_cvtepu8_epi32(mm_cvtsi32_si128(ptr, i32_aligned));
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
inline void _write_endline_rgb_as_uint32(
|
| 54 |
+
uint8_t* C10_RESTRICT output,
|
| 55 |
+
uint32_t data
|
| 56 |
+
) {
|
| 57 |
+
// data is (R G B X), output is (X1 X2 X3 | R1 B1 G1 R2 ...)
|
| 58 |
+
// Here we explicitly set X as R1
|
| 59 |
+
uint8_t* data_ptr = reinterpret_cast<uint8_t*>(&data);
|
| 60 |
+
data_ptr[3] = output[3];
|
| 61 |
+
std::memcpy(output, data_ptr, 4);
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
at::Tensor unpack_rgb(const at::Tensor& packed_tensor) {
|
| 65 |
+
// Convert a "packed" tensor (typically RGBRGBRGB if channels_last) into
|
| 66 |
+
// RGBARGBARGBA format where A is hard-coded to 0. Each pixel is encoded
|
| 67 |
+
// into as 32 bits. This generalizes to num_channels <= 4 and also works for
|
| 68 |
+
// non-channels_last tensors.
|
| 69 |
+
|
| 70 |
+
const uint8_t* packed = (const uint8_t*)packed_tensor.const_data_ptr<uint8_t>();
|
| 71 |
+
auto num_pixels = packed_tensor.size(1) * packed_tensor.size(2);
|
| 72 |
+
auto num_channels = packed_tensor.size(0);
|
| 73 |
+
|
| 74 |
+
constexpr int rgba_size = 4;
|
| 75 |
+
auto unpacked_tensor = at::empty({rgba_size, packed_tensor.size(1), packed_tensor.size(2)}, at::CPU(at::kByte));
|
| 76 |
+
uint8_t* unpacked = (uint8_t*) unpacked_tensor.data_ptr<uint8_t>();
|
| 77 |
+
|
| 78 |
+
auto stride_i = packed_tensor.stride(2);
|
| 79 |
+
auto stride_j = packed_tensor.stride(0);
|
| 80 |
+
|
| 81 |
+
for (const auto i : c10::irange(num_pixels)) {
|
| 82 |
+
for (const auto j : c10::irange(rgba_size)) {
|
| 83 |
+
unpacked[rgba_size * i + j] = (j < num_channels) ? packed[stride_i * i + stride_j * j] : 0;
|
| 84 |
+
}
|
| 85 |
+
}
|
| 86 |
+
return unpacked_tensor;
|
| 87 |
+
}
|
| 88 |
+
|
| 89 |
+
void pack_rgb(
|
| 90 |
+
const at::Tensor& unpacked_tensor, // IN
|
| 91 |
+
const at::Tensor& packed_tensor // OUT
|
| 92 |
+
) {
|
| 93 |
+
// Convert from unpacked channels last 3-channels or 4-channels tensor into original data layout.
|
| 94 |
+
|
| 95 |
+
uint8_t* unpacked = (uint8_t*)unpacked_tensor.data_ptr<uint8_t>();
|
| 96 |
+
uint8_t* packed = (uint8_t*)packed_tensor.data_ptr<uint8_t>();
|
| 97 |
+
auto num_pixels = packed_tensor.size(1) * packed_tensor.size(2);
|
| 98 |
+
auto num_channels = packed_tensor.size(0);
|
| 99 |
+
|
| 100 |
+
auto unpacked_increment = unpacked_tensor.size(0);
|
| 101 |
+
auto packed_increment = packed_tensor.stride(2);
|
| 102 |
+
auto packed_stride = packed_tensor.stride(0);
|
| 103 |
+
|
| 104 |
+
TORCH_INTERNAL_ASSERT(unpacked_increment == 3 || unpacked_increment == 4);
|
| 105 |
+
|
| 106 |
+
for ([[maybe_unused]] const auto i : c10::irange(num_pixels)) {
|
| 107 |
+
for (const auto j : c10::irange(num_channels)) {
|
| 108 |
+
packed[j * packed_stride] = unpacked[j];
|
| 109 |
+
}
|
| 110 |
+
unpacked += unpacked_increment;
|
| 111 |
+
packed += packed_increment;
|
| 112 |
+
}
|
| 113 |
+
}
|
| 114 |
+
|
| 115 |
+
void ImagingResampleHorizontalConvolution8u4x(
|
| 116 |
+
uint8_t* C10_RESTRICT lineOut0,
|
| 117 |
+
uint8_t* C10_RESTRICT lineOut1,
|
| 118 |
+
uint8_t* C10_RESTRICT lineOut2,
|
| 119 |
+
uint8_t* C10_RESTRICT lineOut3,
|
| 120 |
+
int64_t out_xsize,
|
| 121 |
+
const uint8_t* C10_RESTRICT lineIn0,
|
| 122 |
+
const uint8_t* C10_RESTRICT lineIn1,
|
| 123 |
+
const uint8_t* C10_RESTRICT lineIn2,
|
| 124 |
+
const uint8_t* C10_RESTRICT lineIn3,
|
| 125 |
+
int64_t in_xsize,
|
| 126 |
+
const int64_t* idx_ptr_xmin,
|
| 127 |
+
const int64_t* idx_ptr_size,
|
| 128 |
+
const int16_t* kk,
|
| 129 |
+
int kmax,
|
| 130 |
+
unsigned int coefs_precision,
|
| 131 |
+
int64_t num_channels,
|
| 132 |
+
bool is_last_line);
|
| 133 |
+
|
| 134 |
+
void ImagingResampleHorizontalConvolution8u(
|
| 135 |
+
uint8_t* C10_RESTRICT lineOut,
|
| 136 |
+
int64_t out_xsize,
|
| 137 |
+
const uint8_t* C10_RESTRICT lineIn,
|
| 138 |
+
int64_t in_xsize,
|
| 139 |
+
const int64_t* idx_ptr_xmin,
|
| 140 |
+
const int64_t* idx_ptr_size,
|
| 141 |
+
const int16_t* kk,
|
| 142 |
+
int kmax,
|
| 143 |
+
unsigned int coefs_precision,
|
| 144 |
+
int64_t num_channels,
|
| 145 |
+
bool is_last_line);
|
| 146 |
+
|
| 147 |
+
void ImagingResampleVerticalConvolution8u(
|
| 148 |
+
uint8_t* C10_RESTRICT lineOut,
|
| 149 |
+
const uint8_t* C10_RESTRICT lineIn,
|
| 150 |
+
int64_t xsize,
|
| 151 |
+
int64_t ids_min,
|
| 152 |
+
int64_t ids_size,
|
| 153 |
+
const int16_t* k,
|
| 154 |
+
unsigned int coefs_precision,
|
| 155 |
+
int64_t num_channels);
|
| 156 |
+
|
| 157 |
+
template<int num_channels>
|
| 158 |
+
void ImagingResampleHorizontal(
|
| 159 |
+
const at::Tensor & unpacked_output,
|
| 160 |
+
const at::Tensor & unpacked_input,
|
| 161 |
+
int ksize,
|
| 162 |
+
const std::vector<at::Tensor>& horiz_indices_weights,
|
| 163 |
+
unsigned int horiz_weights_precision) {
|
| 164 |
+
|
| 165 |
+
// Interpolation horizontal pass: we compute x-axis (image width) interpolation outputs.
|
| 166 |
+
|
| 167 |
+
// Input data is stored as
|
| 168 |
+
// input = [r[0], g[0], b[0], a[0], r[1], g[1], b[1], a[1], r[2], g[2], b[2], a[2], ...]
|
| 169 |
+
// Weights are float values computed for each output pixel and rescaled to uint16:
|
| 170 |
+
// weights[i] = [w[i, 0], w[i, 1], ..., w[i, K-1]]
|
| 171 |
+
// We want to compute the output as following:
|
| 172 |
+
// output = [oR[0], oG[0], oB[0], oA[0], oR[1], oG[1], oB[1], oA[1], ...]
|
| 173 |
+
// where
|
| 174 |
+
// oR[yoffset + i] = r[yoffset + xmin[i]] * w[i, 0] + ... + r[yoffset + xmin[i] + K-1] * w[i, K-1]
|
| 175 |
+
// oG[yoffset + i] = g[yoffset + xmin[i]] * w[i, 0] + ... + g[yoffset + xmin[i] + K-1] * w[i, K-1]
|
| 176 |
+
// oB[yoffset + i] = b[yoffset + xmin[i]] * w[i, 0] + ... + b[yoffset + xmin[i] + K-1] * w[i, K-1]
|
| 177 |
+
//
|
| 178 |
+
|
| 179 |
+
// TODO: we may want to merge that into the fallback code (currently called
|
| 180 |
+
// basic_loop_aa_horizontal<uint8_t>)
|
| 181 |
+
// Although this may not be needed if / when we port all this code to use
|
| 182 |
+
// Vec.h since this would potentially give us another fall-back implem
|
| 183 |
+
|
| 184 |
+
const int16_t* kk = (int16_t*)(horiz_indices_weights[3].const_data_ptr<double>());
|
| 185 |
+
|
| 186 |
+
auto xout = unpacked_output.size(2);
|
| 187 |
+
auto yout = unpacked_output.size(1);
|
| 188 |
+
auto xin = unpacked_input.size(2);
|
| 189 |
+
TORCH_INTERNAL_ASSERT(num_channels == unpacked_input.size(0));
|
| 190 |
+
|
| 191 |
+
const int64_t* idx_ptr_xmin = horiz_indices_weights[0].const_data_ptr<int64_t>();
|
| 192 |
+
const int64_t* idx_ptr_size = horiz_indices_weights[1].const_data_ptr<int64_t>();
|
| 193 |
+
|
| 194 |
+
uint8_t* unpacked_output_p = unpacked_output.data_ptr<uint8_t>();
|
| 195 |
+
const uint8_t* unpacked_input_p = unpacked_input.const_data_ptr<uint8_t>();
|
| 196 |
+
|
| 197 |
+
int64_t yy = 0;
|
| 198 |
+
auto xout_stride = xout * num_channels;
|
| 199 |
+
auto xin_stride = xin * num_channels;
|
| 200 |
+
for (; yy < yout - 3; yy += 4) {
|
| 201 |
+
ImagingResampleHorizontalConvolution8u4x(
|
| 202 |
+
unpacked_output_p + yy * xout_stride,
|
| 203 |
+
unpacked_output_p + (yy + 1) * xout_stride,
|
| 204 |
+
unpacked_output_p + (yy + 2) * xout_stride,
|
| 205 |
+
unpacked_output_p + (yy + 3) * xout_stride,
|
| 206 |
+
xout,
|
| 207 |
+
unpacked_input_p + yy * xin_stride,
|
| 208 |
+
unpacked_input_p + (yy + 1) * xin_stride,
|
| 209 |
+
unpacked_input_p + (yy + 2) * xin_stride,
|
| 210 |
+
unpacked_input_p + (yy + 3) * xin_stride,
|
| 211 |
+
xin,
|
| 212 |
+
idx_ptr_xmin,
|
| 213 |
+
idx_ptr_size,
|
| 214 |
+
kk,
|
| 215 |
+
ksize,
|
| 216 |
+
horiz_weights_precision,
|
| 217 |
+
num_channels,
|
| 218 |
+
yy + 3 == yout - 1);
|
| 219 |
+
}
|
| 220 |
+
for (; yy < yout; yy++) {
|
| 221 |
+
ImagingResampleHorizontalConvolution8u(
|
| 222 |
+
unpacked_output_p + yy * xout_stride,
|
| 223 |
+
xout,
|
| 224 |
+
unpacked_input_p + yy * xin_stride,
|
| 225 |
+
xin,
|
| 226 |
+
idx_ptr_xmin,
|
| 227 |
+
idx_ptr_size,
|
| 228 |
+
kk,
|
| 229 |
+
ksize,
|
| 230 |
+
horiz_weights_precision,
|
| 231 |
+
num_channels,
|
| 232 |
+
yy == yout - 1);
|
| 233 |
+
}
|
| 234 |
+
}
|
| 235 |
+
|
| 236 |
+
void ImagingResampleVertical(
|
| 237 |
+
const at::Tensor & unpacked_output,
|
| 238 |
+
const at::Tensor & unpacked_input,
|
| 239 |
+
int ksize,
|
| 240 |
+
const std::vector<at::Tensor>& vert_indices_weights,
|
| 241 |
+
unsigned int vert_weights_precision) {
|
| 242 |
+
|
| 243 |
+
// Interpolation vertical pass: we compute y-axis interpolation outputs.
|
| 244 |
+
// Input data is stored as
|
| 245 |
+
// input = [r[0], g[0], b[0], a[0], r[1], g[1], b[1], a[1], r[2], g[2], b[2], a[2], ...]
|
| 246 |
+
// Weights are float values computed for each output pixel and rescaled to uint16:
|
| 247 |
+
// weights[i] = [w[i, 0], w[i, 1], ..., w[i, K-1]]
|
| 248 |
+
// We want to compute the output as following:
|
| 249 |
+
// output = [oR[0], oG[0], oB[0], oA[0], oR[1], oG[1], oB[1], oA[1], ...]
|
| 250 |
+
// where
|
| 251 |
+
// oR[xoffset + i] = r[xoffset + ymin[i]] * w[i, 0] + ... + r[xoffset + ymin[i] + (K-1) * xsize] * w[i, K-1]
|
| 252 |
+
// oG[xoffset + i] = g[xoffset + ymin[i]] * w[i, 0] + ... + g[xoffset + ymin[i] + (K-1) * xsize] * w[i, K-1]
|
| 253 |
+
// oB[xoffset + i] = b[xoffset + ymin[i]] * w[i, 0] + ... + b[xoffset + ymin[i] + (K-1) * xsize] * w[i, K-1]
|
| 254 |
+
|
| 255 |
+
// TODO: we may want to merge that into the fallback code (currently called
|
| 256 |
+
// basic_loop_aa_vertical<uint8_t>)
|
| 257 |
+
// Although this may not be needed if / when we port all this code to use
|
| 258 |
+
// Vec.h since this would potentially give us another fall-back implem
|
| 259 |
+
const int16_t* kk = (int16_t*)(vert_indices_weights[3].const_data_ptr<double>());
|
| 260 |
+
|
| 261 |
+
const int64_t* idx_ptr_xmin = vert_indices_weights[0].const_data_ptr<int64_t>();
|
| 262 |
+
const int64_t* idx_ptr_size = vert_indices_weights[1].const_data_ptr<int64_t>();
|
| 263 |
+
|
| 264 |
+
uint8_t* unpacked_output_p = unpacked_output.data_ptr<uint8_t>();
|
| 265 |
+
const uint8_t* unpacked_input_p = unpacked_input.const_data_ptr<uint8_t>();
|
| 266 |
+
|
| 267 |
+
auto xout = unpacked_output.size(2);
|
| 268 |
+
auto yout = unpacked_output.size(1);
|
| 269 |
+
const auto num_channels = unpacked_input.size(0);
|
| 270 |
+
TORCH_INTERNAL_ASSERT(num_channels == unpacked_output.size(0));
|
| 271 |
+
|
| 272 |
+
auto xout_stride = xout * num_channels;
|
| 273 |
+
for (const auto yy : c10::irange(yout)) {
|
| 274 |
+
const auto* k = &kk[yy * ksize];
|
| 275 |
+
auto ids_min = idx_ptr_xmin[yy];
|
| 276 |
+
auto ids_size = idx_ptr_size[yy];
|
| 277 |
+
ImagingResampleVerticalConvolution8u(
|
| 278 |
+
unpacked_output_p + yy * xout_stride,
|
| 279 |
+
unpacked_input_p,
|
| 280 |
+
xout,
|
| 281 |
+
ids_min,
|
| 282 |
+
ids_size,
|
| 283 |
+
k,
|
| 284 |
+
vert_weights_precision,
|
| 285 |
+
num_channels);
|
| 286 |
+
}
|
| 287 |
+
}
|
| 288 |
+
|
| 289 |
+
// This is the only public entry point in this file. It supports bilinear or bicubic
|
| 290 |
+
// mode for uint8 dtype when C <= 4, with or without antialias. The
|
| 291 |
+
// implem is based on PIL-SIMD.
|
| 292 |
+
// Its equivalent implementation (fallback) for when AVX isn't supported or when
|
| 293 |
+
// C > 4 is separable_upsample_generic_Nd_kernel_impl() There are a bunch of
|
| 294 |
+
// future improvement that can be done: look for the TODOs in this file.
|
| 295 |
+
// For details on how the weights are computed and how the multiplications are
|
| 296 |
+
// run on int (instead of float weights), see
|
| 297 |
+
// [ Weights computation for uint8_t and multiplication trick ]
|
| 298 |
+
// For details on how the AVX kernels are implemented, see
|
| 299 |
+
// https://gist.github.com/NicolasHug/47c97d731f05eaad5694c173849b86f5
|
| 300 |
+
// See also [ Support for antialias=False as a subcase of antialias=True ] to
|
| 301 |
+
// learn more about how the antialias=False case is computed. The same holds
|
| 302 |
+
// here: all these kernels are general enough to handle an arbitrary number of
|
| 303 |
+
// weights, but when aa=False they could be optimized further.
|
| 304 |
+
template <typename scale_type, class F>
|
| 305 |
+
void upsample_avx_bilinear_bicubic_uint8(
|
| 306 |
+
const at::Tensor& input_,
|
| 307 |
+
const at::Tensor& output,
|
| 308 |
+
bool align_corners,
|
| 309 |
+
const scale_type& scales,
|
| 310 |
+
bool antialias) {
|
| 311 |
+
auto batch_size = input_.size(0);
|
| 312 |
+
auto num_channels = input_.size(1);
|
| 313 |
+
auto xin = input_.size(3);
|
| 314 |
+
auto yin = input_.size(2);
|
| 315 |
+
auto xout = output.size(3);
|
| 316 |
+
auto yout = output.size(2);
|
| 317 |
+
|
| 318 |
+
if (xin == xout && yin == yout) {
|
| 319 |
+
output.copy_(input_);
|
| 320 |
+
return;
|
| 321 |
+
}
|
| 322 |
+
|
| 323 |
+
at::Tensor input = input_;
|
| 324 |
+
if (!(input.is_contiguous() || input.is_contiguous(at::MemoryFormat::ChannelsLast))) {
|
| 325 |
+
// If input is not contiguous with memory format channels first or channels last,
|
| 326 |
+
// we explicitly convert the input to contiguous channels last memory format.
|
| 327 |
+
// This simplifies the rest of the code and let us assume that the format is only contiguous channels first or channels last,
|
| 328 |
+
// Most tensors going through this `if` block won't need to go through unpacking, but those having C < 3 may
|
| 329 |
+
// have to (this means 2 copies are made). We could avoid the extra copy by handling non-contiguous input
|
| 330 |
+
// directly within unpack_rgb() and pack_rgb(), but initial attempts showed that this is fairly complex.
|
| 331 |
+
input = input.contiguous(at::MemoryFormat::ChannelsLast);
|
| 332 |
+
}
|
| 333 |
+
|
| 334 |
+
auto need_horizontal = xout != xin;
|
| 335 |
+
auto need_vertical = yout != yin;
|
| 336 |
+
|
| 337 |
+
int ksize_horiz, ksize_vert;
|
| 338 |
+
std::vector<at::Tensor> horiz_indices_weights, vert_indices_weights;
|
| 339 |
+
unsigned int horiz_weights_precision, vert_weights_precision;
|
| 340 |
+
|
| 341 |
+
bool skip_unpacking = (num_channels == 3 || num_channels == 4) && input.is_contiguous(at::MemoryFormat::ChannelsLast);
|
| 342 |
+
bool skip_packing = (num_channels == 3 || num_channels == 4) && output.is_contiguous(at::MemoryFormat::ChannelsLast);
|
| 343 |
+
|
| 344 |
+
if (need_horizontal) {
|
| 345 |
+
int interp_dim = 3;
|
| 346 |
+
auto stride = skip_unpacking ? num_channels : 4;
|
| 347 |
+
std::tie(horiz_indices_weights, ksize_horiz, horiz_weights_precision) =
|
| 348 |
+
F::compute_index_ranges_int16_weights(
|
| 349 |
+
/*input_size=*/xin,
|
| 350 |
+
/*output_size=*/xout,
|
| 351 |
+
/*stride=*/stride,
|
| 352 |
+
/*ndims=*/4,
|
| 353 |
+
/*reshape_dim=*/interp_dim,
|
| 354 |
+
/*align_corners=*/align_corners,
|
| 355 |
+
/*opt_scale=*/scales[interp_dim - 2],
|
| 356 |
+
/*antialias=*/antialias,
|
| 357 |
+
/*align_i32=*/true);
|
| 358 |
+
}
|
| 359 |
+
|
| 360 |
+
if (need_vertical) {
|
| 361 |
+
int interp_dim = 2;
|
| 362 |
+
auto stride = skip_unpacking ? num_channels * xout : 4 * xout;
|
| 363 |
+
std::tie(vert_indices_weights, ksize_vert, vert_weights_precision) =
|
| 364 |
+
F::compute_index_ranges_int16_weights(
|
| 365 |
+
/*input_size=*/yin,
|
| 366 |
+
/*output_size=*/yout,
|
| 367 |
+
/*stride=*/stride,
|
| 368 |
+
/*ndims=*/4,
|
| 369 |
+
/*reshape_dim=*/interp_dim,
|
| 370 |
+
/*align_corners=*/align_corners,
|
| 371 |
+
/*opt_scale=*/scales[interp_dim - 2],
|
| 372 |
+
/*antialias=*/antialias,
|
| 373 |
+
/*align_i32=*/true);
|
| 374 |
+
}
|
| 375 |
+
|
| 376 |
+
at::Tensor buffer_horiz, buffer_vert;
|
| 377 |
+
// Minor optimization: we can avoid allocating an extra buffer if we're performing
|
| 378 |
+
// horizontal-only or vertical-only interpolation, and if the tensor doesn't
|
| 379 |
+
// need repacking
|
| 380 |
+
if (need_horizontal && (need_vertical || !skip_packing)) {
|
| 381 |
+
auto c = skip_unpacking ? num_channels : 4;
|
| 382 |
+
buffer_horiz = at::empty({c, yin, xout}, input.options());
|
| 383 |
+
}
|
| 384 |
+
if (need_vertical && !skip_packing) {
|
| 385 |
+
auto c = skip_unpacking ? num_channels : 4;
|
| 386 |
+
buffer_vert = at::empty({c, yout, xout}, input.options());
|
| 387 |
+
}
|
| 388 |
+
|
| 389 |
+
for (const auto i : c10::irange(batch_size)) {
|
| 390 |
+
|
| 391 |
+
at::Tensor unpacked_input = skip_unpacking ? input[i] : unpack_rgb(input[i]);
|
| 392 |
+
at::Tensor unpacked_output;
|
| 393 |
+
|
| 394 |
+
if (need_horizontal) {
|
| 395 |
+
at::Tensor unpacked_output_temp = (need_vertical || !skip_packing) ? buffer_horiz : output[i];
|
| 396 |
+
|
| 397 |
+
if (skip_unpacking && num_channels == 3) {
|
| 398 |
+
ImagingResampleHorizontal<3>(
|
| 399 |
+
unpacked_output_temp,
|
| 400 |
+
unpacked_input,
|
| 401 |
+
ksize_horiz,
|
| 402 |
+
horiz_indices_weights,
|
| 403 |
+
horiz_weights_precision);
|
| 404 |
+
} else {
|
| 405 |
+
ImagingResampleHorizontal<4>(
|
| 406 |
+
unpacked_output_temp,
|
| 407 |
+
unpacked_input,
|
| 408 |
+
ksize_horiz,
|
| 409 |
+
horiz_indices_weights,
|
| 410 |
+
horiz_weights_precision);
|
| 411 |
+
}
|
| 412 |
+
unpacked_output = unpacked_input = unpacked_output_temp;
|
| 413 |
+
}
|
| 414 |
+
if (need_vertical) {
|
| 415 |
+
unpacked_output = skip_packing ? output[i] : buffer_vert;
|
| 416 |
+
|
| 417 |
+
ImagingResampleVertical(
|
| 418 |
+
unpacked_output,
|
| 419 |
+
unpacked_input,
|
| 420 |
+
ksize_vert,
|
| 421 |
+
vert_indices_weights,
|
| 422 |
+
vert_weights_precision
|
| 423 |
+
);
|
| 424 |
+
}
|
| 425 |
+
|
| 426 |
+
TORCH_INTERNAL_ASSERT(unpacked_output.defined());
|
| 427 |
+
|
| 428 |
+
if (!skip_packing) {
|
| 429 |
+
pack_rgb(unpacked_output, output[i]);
|
| 430 |
+
}
|
| 431 |
+
}
|
| 432 |
+
}
|
| 433 |
+
|
| 434 |
+
void ImagingResampleHorizontalConvolution8u4x(
|
| 435 |
+
uint8_t* C10_RESTRICT lineOut0,
|
| 436 |
+
uint8_t* C10_RESTRICT lineOut1,
|
| 437 |
+
uint8_t* C10_RESTRICT lineOut2,
|
| 438 |
+
uint8_t* C10_RESTRICT lineOut3,
|
| 439 |
+
int64_t out_xsize,
|
| 440 |
+
const uint8_t* C10_RESTRICT lineIn0,
|
| 441 |
+
const uint8_t* C10_RESTRICT lineIn1,
|
| 442 |
+
const uint8_t* C10_RESTRICT lineIn2,
|
| 443 |
+
const uint8_t* C10_RESTRICT lineIn3,
|
| 444 |
+
int64_t in_xsize,
|
| 445 |
+
const int64_t* idx_ptr_xmin,
|
| 446 |
+
const int64_t* idx_ptr_size,
|
| 447 |
+
const int16_t* kk,
|
| 448 |
+
int kmax,
|
| 449 |
+
unsigned int coefs_precision,
|
| 450 |
+
int64_t num_channels,
|
| 451 |
+
bool is_last_line) {
|
| 452 |
+
|
| 453 |
+
// Interpolation horizontal pass processing together 4 vertical lines.
|
| 454 |
+
// - Input data format is RGBA or RGB with R,G,B,A being uint8. In case of RGBA
|
| 455 |
+
// we can encode 4 values as a single uint32 value.
|
| 456 |
+
// - We split the size of weight vector for a given output index as a sum:
|
| 457 |
+
// ids_size = num_blocks_4 * 4 + num_blocks_2 * 2 + num_blocks_1.
|
| 458 |
+
// - We load and process 4 weights values in a loop ("block 4") then we process 2 weights values
|
| 459 |
+
// in another loop ("block 2") and finally we process 1 weights value in the final loop ("block 1").
|
| 460 |
+
|
| 461 |
+
// Define shuffling masks (low/high) for num_channels 4 and 3
|
| 462 |
+
// Mask low casts lower half of each lane to epi16 and reorder RGBARGBA -> RRGGBBAA:
|
| 463 |
+
// [r1 g1 b1 a1 r2 g2 b2 a2 ... | R1 G1 B1 A1 R2 G2 B2 A2 ... ] ->
|
| 464 |
+
// [r1 0 r2 0 g1 0 g2 0 b1 0 b2 0 a1 0 a2 0 | R1 0 R2 0 G1 0 G2 0 B1 0 B2 0 A1 0 A2 0]
|
| 465 |
+
// Mask high casts upper half of each lane to epi16 and reorder RGBARGBA -> RRGGBBAA::
|
| 466 |
+
// [ ... r3 g3 b3 a3 r4 g4 b4 a4 | ... R3 G3 B3 A3 R4 G4 B4 A4 ] ->
|
| 467 |
+
// [r3 0 r4 0 g3 0 g4 0 b3 0 b4 0 a3 0 a4 0 | R3 0 R4 0 G3 0 G4 0 B3 0 B4 0 A3 0 A4 0]
|
| 468 |
+
|
| 469 |
+
const auto mask_low_c4 = _mm256_set_epi8(
|
| 470 |
+
-1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0,
|
| 471 |
+
-1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0);
|
| 472 |
+
const auto mask_high_c4 = _mm256_set_epi8(
|
| 473 |
+
-1, 15, -1, 11, -1, 14, -1, 10, -1, 13, -1, 9, -1, 12, -1, 8,
|
| 474 |
+
-1, 15, -1, 11, -1, 14, -1, 10, -1, 13, -1, 9, -1, 12, -1, 8);
|
| 475 |
+
const auto mask_low_c3 = _mm256_set_epi8(
|
| 476 |
+
-1, -1, -1, -1, -1, 5, -1, 2, -1, 4, -1, 1, -1, 3, -1, 0,
|
| 477 |
+
-1, -1, -1, -1, -1, 5, -1, 2, -1, 4, -1, 1, -1, 3, -1, 0);
|
| 478 |
+
const auto mask_high_c3 = _mm256_set_epi8(
|
| 479 |
+
-1, -1, -1, -1, -1, 11, -1, 8, -1, 10, -1, 7, -1, 9, -1, 6,
|
| 480 |
+
-1, -1, -1, -1, -1, 11, -1, 8, -1, 10, -1, 7, -1, 9, -1, 6);
|
| 481 |
+
|
| 482 |
+
const auto mask_low = (num_channels == 3) ? mask_low_c3 : mask_low_c4;
|
| 483 |
+
const auto mask_high = (num_channels == 3) ? mask_high_c3 : mask_high_c4;
|
| 484 |
+
|
| 485 |
+
const auto stride = num_channels * sizeof(uint8_t);
|
| 486 |
+
|
| 487 |
+
TORCH_INTERNAL_ASSERT(stride == 3 || stride == 4);
|
| 488 |
+
|
| 489 |
+
// out_xsize = output width, out_x = output x index
|
| 490 |
+
// ids_min is the input offset index corresponding to out_x
|
| 491 |
+
// ids_size is the interpolation size for out_x
|
| 492 |
+
|
| 493 |
+
// Let's precompute ids_size limits for block 4 and block 2.
|
| 494 |
+
//
|
| 495 |
+
// In block 4 (4 means we process 4 weight values together), we read input data
|
| 496 |
+
// with _mm_loadu_si128, i.e. 16 bytes, per one line:
|
| 497 |
+
// lineIn0 + stride * (i + ids_min) + 16 <= lineIn0 + stride * (ids_size + ids_min)
|
| 498 |
+
// --> i <= ids_size - 16.0 / stride
|
| 499 |
+
// Strict boundary:
|
| 500 |
+
// --> i < ids_size + 1 - int(ceil(16.0 / stride)) = ids_size - b4_delta
|
| 501 |
+
// Soft boundary for reading inside the buffer except its boundaries:
|
| 502 |
+
// --> i < ids_size + 1 - int(16.0 / stride) = ids_size - b4_delta_soft
|
| 503 |
+
// RGBA: b4_delta = b4_delta_soft = 3
|
| 504 |
+
// RGB : b4_delta = 5
|
| 505 |
+
// RGB : b4_delta_soft = 4
|
| 506 |
+
const auto b4_delta = (stride == 4) ? 3 : (is_last_line ? 5 : 4);
|
| 507 |
+
|
| 508 |
+
// In block 2 (2 means we process 2 weights values together), we read input data
|
| 509 |
+
// with _mm_loadl_epi64, i.e. 8 bytes, per one line:
|
| 510 |
+
// lineIn0 + stride * (i + ids_min) + 8 <= lineIn0 + stride * (ids_size + ids_min)
|
| 511 |
+
// --> i <= ids_size - 8.0 / stride
|
| 512 |
+
// Strict boundary:
|
| 513 |
+
// --> i < ids_size + 1 - int(ceil(8.0 / stride)) = ids_size - b2_delta
|
| 514 |
+
// Soft boundary for reading inside the buffer except its boundaries:
|
| 515 |
+
// --> i < ids_size + 1 - int(8.0 / stride) = ids_size - b2_delta_soft
|
| 516 |
+
// RGBA: b2_delta = b2_delta_soft = 1
|
| 517 |
+
// RGB : b2_delta = 2
|
| 518 |
+
// RGB : b2_delta_soft = 1
|
| 519 |
+
const auto b2_delta = (stride == 4) ? 1 : (is_last_line ? 2 : 1);
|
| 520 |
+
|
| 521 |
+
const auto max_out_x_strided = out_xsize * stride;
|
| 522 |
+
const auto max_in_x_strided = in_xsize * stride;
|
| 523 |
+
|
| 524 |
+
const auto zero = _mm256_setzero_si256();
|
| 525 |
+
const auto initial = _mm256_set1_epi32(1 << (coefs_precision - 1));
|
| 526 |
+
|
| 527 |
+
for (const auto out_x : c10::irange(out_xsize)) {
|
| 528 |
+
const auto ids_min = idx_ptr_xmin[out_x];
|
| 529 |
+
const auto ids_size = idx_ptr_size[out_x];
|
| 530 |
+
const auto * k = &kk[out_x * kmax];
|
| 531 |
+
int64_t i = 0;
|
| 532 |
+
|
| 533 |
+
auto sss0 = initial;
|
| 534 |
+
auto sss1 = initial;
|
| 535 |
+
|
| 536 |
+
const auto * lineIn0_min = lineIn0 + ids_min;
|
| 537 |
+
const auto * lineIn1_min = lineIn1 + ids_min;
|
| 538 |
+
const auto * lineIn2_min = lineIn2 + ids_min;
|
| 539 |
+
const auto * lineIn3_min = lineIn3 + ids_min;
|
| 540 |
+
|
| 541 |
+
// block 4
|
| 542 |
+
for (; i < ids_size - b4_delta; i += 4) {
|
| 543 |
+
// Load 4 values from weight vector
|
| 544 |
+
// mmk0 = [wl_0 wh_0 wl_1 wh_1 wl_0 wh_0 wl_1 wh_1 ...]
|
| 545 |
+
// mmk1 = [wl_2 wh_2 wl_3 wh_3 wl_2 wh_2 wl_3 wh_3 ...]
|
| 546 |
+
const auto mmk0 = _mm256_set1_epi32(*(int32_t*)&k[i]);
|
| 547 |
+
const auto mmk1 = _mm256_set1_epi32(*(int32_t*)&k[i + 2]);
|
| 548 |
+
|
| 549 |
+
// RGBA: Load 8 pixels (4 per line) from input lines 0 and 1:
|
| 550 |
+
// source = [
|
| 551 |
+
// r0 g0 b0 a0 r1 g1 b1 a1 r2 g2 b2 a2 r3 g3 b3 a3
|
| 552 |
+
// R0 G0 B0 A0 R1 G1 B1 A1 R2 G2 B2 A2 R3 G3 B3 A3
|
| 553 |
+
// ]
|
| 554 |
+
// RGB: Load 10 pixels (5 per line)
|
| 555 |
+
// source = [
|
| 556 |
+
// r0 g0 b0 r1 g1 b1 r2 g2 b2 r3 g3 b3 r4 g4 b4 r5
|
| 557 |
+
// R0 G0 B0 R1 G1 B1 R2 G2 B2 R3 G3 B3 R4 G4 B4 R5
|
| 558 |
+
// ]
|
| 559 |
+
auto source = _mm256_inserti128_si256(_mm256_castsi128_si256(
|
| 560 |
+
_mm_loadu_si128((__m128i *) (lineIn0_min + stride * i))),
|
| 561 |
+
_mm_loadu_si128((__m128i *) (lineIn1_min + stride * i)), 1);
|
| 562 |
+
|
| 563 |
+
// Apply mask_low:
|
| 564 |
+
// RGBA:
|
| 565 |
+
// [r0 0 r1 0 g0 0 g1 0 b0 0 b1 0 a0 0 a1 0 | R0 0 R1 0 G0 0 G1 0 B0 0 B1 0 A0 0 A1 0]
|
| 566 |
+
// RGB:
|
| 567 |
+
// [r0 0 r1 0 g0 0 g1 0 b0 0 b1 0 0 0 0 0 | R0 0 R1 0 G0 0 G1 0 B0 0 B1 0 0 0 0 0]
|
| 568 |
+
auto pix1 = _mm256_shuffle_epi8(source, mask_low);
|
| 569 |
+
// Compute output value as C += w0 * C0 + w1 * C1 for each channel in 32-bit precision
|
| 570 |
+
sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix1, mmk0));
|
| 571 |
+
|
| 572 |
+
// Apply mask_high:
|
| 573 |
+
// RGBA:
|
| 574 |
+
// [r2 0 r3 0 g2 0 g3 0 b2 0 b3 0 a2 0 a3 0 | R2 0 R3 0 G2 0 G3 0 B2 0 B3 0 A2 0 A3 0]
|
| 575 |
+
// RGB:
|
| 576 |
+
// [r2 0 r3 0 g2 0 g3 0 b2 0 b3 0 0 0 0 0 | R2 0 R3 0 G2 0 G3 0 B2 0 B3 0 0 0 0 0]
|
| 577 |
+
auto pix2 = _mm256_shuffle_epi8(source, mask_high);
|
| 578 |
+
// Compute output value as C += w2 * C2 + w3 * C3 for each channel in 32-bit precision
|
| 579 |
+
sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix2, mmk1));
|
| 580 |
+
|
| 581 |
+
// Same as above to next lines 2 and 3:
|
| 582 |
+
auto source2 = _mm256_inserti128_si256(_mm256_castsi128_si256(
|
| 583 |
+
_mm_loadu_si128((__m128i *) (lineIn2_min + stride * i))),
|
| 584 |
+
_mm_loadu_si128((__m128i *) (lineIn3_min + stride * i)), 1);
|
| 585 |
+
auto pix3 = _mm256_shuffle_epi8(source2, mask_low);
|
| 586 |
+
sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix3, mmk0));
|
| 587 |
+
auto pix4 = _mm256_shuffle_epi8(source2, mask_high);
|
| 588 |
+
sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix4, mmk1));
|
| 589 |
+
}
|
| 590 |
+
|
| 591 |
+
// block 2
|
| 592 |
+
for (; i < ids_size - b2_delta; i += 2) {
|
| 593 |
+
// Load 2 values from weight vector
|
| 594 |
+
// mmk = [wl_0 wh_0 wl_1 wh_1 wl_0 wh_0 wl_1 wh_1 ...]
|
| 595 |
+
const auto mmk = _mm256_set1_epi32(*(int32_t*)&k[i]);
|
| 596 |
+
|
| 597 |
+
// Load 4 pixels (2 per line) from input lines 0 and 1:
|
| 598 |
+
// RGBA: source1 = [
|
| 599 |
+
// r0 g0 b0 a0 r1 g1 b1 a1 0 0 0 0 0 0 0 0
|
| 600 |
+
// R0 G0 B0 A0 R1 G1 B1 A1 0 0 0 0 0 0 0 0
|
| 601 |
+
// ]
|
| 602 |
+
// RGB: source1 = [
|
| 603 |
+
// r0 g0 b0 r1 g1 b1 r2 0 0 0 0 0 0 0 0
|
| 604 |
+
// R0 G0 B0 R1 G1 B1 R2 0 0 0 0 0 0 0 0
|
| 605 |
+
// ]
|
| 606 |
+
auto source1 = _mm256_inserti128_si256(_mm256_castsi128_si256(
|
| 607 |
+
_mm_loadl_epi64((__m128i *) (lineIn0_min + stride * i))),
|
| 608 |
+
_mm_loadl_epi64((__m128i *) (lineIn1_min + stride * i)), 1);
|
| 609 |
+
// Apply mask_low:
|
| 610 |
+
// RGBA:
|
| 611 |
+
// [r0 0 r1 0 g0 0 g1 0 b0 0 b1 0 a0 0 a1 0 | R0 0 R1 0 G0 0 G1 0 B0 0 B1 0 A0 0 A1 0]
|
| 612 |
+
// RGB:
|
| 613 |
+
// [r0 0 r1 0 g0 0 g1 0 b0 0 b1 0 0 0 0 0 | R0 0 R1 0 G0 0 G1 0 B0 0 B1 0 0 0 0 0]
|
| 614 |
+
auto pix1 = _mm256_shuffle_epi8(source1, mask_low);
|
| 615 |
+
// Compute output value as C += w0 * C0 + w1 * C1 for each channel in 32-bit precision
|
| 616 |
+
sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix1, mmk));
|
| 617 |
+
|
| 618 |
+
// Same as above for lines 2 and 3:
|
| 619 |
+
auto source2 = _mm256_inserti128_si256(_mm256_castsi128_si256(
|
| 620 |
+
_mm_loadl_epi64((__m128i *) (lineIn2_min + stride * i))),
|
| 621 |
+
_mm_loadl_epi64((__m128i *) (lineIn3_min + stride * i)), 1);
|
| 622 |
+
auto pix2 = _mm256_shuffle_epi8(source2, mask_low);
|
| 623 |
+
sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix2, mmk));
|
| 624 |
+
}
|
| 625 |
+
|
| 626 |
+
// block 1
|
| 627 |
+
const auto i32_aligned = num_channels == 4;
|
| 628 |
+
for (; i < ids_size - 1; i++) {
|
| 629 |
+
// Load 1 value from weight vector
|
| 630 |
+
// mmk = [wl_0 wh_0 0 0 wl_0 wh_0 0 0 ...]
|
| 631 |
+
const auto mmk = _mm256_set1_epi32(k[i]);
|
| 632 |
+
|
| 633 |
+
// Load 2 pixels (one per line) from input lines 0 and 1:
|
| 634 |
+
// RGBA: pix1 = [
|
| 635 |
+
// r0 0 0 0 g0 0 0 0 b0 0 0 0 a0 0 0 0
|
| 636 |
+
// R0 0 0 0 G0 0 0 0 B0 0 0 0 A0 0 0 0
|
| 637 |
+
// ]
|
| 638 |
+
// RGB: pix1 = [
|
| 639 |
+
// r0 0 0 0 g0 0 0 0 b0 0 0 0 r1 0 0 0
|
| 640 |
+
// R0 0 0 0 G0 0 0 0 B0 0 0 0 R1 0 0 0
|
| 641 |
+
// ]
|
| 642 |
+
auto pix1 = _mm256_inserti128_si256(_mm256_castsi128_si256(
|
| 643 |
+
mm_cvtepu8_epi32(lineIn0_min + stride * i, i32_aligned)),
|
| 644 |
+
mm_cvtepu8_epi32(lineIn1_min + stride * i, i32_aligned), 1);
|
| 645 |
+
// Compute output value as C += w0 * C0 for each channel in 32-bit precision
|
| 646 |
+
sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix1, mmk));
|
| 647 |
+
|
| 648 |
+
// Same as above for lines 2 and 3
|
| 649 |
+
auto pix2 = _mm256_inserti128_si256(_mm256_castsi128_si256(
|
| 650 |
+
mm_cvtepu8_epi32(lineIn2_min + stride * i, i32_aligned)),
|
| 651 |
+
mm_cvtepu8_epi32(lineIn3_min + stride * i, i32_aligned), 1);
|
| 652 |
+
sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix2, mmk));
|
| 653 |
+
}
|
| 654 |
+
|
| 655 |
+
if (i == ids_size - 1) {
|
| 656 |
+
// last element
|
| 657 |
+
auto mmk = _mm256_set1_epi32(k[i]);
|
| 658 |
+
// For num_channels == 3 (3 bytes = one pixel) we tolerate to read 4 bytes
|
| 659 |
+
// lines 0, 1 and 2 won't go out of allocated memory bounds
|
| 660 |
+
auto pix = _mm256_inserti128_si256(_mm256_castsi128_si256(
|
| 661 |
+
mm_cvtepu8_epi32(lineIn0_min + stride * i, i32_aligned)),
|
| 662 |
+
mm_cvtepu8_epi32(lineIn1_min + stride * i, i32_aligned), 1);
|
| 663 |
+
sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix, mmk));
|
| 664 |
+
|
| 665 |
+
auto p0 = mm_cvtepu8_epi32(lineIn2_min + stride * i, i32_aligned);
|
| 666 |
+
__m128i p1;
|
| 667 |
+
if (num_channels == 3 && C10_UNLIKELY(is_last_line && ids_min + stride * i + 4 >= max_in_x_strided)) {
|
| 668 |
+
uint8_t input[4];
|
| 669 |
+
std::memcpy(input, lineIn3_min + stride * i, 3);
|
| 670 |
+
p1 = mm_cvtepu8_epi32(input, true);
|
| 671 |
+
} else {
|
| 672 |
+
p1 = mm_cvtepu8_epi32(lineIn3_min + stride * i, i32_aligned);
|
| 673 |
+
}
|
| 674 |
+
auto pix2 = _mm256_inserti128_si256(_mm256_castsi128_si256(p0), p1, 1);
|
| 675 |
+
sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix2, mmk));
|
| 676 |
+
}
|
| 677 |
+
|
| 678 |
+
// Convert fixed point values back to integers (truncating)
|
| 679 |
+
sss0 = _mm256_srai_epi32(sss0, coefs_precision);
|
| 680 |
+
sss1 = _mm256_srai_epi32(sss1, coefs_precision);
|
| 681 |
+
// Convert packed signed 32-bit integers to packed 16-bit integers using signed saturation
|
| 682 |
+
// (a a a a b b b b c c c c d d d d) -> (a a b b c c d d 0 0 0 0 0 0 0 0)
|
| 683 |
+
sss0 = _mm256_packs_epi32(sss0, zero);
|
| 684 |
+
sss1 = _mm256_packs_epi32(sss1, zero);
|
| 685 |
+
// Convert packed signed 16-bit integers to packed 8-bit integers using unsigned saturation
|
| 686 |
+
// (a a b b c c d d) -> (a b c d 0 0 0 0)
|
| 687 |
+
sss0 = _mm256_packus_epi16(sss0, zero);
|
| 688 |
+
sss1 = _mm256_packus_epi16(sss1, zero);
|
| 689 |
+
|
| 690 |
+
// Write the output into single uint32
|
| 691 |
+
// (a b c d) -> x_uint32
|
| 692 |
+
auto o0 = _mm_cvtsi128_si32(_mm256_castsi256_si128(sss0));
|
| 693 |
+
auto o1 = _mm_cvtsi128_si32(_mm256_extracti128_si256(sss0, 1));
|
| 694 |
+
auto o2 = _mm_cvtsi128_si32(_mm256_castsi256_si128(sss1));
|
| 695 |
+
auto o3 = _mm_cvtsi128_si32(_mm256_extracti128_si256(sss1, 1));
|
| 696 |
+
|
| 697 |
+
const auto out_x_strided = stride * out_x;
|
| 698 |
+
|
| 699 |
+
if (num_channels == 3 && C10_UNLIKELY(out_x_strided + 4 >= max_out_x_strided)) {
|
| 700 |
+
// Memcpy 4-bytes is faster than 3-bytes and this is a boundary case when we want to write
|
| 701 |
+
// 4 bytes (R G B | X) to the output buffer (X1 X2 X3 | R1).
|
| 702 |
+
// The 4th byte in the register (X) has a garbage value and 4th byte in the output buffer (R1) has a correct
|
| 703 |
+
// value which was previously computed by another line. In other words, it means that we can not overwrite
|
| 704 |
+
// it by simply writing 4 bytes from the register to the output. We'll do the following:
|
| 705 |
+
// v----------|
|
| 706 |
+
// Output = [... X1 X2 X3 | R1 G1 B1 R2 ...]
|
| 707 |
+
// First, we write R1 value to the 4th byte of (R G B | X) -> (R G B | R1)
|
| 708 |
+
// Second, we write 4 bytes from the register to the output: (X1 X2 X3 | R1) -> (R G B | R1)
|
| 709 |
+
// Output = [... R G B | R1 G1 B1 R2 ...]
|
| 710 |
+
|
| 711 |
+
_write_endline_rgb_as_uint32(lineOut0 + out_x_strided, o0);
|
| 712 |
+
_write_endline_rgb_as_uint32(lineOut1 + out_x_strided, o1);
|
| 713 |
+
_write_endline_rgb_as_uint32(lineOut2 + out_x_strided, o2);
|
| 714 |
+
|
| 715 |
+
if (C10_UNLIKELY(is_last_line)) {
|
| 716 |
+
// When we handle the last line, we can not access the next 4 bytes
|
| 717 |
+
// as they are out of memory bounds.
|
| 718 |
+
std::memcpy(lineOut3 + out_x_strided, (uint8_t *) &o3, num_channels);
|
| 719 |
+
} else {
|
| 720 |
+
_write_endline_rgb_as_uint32(lineOut3 + out_x_strided, o3);
|
| 721 |
+
}
|
| 722 |
+
} else if (num_channels == 3) {
|
| 723 |
+
// Memcpy 4-bytes is faster than 3-bytes and here
|
| 724 |
+
// we simply write 4 bytes (... R G B X 0 0 0 0 0 ...) where X is a garbage value
|
| 725 |
+
// that we will overwrite on the next iteration: (... R G B R G B X 0 0 ...)
|
| 726 |
+
std::memcpy(lineOut0 + out_x_strided, (uint8_t *) &o0, 4);
|
| 727 |
+
std::memcpy(lineOut1 + out_x_strided, (uint8_t *) &o1, 4);
|
| 728 |
+
std::memcpy(lineOut2 + out_x_strided, (uint8_t *) &o2, 4);
|
| 729 |
+
std::memcpy(lineOut3 + out_x_strided, (uint8_t *) &o3, 4);
|
| 730 |
+
} else {
|
| 731 |
+
// num_channels = 4 -> lineOutX + out_x_strided should be uint32 aligned
|
| 732 |
+
*(uint32_t *)(lineOut0 + out_x_strided) = o0;
|
| 733 |
+
*(uint32_t *)(lineOut1 + out_x_strided) = o1;
|
| 734 |
+
*(uint32_t *)(lineOut2 + out_x_strided) = o2;
|
| 735 |
+
*(uint32_t *)(lineOut3 + out_x_strided) = o3;
|
| 736 |
+
}
|
| 737 |
+
}
|
| 738 |
+
}
|
| 739 |
+
|
| 740 |
+
void ImagingResampleHorizontalConvolution8u(
|
| 741 |
+
uint8_t* C10_RESTRICT lineOut,
|
| 742 |
+
int64_t out_xsize,
|
| 743 |
+
const uint8_t* C10_RESTRICT lineIn,
|
| 744 |
+
int64_t in_xsize,
|
| 745 |
+
const int64_t* idx_ptr_xmin,
|
| 746 |
+
const int64_t* idx_ptr_size,
|
| 747 |
+
const int16_t* kk,
|
| 748 |
+
int kmax,
|
| 749 |
+
unsigned int coefs_precision,
|
| 750 |
+
int64_t num_channels,
|
| 751 |
+
bool is_last_line) {
|
| 752 |
+
|
| 753 |
+
// Interpolation horizontal pass processing only one vertical line.
|
| 754 |
+
// - Input data format is RGBA or RGB with R,G,B,A being uint8. In case of RGBA
|
| 755 |
+
// we can encode 4 values as a single uint32 value.
|
| 756 |
+
// - We split the size of weight vector for a given output index as a sum:
|
| 757 |
+
// ids_size = num_blocks_8 * 8 + num_blocks_4 * 4 + num_blocks_2 * 2 + num_blocks_1
|
| 758 |
+
// - We load and process 8 weights values in a loop ("block 8") then 4 weights and 2 weights values in
|
| 759 |
+
// in another loops ("block 4" and "block 2") and finally we process 1 weight value in the final loop ("block 1").
|
| 760 |
+
|
| 761 |
+
// Define various shuffling masks
|
| 762 |
+
const auto kmask_low = _mm256_set_epi8(
|
| 763 |
+
11, 10, 9, 8, 11, 10, 9, 8, 11, 10, 9, 8, 11, 10, 9, 8,
|
| 764 |
+
3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0);
|
| 765 |
+
const auto kmask_high = _mm256_set_epi8(
|
| 766 |
+
15, 14, 13, 12, 15, 14, 13, 12, 15, 14, 13, 12, 15, 14, 13, 12,
|
| 767 |
+
7, 6, 5, 4, 7, 6, 5, 4, 7, 6, 5, 4, 7, 6, 5, 4);
|
| 768 |
+
const auto kmask_hl = _mm256_set_epi8(
|
| 769 |
+
7, 6, 5, 4, 7, 6, 5, 4, 7, 6, 5, 4, 7, 6, 5, 4,
|
| 770 |
+
3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0);
|
| 771 |
+
|
| 772 |
+
const auto mask_low_c4 = _mm256_set_epi8(
|
| 773 |
+
-1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0,
|
| 774 |
+
-1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0);
|
| 775 |
+
const auto mask_high_c4 = _mm256_set_epi8(
|
| 776 |
+
-1, 15, -1, 11, -1, 14, -1, 10, -1, 13, -1, 9, -1, 12, -1, 8,
|
| 777 |
+
-1, 15, -1, 11, -1, 14, -1, 10, -1, 13, -1, 9, -1, 12, -1, 8);
|
| 778 |
+
const auto mask_low_c3 = _mm256_set_epi8(
|
| 779 |
+
-1, -1, -1, -1, -1, 5, -1, 2, -1, 4, -1, 1, -1, 3, -1, 0,
|
| 780 |
+
-1, -1, -1, -1, -1, 5, -1, 2, -1, 4, -1, 1, -1, 3, -1, 0);
|
| 781 |
+
const auto mask_high_c3 = _mm256_set_epi8(
|
| 782 |
+
-1, -1, -1, -1, -1, 11, -1, 8, -1, 10, -1, 7, -1, 9, -1, 6,
|
| 783 |
+
-1, -1, -1, -1, -1, 11, -1, 8, -1, 10, -1, 7, -1, 9, -1, 6);
|
| 784 |
+
const auto mask_hl_c3 = _mm256_set_epi8(
|
| 785 |
+
-1, -1, -1, -1, -1, 11, -1, 8, -1, 10, -1, 7, -1, 9, -1, 6,
|
| 786 |
+
-1, -1, -1, -1, -1, 5, -1, 2, -1, 4, -1, 1, -1, 3, -1, 0);
|
| 787 |
+
const auto mask_hl_c4 = _mm256_set_epi8(
|
| 788 |
+
-1, 15, -1, 11, -1, 14, -1, 10, -1, 13, -1, 9, -1, 12, -1, 8,
|
| 789 |
+
-1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0);
|
| 790 |
+
|
| 791 |
+
const auto mask_low128_c3 = _mm_set_epi8(
|
| 792 |
+
-1, -1, -1, -1, -1, 5, -1, 2, -1, 4, -1, 1, -1, 3, -1, 0);
|
| 793 |
+
const auto mask_low128_c4 = _mm_set_epi8(
|
| 794 |
+
-1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0);
|
| 795 |
+
|
| 796 |
+
const auto mask_low = (num_channels == 3) ? mask_low_c3 : mask_low_c4;
|
| 797 |
+
const auto mask_high = (num_channels == 3) ? mask_high_c3 : mask_high_c4;
|
| 798 |
+
const auto mask_hl = (num_channels == 3) ? mask_hl_c3 : mask_hl_c4;
|
| 799 |
+
const auto mask_low128 = (num_channels == 3) ? mask_low128_c3 : mask_low128_c4;
|
| 800 |
+
|
| 801 |
+
// out_xsize = output width, out_x = output x index
|
| 802 |
+
// ids_min is the input offset index corresponding to out_x
|
| 803 |
+
// ids_size is the interpolation size for out_x
|
| 804 |
+
|
| 805 |
+
const auto stride = num_channels * sizeof(uint8_t);
|
| 806 |
+
const auto zero = _mm_setzero_si128();
|
| 807 |
+
|
| 808 |
+
TORCH_INTERNAL_ASSERT(stride == 3 || stride == 4);
|
| 809 |
+
|
| 810 |
+
// Let's precompute ids_size limits for block 8, block 4 and block 2
|
| 811 |
+
//
|
| 812 |
+
// In block 8 (8 means we process 8 weight values together), we read at
|
| 813 |
+
// most 32 bytes input data (16 + 16 bytes for RGBA and 12 + 16 bytes for RGB)
|
| 814 |
+
// lineIn + stride * (i + ids_min) + 32 <= lineIn + stride * (ids_size + ids_min)
|
| 815 |
+
// --> i <= ids_size - 32.0 / stride
|
| 816 |
+
// Strict boundary:
|
| 817 |
+
// --> i < ids_size + 1 - int(ceil(32.0 / stride)) = ids_size - b8_delta
|
| 818 |
+
// Soft boundary for reading inside the buffer except its boundaries:
|
| 819 |
+
// --> i < ids_size + 1 - int(32.0 / stride) = ids_size - b8_delta_soft
|
| 820 |
+
// RGBA: b8_delta = b8_delta_soft = 7
|
| 821 |
+
// RGB : b8_delta = 10
|
| 822 |
+
// RGB : b8_delta_soft = 9
|
| 823 |
+
const auto b8_delta = (stride == 4) ? 7 : (is_last_line ? 10 : 9);
|
| 824 |
+
|
| 825 |
+
// In block 4 (4 means we process 4 weight values together), we read
|
| 826 |
+
// 16 bytes of input data.
|
| 827 |
+
// lineIn + stride * (i + ids_min) + 16 <= lineIn0 + stride * (ids_size + ids_min)
|
| 828 |
+
// --> i <= ids_size - 16.0 / stride
|
| 829 |
+
// Strict boundary:
|
| 830 |
+
// --> i < ids_size + 1 - int(ceil(16.0 / stride)) = ids_size - b4_delta
|
| 831 |
+
// Soft boundary for reading inside the buffer except its boundaries:
|
| 832 |
+
// --> i < ids_size + 1 - int(16.0 / stride) = ids_size - b4_delta_soft
|
| 833 |
+
// RGBA: b4_delta = b4_delta_soft = 3
|
| 834 |
+
// RGB : b4_delta = 5
|
| 835 |
+
// RGB : b4_delta_soft = 4
|
| 836 |
+
const auto b4_delta = (stride == 4) ? 3 : (is_last_line ? 5 : 4);
|
| 837 |
+
|
| 838 |
+
// In block 2 (2 means we process 2 weight values together), we read
|
| 839 |
+
// 8 bytes of input data.
|
| 840 |
+
// lineIn0 + stride * (i + ids_min) + 8 <= lineIn0 + stride * (ids_size + ids_min)
|
| 841 |
+
// --> i <= ids_size - 8.0 / stride
|
| 842 |
+
// Strict boundary:
|
| 843 |
+
// --> i < ids_size + 1 - int(ceil(8.0 / stride)) = ids_size - b2_delta
|
| 844 |
+
// Soft boundary for reading inside the buffer except its boundaries:
|
| 845 |
+
// --> i < ids_size + 1 - int(8.0 / stride) = ids_size - b2_delta_soft
|
| 846 |
+
// RGBA: b2_delta = b2_delta_soft = 1
|
| 847 |
+
// RGB : b2_delta = 2
|
| 848 |
+
// RGB : b2_delta_soft = 1
|
| 849 |
+
const auto b2_delta = (stride == 4) ? 1 : (is_last_line ? 2 : 1);
|
| 850 |
+
|
| 851 |
+
const auto max_out_x_strided = out_xsize * stride;
|
| 852 |
+
const auto max_in_x_strided = in_xsize * stride;
|
| 853 |
+
|
| 854 |
+
for (const auto out_x : c10::irange(out_xsize)) {
|
| 855 |
+
__m128i sss;
|
| 856 |
+
const auto ids_min = idx_ptr_xmin[out_x];
|
| 857 |
+
const auto ids_size = idx_ptr_size[out_x];
|
| 858 |
+
const auto * k = &kk[out_x * kmax];
|
| 859 |
+
int64_t i = 0;
|
| 860 |
+
|
| 861 |
+
const auto * lineIn_min = lineIn + ids_min;
|
| 862 |
+
|
| 863 |
+
if (ids_size < 8) {
|
| 864 |
+
sss = _mm_set1_epi32(1 << (coefs_precision - 1));
|
| 865 |
+
} else {
|
| 866 |
+
// Lower part will be added to higher, use only half of the error
|
| 867 |
+
auto sss256 = _mm256_set1_epi32(1 << (coefs_precision - 2));
|
| 868 |
+
|
| 869 |
+
// block 8
|
| 870 |
+
for (; i < ids_size - b8_delta; i += 8) {
|
| 871 |
+
// Load 8 values from weight vector
|
| 872 |
+
auto tmp = _mm_loadu_si128((__m128i*)&k[i]);
|
| 873 |
+
// ksource = [
|
| 874 |
+
// wl_0 wh_0 wl_1 wh_1 wl_2 wh_2 wl_3 wh_3 wl_4 wh_4 wl_5 wh_5 wl_6 wh_6 wl_7 wh_7
|
| 875 |
+
// wl_0 wh_0 wl_1 wh_1 wl_2 wh_2 wl_3 wh_3 wl_4 wh_4 wl_5 wh_5 wl_6 wh_6 wl_7 wh_7
|
| 876 |
+
// ]
|
| 877 |
+
auto ksource = _mm256_insertf128_si256(_mm256_castsi128_si256(tmp), tmp, 1);
|
| 878 |
+
|
| 879 |
+
// RGBA: Load 8 pixels from input:
|
| 880 |
+
// source = [
|
| 881 |
+
// r0 g0 b0 a0 r1 g1 b1 a1 r2 g2 b2 a2 r3 g3 b3 a3
|
| 882 |
+
// r4 g4 b4 a4 r5 g5 b5 a5 r6 g6 b6 a6 r7 g7 b7 a7
|
| 883 |
+
// ]
|
| 884 |
+
// RGB: Load 10 pixels from input (however we can process only 8 pixels):
|
| 885 |
+
// source = [
|
| 886 |
+
// r0 g0 b0 r1 g1 b1 r2 g2 b2 r3 g3 b3 r4 g4 b4 r5
|
| 887 |
+
// r4 g4 b4 r5 g5 b5 r6 g6 b6 r7 g7 b7 r8 g8 b8 r9
|
| 888 |
+
// ]
|
| 889 |
+
auto source = _mm256_inserti128_si256(_mm256_castsi128_si256(
|
| 890 |
+
_mm_loadu_si128((__m128i *) (lineIn_min + stride * i))),
|
| 891 |
+
_mm_loadu_si128((__m128i *) (lineIn_min + stride * (i + 4))), 1);
|
| 892 |
+
|
| 893 |
+
// Extract lower part of each lane, cast to epi16 and reorder RGBARGBA -> RRGGBBAA
|
| 894 |
+
// RGBA: pix1 = [
|
| 895 |
+
// r0 0 r1 0 g0 0 g1 0 b0 0 b1 0 a0 0 a1 0
|
| 896 |
+
// r4 0 r5 0 g4 0 g5 0 b4 0 b5 0 a4 0 a5 0
|
| 897 |
+
// ]
|
| 898 |
+
// RGB: pix1 = [
|
| 899 |
+
// r0 0 r1 0 g0 0 g1 0 b0 0 b1 0 0 0 0 0
|
| 900 |
+
// r4 0 r5 0 g4 0 g5 0 b4 0 b5 0 0 0 0 0
|
| 901 |
+
// ]
|
| 902 |
+
auto pix1 = _mm256_shuffle_epi8(source, mask_low);
|
| 903 |
+
// mmk1 = [
|
| 904 |
+
// wl_0 wh_0 wl_1 wh_1 wl_0 wh_0 wl_1 wh_1 ... ...
|
| 905 |
+
// wl_4 wh_4 wl_5 wh_5 wl_4 wh_4 wl_5 wh_5 ... ...
|
| 906 |
+
// ]
|
| 907 |
+
auto mmk1 = _mm256_shuffle_epi8(ksource, kmask_low);
|
| 908 |
+
// Compute output value as
|
| 909 |
+
// C += w0 * C0 + w1 * C1
|
| 910 |
+
// C += w4 * C4 + w5 * C5 for each channel in 32-bit precision
|
| 911 |
+
sss256 = _mm256_add_epi32(sss256, _mm256_madd_epi16(pix1, mmk1));
|
| 912 |
+
|
| 913 |
+
// Same as above for higher part of each lane
|
| 914 |
+
auto pix2 = _mm256_shuffle_epi8(source, mask_high);
|
| 915 |
+
auto mmk2 = _mm256_shuffle_epi8(ksource, kmask_high);
|
| 916 |
+
// Compute output value as
|
| 917 |
+
// C += w2 * C2 + w3 * C3
|
| 918 |
+
// C += w6 * C6 + w7 * C7 for each channel in 32-bit precision
|
| 919 |
+
sss256 = _mm256_add_epi32(sss256, _mm256_madd_epi16(pix2, mmk2));
|
| 920 |
+
}
|
| 921 |
+
|
| 922 |
+
// block 4
|
| 923 |
+
for (; i < ids_size - b4_delta; i += 4) {
|
| 924 |
+
// Load 4 values from weight vector
|
| 925 |
+
auto tmp = _mm_loadl_epi64((__m128i *) &k[i]);
|
| 926 |
+
// ksource = [
|
| 927 |
+
// wl_0 wh_0 wl_1 wh_1 wl_2 wh_2 wl_3 wh_3 0 0 0 0 0 0 0 0
|
| 928 |
+
// wl_0 wh_0 wl_1 wh_1 wl_2 wh_2 wl_3 wh_3 0 0 0 0 0 0 0 0
|
| 929 |
+
// ]
|
| 930 |
+
auto ksource = _mm256_insertf128_si256(_mm256_castsi128_si256(tmp), tmp, 1);
|
| 931 |
+
|
| 932 |
+
// Load pixels from input line
|
| 933 |
+
tmp = _mm_loadu_si128((__m128i *) (lineIn_min + stride * i));
|
| 934 |
+
// RGBA: source = [
|
| 935 |
+
// r0 g0 b0 a0 r1 g1 b1 a1 r2 g2 b2 a2 r3 g3 b3 a3
|
| 936 |
+
// r0 g0 b0 a0 r1 g1 b1 a1 r2 g2 b2 a2 r3 g3 b3 a3
|
| 937 |
+
// ]
|
| 938 |
+
// RGB: source = [
|
| 939 |
+
// r0 g0 b0 r1 g1 b1 r2 g2 b2 r3 g3 b3 r4 g4 b4 r5
|
| 940 |
+
// r0 g0 b0 r1 g1 b1 r2 g2 b2 r3 g3 b3 r4 g4 b4 r5
|
| 941 |
+
// ]
|
| 942 |
+
auto source = _mm256_insertf128_si256(_mm256_castsi128_si256(tmp), tmp, 1);
|
| 943 |
+
|
| 944 |
+
// Cast source to epi16 and reorder RGBARGBA -> RRGGBBAA
|
| 945 |
+
// RGBA: pix = [
|
| 946 |
+
// r0 0 r1 0 g0 0 g1 0 b0 0 b1 0 a0 0 a1 0
|
| 947 |
+
// r2 0 r3 0 g2 0 g3 0 b2 0 b3 0 a2 0 a3 0
|
| 948 |
+
// ]
|
| 949 |
+
// RGB: pix = [
|
| 950 |
+
// r0 0 r1 0 g0 0 g1 0 b0 0 b1 0 0 0 0 0
|
| 951 |
+
// r2 0 r3 0 g2 0 g3 0 b2 0 b3 0 0 0 0 0
|
| 952 |
+
// ]
|
| 953 |
+
auto pix = _mm256_shuffle_epi8(source, mask_hl);
|
| 954 |
+
// mmk = [
|
| 955 |
+
// wl_0 wh_0 wl_1 wh_1 wl_0 wh_0 wl_1 wh_1 ... ...
|
| 956 |
+
// wl_2 wh_2 wl_3 wh_3 wl_2 wh_2 wl_3 wh_3 ... ...
|
| 957 |
+
// ]
|
| 958 |
+
auto mmk = _mm256_shuffle_epi8(ksource, kmask_hl);
|
| 959 |
+
// Compute output value as
|
| 960 |
+
// C += w0 * C0 + w1 * C1
|
| 961 |
+
// C += w2 * C2 + w3 * C3 for each channel in 32-bit precision
|
| 962 |
+
sss256 = _mm256_add_epi32(sss256, _mm256_madd_epi16(pix, mmk));
|
| 963 |
+
}
|
| 964 |
+
|
| 965 |
+
// Sum results between the lanes
|
| 966 |
+
sss = _mm_add_epi32(
|
| 967 |
+
_mm256_extracti128_si256(sss256, 0),
|
| 968 |
+
_mm256_extracti128_si256(sss256, 1));
|
| 969 |
+
}
|
| 970 |
+
|
| 971 |
+
// block 2
|
| 972 |
+
for (; i < ids_size - b2_delta; i += 2) {
|
| 973 |
+
// Load 2 values from weight vector
|
| 974 |
+
// mmk = [wl_0 wh_0 wl_1 wh_1 wl_0 wh_0 wl_1 wh_1 ...]
|
| 975 |
+
auto mmk = _mm_set1_epi32(*(int32_t*)&k[i]);
|
| 976 |
+
// Load pixels from input line
|
| 977 |
+
// RGBA: source = [
|
| 978 |
+
// r0 g0 b0 a0 r1 g1 b1 a1 0 0 0 0 0 0 0 0
|
| 979 |
+
// ]
|
| 980 |
+
// RGB: source = [
|
| 981 |
+
// r0 g0 b0 r1 g1 b1 r2 g2 0 0 0 0 0 0 0 0
|
| 982 |
+
// ]
|
| 983 |
+
auto source = _mm_loadl_epi64((__m128i *) (lineIn_min + stride * i));
|
| 984 |
+
// Cast source to epi16 and reorder RGBARGBA -> RRGGBBAA
|
| 985 |
+
auto pix = _mm_shuffle_epi8(source, mask_low128);
|
| 986 |
+
// Compute output value as C += w0 * C0 + w1 * C1 for each channel in 32-bit precision
|
| 987 |
+
sss = _mm_add_epi32(sss, _mm_madd_epi16(pix, mmk));
|
| 988 |
+
}
|
| 989 |
+
|
| 990 |
+
// block 1
|
| 991 |
+
const auto i32_aligned = num_channels == 4;
|
| 992 |
+
for (; i < ids_size - 1; i++) {
|
| 993 |
+
// Load 1 value from weight vector
|
| 994 |
+
// mmk = [wl_0 wh_0 0 0 wl_0 wh_0 0 0 ...]
|
| 995 |
+
auto mmk = _mm_set1_epi32(k[i]);
|
| 996 |
+
// Load one pixel from input line
|
| 997 |
+
// RGBA: pix = [
|
| 998 |
+
// r0 0 0 0 g0 0 0 0 b0 0 0 0 a0 0 0 0
|
| 999 |
+
// ]
|
| 1000 |
+
// RGB: pix = [
|
| 1001 |
+
// r0 0 0 0 g0 0 0 0 b0 0 0 0 r1 0 0 0
|
| 1002 |
+
// ]
|
| 1003 |
+
auto pix = mm_cvtepu8_epi32(lineIn_min + stride * i, i32_aligned);
|
| 1004 |
+
// Compute output value as C += w0 * C0 for each channel in 32-bit precision
|
| 1005 |
+
sss = _mm_add_epi32(sss, _mm_madd_epi16(pix, mmk));
|
| 1006 |
+
}
|
| 1007 |
+
|
| 1008 |
+
if (i == ids_size - 1) {
|
| 1009 |
+
// last element
|
| 1010 |
+
auto mmk = _mm_set1_epi32(k[i]);
|
| 1011 |
+
__m128i pix;
|
| 1012 |
+
auto p = lineIn_min + stride * i;
|
| 1013 |
+
if (num_channels == 3 && C10_UNLIKELY(is_last_line && ids_min + stride * i + 4 >= max_in_x_strided)) {
|
| 1014 |
+
uint8_t input[4];
|
| 1015 |
+
std::memcpy(input, p, 3);
|
| 1016 |
+
pix = mm_cvtepu8_epi32(input, true);
|
| 1017 |
+
} else {
|
| 1018 |
+
pix = mm_cvtepu8_epi32(p, i32_aligned);
|
| 1019 |
+
}
|
| 1020 |
+
sss = _mm_add_epi32(sss, _mm_madd_epi16(pix, mmk));
|
| 1021 |
+
}
|
| 1022 |
+
|
| 1023 |
+
// Convert fixed point values back to integers (truncating)
|
| 1024 |
+
sss = _mm_srai_epi32(sss, coefs_precision);
|
| 1025 |
+
// Convert packed signed 32-bit integers to packed 16-bit integers using signed saturation
|
| 1026 |
+
// (a a a a b b b b c c c c d d d d) -> (a a b b c c d d 0 0 0 0 0 0 0 0)
|
| 1027 |
+
sss = _mm_packs_epi32(sss, zero);
|
| 1028 |
+
// Convert packed signed 16-bit integers to packed 8-bit integers using unsigned saturation
|
| 1029 |
+
// (a a b b c c d d) -> (a b c d 0 0 0 0)
|
| 1030 |
+
sss = _mm_packus_epi16(sss, zero);
|
| 1031 |
+
// Write the output into single uint32
|
| 1032 |
+
// (a b c d) -> x_uint32
|
| 1033 |
+
auto o = _mm_cvtsi128_si32(sss);
|
| 1034 |
+
const auto out_x_strided = stride * out_x;
|
| 1035 |
+
if (num_channels == 3 && C10_UNLIKELY(out_x_strided + 4 >= max_out_x_strided)) {
|
| 1036 |
+
if (C10_UNLIKELY(is_last_line)) {
|
| 1037 |
+
// When we handle the last line, we can not access the next 4 bytes
|
| 1038 |
+
// as they are out of memory bounds.
|
| 1039 |
+
std::memcpy(lineOut + out_x_strided, (uint8_t *) &o, 3);
|
| 1040 |
+
} else {
|
| 1041 |
+
// Memcpy 4-bytes is faster than 3-bytes and this is a boundary case when we want to write
|
| 1042 |
+
// 4 bytes (R G B | X) to the output buffer (X1 X2 X3 | R1).
|
| 1043 |
+
// The 4th byte in the register (X) has a garbage value and 4th byte in the output buffer (R1) has a correct
|
| 1044 |
+
// value which was previously computed by another line. In other words, it means that we can not overwrite
|
| 1045 |
+
// it by simply writing 4 bytes from the register to the output. We'll do the following:
|
| 1046 |
+
// v----------|
|
| 1047 |
+
// Output = [... X1 X2 X3 | R1 G1 B1 R2 ...]
|
| 1048 |
+
// First, we write R1 value to the 4th byte of (R G B | X) -> (R G B | R1)
|
| 1049 |
+
// Second, we write 4 bytes from the register to the output: (X1 X2 X3 | R1) -> (R G B | R1)
|
| 1050 |
+
// Output = [... R G B | R1 G1 B1 R2 ...]
|
| 1051 |
+
_write_endline_rgb_as_uint32(lineOut + out_x_strided, o);
|
| 1052 |
+
}
|
| 1053 |
+
} else if (num_channels == 3) {
|
| 1054 |
+
// Memcpy 4-bytes is faster than 3-bytes and here
|
| 1055 |
+
// we simply write 4 bytes (... R G B X 0 0 0 0 0 ...) where X is a garbage value
|
| 1056 |
+
// that we will overwrite on the next iteration: (... R G B R G B X 0 0 ...)
|
| 1057 |
+
std::memcpy(lineOut + out_x_strided, (uint8_t *) &o, 4);
|
| 1058 |
+
} else {
|
| 1059 |
+
// num_channels = 4 -> lineOut + out_x_strided should be uint32 aligned
|
| 1060 |
+
*(uint32_t *)(lineOut + out_x_strided) = o;
|
| 1061 |
+
}
|
| 1062 |
+
}
|
| 1063 |
+
}
|
| 1064 |
+
|
| 1065 |
+
void ImagingResampleVerticalConvolution8u(
|
| 1066 |
+
uint8_t* C10_RESTRICT lineOut,
|
| 1067 |
+
const uint8_t* C10_RESTRICT lineIn,
|
| 1068 |
+
int64_t xsize,
|
| 1069 |
+
int64_t ids_min,
|
| 1070 |
+
int64_t ids_size,
|
| 1071 |
+
const int16_t* k,
|
| 1072 |
+
unsigned int coefs_precision,
|
| 1073 |
+
int64_t num_channels) {
|
| 1074 |
+
|
| 1075 |
+
// Interpolation vertical pass processing one line.
|
| 1076 |
+
// - We process x-axis data with blocks of 8, 2 and 1
|
| 1077 |
+
// - We split the size of weight vector for a given output index as a sum: K = n * 2 + m.
|
| 1078 |
+
|
| 1079 |
+
// xsize = output width, also equals to input width
|
| 1080 |
+
// ids_size = interpolation size
|
| 1081 |
+
// ids_min = input y start index
|
| 1082 |
+
const auto stride = num_channels * sizeof(uint8_t);
|
| 1083 |
+
|
| 1084 |
+
TORCH_INTERNAL_ASSERT(stride == 3 || stride == 4);
|
| 1085 |
+
|
| 1086 |
+
const int64_t data_size = xsize * stride;
|
| 1087 |
+
const int64_t data_stride = stride;
|
| 1088 |
+
constexpr auto vec_size = 256 / 8;
|
| 1089 |
+
|
| 1090 |
+
const auto initial = _mm_set1_epi32(1 << (coefs_precision - 1));
|
| 1091 |
+
const auto initial_256 = _mm256_set1_epi32(1 << (coefs_precision - 1));
|
| 1092 |
+
const auto zero = _mm_setzero_si128();
|
| 1093 |
+
const auto zero_256 = _mm256_setzero_si256();
|
| 1094 |
+
|
| 1095 |
+
int64_t j = 0;
|
| 1096 |
+
// block 8
|
| 1097 |
+
const auto b8_usable_vec_stride = (vec_size / data_stride) * data_stride;
|
| 1098 |
+
for (; j < data_size - vec_size; j += b8_usable_vec_stride) {
|
| 1099 |
+
auto sss0 = initial_256;
|
| 1100 |
+
auto sss1 = initial_256;
|
| 1101 |
+
auto sss2 = initial_256;
|
| 1102 |
+
auto sss3 = initial_256;
|
| 1103 |
+
int64_t i = 0;
|
| 1104 |
+
const auto * lineIn_min = lineIn + j + ids_min;
|
| 1105 |
+
|
| 1106 |
+
for (; i < ids_size - 1; i += 2) {
|
| 1107 |
+
// Load 2 values from weight vector
|
| 1108 |
+
auto mmk = _mm256_set1_epi32(*(int32_t*)&k[i]);
|
| 1109 |
+
|
| 1110 |
+
// RGBA: Load 8 pixels per line
|
| 1111 |
+
// source1 = [
|
| 1112 |
+
// r0 g0 b0 a0 r1 g1 b1 a1 r2 g2 b2 a2 r3 g3 b3 a3
|
| 1113 |
+
// r4 g4 b4 a4 r5 g5 b5 a5 r6 g6 b6 a6 r7 g7 b7 a7
|
| 1114 |
+
// ]
|
| 1115 |
+
// RGB: Load 10 pixels per line (however we can process only 8 pixels):
|
| 1116 |
+
// source1 = [
|
| 1117 |
+
// r0 g0 b0 r1 g1 b1 r2 g2 b2 r3 g3 b3 r4 g4 b4 r5
|
| 1118 |
+
// r4 g4 b4 r5 g5 b5 r6 g6 b6 r7 g7 b7 r8 g8 b8 r9
|
| 1119 |
+
// ]
|
| 1120 |
+
auto source1 =
|
| 1121 |
+
_mm256_loadu_si256((__m256i*)(lineIn_min + data_size * i));
|
| 1122 |
+
auto source2 =
|
| 1123 |
+
_mm256_loadu_si256((__m256i*)(lineIn_min + data_size * (i + 1)));
|
| 1124 |
+
|
| 1125 |
+
// Interleave source1 and source2 from the low half of each 128-bit lane
|
| 1126 |
+
// and cast the result to epi16
|
| 1127 |
+
// RGBA: pix1 = [
|
| 1128 |
+
// r0 0 R0 0 g0 0 G0 0 b0 0 B0 0 a0 0 A0 0
|
| 1129 |
+
// r1 0 R1 0 g1 0 G1 0 b1 0 B1 0 a1 0 A1 0
|
| 1130 |
+
// ]
|
| 1131 |
+
// RGB: pix1 = [
|
| 1132 |
+
// r0 0 R0 0 g0 0 G0 0 b0 0 B0 0 0 0 0 0
|
| 1133 |
+
// r1 0 R1 0 g1 0 G1 0 b1 0 B1 0 0 0 0 0
|
| 1134 |
+
// ]
|
| 1135 |
+
auto source_lo = _mm256_unpacklo_epi8(source1, source2);
|
| 1136 |
+
auto pix1 = _mm256_unpacklo_epi8(source_lo, zero_256);
|
| 1137 |
+
// Compute output value as
|
| 1138 |
+
// C += w0 * c0 + w1 * C0
|
| 1139 |
+
// C += w0 * c1 + w1 * C1 for each channel in 32-bit precision
|
| 1140 |
+
sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix1, mmk));
|
| 1141 |
+
|
| 1142 |
+
// RGBA: pix2 = [
|
| 1143 |
+
// r2 0 R2 0 g2 0 G2 0 b2 0 B2 0 a2 0 A2 0
|
| 1144 |
+
// r3 0 R3 0 g3 0 G3 0 b3 0 B3 0 a3 0 A3 0
|
| 1145 |
+
// ]
|
| 1146 |
+
// RGB: pix2 = [
|
| 1147 |
+
// r2 0 R2 0 g2 0 G2 0 b2 0 B2 0 0 0 0 0
|
| 1148 |
+
// r3 0 R3 0 g3 0 G3 0 b3 0 B3 0 0 0 0 0
|
| 1149 |
+
// ]
|
| 1150 |
+
auto pix2 = _mm256_unpackhi_epi8(source_lo, zero_256);
|
| 1151 |
+
// Compute output value as
|
| 1152 |
+
// C += w0 * c2 + w1 * C2
|
| 1153 |
+
// C += w0 * c3 + w1 * C3 for each channel in 32-bit precision
|
| 1154 |
+
sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix2, mmk));
|
| 1155 |
+
|
| 1156 |
+
// Same as above for the high half of each 128-bit lane
|
| 1157 |
+
auto source_hi = _mm256_unpackhi_epi8(source1, source2);
|
| 1158 |
+
auto pix3 = _mm256_unpacklo_epi8(source_hi, zero_256);
|
| 1159 |
+
sss2 = _mm256_add_epi32(sss2, _mm256_madd_epi16(pix3, mmk));
|
| 1160 |
+
auto pix4 = _mm256_unpackhi_epi8(source_hi, zero_256);
|
| 1161 |
+
sss3 = _mm256_add_epi32(sss3, _mm256_madd_epi16(pix4, mmk));
|
| 1162 |
+
}
|
| 1163 |
+
// Same processing as above but with a single weight value
|
| 1164 |
+
for (; i < ids_size; i += 1) {
|
| 1165 |
+
auto mmk = _mm256_set1_epi32(k[i]);
|
| 1166 |
+
|
| 1167 |
+
auto source1 = _mm256_loadu_si256((__m256i*)(lineIn_min + i * data_size));
|
| 1168 |
+
|
| 1169 |
+
auto source_lo = _mm256_unpacklo_epi8(source1, zero_256);
|
| 1170 |
+
auto pix1 = _mm256_unpacklo_epi8(source_lo, zero_256);
|
| 1171 |
+
sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix1, mmk));
|
| 1172 |
+
auto pix2 = _mm256_unpackhi_epi8(source_lo, zero_256);
|
| 1173 |
+
sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix2, mmk));
|
| 1174 |
+
|
| 1175 |
+
auto source_hi = _mm256_unpackhi_epi8(source1, zero_256);
|
| 1176 |
+
auto pix3 = _mm256_unpacklo_epi8(source_hi, _mm256_setzero_si256());
|
| 1177 |
+
sss2 = _mm256_add_epi32(sss2, _mm256_madd_epi16(pix3, mmk));
|
| 1178 |
+
auto pix4 = _mm256_unpackhi_epi8(source_hi, _mm256_setzero_si256());
|
| 1179 |
+
sss3 = _mm256_add_epi32(sss3, _mm256_madd_epi16(pix4, mmk));
|
| 1180 |
+
}
|
| 1181 |
+
// Convert fixed point values back to integers (truncating)
|
| 1182 |
+
sss0 = _mm256_srai_epi32(sss0, coefs_precision);
|
| 1183 |
+
sss1 = _mm256_srai_epi32(sss1, coefs_precision);
|
| 1184 |
+
sss2 = _mm256_srai_epi32(sss2, coefs_precision);
|
| 1185 |
+
sss3 = _mm256_srai_epi32(sss3, coefs_precision);
|
| 1186 |
+
// Convert packed signed 32-bit integers to packed 16-bit integers using signed saturation
|
| 1187 |
+
// (a a a a b b b b c c c c d d d d) -> (a a b b c c d d)
|
| 1188 |
+
sss0 = _mm256_packs_epi32(sss0, sss1);
|
| 1189 |
+
sss2 = _mm256_packs_epi32(sss2, sss3);
|
| 1190 |
+
// Convert packed signed 16-bit integers to packed 8-bit integers using unsigned saturation
|
| 1191 |
+
// (a a b b c c d d) -> (a b c d)
|
| 1192 |
+
sss0 = _mm256_packus_epi16(sss0, sss2);
|
| 1193 |
+
|
| 1194 |
+
// Stores 32 bytes
|
| 1195 |
+
_mm256_storeu_si256((__m256i*)(lineOut + j), sss0);
|
| 1196 |
+
}
|
| 1197 |
+
|
| 1198 |
+
// TODO: Do we also need block 4 ???
|
| 1199 |
+
// block 2
|
| 1200 |
+
const auto b2_usable_vec_stride = (8 / data_stride) * data_stride;
|
| 1201 |
+
for (; j < data_size - vec_size / 4; j += b2_usable_vec_stride) {
|
| 1202 |
+
auto sss0 = initial;
|
| 1203 |
+
auto sss1 = initial;
|
| 1204 |
+
int64_t i = 0;
|
| 1205 |
+
const auto * lineIn_min = lineIn + j + ids_min;
|
| 1206 |
+
|
| 1207 |
+
for (; i < ids_size - 1; i += 2) {
|
| 1208 |
+
// Load 2 values from weight vector
|
| 1209 |
+
// mmk = [wl_0 wh_0 wl_1 wh_1 wl_0 wh_0 wl_1 wh_1 ... ]
|
| 1210 |
+
auto mmk = _mm_set1_epi32(*(int32_t*)&k[i]);
|
| 1211 |
+
|
| 1212 |
+
// Load 2 pixels per line
|
| 1213 |
+
// RGBA: source1 = [
|
| 1214 |
+
// r0 g0 b0 a0 r1 g1 b1 a1 0 0 0 0 0 0 0 0
|
| 1215 |
+
// ]
|
| 1216 |
+
// RGB: source1 = [
|
| 1217 |
+
// r0 g0 b0 r1 g1 b1 r2 g2 0 0 0 0 0 0 0 0
|
| 1218 |
+
// ]
|
| 1219 |
+
auto source1 = _mm_loadl_epi64((__m128i *) (lineIn_min + i * data_size));
|
| 1220 |
+
auto source2 = _mm_loadl_epi64((__m128i *) (lineIn_min + (i + 1) * data_size));
|
| 1221 |
+
// Interleave source1 and source2 and cast the result to epi16
|
| 1222 |
+
// RGBA: pix = [
|
| 1223 |
+
// r0 0 R0 0 g0 0 G0 0 b0 0 B0 0 a0 0 A0 0
|
| 1224 |
+
// ]
|
| 1225 |
+
// RGB: pix = [
|
| 1226 |
+
// r0 0 R0 0 g0 0 G0 0 b0 0 B0 0 0 0 0 0
|
| 1227 |
+
// ]
|
| 1228 |
+
auto source = _mm_unpacklo_epi8(source1, source2);
|
| 1229 |
+
auto pix = _mm_unpacklo_epi8(source, zero);
|
| 1230 |
+
// Compute output value as C += w0 * c0 + w1 * C0 for each channel in 32-bit precision
|
| 1231 |
+
sss0 = _mm_add_epi32(sss0, _mm_madd_epi16(pix, mmk));
|
| 1232 |
+
// RGBA: pix = [
|
| 1233 |
+
// r1 0 R1 0 g1 0 G1 0 b1 0 B1 0 a1 0 A1 0
|
| 1234 |
+
// ]
|
| 1235 |
+
// RGB: pix = [
|
| 1236 |
+
// r1 0 R1 0 g1 0 G1 0 b1 0 B1 0 0 0 0 0
|
| 1237 |
+
// ]
|
| 1238 |
+
pix = _mm_unpackhi_epi8(source, zero);
|
| 1239 |
+
// Compute output value as C += w0 * c1 + w1 * C1 for each channel in 32-bit precision
|
| 1240 |
+
sss1 = _mm_add_epi32(sss1, _mm_madd_epi16(pix, mmk));
|
| 1241 |
+
}
|
| 1242 |
+
// Same processing as above but with a single weight value
|
| 1243 |
+
for (; i < ids_size; i += 1) {
|
| 1244 |
+
auto mmk = _mm_set1_epi32(k[i]);
|
| 1245 |
+
|
| 1246 |
+
auto source1 = _mm_loadl_epi64((__m128i*) (lineIn_min + i * data_size));
|
| 1247 |
+
|
| 1248 |
+
auto source = _mm_unpacklo_epi8(source1, zero);
|
| 1249 |
+
auto pix1 = _mm_unpacklo_epi8(source, zero);
|
| 1250 |
+
sss0 = _mm_add_epi32(sss0, _mm_madd_epi16(pix1, mmk));
|
| 1251 |
+
auto pix2 = _mm_unpackhi_epi8(source, zero);
|
| 1252 |
+
sss1 = _mm_add_epi32(sss1, _mm_madd_epi16(pix2, mmk));
|
| 1253 |
+
}
|
| 1254 |
+
// Convert fixed point values back to integers (truncating)
|
| 1255 |
+
sss0 = _mm_srai_epi32(sss0, coefs_precision);
|
| 1256 |
+
sss1 = _mm_srai_epi32(sss1, coefs_precision);
|
| 1257 |
+
// Convert packed signed 32-bit integers to packed 16-bit integers using signed saturation
|
| 1258 |
+
// (a a a a b b b b c c c c d d d d) -> (a a b b c c d d)
|
| 1259 |
+
sss0 = _mm_packs_epi32(sss0, sss1);
|
| 1260 |
+
// Convert packed signed 16-bit integers to packed 8-bit integers using unsigned saturation
|
| 1261 |
+
// (a a b b c c d d) -> (a b c d)
|
| 1262 |
+
sss0 = _mm_packus_epi16(sss0, sss0);
|
| 1263 |
+
// Store 2 pixels to the output
|
| 1264 |
+
_mm_storel_epi64((__m128i*)(lineOut + j), sss0);
|
| 1265 |
+
}
|
| 1266 |
+
|
| 1267 |
+
// block 1
|
| 1268 |
+
const auto b1_usable_vec_stride = (4 / data_stride) * data_stride;
|
| 1269 |
+
const auto i32_aligned = num_channels == 4;
|
| 1270 |
+
for (; j < data_size - 4; j += b1_usable_vec_stride) {
|
| 1271 |
+
auto sss = initial;
|
| 1272 |
+
int64_t i = 0;
|
| 1273 |
+
const auto * lineIn_min = lineIn + j + ids_min;
|
| 1274 |
+
|
| 1275 |
+
for (; i < ids_size - 1; i += 2) {
|
| 1276 |
+
// Load 2 values from weight vector
|
| 1277 |
+
// mmk = [wl_0 wh_0 wl_1 wh_1 wl_0 wh_0 wl_1 wh_1 ... ]
|
| 1278 |
+
auto mmk = _mm_set1_epi32(*(int32_t*)&k[i]);
|
| 1279 |
+
|
| 1280 |
+
// Load one pixel per line
|
| 1281 |
+
// RGBA: source1 = [
|
| 1282 |
+
// r0 g0 b0 a0 0 0 0 0 0 0 0 0 0 0 0 0
|
| 1283 |
+
// ]
|
| 1284 |
+
// RGB: source1 = [
|
| 1285 |
+
// r0 g0 b0 r1 0 0 0 0 0 0 0 0 0 0 0 0
|
| 1286 |
+
// ]
|
| 1287 |
+
auto source1 = mm_cvtsi32_si128(lineIn_min + i * data_size, i32_aligned);
|
| 1288 |
+
auto source2 = mm_cvtsi32_si128(lineIn_min + (i + 1) * data_size, i32_aligned);
|
| 1289 |
+
|
| 1290 |
+
// Interleave source1 and source2 and cast the result to epi16
|
| 1291 |
+
// RGBA: pix = [
|
| 1292 |
+
// r0 0 R0 0 g0 0 G0 0 b0 0 B0 0 a0 0 A0 0
|
| 1293 |
+
// ]
|
| 1294 |
+
// RGB: pix = [
|
| 1295 |
+
// r0 0 R0 0 g0 0 G0 0 b0 0 B0 0 0 0 0 0
|
| 1296 |
+
// ]
|
| 1297 |
+
auto source = _mm_unpacklo_epi8(source1, source2);
|
| 1298 |
+
auto pix = _mm_unpacklo_epi8(source, zero);
|
| 1299 |
+
// Compute output value as C += w0 * c0 + w1 * C0 for each channel in 32-bit precision
|
| 1300 |
+
sss = _mm_add_epi32(sss, _mm_madd_epi16(pix, mmk));
|
| 1301 |
+
}
|
| 1302 |
+
|
| 1303 |
+
for (; i < ids_size; i++) {
|
| 1304 |
+
auto mmk = _mm_set1_epi32(k[i]);
|
| 1305 |
+
auto pix = mm_cvtepu8_epi32(lineIn_min + i * data_size, i32_aligned);
|
| 1306 |
+
sss = _mm_add_epi32(sss, _mm_madd_epi16(pix, mmk));
|
| 1307 |
+
}
|
| 1308 |
+
sss = _mm_srai_epi32(sss, coefs_precision);
|
| 1309 |
+
sss = _mm_packs_epi32(sss, zero);
|
| 1310 |
+
sss = _mm_packus_epi16(sss, zero);
|
| 1311 |
+
|
| 1312 |
+
auto o = _mm_cvtsi128_si32(sss);
|
| 1313 |
+
|
| 1314 |
+
// Here we write 4 bytes to the output even if num_channels < 4, e.g o = {r,g,b,X} for num_channels=3
|
| 1315 |
+
// It is OK to write 4th byte (e.g. X) as on the next step we will overwrite it with new data.
|
| 1316 |
+
// We also won't go out of bounds of lineOut memory allocation
|
| 1317 |
+
std::memcpy(lineOut + j, (uint8_t *) &o, 4);
|
| 1318 |
+
}
|
| 1319 |
+
|
| 1320 |
+
for (; j < data_size; j += data_stride) {
|
| 1321 |
+
auto sss = initial;
|
| 1322 |
+
int64_t i = 0;
|
| 1323 |
+
const auto * lineIn_min = lineIn + j + ids_min;
|
| 1324 |
+
// For RGBA we can use (ids_size - 1) as tighter limit but for RGB we can read outside memory boundary
|
| 1325 |
+
// for the last remaining line
|
| 1326 |
+
for (; i < ids_size - 2; i += 2) {
|
| 1327 |
+
// Load two coefficients at once
|
| 1328 |
+
auto mmk = _mm_set1_epi32(*(int32_t*)&k[i]);
|
| 1329 |
+
|
| 1330 |
+
// Load 2 lines
|
| 1331 |
+
auto source1 = mm_cvtsi32_si128(lineIn_min + i * data_size, i32_aligned);
|
| 1332 |
+
auto source2 = mm_cvtsi32_si128(lineIn_min + (i + 1) * data_size, i32_aligned);
|
| 1333 |
+
|
| 1334 |
+
auto source = _mm_unpacklo_epi8(source1, source2);
|
| 1335 |
+
auto pix = _mm_unpacklo_epi8(source, zero);
|
| 1336 |
+
sss = _mm_add_epi32(sss, _mm_madd_epi16(pix, mmk));
|
| 1337 |
+
}
|
| 1338 |
+
|
| 1339 |
+
// Same processing as above but with a single weight value
|
| 1340 |
+
for (; i < ids_size; i++) {
|
| 1341 |
+
auto mmk = _mm_set1_epi32(k[i]);
|
| 1342 |
+
|
| 1343 |
+
const uint8_t * p = lineIn_min + i * data_size;
|
| 1344 |
+
__m128i pix;
|
| 1345 |
+
// There is no much perf gain using more detailed condition like
|
| 1346 |
+
// num_channels == 3 && ids_min + j + data_size * i + 4 >= in_max_size
|
| 1347 |
+
// const int64_t in_max_size = data_size * in_ysize;
|
| 1348 |
+
if (num_channels == 3) {
|
| 1349 |
+
uint8_t input[4];
|
| 1350 |
+
std::memcpy(input, p, 3);
|
| 1351 |
+
pix = mm_cvtepu8_epi32(input, true);
|
| 1352 |
+
} else {
|
| 1353 |
+
pix = mm_cvtepu8_epi32(p, true);
|
| 1354 |
+
}
|
| 1355 |
+
sss = _mm_add_epi32(sss, _mm_madd_epi16(pix, mmk));
|
| 1356 |
+
}
|
| 1357 |
+
|
| 1358 |
+
// Convert fixed point values back to integers (truncating)
|
| 1359 |
+
sss = _mm_srai_epi32(sss, coefs_precision);
|
| 1360 |
+
// Convert packed signed 32-bit integers to packed 16-bit integers using signed saturation
|
| 1361 |
+
// (a a a a b b b b c c c c d d d d) -> (a a b b c c d d)
|
| 1362 |
+
sss = _mm_packs_epi32(sss, zero);
|
| 1363 |
+
// Convert packed signed 16-bit integers to packed 8-bit integers using unsigned saturation
|
| 1364 |
+
// (a a b b c c d d) -> (a b c d)
|
| 1365 |
+
sss = _mm_packus_epi16(sss, zero);
|
| 1366 |
+
// Store one pixel to the output
|
| 1367 |
+
auto o = _mm_cvtsi128_si32(sss);
|
| 1368 |
+
if (num_channels == 3 && C10_UNLIKELY(j + 4 >= data_size)) {
|
| 1369 |
+
std::memcpy(lineOut + j, (uint8_t *) &o, 3);
|
| 1370 |
+
} else {
|
| 1371 |
+
std::memcpy(lineOut + j, (uint8_t *) &o, 4);
|
| 1372 |
+
}
|
| 1373 |
+
}
|
| 1374 |
+
}
|
| 1375 |
+
|
| 1376 |
+
} // anonymous namespace
|
| 1377 |
+
#endif // CPU_CAPABILITY_AVX2
|
| 1378 |
+
|
| 1379 |
+
#else
|
| 1380 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 1381 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/avx_mathfun.h
ADDED
|
@@ -0,0 +1,527 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
/*
|
| 4 |
+
AVX implementation of sin, cos, sincos, exp and log
|
| 5 |
+
|
| 6 |
+
Based on "sse_mathfun.h", by Julien Pommier
|
| 7 |
+
http://gruntthepeon.free.fr/ssemath/
|
| 8 |
+
|
| 9 |
+
Copyright (C) 2012 Giovanni Garberoglio
|
| 10 |
+
Interdisciplinary Laboratory for Computational Science (LISC)
|
| 11 |
+
Fondazione Bruno Kessler and University of Trento
|
| 12 |
+
via Sommarive, 18
|
| 13 |
+
I-38123 Trento (Italy)
|
| 14 |
+
|
| 15 |
+
This software is provided 'as-is', without any express or implied
|
| 16 |
+
warranty. In no event will the authors be held liable for any damages
|
| 17 |
+
arising from the use of this software.
|
| 18 |
+
|
| 19 |
+
Permission is granted to anyone to use this software for any purpose,
|
| 20 |
+
including commercial applications, and to alter it and redistribute it
|
| 21 |
+
freely, subject to the following restrictions:
|
| 22 |
+
|
| 23 |
+
1. The origin of this software must not be misrepresented; you must not
|
| 24 |
+
claim that you wrote the original software. If you use this software
|
| 25 |
+
in a product, an acknowledgment in the product documentation would be
|
| 26 |
+
appreciated but is not required.
|
| 27 |
+
2. Altered source versions must be plainly marked as such, and must not be
|
| 28 |
+
misrepresented as being the original software.
|
| 29 |
+
3. This notice may not be removed or altered from any source distribution.
|
| 30 |
+
|
| 31 |
+
(this is the zlib license)
|
| 32 |
+
*/
|
| 33 |
+
|
| 34 |
+
#include <ATen/native/cpu/Intrinsics.h>
|
| 35 |
+
|
| 36 |
+
/* The original source of this file has been modified. */
|
| 37 |
+
#if defined(CPU_CAPABILITY_AVX2)
|
| 38 |
+
|
| 39 |
+
#if defined(__GNUC__)
|
| 40 |
+
# define ALIGN32_BEG __attribute__((aligned(32)))
|
| 41 |
+
#elif defined(_WIN32)
|
| 42 |
+
# define ALIGN32_BEG __declspec(align(32))
|
| 43 |
+
#endif
|
| 44 |
+
|
| 45 |
+
typedef __m256 v8sf; // vector of 8 float (avx2)
|
| 46 |
+
typedef __m256i v8si; // vector of 8 int (avx2)
|
| 47 |
+
|
| 48 |
+
/* declare some AVX constants -- why can't I figure a better way to do that? */
|
| 49 |
+
#define _PS256_CONST(Name, Val) \
|
| 50 |
+
static const ALIGN32_BEG float _ps256_##Name[8] = { Val, Val, Val, Val, Val, Val, Val, Val }
|
| 51 |
+
#define _PI32_CONST256(Name, Val) \
|
| 52 |
+
static const ALIGN32_BEG int _pi32_256_##Name[8] = { Val, Val, Val, Val, Val, Val, Val, Val }
|
| 53 |
+
#define _PS256_CONST_TYPE(Name, Type, Val) \
|
| 54 |
+
static const ALIGN32_BEG Type _ps256_##Name[8] = { Val, Val, Val, Val, Val, Val, Val, Val }
|
| 55 |
+
|
| 56 |
+
_PS256_CONST(1 , 1.0f);
|
| 57 |
+
_PS256_CONST(0p5, 0.5f);
|
| 58 |
+
/* the smallest non denormalized float number */
|
| 59 |
+
_PS256_CONST_TYPE(min_norm_pos, int, 0x00800000);
|
| 60 |
+
_PS256_CONST_TYPE(mant_mask, int, 0x7f800000);
|
| 61 |
+
_PS256_CONST_TYPE(inv_mant_mask, int, ~0x7f800000);
|
| 62 |
+
|
| 63 |
+
_PS256_CONST_TYPE(sign_mask, int, (int)0x80000000);
|
| 64 |
+
_PS256_CONST_TYPE(inv_sign_mask, int, ~0x80000000);
|
| 65 |
+
|
| 66 |
+
_PI32_CONST256(0, 0);
|
| 67 |
+
_PI32_CONST256(1, 1);
|
| 68 |
+
_PI32_CONST256(inv1, ~1);
|
| 69 |
+
_PI32_CONST256(2, 2);
|
| 70 |
+
_PI32_CONST256(4, 4);
|
| 71 |
+
_PI32_CONST256(0x7f, 0x7f);
|
| 72 |
+
|
| 73 |
+
_PS256_CONST(cephes_SQRTHF, 0.707106781186547524);
|
| 74 |
+
_PS256_CONST(cephes_log_p0, 7.0376836292E-2);
|
| 75 |
+
_PS256_CONST(cephes_log_p1, - 1.1514610310E-1);
|
| 76 |
+
_PS256_CONST(cephes_log_p2, 1.1676998740E-1);
|
| 77 |
+
_PS256_CONST(cephes_log_p3, - 1.2420140846E-1);
|
| 78 |
+
_PS256_CONST(cephes_log_p4, + 1.4249322787E-1);
|
| 79 |
+
_PS256_CONST(cephes_log_p5, - 1.6668057665E-1);
|
| 80 |
+
_PS256_CONST(cephes_log_p6, + 2.0000714765E-1);
|
| 81 |
+
_PS256_CONST(cephes_log_p7, - 2.4999993993E-1);
|
| 82 |
+
_PS256_CONST(cephes_log_p8, + 3.3333331174E-1);
|
| 83 |
+
_PS256_CONST(cephes_log_q1, -2.12194440e-4);
|
| 84 |
+
_PS256_CONST(cephes_log_q2, 0.693359375);
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
/* natural logarithm computed for 8 simultaneous float
|
| 88 |
+
return NaN for x <= 0
|
| 89 |
+
*/
|
| 90 |
+
inline v8sf log256_ps(v8sf x) {
|
| 91 |
+
v8si imm0;
|
| 92 |
+
v8sf one = *(v8sf*)_ps256_1;
|
| 93 |
+
|
| 94 |
+
//v8sf invalid_mask = _mm256_cmple_ps(x, _mm256_setzero_ps());
|
| 95 |
+
v8sf invalid_mask = _mm256_cmp_ps(x, _mm256_setzero_ps(), _CMP_LE_OS);
|
| 96 |
+
|
| 97 |
+
x = _mm256_max_ps(x, *(v8sf*)_ps256_min_norm_pos); /* cut off denormalized stuff */
|
| 98 |
+
|
| 99 |
+
// can be done with AVX2
|
| 100 |
+
imm0 = _mm256_srli_epi32(_mm256_castps_si256(x), 23);
|
| 101 |
+
|
| 102 |
+
/* keep only the fractional part */
|
| 103 |
+
x = _mm256_and_ps(x, *(v8sf*)_ps256_inv_mant_mask);
|
| 104 |
+
x = _mm256_or_ps(x, *(v8sf*)_ps256_0p5);
|
| 105 |
+
|
| 106 |
+
// this is again another AVX2 instruction
|
| 107 |
+
imm0 = _mm256_sub_epi32(imm0, *(v8si*)_pi32_256_0x7f);
|
| 108 |
+
v8sf e = _mm256_cvtepi32_ps(imm0);
|
| 109 |
+
|
| 110 |
+
e = _mm256_add_ps(e, one);
|
| 111 |
+
|
| 112 |
+
/* part2:
|
| 113 |
+
if( x < SQRTHF ) {
|
| 114 |
+
e -= 1;
|
| 115 |
+
x = x + x - 1.0;
|
| 116 |
+
} else { x = x - 1.0; }
|
| 117 |
+
*/
|
| 118 |
+
//v8sf mask = _mm256_cmplt_ps(x, *(v8sf*)_ps256_cephes_SQRTHF);
|
| 119 |
+
v8sf mask = _mm256_cmp_ps(x, *(v8sf*)_ps256_cephes_SQRTHF, _CMP_LT_OS);
|
| 120 |
+
v8sf tmp = _mm256_and_ps(x, mask);
|
| 121 |
+
x = _mm256_sub_ps(x, one);
|
| 122 |
+
e = _mm256_sub_ps(e, _mm256_and_ps(one, mask));
|
| 123 |
+
x = _mm256_add_ps(x, tmp);
|
| 124 |
+
|
| 125 |
+
v8sf z = _mm256_mul_ps(x,x);
|
| 126 |
+
|
| 127 |
+
v8sf y = *(v8sf*)_ps256_cephes_log_p0;
|
| 128 |
+
y = _mm256_mul_ps(y, x);
|
| 129 |
+
y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p1);
|
| 130 |
+
y = _mm256_mul_ps(y, x);
|
| 131 |
+
y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p2);
|
| 132 |
+
y = _mm256_mul_ps(y, x);
|
| 133 |
+
y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p3);
|
| 134 |
+
y = _mm256_mul_ps(y, x);
|
| 135 |
+
y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p4);
|
| 136 |
+
y = _mm256_mul_ps(y, x);
|
| 137 |
+
y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p5);
|
| 138 |
+
y = _mm256_mul_ps(y, x);
|
| 139 |
+
y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p6);
|
| 140 |
+
y = _mm256_mul_ps(y, x);
|
| 141 |
+
y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p7);
|
| 142 |
+
y = _mm256_mul_ps(y, x);
|
| 143 |
+
y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p8);
|
| 144 |
+
y = _mm256_mul_ps(y, x);
|
| 145 |
+
|
| 146 |
+
y = _mm256_mul_ps(y, z);
|
| 147 |
+
|
| 148 |
+
tmp = _mm256_mul_ps(e, *(v8sf*)_ps256_cephes_log_q1);
|
| 149 |
+
y = _mm256_add_ps(y, tmp);
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
tmp = _mm256_mul_ps(z, *(v8sf*)_ps256_0p5);
|
| 153 |
+
y = _mm256_sub_ps(y, tmp);
|
| 154 |
+
|
| 155 |
+
tmp = _mm256_mul_ps(e, *(v8sf*)_ps256_cephes_log_q2);
|
| 156 |
+
x = _mm256_add_ps(x, y);
|
| 157 |
+
x = _mm256_add_ps(x, tmp);
|
| 158 |
+
x = _mm256_or_ps(x, invalid_mask); // negative arg will be NAN
|
| 159 |
+
return x;
|
| 160 |
+
}
|
| 161 |
+
|
| 162 |
+
_PS256_CONST(exp_hi, 88.3762626647949f);
|
| 163 |
+
_PS256_CONST(exp_lo, -88.3762626647949f);
|
| 164 |
+
|
| 165 |
+
_PS256_CONST(cephes_LOG2EF, 1.44269504088896341);
|
| 166 |
+
_PS256_CONST(cephes_exp_C1, 0.693359375);
|
| 167 |
+
_PS256_CONST(cephes_exp_C2, -2.12194440e-4);
|
| 168 |
+
|
| 169 |
+
_PS256_CONST(cephes_exp_p0, 1.9875691500E-4);
|
| 170 |
+
_PS256_CONST(cephes_exp_p1, 1.3981999507E-3);
|
| 171 |
+
_PS256_CONST(cephes_exp_p2, 8.3334519073E-3);
|
| 172 |
+
_PS256_CONST(cephes_exp_p3, 4.1665795894E-2);
|
| 173 |
+
_PS256_CONST(cephes_exp_p4, 1.6666665459E-1);
|
| 174 |
+
_PS256_CONST(cephes_exp_p5, 5.0000001201E-1);
|
| 175 |
+
|
| 176 |
+
inline v8sf exp256_ps(v8sf x) {
|
| 177 |
+
v8sf tmp = _mm256_setzero_ps(), fx;
|
| 178 |
+
v8si imm0;
|
| 179 |
+
v8sf one = *(v8sf*)_ps256_1;
|
| 180 |
+
|
| 181 |
+
x = _mm256_min_ps(x, *(v8sf*)_ps256_exp_hi);
|
| 182 |
+
x = _mm256_max_ps(x, *(v8sf*)_ps256_exp_lo);
|
| 183 |
+
|
| 184 |
+
/* express exp(x) as exp(g + n*log(2)) */
|
| 185 |
+
fx = _mm256_mul_ps(x, *(v8sf*)_ps256_cephes_LOG2EF);
|
| 186 |
+
fx = _mm256_add_ps(fx, *(v8sf*)_ps256_0p5);
|
| 187 |
+
|
| 188 |
+
/* how to perform a floorf with SSE: just below */
|
| 189 |
+
//imm0 = _mm256_cvttps_epi32(fx);
|
| 190 |
+
//tmp = _mm256_cvtepi32_ps(imm0);
|
| 191 |
+
|
| 192 |
+
tmp = _mm256_floor_ps(fx);
|
| 193 |
+
|
| 194 |
+
/* if greater, subtract 1 */
|
| 195 |
+
//v8sf mask = _mm256_cmpgt_ps(tmp, fx);
|
| 196 |
+
v8sf mask = _mm256_cmp_ps(tmp, fx, _CMP_GT_OS);
|
| 197 |
+
mask = _mm256_and_ps(mask, one);
|
| 198 |
+
fx = _mm256_sub_ps(tmp, mask);
|
| 199 |
+
|
| 200 |
+
tmp = _mm256_mul_ps(fx, *(v8sf*)_ps256_cephes_exp_C1);
|
| 201 |
+
v8sf z = _mm256_mul_ps(fx, *(v8sf*)_ps256_cephes_exp_C2);
|
| 202 |
+
x = _mm256_sub_ps(x, tmp);
|
| 203 |
+
x = _mm256_sub_ps(x, z);
|
| 204 |
+
|
| 205 |
+
z = _mm256_mul_ps(x,x);
|
| 206 |
+
|
| 207 |
+
v8sf y = *(v8sf*)_ps256_cephes_exp_p0;
|
| 208 |
+
y = _mm256_mul_ps(y, x);
|
| 209 |
+
y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_exp_p1);
|
| 210 |
+
y = _mm256_mul_ps(y, x);
|
| 211 |
+
y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_exp_p2);
|
| 212 |
+
y = _mm256_mul_ps(y, x);
|
| 213 |
+
y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_exp_p3);
|
| 214 |
+
y = _mm256_mul_ps(y, x);
|
| 215 |
+
y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_exp_p4);
|
| 216 |
+
y = _mm256_mul_ps(y, x);
|
| 217 |
+
y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_exp_p5);
|
| 218 |
+
y = _mm256_mul_ps(y, z);
|
| 219 |
+
y = _mm256_add_ps(y, x);
|
| 220 |
+
y = _mm256_add_ps(y, one);
|
| 221 |
+
|
| 222 |
+
/* build 2^n */
|
| 223 |
+
imm0 = _mm256_cvttps_epi32(fx);
|
| 224 |
+
// another two AVX2 instructions
|
| 225 |
+
imm0 = _mm256_add_epi32(imm0, *(v8si*)_pi32_256_0x7f);
|
| 226 |
+
imm0 = _mm256_slli_epi32(imm0, 23);
|
| 227 |
+
v8sf pow2n = _mm256_castsi256_ps(imm0);
|
| 228 |
+
y = _mm256_mul_ps(y, pow2n);
|
| 229 |
+
return y;
|
| 230 |
+
}
|
| 231 |
+
|
| 232 |
+
_PS256_CONST(minus_cephes_DP1, -0.78515625);
|
| 233 |
+
_PS256_CONST(minus_cephes_DP2, -2.4187564849853515625e-4);
|
| 234 |
+
_PS256_CONST(minus_cephes_DP3, -3.77489497744594108e-8);
|
| 235 |
+
_PS256_CONST(sincof_p0, -1.9515295891E-4);
|
| 236 |
+
_PS256_CONST(sincof_p1, 8.3321608736E-3);
|
| 237 |
+
_PS256_CONST(sincof_p2, -1.6666654611E-1);
|
| 238 |
+
_PS256_CONST(coscof_p0, 2.443315711809948E-005);
|
| 239 |
+
_PS256_CONST(coscof_p1, -1.388731625493765E-003);
|
| 240 |
+
_PS256_CONST(coscof_p2, 4.166664568298827E-002);
|
| 241 |
+
_PS256_CONST(cephes_FOPI, 1.27323954473516); // 4 / M_PI
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
/* evaluation of 8 sines at once using AVX intrinsics
|
| 245 |
+
|
| 246 |
+
The code is the exact rewriting of the cephes sinf function.
|
| 247 |
+
Precision is excellent as long as x < 8192 (I did not bother to
|
| 248 |
+
take into account the special handling they have for greater values
|
| 249 |
+
-- it does not return garbage for arguments over 8192, though, but
|
| 250 |
+
the extra precision is missing).
|
| 251 |
+
|
| 252 |
+
Note that it is such that sinf((float)M_PI) = 8.74e-8, which is the
|
| 253 |
+
surprising but correct result.
|
| 254 |
+
|
| 255 |
+
*/
|
| 256 |
+
inline v8sf sin256_ps(v8sf x) { // any x
|
| 257 |
+
v8sf xmm1, xmm2 = _mm256_setzero_ps(), xmm3, sign_bit, y;
|
| 258 |
+
v8si imm0, imm2;
|
| 259 |
+
|
| 260 |
+
sign_bit = x;
|
| 261 |
+
/* take the absolute value */
|
| 262 |
+
x = _mm256_and_ps(x, *(v8sf*)_ps256_inv_sign_mask);
|
| 263 |
+
/* extract the sign bit (upper one) */
|
| 264 |
+
sign_bit = _mm256_and_ps(sign_bit, *(v8sf*)_ps256_sign_mask);
|
| 265 |
+
|
| 266 |
+
/* scale by 4/Pi */
|
| 267 |
+
y = _mm256_mul_ps(x, *(v8sf*)_ps256_cephes_FOPI);
|
| 268 |
+
|
| 269 |
+
/*
|
| 270 |
+
Here we start a series of integer operations, which are in the
|
| 271 |
+
realm of AVX2.
|
| 272 |
+
If we don't have AVX, let's perform them using SSE2 directives
|
| 273 |
+
*/
|
| 274 |
+
|
| 275 |
+
/* store the integer part of y in mm0 */
|
| 276 |
+
imm2 = _mm256_cvttps_epi32(y);
|
| 277 |
+
/* j=(j+1) & (~1) (see the cephes sources) */
|
| 278 |
+
// another two AVX2 instruction
|
| 279 |
+
imm2 = _mm256_add_epi32(imm2, *(v8si*)_pi32_256_1);
|
| 280 |
+
imm2 = _mm256_and_si256(imm2, *(v8si*)_pi32_256_inv1);
|
| 281 |
+
y = _mm256_cvtepi32_ps(imm2);
|
| 282 |
+
|
| 283 |
+
/* get the swap sign flag */
|
| 284 |
+
imm0 = _mm256_and_si256(imm2, *(v8si*)_pi32_256_4);
|
| 285 |
+
imm0 = _mm256_slli_epi32(imm0, 29);
|
| 286 |
+
/* get the polynom selection mask
|
| 287 |
+
there is one polynom for 0 <= x <= Pi/4
|
| 288 |
+
and another one for Pi/4<x<=Pi/2
|
| 289 |
+
|
| 290 |
+
Both branches will be computed.
|
| 291 |
+
*/
|
| 292 |
+
imm2 = _mm256_and_si256(imm2, *(v8si*)_pi32_256_2);
|
| 293 |
+
imm2 = _mm256_cmpeq_epi32(imm2,*(v8si*)_pi32_256_0);
|
| 294 |
+
|
| 295 |
+
v8sf swap_sign_bit = _mm256_castsi256_ps(imm0);
|
| 296 |
+
v8sf poly_mask = _mm256_castsi256_ps(imm2);
|
| 297 |
+
sign_bit = _mm256_xor_ps(sign_bit, swap_sign_bit);
|
| 298 |
+
|
| 299 |
+
/* The magic pass: "Extended precision modular arithmetic"
|
| 300 |
+
x = ((x - y * DP1) - y * DP2) - y * DP3; */
|
| 301 |
+
xmm1 = *(v8sf*)_ps256_minus_cephes_DP1;
|
| 302 |
+
xmm2 = *(v8sf*)_ps256_minus_cephes_DP2;
|
| 303 |
+
xmm3 = *(v8sf*)_ps256_minus_cephes_DP3;
|
| 304 |
+
xmm1 = _mm256_mul_ps(y, xmm1);
|
| 305 |
+
xmm2 = _mm256_mul_ps(y, xmm2);
|
| 306 |
+
xmm3 = _mm256_mul_ps(y, xmm3);
|
| 307 |
+
x = _mm256_add_ps(x, xmm1);
|
| 308 |
+
x = _mm256_add_ps(x, xmm2);
|
| 309 |
+
x = _mm256_add_ps(x, xmm3);
|
| 310 |
+
|
| 311 |
+
/* Evaluate the first polynom (0 <= x <= Pi/4) */
|
| 312 |
+
y = *(v8sf*)_ps256_coscof_p0;
|
| 313 |
+
v8sf z = _mm256_mul_ps(x,x);
|
| 314 |
+
|
| 315 |
+
y = _mm256_mul_ps(y, z);
|
| 316 |
+
y = _mm256_add_ps(y, *(v8sf*)_ps256_coscof_p1);
|
| 317 |
+
y = _mm256_mul_ps(y, z);
|
| 318 |
+
y = _mm256_add_ps(y, *(v8sf*)_ps256_coscof_p2);
|
| 319 |
+
y = _mm256_mul_ps(y, z);
|
| 320 |
+
y = _mm256_mul_ps(y, z);
|
| 321 |
+
v8sf tmp = _mm256_mul_ps(z, *(v8sf*)_ps256_0p5);
|
| 322 |
+
y = _mm256_sub_ps(y, tmp);
|
| 323 |
+
y = _mm256_add_ps(y, *(v8sf*)_ps256_1);
|
| 324 |
+
|
| 325 |
+
/* Evaluate the second polynom (Pi/4 <= x <= 0) */
|
| 326 |
+
|
| 327 |
+
v8sf y2 = *(v8sf*)_ps256_sincof_p0;
|
| 328 |
+
y2 = _mm256_mul_ps(y2, z);
|
| 329 |
+
y2 = _mm256_add_ps(y2, *(v8sf*)_ps256_sincof_p1);
|
| 330 |
+
y2 = _mm256_mul_ps(y2, z);
|
| 331 |
+
y2 = _mm256_add_ps(y2, *(v8sf*)_ps256_sincof_p2);
|
| 332 |
+
y2 = _mm256_mul_ps(y2, z);
|
| 333 |
+
y2 = _mm256_mul_ps(y2, x);
|
| 334 |
+
y2 = _mm256_add_ps(y2, x);
|
| 335 |
+
|
| 336 |
+
/* select the correct result from the two polynoms */
|
| 337 |
+
xmm3 = poly_mask;
|
| 338 |
+
y2 = _mm256_and_ps(xmm3, y2); //, xmm3);
|
| 339 |
+
y = _mm256_andnot_ps(xmm3, y);
|
| 340 |
+
y = _mm256_add_ps(y,y2);
|
| 341 |
+
/* update the sign */
|
| 342 |
+
y = _mm256_xor_ps(y, sign_bit);
|
| 343 |
+
|
| 344 |
+
return y;
|
| 345 |
+
}
|
| 346 |
+
|
| 347 |
+
/* almost the same as sin_ps */
|
| 348 |
+
inline v8sf cos256_ps(v8sf x) { // any x
|
| 349 |
+
v8sf xmm1, xmm2 = _mm256_setzero_ps(), xmm3, y;
|
| 350 |
+
v8si imm0, imm2;
|
| 351 |
+
|
| 352 |
+
/* take the absolute value */
|
| 353 |
+
x = _mm256_and_ps(x, *(v8sf*)_ps256_inv_sign_mask);
|
| 354 |
+
|
| 355 |
+
/* scale by 4/Pi */
|
| 356 |
+
y = _mm256_mul_ps(x, *(v8sf*)_ps256_cephes_FOPI);
|
| 357 |
+
|
| 358 |
+
/* store the integer part of y in mm0 */
|
| 359 |
+
imm2 = _mm256_cvttps_epi32(y);
|
| 360 |
+
/* j=(j+1) & (~1) (see the cephes sources) */
|
| 361 |
+
imm2 = _mm256_add_epi32(imm2, *(v8si*)_pi32_256_1);
|
| 362 |
+
imm2 = _mm256_and_si256(imm2, *(v8si*)_pi32_256_inv1);
|
| 363 |
+
y = _mm256_cvtepi32_ps(imm2);
|
| 364 |
+
imm2 = _mm256_sub_epi32(imm2, *(v8si*)_pi32_256_2);
|
| 365 |
+
|
| 366 |
+
/* get the swap sign flag */
|
| 367 |
+
imm0 = _mm256_andnot_si256(imm2, *(v8si*)_pi32_256_4);
|
| 368 |
+
imm0 = _mm256_slli_epi32(imm0, 29);
|
| 369 |
+
/* get the polynom selection mask */
|
| 370 |
+
imm2 = _mm256_and_si256(imm2, *(v8si*)_pi32_256_2);
|
| 371 |
+
imm2 = _mm256_cmpeq_epi32(imm2, *(v8si*)_pi32_256_0);
|
| 372 |
+
|
| 373 |
+
v8sf sign_bit = _mm256_castsi256_ps(imm0);
|
| 374 |
+
v8sf poly_mask = _mm256_castsi256_ps(imm2);
|
| 375 |
+
|
| 376 |
+
/* The magic pass: "Extended precision modular arithmetic"
|
| 377 |
+
x = ((x - y * DP1) - y * DP2) - y * DP3; */
|
| 378 |
+
xmm1 = *(v8sf*)_ps256_minus_cephes_DP1;
|
| 379 |
+
xmm2 = *(v8sf*)_ps256_minus_cephes_DP2;
|
| 380 |
+
xmm3 = *(v8sf*)_ps256_minus_cephes_DP3;
|
| 381 |
+
xmm1 = _mm256_mul_ps(y, xmm1);
|
| 382 |
+
xmm2 = _mm256_mul_ps(y, xmm2);
|
| 383 |
+
xmm3 = _mm256_mul_ps(y, xmm3);
|
| 384 |
+
x = _mm256_add_ps(x, xmm1);
|
| 385 |
+
x = _mm256_add_ps(x, xmm2);
|
| 386 |
+
x = _mm256_add_ps(x, xmm3);
|
| 387 |
+
|
| 388 |
+
/* Evaluate the first polynom (0 <= x <= Pi/4) */
|
| 389 |
+
y = *(v8sf*)_ps256_coscof_p0;
|
| 390 |
+
v8sf z = _mm256_mul_ps(x,x);
|
| 391 |
+
|
| 392 |
+
y = _mm256_mul_ps(y, z);
|
| 393 |
+
y = _mm256_add_ps(y, *(v8sf*)_ps256_coscof_p1);
|
| 394 |
+
y = _mm256_mul_ps(y, z);
|
| 395 |
+
y = _mm256_add_ps(y, *(v8sf*)_ps256_coscof_p2);
|
| 396 |
+
y = _mm256_mul_ps(y, z);
|
| 397 |
+
y = _mm256_mul_ps(y, z);
|
| 398 |
+
v8sf tmp = _mm256_mul_ps(z, *(v8sf*)_ps256_0p5);
|
| 399 |
+
y = _mm256_sub_ps(y, tmp);
|
| 400 |
+
y = _mm256_add_ps(y, *(v8sf*)_ps256_1);
|
| 401 |
+
|
| 402 |
+
/* Evaluate the second polynom (Pi/4 <= x <= 0) */
|
| 403 |
+
|
| 404 |
+
v8sf y2 = *(v8sf*)_ps256_sincof_p0;
|
| 405 |
+
y2 = _mm256_mul_ps(y2, z);
|
| 406 |
+
y2 = _mm256_add_ps(y2, *(v8sf*)_ps256_sincof_p1);
|
| 407 |
+
y2 = _mm256_mul_ps(y2, z);
|
| 408 |
+
y2 = _mm256_add_ps(y2, *(v8sf*)_ps256_sincof_p2);
|
| 409 |
+
y2 = _mm256_mul_ps(y2, z);
|
| 410 |
+
y2 = _mm256_mul_ps(y2, x);
|
| 411 |
+
y2 = _mm256_add_ps(y2, x);
|
| 412 |
+
|
| 413 |
+
/* select the correct result from the two polynoms */
|
| 414 |
+
xmm3 = poly_mask;
|
| 415 |
+
y2 = _mm256_and_ps(xmm3, y2); //, xmm3);
|
| 416 |
+
y = _mm256_andnot_ps(xmm3, y);
|
| 417 |
+
y = _mm256_add_ps(y,y2);
|
| 418 |
+
/* update the sign */
|
| 419 |
+
y = _mm256_xor_ps(y, sign_bit);
|
| 420 |
+
|
| 421 |
+
return y;
|
| 422 |
+
}
|
| 423 |
+
|
| 424 |
+
/* since sin256_ps and cos256_ps are almost identical, sincos256_ps could replace both of them..
|
| 425 |
+
it is almost as fast, and gives you a free cosine with your sine */
|
| 426 |
+
inline void sincos256_ps(v8sf x, v8sf *s, v8sf *c) {
|
| 427 |
+
|
| 428 |
+
v8sf xmm1, xmm2, xmm3 = _mm256_setzero_ps(), sign_bit_sin, y;
|
| 429 |
+
v8si imm0, imm2, imm4;
|
| 430 |
+
|
| 431 |
+
sign_bit_sin = x;
|
| 432 |
+
/* take the absolute value */
|
| 433 |
+
x = _mm256_and_ps(x, *(v8sf*)_ps256_inv_sign_mask);
|
| 434 |
+
/* extract the sign bit (upper one) */
|
| 435 |
+
sign_bit_sin = _mm256_and_ps(sign_bit_sin, *(v8sf*)_ps256_sign_mask);
|
| 436 |
+
|
| 437 |
+
/* scale by 4/Pi */
|
| 438 |
+
y = _mm256_mul_ps(x, *(v8sf*)_ps256_cephes_FOPI);
|
| 439 |
+
|
| 440 |
+
/* store the integer part of y in imm2 */
|
| 441 |
+
imm2 = _mm256_cvttps_epi32(y);
|
| 442 |
+
|
| 443 |
+
/* j=(j+1) & (~1) (see the cephes sources) */
|
| 444 |
+
imm2 = _mm256_add_epi32(imm2, *(v8si*)_pi32_256_1);
|
| 445 |
+
imm2 = _mm256_and_si256(imm2, *(v8si*)_pi32_256_inv1);
|
| 446 |
+
|
| 447 |
+
y = _mm256_cvtepi32_ps(imm2);
|
| 448 |
+
imm4 = imm2;
|
| 449 |
+
|
| 450 |
+
/* get the swap sign flag for the sine */
|
| 451 |
+
imm0 = _mm256_and_si256(imm2, *(v8si*)_pi32_256_4);
|
| 452 |
+
imm0 = _mm256_slli_epi32(imm0, 29);
|
| 453 |
+
//v8sf swap_sign_bit_sin = _mm256_castsi256_ps(imm0);
|
| 454 |
+
|
| 455 |
+
/* get the polynom selection mask for the sine*/
|
| 456 |
+
imm2 = _mm256_and_si256(imm2, *(v8si*)_pi32_256_2);
|
| 457 |
+
imm2 = _mm256_cmpeq_epi32(imm2, *(v8si*)_pi32_256_0);
|
| 458 |
+
//v8sf poly_mask = _mm256_castsi256_ps(imm2);
|
| 459 |
+
|
| 460 |
+
v8sf swap_sign_bit_sin = _mm256_castsi256_ps(imm0);
|
| 461 |
+
v8sf poly_mask = _mm256_castsi256_ps(imm2);
|
| 462 |
+
|
| 463 |
+
/* The magic pass: "Extended precision modular arithmetic"
|
| 464 |
+
x = ((x - y * DP1) - y * DP2) - y * DP3; */
|
| 465 |
+
xmm1 = *(v8sf*)_ps256_minus_cephes_DP1;
|
| 466 |
+
xmm2 = *(v8sf*)_ps256_minus_cephes_DP2;
|
| 467 |
+
xmm3 = *(v8sf*)_ps256_minus_cephes_DP3;
|
| 468 |
+
xmm1 = _mm256_mul_ps(y, xmm1);
|
| 469 |
+
xmm2 = _mm256_mul_ps(y, xmm2);
|
| 470 |
+
xmm3 = _mm256_mul_ps(y, xmm3);
|
| 471 |
+
x = _mm256_add_ps(x, xmm1);
|
| 472 |
+
x = _mm256_add_ps(x, xmm2);
|
| 473 |
+
x = _mm256_add_ps(x, xmm3);
|
| 474 |
+
|
| 475 |
+
imm4 = _mm256_sub_epi32(imm4, *(v8si*)_pi32_256_2);
|
| 476 |
+
imm4 = _mm256_andnot_si256(imm4, *(v8si*)_pi32_256_4);
|
| 477 |
+
imm4 = _mm256_slli_epi32(imm4, 29);
|
| 478 |
+
|
| 479 |
+
v8sf sign_bit_cos = _mm256_castsi256_ps(imm4);
|
| 480 |
+
|
| 481 |
+
sign_bit_sin = _mm256_xor_ps(sign_bit_sin, swap_sign_bit_sin);
|
| 482 |
+
|
| 483 |
+
/* Evaluate the first polynom (0 <= x <= Pi/4) */
|
| 484 |
+
v8sf z = _mm256_mul_ps(x,x);
|
| 485 |
+
y = *(v8sf*)_ps256_coscof_p0;
|
| 486 |
+
|
| 487 |
+
y = _mm256_mul_ps(y, z);
|
| 488 |
+
y = _mm256_add_ps(y, *(v8sf*)_ps256_coscof_p1);
|
| 489 |
+
y = _mm256_mul_ps(y, z);
|
| 490 |
+
y = _mm256_add_ps(y, *(v8sf*)_ps256_coscof_p2);
|
| 491 |
+
y = _mm256_mul_ps(y, z);
|
| 492 |
+
y = _mm256_mul_ps(y, z);
|
| 493 |
+
v8sf tmp = _mm256_mul_ps(z, *(v8sf*)_ps256_0p5);
|
| 494 |
+
y = _mm256_sub_ps(y, tmp);
|
| 495 |
+
y = _mm256_add_ps(y, *(v8sf*)_ps256_1);
|
| 496 |
+
|
| 497 |
+
/* Evaluate the second polynom (Pi/4 <= x <= 0) */
|
| 498 |
+
|
| 499 |
+
v8sf y2 = *(v8sf*)_ps256_sincof_p0;
|
| 500 |
+
y2 = _mm256_mul_ps(y2, z);
|
| 501 |
+
y2 = _mm256_add_ps(y2, *(v8sf*)_ps256_sincof_p1);
|
| 502 |
+
y2 = _mm256_mul_ps(y2, z);
|
| 503 |
+
y2 = _mm256_add_ps(y2, *(v8sf*)_ps256_sincof_p2);
|
| 504 |
+
y2 = _mm256_mul_ps(y2, z);
|
| 505 |
+
y2 = _mm256_mul_ps(y2, x);
|
| 506 |
+
y2 = _mm256_add_ps(y2, x);
|
| 507 |
+
|
| 508 |
+
/* select the correct result from the two polynoms */
|
| 509 |
+
xmm3 = poly_mask;
|
| 510 |
+
v8sf ysin2 = _mm256_and_ps(xmm3, y2);
|
| 511 |
+
v8sf ysin1 = _mm256_andnot_ps(xmm3, y);
|
| 512 |
+
y2 = _mm256_sub_ps(y2,ysin2);
|
| 513 |
+
y = _mm256_sub_ps(y, ysin1);
|
| 514 |
+
|
| 515 |
+
xmm1 = _mm256_add_ps(ysin1,ysin2);
|
| 516 |
+
xmm2 = _mm256_add_ps(y,y2);
|
| 517 |
+
|
| 518 |
+
/* update the sign */
|
| 519 |
+
*s = _mm256_xor_ps(xmm1, sign_bit_sin);
|
| 520 |
+
*c = _mm256_xor_ps(xmm2, sign_bit_cos);
|
| 521 |
+
}
|
| 522 |
+
|
| 523 |
+
#endif // CPU_CAPABILITY_AVX2
|
| 524 |
+
|
| 525 |
+
#else
|
| 526 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 527 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/int_mm_kernel.h
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
#include <ATen/core/Tensor.h>
|
| 5 |
+
#include <ATen/native/DispatchStub.h>
|
| 6 |
+
|
| 7 |
+
namespace at::native {
|
| 8 |
+
|
| 9 |
+
using weight_to_int4pack_fn = void (*)(const Tensor&, const Tensor&);
|
| 10 |
+
using int4pack_mm_fn =
|
| 11 |
+
void (*)(const Tensor&, const Tensor&, const Tensor&, int, const Tensor&);
|
| 12 |
+
using int8pack_mm_fn =
|
| 13 |
+
void (*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&);
|
| 14 |
+
using dyn_quant_pack_4bit_weight_fn = void (*)(
|
| 15 |
+
Tensor&,
|
| 16 |
+
const Tensor&,
|
| 17 |
+
const Tensor&,
|
| 18 |
+
const std::optional<Tensor>& bias,
|
| 19 |
+
const int64_t,
|
| 20 |
+
const int64_t,
|
| 21 |
+
const int64_t);
|
| 22 |
+
using dyn_quant_matmul_4bit_fn = void (*)(
|
| 23 |
+
const Tensor&,
|
| 24 |
+
const Tensor&,
|
| 25 |
+
const Tensor&,
|
| 26 |
+
const int64_t,
|
| 27 |
+
const int64_t,
|
| 28 |
+
const int64_t,
|
| 29 |
+
const int64_t);
|
| 30 |
+
|
| 31 |
+
DECLARE_DISPATCH(weight_to_int4pack_fn, weight_to_int4pack_stub)
|
| 32 |
+
DECLARE_DISPATCH(int4pack_mm_fn, int4pack_mm_stub)
|
| 33 |
+
DECLARE_DISPATCH(int8pack_mm_fn, int8pack_mm_stub)
|
| 34 |
+
DECLARE_DISPATCH(
|
| 35 |
+
dyn_quant_pack_4bit_weight_fn,
|
| 36 |
+
dyn_quant_pack_4bit_weight_stub)
|
| 37 |
+
DECLARE_DISPATCH(dyn_quant_matmul_4bit_fn, dyn_quant_matmul_4bit_stub)
|
| 38 |
+
|
| 39 |
+
} // namespace at::native
|
| 40 |
+
|
| 41 |
+
#else
|
| 42 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 43 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/mixed_data_type.h
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
#include <ATen/core/Tensor.h>
|
| 5 |
+
|
| 6 |
+
namespace at::native {
|
| 7 |
+
|
| 8 |
+
inline ScalarType first_type() {
|
| 9 |
+
return ScalarType::Undefined;
|
| 10 |
+
}
|
| 11 |
+
|
| 12 |
+
template <typename... Args>
|
| 13 |
+
inline ScalarType first_type(const Tensor& arg, const Args&... parameters) {
|
| 14 |
+
return arg.defined() ? arg.scalar_type() : first_type(parameters...);
|
| 15 |
+
}
|
| 16 |
+
|
| 17 |
+
template <typename... Args>
|
| 18 |
+
inline bool is_mixed_type(const Tensor& input, const Args&... parameters) {
|
| 19 |
+
const auto parameter_type = first_type(parameters...);
|
| 20 |
+
return ((parameter_type != ScalarType::Undefined) &&
|
| 21 |
+
(parameter_type != input.scalar_type()));
|
| 22 |
+
}
|
| 23 |
+
|
| 24 |
+
// currently on CPU, mixed data type is only supported
|
| 25 |
+
// when input is 'BFloat16' or 'Half' and parameters are 'Float'
|
| 26 |
+
inline void check_mixed_data_type(const Tensor& input) {
|
| 27 |
+
TORCH_CHECK(at::isReducedFloatingType(input.scalar_type()),
|
| 28 |
+
"mixed dtype (CPU): all inputs must share same datatype.");
|
| 29 |
+
}
|
| 30 |
+
|
| 31 |
+
template <typename... Args>
|
| 32 |
+
inline void check_mixed_data_type(const Tensor& input, const Tensor& parameter, const Args&... parameters) {
|
| 33 |
+
TORCH_CHECK(!parameter.defined() || parameter.scalar_type() == ScalarType::Float,
|
| 34 |
+
"mixed dtype (CPU): expect parameter to have scalar type of Float");
|
| 35 |
+
check_mixed_data_type(input, parameters...);
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
inline ScalarType param_scalar_type(const Tensor& t, bool is_mixed_type) {
|
| 39 |
+
return is_mixed_type ? ScalarType::Float : t.scalar_type();
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
} // namespace at::native
|
| 43 |
+
|
| 44 |
+
#else
|
| 45 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 46 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/utils.h
ADDED
|
@@ -0,0 +1,225 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
#include <ATen/Parallel.h>
|
| 5 |
+
#include <ATen/core/TensorAccessor.h>
|
| 6 |
+
#include <ATen/cpu/vec/vec.h>
|
| 7 |
+
#include <c10/util/llvmMathExtras.h>
|
| 8 |
+
|
| 9 |
+
#ifdef USE_FBGEMM
|
| 10 |
+
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wextra-semi")
|
| 11 |
+
#include <fbgemm/Fbgemm.h>
|
| 12 |
+
C10_DIAGNOSTIC_POP()
|
| 13 |
+
#endif
|
| 14 |
+
|
| 15 |
+
namespace at::native {
|
| 16 |
+
|
| 17 |
+
template <typename T>
|
| 18 |
+
inline void _store(T* dst, at::vec::Vectorized<T> src) {
|
| 19 |
+
src.store(dst);
|
| 20 |
+
}
|
| 21 |
+
|
| 22 |
+
inline void _store(at::BFloat16* dst, at::vec::Vectorized<float> src) {
|
| 23 |
+
auto res = at::vec::convert_float_bfloat16(src, src);
|
| 24 |
+
res.store(dst, at::vec::Vectorized<float>::size());
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
inline void _store(at::Half* dst, at::vec::Vectorized<float> src) {
|
| 28 |
+
auto res = at::vec::convert_float_half(src, src);
|
| 29 |
+
res.store(dst, at::vec::Vectorized<float>::size());
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
inline namespace CPU_CAPABILITY {
|
| 33 |
+
|
| 34 |
+
template <typename T>
|
| 35 |
+
inline T data_index_init(T offset) {
|
| 36 |
+
return offset;
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
template <typename T, typename... Args>
|
| 40 |
+
inline T data_index_init(T offset, T& x, const T& X, Args&&... args) {
|
| 41 |
+
offset = data_index_init(offset, std::forward<Args>(args)...);
|
| 42 |
+
x = offset % X;
|
| 43 |
+
return offset / X;
|
| 44 |
+
}
|
| 45 |
+
|
| 46 |
+
inline bool data_index_step() {
|
| 47 |
+
return true;
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
template <typename T, typename... Args>
|
| 51 |
+
inline bool data_index_step(T& x, const T& X, Args&&... args) {
|
| 52 |
+
if (data_index_step(std::forward<Args>(args)...)) {
|
| 53 |
+
x = ((x + 1) == X) ? 0 : (x + 1);
|
| 54 |
+
return x == 0;
|
| 55 |
+
}
|
| 56 |
+
return false;
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
// Helper struct for bfloat16/float16 vectorization
|
| 60 |
+
// Useful when you need float as immediate dtype or accumulate dtype
|
| 61 |
+
using namespace vec;
|
| 62 |
+
struct Vec2 {
|
| 63 |
+
Vectorized<float> val0, val1;
|
| 64 |
+
Vec2(Vectorized<float> v0, Vectorized<float> v1) : val0(v0), val1(v1) {}
|
| 65 |
+
Vec2(float v) : val0(v), val1(v) {}
|
| 66 |
+
static Vec2 loadu(const BFloat16* ptr) {
|
| 67 |
+
auto [v0, v1] = convert_bfloat16_float(Vectorized<BFloat16>::loadu(ptr));
|
| 68 |
+
return {v0, v1};
|
| 69 |
+
}
|
| 70 |
+
static Vec2 loadu(const Half* ptr) {
|
| 71 |
+
auto [v0, v1] = convert_half_float(Vectorized<Half>::loadu(ptr));
|
| 72 |
+
return {v0, v1};
|
| 73 |
+
}
|
| 74 |
+
static Vec2 loadu(const float* ptr) {
|
| 75 |
+
return {Vectorized<float>::loadu(ptr), Vectorized<float>::loadu(ptr + Vectorized<float>::size())};
|
| 76 |
+
}
|
| 77 |
+
void store(BFloat16* ptr) const {
|
| 78 |
+
Vectorized<BFloat16> val = convert_float_bfloat16(val0, val1);
|
| 79 |
+
val.store(ptr);
|
| 80 |
+
}
|
| 81 |
+
void store(Half* ptr) const {
|
| 82 |
+
Vectorized<Half> val = convert_float_half(val0, val1);
|
| 83 |
+
val.store(ptr);
|
| 84 |
+
}
|
| 85 |
+
void store(float* ptr) const {
|
| 86 |
+
val0.store(ptr);
|
| 87 |
+
val1.store(ptr + Vectorized<float>::size());
|
| 88 |
+
}
|
| 89 |
+
};
|
| 90 |
+
inline Vec2 operator+(const Vec2& a, const Vec2& b) { return {a.val0 + b.val0, a.val1 + b.val1}; }
|
| 91 |
+
inline Vec2 operator*(const Vec2& a, const Vec2& b) { return {a.val0 * b.val0, a.val1 * b.val1}; }
|
| 92 |
+
inline Vec2 operator-(const Vec2& a, const Vec2& b) { return {a.val0 - b.val0, a.val1 - b.val1}; }
|
| 93 |
+
inline Vec2 operator/(const Vec2& a, const Vec2& b) { return {a.val0 / b.val0, a.val1 / b.val1}; }
|
| 94 |
+
inline Vec2 maximum(const Vec2& a, const Vec2& b) { return {vec::maximum(a.val0, b.val0), vec::maximum(a.val1, b.val1)}; }
|
| 95 |
+
inline Vec2 minimum(const Vec2& a, const Vec2& b) { return {vec::minimum(a.val0, b.val0), vec::minimum(a.val1, b.val1)}; }
|
| 96 |
+
|
| 97 |
+
template <typename scalar_t> struct VectorizedType { using type = Vectorized<scalar_t>; };
|
| 98 |
+
template <> struct VectorizedType<BFloat16> { using type = Vec2; };
|
| 99 |
+
template <> struct VectorizedType<Half> { using type = Vec2; };
|
| 100 |
+
template <typename scalar_t> using VecType = typename VectorizedType<scalar_t>::type;
|
| 101 |
+
|
| 102 |
+
// Helper for mixed data type parameter Vec::load
|
| 103 |
+
inline std::tuple<Vectorized<float>, Vectorized<float>> load2f(const BFloat16* ptr) {
|
| 104 |
+
return convert_bfloat16_float(Vectorized<BFloat16>::loadu(ptr));
|
| 105 |
+
}
|
| 106 |
+
|
| 107 |
+
inline std::tuple<Vectorized<float>, Vectorized<float>> load2f(const Half* ptr) {
|
| 108 |
+
return convert_half_float(Vectorized<Half>::loadu(ptr));
|
| 109 |
+
}
|
| 110 |
+
|
| 111 |
+
inline std::tuple<Vectorized<float>, Vectorized<float>> load2f(const float* ptr) {
|
| 112 |
+
using Vec = Vectorized<float>;
|
| 113 |
+
return std::make_tuple(Vec::loadu(ptr), Vec::loadu(ptr + Vec::size()));
|
| 114 |
+
}
|
| 115 |
+
|
| 116 |
+
inline std::tuple<Vectorized<float>, Vectorized<float>> load2f(const BFloat16* ptr, int64_t count) {
|
| 117 |
+
return convert_bfloat16_float(Vectorized<BFloat16>::loadu(ptr, count));
|
| 118 |
+
}
|
| 119 |
+
|
| 120 |
+
inline std::tuple<Vectorized<float>, Vectorized<float>> load2f(const Half* ptr, int64_t count) {
|
| 121 |
+
return convert_half_float(Vectorized<Half>::loadu(ptr, count));
|
| 122 |
+
}
|
| 123 |
+
|
| 124 |
+
inline std::tuple<Vectorized<float>, Vectorized<float>> load2f(const float* ptr, int64_t count) {
|
| 125 |
+
using Vec = Vectorized<float>;
|
| 126 |
+
if (count > Vec::size()) {
|
| 127 |
+
return std::make_tuple(Vec::loadu(ptr), Vec::loadu(ptr + Vec::size(), count - Vec::size()));
|
| 128 |
+
} else {
|
| 129 |
+
return std::make_tuple(Vec::loadu(ptr, count), Vec(0));
|
| 130 |
+
}
|
| 131 |
+
}
|
| 132 |
+
|
| 133 |
+
} // namespace
|
| 134 |
+
|
| 135 |
+
namespace utils {
|
| 136 |
+
|
| 137 |
+
template <typename T>
|
| 138 |
+
T CeilLog2(const T& x) {
|
| 139 |
+
if (x <= 2) {
|
| 140 |
+
return 1;
|
| 141 |
+
}
|
| 142 |
+
// Last set bit is floor(log2(x)), floor + 1 is ceil
|
| 143 |
+
// except when x is an exact powers of 2, so subtract 1 first
|
| 144 |
+
return static_cast<T>(llvm::findLastSet(static_cast<uint64_t>(x) - 1)) + 1;
|
| 145 |
+
}
|
| 146 |
+
|
| 147 |
+
// matrix transpose:
|
| 148 |
+
// src has shape of M by N, with leading dimension of ld_src
|
| 149 |
+
// dst has shape of N by M, with leading dimension of ld_dst
|
| 150 |
+
template <typename T>
|
| 151 |
+
inline void transpose(int64_t M, int64_t N, const T* src, int64_t ld_src, T* dst, int64_t ld_dst) {
|
| 152 |
+
for (int64_t j = 0; j < N; j++) {
|
| 153 |
+
for (int64_t i = 0; i < M; i++) {
|
| 154 |
+
dst[j * ld_dst + i] = c10::load(&(src[i * ld_src + j]));
|
| 155 |
+
}
|
| 156 |
+
}
|
| 157 |
+
}
|
| 158 |
+
|
| 159 |
+
#ifdef USE_FBGEMM
|
| 160 |
+
template <>
|
| 161 |
+
inline void transpose<float>(int64_t M, int64_t N, const float* src, int64_t ld_src, float* dst, int64_t ld_dst) {
|
| 162 |
+
TORCH_CHECK(fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM.");
|
| 163 |
+
fbgemm::transpose_simd<float>(M, N, src, ld_src, dst, ld_dst);
|
| 164 |
+
}
|
| 165 |
+
|
| 166 |
+
template <>
|
| 167 |
+
inline void transpose<uint16_t>(int64_t M, int64_t N, const uint16_t* src, int64_t ld_src, uint16_t* dst, int64_t ld_dst) {
|
| 168 |
+
TORCH_CHECK(fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM.");
|
| 169 |
+
fbgemm::transpose_simd<uint16_t>(M, N, src, ld_src, dst, ld_dst);
|
| 170 |
+
}
|
| 171 |
+
|
| 172 |
+
template <>
|
| 173 |
+
inline void transpose<uint8_t>(int64_t M, int64_t N, const uint8_t* src, int64_t ld_src, uint8_t* dst, int64_t ld_dst) {
|
| 174 |
+
TORCH_CHECK(fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM.");
|
| 175 |
+
fbgemm::transpose_simd<uint8_t>(M, N, src, ld_src, dst, ld_dst);
|
| 176 |
+
}
|
| 177 |
+
#endif
|
| 178 |
+
|
| 179 |
+
template <typename index_t, typename F>
|
| 180 |
+
inline void parallel_sparse_csr(
|
| 181 |
+
const TensorAccessor<index_t, 1>& crow_acc,
|
| 182 |
+
const int64_t M,
|
| 183 |
+
const int64_t nnz,
|
| 184 |
+
const F& f) {
|
| 185 |
+
TORCH_CHECK(crow_acc.size(0) == M + 1);
|
| 186 |
+
|
| 187 |
+
// directly parallel on `M` may lead to load imbalance,
|
| 188 |
+
// statically determine thread partition here to average payload
|
| 189 |
+
// for each thread.
|
| 190 |
+
int num_threads = at::get_num_threads();
|
| 191 |
+
std::vector<int64_t> thread_splits(num_threads + 1, M);
|
| 192 |
+
|
| 193 |
+
int64_t thread_averge_payload = std::max((int64_t)1, divup(nnz, num_threads));
|
| 194 |
+
|
| 195 |
+
thread_splits[0] = 0;
|
| 196 |
+
int64_t sum = 0;
|
| 197 |
+
int64_t t = 1;
|
| 198 |
+
for (const auto m : c10::irange(M)) {
|
| 199 |
+
int64_t row_start = crow_acc[m];
|
| 200 |
+
int64_t row_end = crow_acc[m + 1];
|
| 201 |
+
sum += row_end - row_start;
|
| 202 |
+
if (sum > t * thread_averge_payload) {
|
| 203 |
+
thread_splits[t] = m;
|
| 204 |
+
t++;
|
| 205 |
+
}
|
| 206 |
+
}
|
| 207 |
+
// need to restore the last index,
|
| 208 |
+
// due to rounding error when calculating `thread_averge_payload`.
|
| 209 |
+
thread_splits[num_threads] = M;
|
| 210 |
+
|
| 211 |
+
at::parallel_for(0, num_threads, 1, [&](int64_t cbegin, int64_t cend) {
|
| 212 |
+
int tid = at::get_thread_num();
|
| 213 |
+
int64_t begin = thread_splits[tid];
|
| 214 |
+
int64_t end = thread_splits[tid + 1];
|
| 215 |
+
f(begin, end);
|
| 216 |
+
});
|
| 217 |
+
}
|
| 218 |
+
|
| 219 |
+
} // namespace utils
|
| 220 |
+
|
| 221 |
+
} // namespace at::native
|
| 222 |
+
|
| 223 |
+
#else
|
| 224 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 225 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cpu/zmath.h
ADDED
|
@@ -0,0 +1,255 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
// Complex number math operations that act as no-ops for other dtypes.
|
| 5 |
+
#include <c10/util/complex.h>
|
| 6 |
+
#include <c10/util/MathConstants.h>
|
| 7 |
+
#include<ATen/NumericUtils.h>
|
| 8 |
+
|
| 9 |
+
namespace at::native {
|
| 10 |
+
inline namespace CPU_CAPABILITY {
|
| 11 |
+
|
| 12 |
+
template <typename SCALAR_TYPE, typename VALUE_TYPE=SCALAR_TYPE>
|
| 13 |
+
inline VALUE_TYPE zabs (SCALAR_TYPE z) {
|
| 14 |
+
return z;
|
| 15 |
+
}
|
| 16 |
+
|
| 17 |
+
template<>
|
| 18 |
+
inline c10::complex<float> zabs <c10::complex<float>> (c10::complex<float> z) {
|
| 19 |
+
return c10::complex<float>(std::abs(z));
|
| 20 |
+
}
|
| 21 |
+
|
| 22 |
+
template<>
|
| 23 |
+
inline float zabs <c10::complex<float>, float> (c10::complex<float> z) {
|
| 24 |
+
return std::abs(z);
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
template<>
|
| 28 |
+
inline c10::complex<double> zabs <c10::complex<double>> (c10::complex<double> z) {
|
| 29 |
+
return c10::complex<double>(std::abs(z));
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
template<>
|
| 33 |
+
inline double zabs <c10::complex<double>, double> (c10::complex<double> z) {
|
| 34 |
+
return std::abs(z);
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
// This overload corresponds to non-complex dtypes.
|
| 38 |
+
// The function is consistent with its NumPy equivalent
|
| 39 |
+
// for non-complex dtypes where `pi` is returned for
|
| 40 |
+
// negative real numbers and `0` is returned for 0 or positive
|
| 41 |
+
// real numbers.
|
| 42 |
+
// Note: `nan` is propagated.
|
| 43 |
+
template <typename SCALAR_TYPE, typename VALUE_TYPE=SCALAR_TYPE>
|
| 44 |
+
inline VALUE_TYPE angle_impl (SCALAR_TYPE z) {
|
| 45 |
+
if (at::_isnan(z)) {
|
| 46 |
+
return z;
|
| 47 |
+
}
|
| 48 |
+
return z < 0 ? c10::pi<double> : 0;
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
template<>
|
| 52 |
+
inline c10::complex<float> angle_impl <c10::complex<float>> (c10::complex<float> z) {
|
| 53 |
+
return c10::complex<float>(std::arg(z), 0.0);
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
template<>
|
| 57 |
+
inline float angle_impl <c10::complex<float>, float> (c10::complex<float> z) {
|
| 58 |
+
return std::arg(z);
|
| 59 |
+
}
|
| 60 |
+
|
| 61 |
+
template<>
|
| 62 |
+
inline c10::complex<double> angle_impl <c10::complex<double>> (c10::complex<double> z) {
|
| 63 |
+
return c10::complex<double>(std::arg(z), 0.0);
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
template<>
|
| 67 |
+
inline double angle_impl <c10::complex<double>, double> (c10::complex<double> z) {
|
| 68 |
+
return std::arg(z);
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
template <typename SCALAR_TYPE, typename VALUE_TYPE=SCALAR_TYPE>
|
| 72 |
+
constexpr VALUE_TYPE real_impl (SCALAR_TYPE z) {
|
| 73 |
+
return z; //No-Op
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
template<>
|
| 77 |
+
constexpr c10::complex<float> real_impl <c10::complex<float>> (c10::complex<float> z) {
|
| 78 |
+
return c10::complex<float>(z.real(), 0.0);
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
template<>
|
| 82 |
+
constexpr float real_impl <c10::complex<float>, float> (c10::complex<float> z) {
|
| 83 |
+
return z.real();
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
template<>
|
| 87 |
+
constexpr c10::complex<double> real_impl <c10::complex<double>> (c10::complex<double> z) {
|
| 88 |
+
return c10::complex<double>(z.real(), 0.0);
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
template<>
|
| 92 |
+
constexpr double real_impl <c10::complex<double>, double> (c10::complex<double> z) {
|
| 93 |
+
return z.real();
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
template <typename SCALAR_TYPE, typename VALUE_TYPE=SCALAR_TYPE>
|
| 97 |
+
constexpr VALUE_TYPE imag_impl (SCALAR_TYPE /*z*/) {
|
| 98 |
+
return 0;
|
| 99 |
+
}
|
| 100 |
+
|
| 101 |
+
template<>
|
| 102 |
+
constexpr c10::complex<float> imag_impl <c10::complex<float>> (c10::complex<float> z) {
|
| 103 |
+
return c10::complex<float>(z.imag(), 0.0);
|
| 104 |
+
}
|
| 105 |
+
|
| 106 |
+
template<>
|
| 107 |
+
constexpr float imag_impl <c10::complex<float>, float> (c10::complex<float> z) {
|
| 108 |
+
return z.imag();
|
| 109 |
+
}
|
| 110 |
+
|
| 111 |
+
template<>
|
| 112 |
+
constexpr c10::complex<double> imag_impl <c10::complex<double>> (c10::complex<double> z) {
|
| 113 |
+
return c10::complex<double>(z.imag(), 0.0);
|
| 114 |
+
}
|
| 115 |
+
|
| 116 |
+
template<>
|
| 117 |
+
constexpr double imag_impl <c10::complex<double>, double> (c10::complex<double> z) {
|
| 118 |
+
return z.imag();
|
| 119 |
+
}
|
| 120 |
+
|
| 121 |
+
template <typename TYPE>
|
| 122 |
+
inline TYPE conj_impl (TYPE z) {
|
| 123 |
+
return z; //No-Op
|
| 124 |
+
}
|
| 125 |
+
|
| 126 |
+
template<>
|
| 127 |
+
inline c10::complex<at::Half> conj_impl <c10::complex<at::Half>> (c10::complex<at::Half> z) {
|
| 128 |
+
return c10::complex<at::Half>{z.real(), -z.imag()};
|
| 129 |
+
}
|
| 130 |
+
|
| 131 |
+
template<>
|
| 132 |
+
inline c10::complex<float> conj_impl <c10::complex<float>> (c10::complex<float> z) {
|
| 133 |
+
return c10::complex<float>(z.real(), -z.imag());
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
template<>
|
| 137 |
+
inline c10::complex<double> conj_impl <c10::complex<double>> (c10::complex<double> z) {
|
| 138 |
+
return c10::complex<double>(z.real(), -z.imag());
|
| 139 |
+
}
|
| 140 |
+
|
| 141 |
+
template <typename TYPE>
|
| 142 |
+
inline TYPE ceil_impl (TYPE z) {
|
| 143 |
+
return std::ceil(z);
|
| 144 |
+
}
|
| 145 |
+
|
| 146 |
+
template <>
|
| 147 |
+
inline c10::complex<float> ceil_impl (c10::complex<float> z) {
|
| 148 |
+
return c10::complex<float>(std::ceil(z.real()), std::ceil(z.imag()));
|
| 149 |
+
}
|
| 150 |
+
|
| 151 |
+
template <>
|
| 152 |
+
inline c10::complex<double> ceil_impl (c10::complex<double> z) {
|
| 153 |
+
return c10::complex<double>(std::ceil(z.real()), std::ceil(z.imag()));
|
| 154 |
+
}
|
| 155 |
+
|
| 156 |
+
template<typename T>
|
| 157 |
+
inline c10::complex<T> sgn_impl (c10::complex<T> z) {
|
| 158 |
+
if (z == c10::complex<T>(0, 0)) {
|
| 159 |
+
return c10::complex<T>(0, 0);
|
| 160 |
+
} else {
|
| 161 |
+
return z / zabs(z);
|
| 162 |
+
}
|
| 163 |
+
}
|
| 164 |
+
|
| 165 |
+
template <typename TYPE>
|
| 166 |
+
inline TYPE floor_impl (TYPE z) {
|
| 167 |
+
return std::floor(z);
|
| 168 |
+
}
|
| 169 |
+
|
| 170 |
+
template <>
|
| 171 |
+
inline c10::complex<float> floor_impl (c10::complex<float> z) {
|
| 172 |
+
return c10::complex<float>(std::floor(z.real()), std::floor(z.imag()));
|
| 173 |
+
}
|
| 174 |
+
|
| 175 |
+
template <>
|
| 176 |
+
inline c10::complex<double> floor_impl (c10::complex<double> z) {
|
| 177 |
+
return c10::complex<double>(std::floor(z.real()), std::floor(z.imag()));
|
| 178 |
+
}
|
| 179 |
+
|
| 180 |
+
template <typename TYPE>
|
| 181 |
+
inline TYPE round_impl (TYPE z) {
|
| 182 |
+
return std::nearbyint(z);
|
| 183 |
+
}
|
| 184 |
+
|
| 185 |
+
template <>
|
| 186 |
+
inline c10::complex<float> round_impl (c10::complex<float> z) {
|
| 187 |
+
return c10::complex<float>(std::nearbyint(z.real()), std::nearbyint(z.imag()));
|
| 188 |
+
}
|
| 189 |
+
|
| 190 |
+
template <>
|
| 191 |
+
inline c10::complex<double> round_impl (c10::complex<double> z) {
|
| 192 |
+
return c10::complex<double>(std::nearbyint(z.real()), std::nearbyint(z.imag()));
|
| 193 |
+
}
|
| 194 |
+
|
| 195 |
+
template <typename TYPE>
|
| 196 |
+
inline TYPE trunc_impl (TYPE z) {
|
| 197 |
+
return std::trunc(z);
|
| 198 |
+
}
|
| 199 |
+
|
| 200 |
+
template <>
|
| 201 |
+
inline c10::complex<float> trunc_impl (c10::complex<float> z) {
|
| 202 |
+
return c10::complex<float>(std::trunc(z.real()), std::trunc(z.imag()));
|
| 203 |
+
}
|
| 204 |
+
|
| 205 |
+
template <>
|
| 206 |
+
inline c10::complex<double> trunc_impl (c10::complex<double> z) {
|
| 207 |
+
return c10::complex<double>(std::trunc(z.real()), std::trunc(z.imag()));
|
| 208 |
+
}
|
| 209 |
+
|
| 210 |
+
template <typename TYPE, std::enable_if_t<!c10::is_complex<TYPE>::value, int> = 0>
|
| 211 |
+
inline TYPE max_impl (TYPE a, TYPE b) {
|
| 212 |
+
if (_isnan<TYPE>(a) || _isnan<TYPE>(b)) {
|
| 213 |
+
return std::numeric_limits<TYPE>::quiet_NaN();
|
| 214 |
+
} else {
|
| 215 |
+
return std::max(a, b);
|
| 216 |
+
}
|
| 217 |
+
}
|
| 218 |
+
|
| 219 |
+
template <typename TYPE, std::enable_if_t<c10::is_complex<TYPE>::value, int> = 0>
|
| 220 |
+
inline TYPE max_impl (TYPE a, TYPE b) {
|
| 221 |
+
if (_isnan<TYPE>(a)) {
|
| 222 |
+
return a;
|
| 223 |
+
} else if (_isnan<TYPE>(b)) {
|
| 224 |
+
return b;
|
| 225 |
+
} else {
|
| 226 |
+
return std::abs(a) > std::abs(b) ? a : b;
|
| 227 |
+
}
|
| 228 |
+
}
|
| 229 |
+
|
| 230 |
+
template <typename TYPE, std::enable_if_t<!c10::is_complex<TYPE>::value, int> = 0>
|
| 231 |
+
inline TYPE min_impl (TYPE a, TYPE b) {
|
| 232 |
+
if (_isnan<TYPE>(a) || _isnan<TYPE>(b)) {
|
| 233 |
+
return std::numeric_limits<TYPE>::quiet_NaN();
|
| 234 |
+
} else {
|
| 235 |
+
return std::min(a, b);
|
| 236 |
+
}
|
| 237 |
+
}
|
| 238 |
+
|
| 239 |
+
template <typename TYPE, std::enable_if_t<c10::is_complex<TYPE>::value, int> = 0>
|
| 240 |
+
inline TYPE min_impl (TYPE a, TYPE b) {
|
| 241 |
+
if (_isnan<TYPE>(a)) {
|
| 242 |
+
return a;
|
| 243 |
+
} else if (_isnan<TYPE>(b)) {
|
| 244 |
+
return b;
|
| 245 |
+
} else {
|
| 246 |
+
return std::abs(a) < std::abs(b) ? a : b;
|
| 247 |
+
}
|
| 248 |
+
}
|
| 249 |
+
|
| 250 |
+
} // end namespace
|
| 251 |
+
} //end at::native
|
| 252 |
+
|
| 253 |
+
#else
|
| 254 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 255 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/native/cuda/CUDAJitLoops.cuh
ADDED
|
@@ -0,0 +1,332 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
| 2 |
+
#pragma once
|
| 3 |
+
#include <ATen/jit_macros.h>
|
| 4 |
+
|
| 5 |
+
// Jiterator functions are guarded behind this macro
|
| 6 |
+
#if AT_USE_JITERATOR()
|
| 7 |
+
|
| 8 |
+
#include <ATen/OpMathType.h>
|
| 9 |
+
#include <ATen/TensorIterator.h>
|
| 10 |
+
#include <ATen/cuda/CUDAContext.h>
|
| 11 |
+
#include <ATen/cuda/detail/OffsetCalculator.cuh>
|
| 12 |
+
#include <ATen/native/cuda/jit_utils.h>
|
| 13 |
+
#include <ATen/native/cuda/MemoryAccess.cuh>
|
| 14 |
+
#include <ATen/native/cuda/thread_constants.h>
|
| 15 |
+
|
| 16 |
+
#include <ATen/native/cuda/Loops.cuh>
|
| 17 |
+
|
| 18 |
+
#include <c10/macros/Macros.h>
|
| 19 |
+
#include <c10/core/ScalarType.h>
|
| 20 |
+
#include <c10/util/SmallBuffer.h>
|
| 21 |
+
|
| 22 |
+
#include <array>
|
| 23 |
+
#include <initializer_list>
|
| 24 |
+
#include <type_traits>
|
| 25 |
+
#include <tuple>
|
| 26 |
+
#include <mutex>
|
| 27 |
+
|
| 28 |
+
namespace at::native {
|
| 29 |
+
|
| 30 |
+
template <typename Tuple, std::size_t... I>
|
| 31 |
+
// warning : unused parameter when tuple is empty.
|
| 32 |
+
constexpr auto tuple_to_array_helper(const Tuple& t [[maybe_unused]], std::index_sequence<I...> seq) {
|
| 33 |
+
constexpr auto size = seq.size();
|
| 34 |
+
return std::array<const void*, size>{static_cast<const void*>(&std::get<I>(t))...};
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
// Helper function convert tuple to std::array<const void*, N>
|
| 38 |
+
// for passing the arguments to CUDA Kernel
|
| 39 |
+
// NOTE: We capture tuple by reference,
|
| 40 |
+
// so the pointers in returned array are only valid
|
| 41 |
+
// till tuple is alive.
|
| 42 |
+
template <typename ...Args>
|
| 43 |
+
constexpr auto tuple_to_array(const std::tuple<Args...>& extra_args) {
|
| 44 |
+
constexpr auto tuple_size = sizeof...(Args);
|
| 45 |
+
return tuple_to_array_helper(extra_args, std::make_index_sequence<tuple_size>{});
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
struct JittedVecKernelCache {
|
| 49 |
+
// Different kernels are compiled depending on what we're vectorizing up to (1, 2 or 4 elements)
|
| 50 |
+
at::cuda::jit::NvrtcFunction vec1;
|
| 51 |
+
at::cuda::jit::NvrtcFunction vec2;
|
| 52 |
+
at::cuda::jit::NvrtcFunction vec4;
|
| 53 |
+
at::cuda::jit::NvrtcFunction vec8;
|
| 54 |
+
#ifdef USE_ROCM
|
| 55 |
+
at::cuda::jit::NvrtcFunction vec16;
|
| 56 |
+
#endif
|
| 57 |
+
|
| 58 |
+
};
|
| 59 |
+
|
| 60 |
+
struct JittedKernelVariantCache {
|
| 61 |
+
JittedVecKernelCache vec;
|
| 62 |
+
at::cuda::jit::NvrtcFunction noncontiguous;
|
| 63 |
+
at::cuda::jit::NvrtcFunction dynamic_contiguous;
|
| 64 |
+
at::cuda::jit::NvrtcFunction dynamic_noncontiguous;
|
| 65 |
+
};
|
| 66 |
+
|
| 67 |
+
inline c10::SmallBuffer<const void*, 64> pack_kernel_args(
|
| 68 |
+
std::initializer_list<const void*> args,
|
| 69 |
+
c10::ArrayRef<const void*> extra_args) {
|
| 70 |
+
c10::SmallBuffer<const void*, 64> ret(args.size() + extra_args.size());
|
| 71 |
+
std::copy(args.begin(), args.end(), ret.data());
|
| 72 |
+
std::copy(extra_args.begin(), extra_args.end(), ret.data() + args.size());
|
| 73 |
+
return ret;
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
template<typename array_t,
|
| 77 |
+
typename inp_calc_t,
|
| 78 |
+
typename out_calc_t,
|
| 79 |
+
typename loader_t,
|
| 80 |
+
typename storer_t>
|
| 81 |
+
void launch_jitted_unrolled_kernel(
|
| 82 |
+
std::mutex &jiterator_mutex,
|
| 83 |
+
at::cuda::jit::NvrtcFunction &fn_cache,
|
| 84 |
+
const at::cuda::jit::KernelDescriptor &desc,
|
| 85 |
+
int64_t N,
|
| 86 |
+
array_t data,
|
| 87 |
+
inp_calc_t ic,
|
| 88 |
+
out_calc_t oc,
|
| 89 |
+
loader_t l,
|
| 90 |
+
storer_t s,
|
| 91 |
+
bool contiguous,
|
| 92 |
+
at::cuda::jit::BinaryFuncVariant scalar_pos,
|
| 93 |
+
const void* scalar_val,
|
| 94 |
+
c10::ArrayRef<const void*> extra_args) {
|
| 95 |
+
|
| 96 |
+
TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
|
| 97 |
+
|
| 98 |
+
int tws = at::cuda::jit::calc_thread_work_size(desc.nInputs, desc.nOutputs, desc.f_inputs_type, desc.result_type);
|
| 99 |
+
int bws = tws * num_threads();
|
| 100 |
+
//casting result to int is always safe, intermediate is int64 and won't overflow
|
| 101 |
+
const uint32_t grid = (N + bws - 1) / bws;
|
| 102 |
+
|
| 103 |
+
if (!fn_cache.function) {
|
| 104 |
+
const std::lock_guard<std::mutex> lock{jiterator_mutex};
|
| 105 |
+
if (!fn_cache.function) {
|
| 106 |
+
constexpr bool dynamic_casting = !std::is_same<decltype(l), memory::LoadWithoutCast>() ||
|
| 107 |
+
!std::is_same<decltype(s), memory::StoreWithoutCast>();
|
| 108 |
+
auto code = at::cuda::jit::generate_code(
|
| 109 |
+
desc, contiguous, dynamic_casting, scalar_pos, tws);
|
| 110 |
+
fn_cache = at::cuda::jit::jit_pwise_function(code, desc.name);
|
| 111 |
+
}
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
auto args = pack_kernel_args({&N, &data, &ic, &oc, &l, &s, scalar_val}, extra_args);
|
| 115 |
+
at::cuda::jit::launch_jitted_pwise_function(fn_cache, args.data(), {grid, 1u, 1u},
|
| 116 |
+
{num_threads(), 1u, 1u});
|
| 117 |
+
}
|
| 118 |
+
|
| 119 |
+
template<int arity, typename array_t>
|
| 120 |
+
void launch_jitted_vectorized_kernel(
|
| 121 |
+
std::mutex &jiterator_mutex, JittedVecKernelCache &fn_cache,
|
| 122 |
+
const at::cuda::jit::KernelDescriptor &desc, int64_t N, array_t data,
|
| 123 |
+
at::cuda::jit::BinaryFuncVariant scalar_pos,
|
| 124 |
+
const void *scalar_val, c10::ArrayRef<const void*> extra_args) {
|
| 125 |
+
TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
|
| 126 |
+
|
| 127 |
+
int tws = at::cuda::jit::calc_thread_work_size(desc.nInputs, desc.nOutputs, desc.f_inputs_type, desc.result_type);
|
| 128 |
+
int bws = tws * num_threads();
|
| 129 |
+
// N is still int64_t for the computation, but it's always safe to cast result to int
|
| 130 |
+
const uint32_t grid = (N + bws - 1) / bws;
|
| 131 |
+
|
| 132 |
+
int vec_size = at::cuda::jit::can_vectorize_up_to(
|
| 133 |
+
desc, c10::ArrayRef<char*>(data.data(), data.size()));
|
| 134 |
+
|
| 135 |
+
#ifndef USE_ROCM
|
| 136 |
+
const auto input_size = c10::scalarTypeToTypeMeta(desc.f_inputs_type).itemsize();
|
| 137 |
+
const int optimal_vec_size = 16 / static_cast<int>(input_size);
|
| 138 |
+
vec_size = std::min<int>(optimal_vec_size, vec_size);
|
| 139 |
+
// Here we purposely omit vec8 for 1-byte data because of a bug in NVCC
|
| 140 |
+
// that causes some numerical mismatches with uint8 on sm80 and sm90.
|
| 141 |
+
// TODO: Revisit this after CUDA 12.8 update.
|
| 142 |
+
if (input_size < 2) {
|
| 143 |
+
vec_size = std::min<int>(vec_size, 4);
|
| 144 |
+
}
|
| 145 |
+
#endif
|
| 146 |
+
|
| 147 |
+
// Different kernels are compiled depending on what we're vectorizing up to (1, 2 or 4 elements)
|
| 148 |
+
// fn_ptr is set to the appropriate function based on the vec size and GPU used
|
| 149 |
+
at::cuda::jit::NvrtcFunction* fn_ptr = nullptr;
|
| 150 |
+
|
| 151 |
+
#ifdef USE_ROCM
|
| 152 |
+
if (vec_size == 16) {
|
| 153 |
+
fn_ptr = &fn_cache.vec16;
|
| 154 |
+
} else
|
| 155 |
+
#endif
|
| 156 |
+
if (vec_size == 8) {
|
| 157 |
+
fn_ptr = &fn_cache.vec8;
|
| 158 |
+
} else if (vec_size == 4) {
|
| 159 |
+
fn_ptr = &fn_cache.vec4;
|
| 160 |
+
} else if (vec_size == 2) {
|
| 161 |
+
fn_ptr = &fn_cache.vec2;
|
| 162 |
+
} else if (vec_size ==1) {
|
| 163 |
+
fn_ptr = &fn_cache.vec1;
|
| 164 |
+
} else {
|
| 165 |
+
TORCH_INTERNAL_ASSERT(false, "unexpected vec_size for jitter vectorized kernel");
|
| 166 |
+
}
|
| 167 |
+
|
| 168 |
+
bool vectorized = vec_size > 1;
|
| 169 |
+
|
| 170 |
+
if (!fn_ptr->function) {
|
| 171 |
+
const std::lock_guard<std::mutex> lock{jiterator_mutex};
|
| 172 |
+
if (!fn_ptr->function) { // cache miss!
|
| 173 |
+
|
| 174 |
+
// Generates program
|
| 175 |
+
auto code = at::cuda::jit::generate_code(
|
| 176 |
+
desc, /*contiguous=*/true, /*dynamic_casting=*/false,
|
| 177 |
+
scalar_pos, tws, vectorized, vec_size);
|
| 178 |
+
std::string kernel_name = vectorized ? desc.name + "_vectorized" + std::to_string(vec_size) : desc.name;
|
| 179 |
+
|
| 180 |
+
// Acquires the program
|
| 181 |
+
*fn_ptr = at::cuda::jit::jit_pwise_function(code, kernel_name);
|
| 182 |
+
}
|
| 183 |
+
}
|
| 184 |
+
|
| 185 |
+
if (vectorized) {
|
| 186 |
+
auto args = pack_kernel_args({&N, &data, scalar_val}, extra_args);
|
| 187 |
+
at::cuda::jit::launch_jitted_pwise_function(
|
| 188 |
+
*fn_ptr, args.data(), {grid, 1u, 1u}, {num_threads(), 1u, 1u});
|
| 189 |
+
} else {
|
| 190 |
+
// NVCC complains about unused variables l and s.
|
| 191 |
+
// It should be false positive in most cases, so we suppress the warnings.
|
| 192 |
+
#pragma nv_diagnostic push
|
| 193 |
+
#pragma nv_diag_suppress 177
|
| 194 |
+
auto ic = TrivialOffsetCalculator<arity>();
|
| 195 |
+
auto oc = TrivialOffsetCalculator<1>();
|
| 196 |
+
auto l = memory::LoadWithoutCast();
|
| 197 |
+
auto s = memory::StoreWithoutCast();
|
| 198 |
+
|
| 199 |
+
auto args = pack_kernel_args(
|
| 200 |
+
{&N, &data, &ic, &oc, &l, &s, scalar_val}, extra_args);
|
| 201 |
+
at::cuda::jit::launch_jitted_pwise_function(
|
| 202 |
+
*fn_ptr, args.data(), {grid, 1u, 1u}, {num_threads(), 1u, 1u});
|
| 203 |
+
#pragma nv_diagnostic pop
|
| 204 |
+
}
|
| 205 |
+
}
|
| 206 |
+
|
| 207 |
+
template <int arity>
|
| 208 |
+
void jitted_gpu_kernel_generic(
|
| 209 |
+
std::mutex &jiterator_mutex,
|
| 210 |
+
JittedKernelVariantCache &cache,
|
| 211 |
+
const at::cuda::jit::KernelDescriptor &desc,
|
| 212 |
+
at::cuda::jit::BinaryFuncVariant scalar_pos,
|
| 213 |
+
c10::ArrayRef<const void*> extra_args,
|
| 214 |
+
TensorIteratorBase& iter,
|
| 215 |
+
const bool dynamic_casting,
|
| 216 |
+
const void *scalar_val) {
|
| 217 |
+
TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing());
|
| 218 |
+
TORCH_INTERNAL_ASSERT(iter.ninputs() == arity);
|
| 219 |
+
TORCH_INTERNAL_ASSERT(iter.noutputs() == 1);
|
| 220 |
+
|
| 221 |
+
constexpr int ntensors = arity + 1;
|
| 222 |
+
std::array<char*, ntensors> data;
|
| 223 |
+
for (auto i : c10::irange(ntensors)) {
|
| 224 |
+
data[i] = (char*)iter.data_ptr(i);
|
| 225 |
+
}
|
| 226 |
+
|
| 227 |
+
int64_t numel = iter.numel();
|
| 228 |
+
bool contiguous = iter.is_contiguous();
|
| 229 |
+
|
| 230 |
+
// Decides which of 4 kernel types to launch
|
| 231 |
+
// Variations are:
|
| 232 |
+
// - Case 1: no dynamic casting and contiguous
|
| 233 |
+
// - Case 2: no dynamic casting and noncontiguous
|
| 234 |
+
// - Case 3: dynamic casting and contiguous
|
| 235 |
+
// - Case 4: dynamic casting and noncontiguous
|
| 236 |
+
// These cases align with the non-jitted CUDALoops.cuh cases in gpu_kernel_impl
|
| 237 |
+
|
| 238 |
+
if (!dynamic_casting) {
|
| 239 |
+
if (contiguous) {
|
| 240 |
+
// Case 1: no dynamic casting and contiguous
|
| 241 |
+
launch_jitted_vectorized_kernel<arity>(
|
| 242 |
+
jiterator_mutex, cache.vec, desc,
|
| 243 |
+
numel, data, scalar_pos, scalar_val, extra_args);
|
| 244 |
+
return;
|
| 245 |
+
}
|
| 246 |
+
|
| 247 |
+
// Case 2: no dynamic casting and noncontiguous
|
| 248 |
+
auto input_offset_calculator = make_input_offset_calculator<arity>(iter);
|
| 249 |
+
auto output_offset_calculator = make_output_offset_calculator(iter);
|
| 250 |
+
auto loader = memory::LoadWithoutCast();
|
| 251 |
+
auto storer = memory::StoreWithoutCast();
|
| 252 |
+
launch_jitted_unrolled_kernel(
|
| 253 |
+
jiterator_mutex, cache.noncontiguous, desc, numel, data,
|
| 254 |
+
input_offset_calculator, output_offset_calculator, loader,
|
| 255 |
+
storer, contiguous, scalar_pos, scalar_val, extra_args);
|
| 256 |
+
return;
|
| 257 |
+
}
|
| 258 |
+
|
| 259 |
+
// Cases 3 and 4 are handled below
|
| 260 |
+
// Both require construction of a storer (this asserts 1 output) and one or more loaders
|
| 261 |
+
|
| 262 |
+
// Creates store cast to output (the zeroth tensor in TensorIterator)
|
| 263 |
+
auto storer = memory::StoreWithCast<1>(iter);
|
| 264 |
+
|
| 265 |
+
// Creates load casts from inputs (note offset indexing into the iterators 1...n tensors)
|
| 266 |
+
auto loader = memory::LoadWithCast<arity>(iter);
|
| 267 |
+
|
| 268 |
+
if (contiguous) {
|
| 269 |
+
// Case 3: dynamic casting and contiguous
|
| 270 |
+
auto input_offset_calculator = TrivialOffsetCalculator<arity>();
|
| 271 |
+
auto output_offset_calculator = TrivialOffsetCalculator<1>();
|
| 272 |
+
launch_jitted_unrolled_kernel(
|
| 273 |
+
jiterator_mutex, cache.dynamic_contiguous, desc, numel, data, input_offset_calculator,
|
| 274 |
+
output_offset_calculator, loader, storer, contiguous, scalar_pos, scalar_val, extra_args);
|
| 275 |
+
return;
|
| 276 |
+
}
|
| 277 |
+
|
| 278 |
+
// Case 4: dynamic casting and noncontiguous
|
| 279 |
+
auto input_offset_calculator = make_input_offset_calculator<arity>(iter);
|
| 280 |
+
auto output_offset_calculator = make_output_offset_calculator(iter);
|
| 281 |
+
launch_jitted_unrolled_kernel(
|
| 282 |
+
jiterator_mutex, cache.dynamic_noncontiguous, desc, numel, data, input_offset_calculator,
|
| 283 |
+
output_offset_calculator, loader, storer, contiguous, scalar_pos, scalar_val, extra_args);
|
| 284 |
+
}
|
| 285 |
+
|
| 286 |
+
// NOTE: static to reduce chances of name collision.
|
| 287 |
+
template <
|
| 288 |
+
char const* name,
|
| 289 |
+
typename result_type,
|
| 290 |
+
typename f_inputs_type,
|
| 291 |
+
int arity,
|
| 292 |
+
at::cuda::jit::BinaryFuncVariant scalar_pos =
|
| 293 |
+
at::cuda::jit::BinaryFuncVariant::NoScalar,
|
| 294 |
+
typename... ExtraArgs>
|
| 295 |
+
static void jitted_gpu_kernel_impl(
|
| 296 |
+
TensorIteratorBase& iter,
|
| 297 |
+
const std::string &f,
|
| 298 |
+
const bool dynamic_casting,
|
| 299 |
+
at::opmath_type<f_inputs_type> scalar_val,
|
| 300 |
+
const std::tuple<ExtraArgs...>& extra_args) {
|
| 301 |
+
|
| 302 |
+
// TODO: Memory use can probably be optimized by reusing kernels across GPUs with
|
| 303 |
+
// the same compute capability
|
| 304 |
+
static std::mutex jiterator_mutex;
|
| 305 |
+
static std::vector<JittedKernelVariantCache> device_caches(c10::cuda::device_count());
|
| 306 |
+
|
| 307 |
+
constexpr int nInputs = arity;
|
| 308 |
+
constexpr int nOutputs = 1; // TODO: Support more than 1 output
|
| 309 |
+
static const auto desc = at::cuda::jit::make_kernel_descriptor<
|
| 310 |
+
result_type, f_inputs_type, ExtraArgs...>(name, f, nInputs, nOutputs);
|
| 311 |
+
|
| 312 |
+
auto &cache = device_caches[iter.device().index()];
|
| 313 |
+
auto extra_args_array = tuple_to_array(extra_args);
|
| 314 |
+
return jitted_gpu_kernel_generic<arity>(
|
| 315 |
+
jiterator_mutex,
|
| 316 |
+
cache,
|
| 317 |
+
desc,
|
| 318 |
+
scalar_pos,
|
| 319 |
+
extra_args_array,
|
| 320 |
+
iter,
|
| 321 |
+
dynamic_casting,
|
| 322 |
+
&scalar_val
|
| 323 |
+
);
|
| 324 |
+
}
|
| 325 |
+
|
| 326 |
+
} // at::native
|
| 327 |
+
|
| 328 |
+
#endif // AT_USE_JITERATOR()
|
| 329 |
+
|
| 330 |
+
#else
|
| 331 |
+
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
| 332 |
+
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|