Add files using upload-large-folder tool
This view is limited to 50 files because it contains too many changes.
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/AtomicAddFloat.h +37 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/ChannelShuffleKernel.h +14 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/CopyKernel.h +14 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/DepthwiseConvKernel.h +21 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/DistributionTemplates.h +425 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/GridSamplerKernel.h +34 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/IndexKernelUtils.h +87 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/Intrinsics.h +33 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/IsContiguous.h +62 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/LogAddExp.h +61 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/MaxUnpoolKernel.h +14 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/PixelShuffleKernel.h +14 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/Reduce.h +314 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/SampledAddmmKernel.h +12 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/SerialStackImpl.h +146 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/SoftmaxKernel.h +28 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/SpmmReduceKernel.h +22 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/StackKernel.h +12 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/UpSampleKernelAVXAntialias.h +1376 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/WeightNormKernel.h +20 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/avx_mathfun.h +522 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/int_mm_kernel.h +16 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/mixed_data_type.h +41 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/moments_utils.h +202 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/utils.h +212 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/zmath.h +250 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Activation.h +20 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/BinaryInternal.h +48 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/CUDAJitLoops.cuh +296 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/CUDALoops.cuh +348 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/CompositeRandomAccessor.h +35 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Copy.h +10 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/CuFFTUtils.h +73 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/DeviceSqrt.cuh +25 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/DistributionTemplates.h +671 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Distributions.h +25 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/ForeachFunctors.cuh +681 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/GridSampler.cuh +321 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/IndexKernel.h +16 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/KernelUtils.cuh +149 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/LaunchUtils.h +18 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/MemoryAccess.cuh +389 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/MiscUtils.h +32 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/MultiTensorApply.cuh +379 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Normalization.cuh +1742 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Pow.cuh +58 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Randperm.cuh +58 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Resize.h +53 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/RowwiseScaledMM.h +15 -0
- .venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/ScanUtils.cuh +459 -0
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/AtomicAddFloat.h
ADDED
@@ -0,0 +1,37 @@
+#ifndef ATOMIC_ADD_FLOAT
+#define ATOMIC_ADD_FLOAT
+
+#if (defined(__x86_64__) || defined(__i386__) || defined(__aarch64__))
+#include <ATen/native/cpu/Intrinsics.h>
+#else
+#define _mm_pause()
+#endif
+
+#include <atomic>
+
+static inline void cpu_atomic_add_float(float* dst, float fvalue)
+{
+  typedef union {
+    unsigned intV;
+    float floatV;
+  } uf32_t;
+
+  uf32_t new_value, old_value;
+  std::atomic<unsigned>* dst_intV = (std::atomic<unsigned>*)(dst);
+
+  old_value.floatV = *dst;
+  new_value.floatV = old_value.floatV + fvalue;
+
+  unsigned* old_intV = (unsigned*)(&old_value.intV);
+  while (!std::atomic_compare_exchange_strong(dst_intV, old_intV, new_value.intV)) {
+#ifdef __aarch64__
+    __asm__ __volatile__("yield;" : : : "memory");
+#else
+    _mm_pause();
+#endif
+    old_value.floatV = *dst;
+    new_value.floatV = old_value.floatV + fvalue;
+  }
+}
+
+#endif
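For orientation, a minimal standalone sketch (not part of the uploaded file) of the compare-and-swap pattern that cpu_atomic_add_float relies on; the thread count and iteration counts below are arbitrary illustration values.

// Illustrative only: a portable CAS loop for accumulating floats, in the same spirit as cpu_atomic_add_float.
#include <atomic>
#include <cstdio>
#include <cstring>
#include <thread>
#include <vector>

static void atomic_add_float(float* dst, float value) {
  auto* bits = reinterpret_cast<std::atomic<unsigned>*>(dst);
  unsigned expected = bits->load();
  unsigned desired;
  do {
    float old_val;
    std::memcpy(&old_val, &expected, sizeof(float));
    const float new_val = old_val + value;
    std::memcpy(&desired, &new_val, sizeof(float));
    // compare_exchange_weak reloads `expected` on failure, so the loop retries on fresh bits.
  } while (!bits->compare_exchange_weak(expected, desired));
}

int main() {
  float sum = 0.0f;
  std::vector<std::thread> workers;
  for (int t = 0; t < 4; ++t) {
    workers.emplace_back([&sum] {
      for (int i = 0; i < 100000; ++i) {
        atomic_add_float(&sum, 1.0f);  // no update is lost, unlike a plain `sum += 1.0f`
      }
    });
  }
  for (auto& w : workers) {
    w.join();
  }
  std::printf("%f\n", sum);  // 400000, exactly, since every intermediate value is representable
}

Treating the float's bit pattern as an unsigned atomic is what lets a single compare-exchange publish the new sum; on contention the loop simply retries with the freshly observed value, which is also why the header pauses (_mm_pause() / yield) between attempts.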
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/ChannelShuffleKernel.h
ADDED
@@ -0,0 +1,14 @@
+#pragma once
+#include <ATen/native/DispatchStub.h>
+#include <cstdint>
+
+namespace at {
+class TensorBase;
+}
+
+namespace at::native {
+
+using channel_shuffle_fn = void(*)(TensorBase&, const TensorBase&, int64_t);
+DECLARE_DISPATCH(channel_shuffle_fn, channel_shuffle_kernel);
+
+} // at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/CopyKernel.h
ADDED
@@ -0,0 +1,14 @@
+#pragma once
+
+#include <ATen/native/TensorIterator.h>
+
+namespace at {
+struct TensorIteratorBase;
+
+namespace native {
+inline namespace CPU_CAPABILITY {
+
+void direct_copy_kernel(TensorIteratorBase &iter);
+void copy_kernel(TensorIterator& iter, bool /*non_blocking*/);
+
+}}} // namespace at::native::CPU_CAPABILITY
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/DepthwiseConvKernel.h
ADDED
@@ -0,0 +1,21 @@
+#pragma once
+
+#include <ATen/native/DispatchStub.h>
+#include <c10/util/ArrayRef.h>
+
+/*
+  Depthwise 3x3 Winograd convolution operator
+*/
+
+namespace at {
+class Tensor;
+
+namespace native {
+
+using convolution_depthwise3x3_winograd_fn =
+    Tensor (*)(const Tensor &, const Tensor &, const Tensor &, IntArrayRef, IntArrayRef, int64_t);
+
+DECLARE_DISPATCH(convolution_depthwise3x3_winograd_fn, convolution_depthwise3x3_winograd_stub);
+
+} // namespace native
+} // namespace at
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/DistributionTemplates.h
ADDED
@@ -0,0 +1,425 @@
+#pragma once
+
+#include <ATen/CPUApplyUtils.h>
+#include <ATen/Dispatch.h>
+#include <ATen/Dispatch_v2.h>
+#include <ATen/ExpandBase.h>
+#include <ATen/core/DistributionsHelper.h>
+#include <ATen/native/TensorIterator.h>
+#include <ATen/native/cpu/Loops.h>
+#include <mutex>
+
+#ifdef CPU_CAPABILITY_AVX2
+#include <ATen/native/cpu/avx_mathfun.h>
+#include <c10/util/irange.h>
+#endif
+
+
+
+
+namespace at::native::templates::cpu {
+namespace {
+
+// ==================================================== Random ========================================================
+
+template<typename RNG>
+void random_from_to_kernel(TensorIteratorBase& iter, uint64_t range, int64_t base, RNG generator) {
+  AT_DISPATCH_V2(iter.dtype(), "random_from_to_kernel_cpu", AT_WRAP([&] {
+    std::lock_guard<std::mutex> lock(generator->mutex_);
+    cpu_serial_kernel(iter, [range, base, generator]() -> scalar_t {
+      uniform_int_from_to_distribution<scalar_t> random(range, base);
+      return random(generator);
+    });
+  }), kBool, kHalf, kBFloat16, AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
+}
+
+// This is the special kernel to handle single specific case:
+// from(inclusive) = std::numeric_limits<int64_t>::lowest()
+// to(exclusive) = None (= std::numeric_limits<int64_t>::max() + 1)
+template<typename RNG>
+void random_full_64_bits_range_kernel(TensorIteratorBase& iter, RNG generator) {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::BFloat16, iter.dtype(), "random_full_64_bits_range_kernel_cpu", [&] {
+    if constexpr (std::is_same_v<scalar_t, int64_t> ||
+        std::is_same_v<scalar_t, double> ||
+        std::is_same_v<scalar_t, float> ||
+        std::is_same_v<scalar_t, at::BFloat16>) {
+      std::lock_guard<std::mutex> lock(generator->mutex_);
+      cpu_serial_kernel(iter, [generator]() -> scalar_t {
+        uniform_int_full_range_distribution<scalar_t> random;
+        return random(generator);
+      });
+    } else {
+      TORCH_CHECK(false, "random_full_64_bits_range_kernel_cpu handles only int64, double, float and bfloat16");
+    }
+  });
+}
+
+template<typename RNG>
+struct RandomFromToKernel {
+  void operator()(TensorIteratorBase& iter, uint64_t range, int64_t base, std::optional<Generator> gen) {
+    random_from_to_kernel(iter, range, base, check_generator<RNG>(gen));
+  }
+  void operator()(TensorIteratorBase& iter, std::optional<Generator> gen) {
+    random_full_64_bits_range_kernel(iter, check_generator<RNG>(gen));
+  }
+};
+
+template<typename RNG>
+void random_kernel(TensorIteratorBase& iter, RNG generator) {
+  std::lock_guard<std::mutex> lock(generator->mutex_);
+  AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, iter.dtype(), "random_kernel_cpu", [&] {
+    cpu_serial_kernel(iter, [generator]() -> scalar_t {
+      uniform_int_distribution<scalar_t> random;
+      return random(generator);
+    });
+  });
+}
+
+template<typename RNG>
+struct RandomKernel {
+  void operator()(TensorIteratorBase& iter, std::optional<Generator> gen) {
+    random_kernel(iter, check_generator<RNG>(gen));
+  }
+};
+
+// ==================================================== Normal ========================================================
+
+#ifdef CPU_CAPABILITY_AVX2
+static void normal_fill_16_AVX2(float *data,
+                                const __m256* two_pi,
+                                const __m256* one,
+                                const __m256* minus_two,
+                                const __m256* mean,
+                                const __m256* std_v) {
+  const __m256 u1 = _mm256_sub_ps(*one, _mm256_loadu_ps(data));
+  const __m256 u2 = _mm256_loadu_ps(data + 8);
+  // sincos256_ps and log256_ps are from avx_mathfun.h
+  const __m256 radius = _mm256_sqrt_ps(_mm256_mul_ps(*minus_two, log256_ps(u1)));
+  const __m256 theta = _mm256_mul_ps(*two_pi, u2);
+  __m256 sintheta, costheta;
+  sincos256_ps(theta, &sintheta, &costheta);
+  const __m256 n1 = _mm256_mul_ps(radius, costheta);
+  const __m256 n2 = _mm256_mul_ps(radius, sintheta);
+  _mm256_storeu_ps(data, _mm256_fmadd_ps(n1, *std_v, *mean));
+  _mm256_storeu_ps(data + 8, _mm256_fmadd_ps(n2, *std_v, *mean));
+}
+
+template<typename RNG>
+void normal_fill_AVX2(const TensorBase &self, const float mean, const float std, RNG generator) {
+  float *data = self.data_ptr<float>();
+  auto size = self.numel();
+  std::lock_guard<std::mutex> lock(generator->mutex_);
+  for (const auto i : c10::irange(size)) {
+    at::uniform_real_distribution<float> uniform(0, 1);
+    data[i] = uniform(generator);
+  }
+  const __m256 two_pi = _mm256_set1_ps(2.0f * c10::pi<double>);
+  const __m256 one = _mm256_set1_ps(1.0f);
+  const __m256 minus_two = _mm256_set1_ps(-2.0f);
+  const __m256 mean_v = _mm256_set1_ps(mean);
+  const __m256 std_v = _mm256_set1_ps(std);
+
+  for (int64_t i = 0; i < size - 15; i += 16) {
+    normal_fill_16_AVX2(data + i, &two_pi, &one, &minus_two, &mean_v, &std_v);
+  }
+
+  if (size % 16 != 0) {
+    // Recompute the last 16 values.
+    data = data + size - 16;
+    for (const auto i : c10::irange(16)) {
+      at::uniform_real_distribution<float> uniform(0, 1);
+      data[i] = uniform(generator);
+    }
+    normal_fill_16_AVX2(data, &two_pi, &one, &minus_two, &mean_v, &std_v);
+  }
+}
+#endif
+
+template <typename scalar_t>
+static void normal_fill_16(scalar_t *data, const scalar_t mean, const scalar_t std) {
+  for (const auto j : c10::irange(8)) {
+    const scalar_t u1 = 1 - data[j]; // [0, 1) -> (0, 1] for log.
+    const scalar_t u2 = data[j + 8];
+    const scalar_t radius = std::sqrt(-2 * std::log(u1));
+    const scalar_t theta = 2.0f * c10::pi<double> * u2;
+    data[j] = radius * std::cos(theta) * std + mean;
+    data[j + 8] = radius * std::sin(theta) * std + mean;
+  }
+}
+
+#if defined(__VSX__) || defined(CPU_CAPABILITY_VSX)
+static void normal_fill_16_VSX(float *data, const Vectorized<float> &two_pi, const Vectorized<float> &one, const Vectorized<float> &minus_two, const Vectorized<float> &mean, const Vectorized<float> &std) {
+  using Vec = Vectorized<float>;
+  Vec u1 = one - Vec::loadu(data);
+  Vec u2 = Vec::loadu(data + 8);
+  Vec radius = (minus_two * u1.log());
+  radius = radius.sqrt();
+  Vec theta = two_pi * u2;
+  Vec output_vec = radius * theta.cos() * std + mean;
+  Vec output_vec2 = radius * theta.sin() * std + mean;
+  output_vec.store(data);
+  output_vec2.store(data + 8);
+}
+
+template <typename scalar_t, typename RNG>
+void normal_fill_VSX(const TensorBase &self, const scalar_t mean, const scalar_t std, RNG generator) {
+  float *data = self.data_ptr<float>();
+  auto size = self.numel();
+  std::lock_guard<std::mutex> lock(generator->mutex_);
+  for (const auto i : c10::irange(size)) {
+    at::uniform_real_distribution<scalar_t> uniform(0, 1);
+    data[i] = uniform(generator);
+  }
+
+  using Vec = Vectorized<float>;
+  const Vec two_pi = Vec(2.0f * c10::pi<double>);
+  const Vec one = Vec(1.0f);
+  const Vec minus_two = Vec(-2.0f);
+  const Vec var_vec = Vec(std);
+  const Vec mean_vec = Vec(mean);
+
+  for (int64_t i = 0; i < size - 15; i += 16) {
+    if (Vec::size() == 8) {
+      normal_fill_16_VSX(data + i, two_pi, one, minus_two, mean_vec, var_vec);
+    }
+    else {
+      normal_fill_16<scalar_t>(data + i, mean, std);
+    }
+  }
+  if (size % 16 != 0) {
+    // Recompute the last 16 values.
+    data = data + size - 16;
+    for (const auto i : c10::irange(16)) {
+      at::uniform_real_distribution<scalar_t> uniform(0, 1);
+      data[i] = uniform(generator);
+    }
+    if (Vec::size() == 8) {
+      normal_fill_16_VSX(data, two_pi, one, minus_two, mean_vec, var_vec);
+    }
+    else {
+      normal_fill_16<scalar_t>(data, mean, std);
+    }
+  }
+}
+#endif //VSX
+
+template <typename scalar_t, typename RNG>
+void normal_fill(const TensorBase &self, const scalar_t mean, const scalar_t std, RNG generator) {
+  scalar_t *data = self.data_ptr<scalar_t>();
+  auto size = self.numel();
+  std::lock_guard<std::mutex> lock(generator->mutex_);
+  for (const auto i : c10::irange(size)) {
+    at::uniform_real_distribution<scalar_t> uniform(0, 1);
+    data[i] = uniform(generator);
+  }
+
+  for (int64_t i = 0; i < size - 15; i += 16) {
+    normal_fill_16<scalar_t>(data + i, mean, std);
+  }
+  if (size % 16 != 0) {
+    // Recompute the last 16 values.
+    data = data + size - 16;
+    for (const auto i : c10::irange(16)) {
+      at::uniform_real_distribution<scalar_t> uniform(0, 1);
+      data[i] = uniform(generator);
+    }
+    normal_fill_16<scalar_t>(data, mean, std);
+  }
+}
+
+template<typename RNG>
+void normal_kernel(const TensorBase &self, double mean, double std, RNG generator) {
+  auto size = self.numel();
+  if (self.scalar_type() == ScalarType::Float && size >= 16 && self.is_contiguous()) {
+#ifdef CPU_CAPABILITY_AVX2
+    normal_fill_AVX2(self, static_cast<float>(mean), static_cast<float>(std), generator);
+#elif defined(__VSX__) || defined(CPU_CAPABILITY_VSX)
+    normal_fill_VSX(self, static_cast<float>(mean), static_cast<float>(std), generator);
+#else
+    normal_fill(self, static_cast<float>(mean), static_cast<float>(std), generator);
+#endif
+  } else {
+    AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, self.scalar_type(), "normal_kernel_cpu", [&] {
+      if (size >= 16 && self.is_contiguous()) {
+        normal_fill<scalar_t>(self, static_cast<scalar_t>(mean), static_cast<scalar_t>(std), generator);
+      } else {
+        auto iter = TensorIterator::borrowing_nullary_op(self);
+        std::lock_guard<std::mutex> lock(generator->mutex_);
+        cpu_serial_kernel(iter, [mean, std, generator]() -> scalar_t {
+          at::normal_distribution<double> normal(mean, std);
+          return static_cast<scalar_t>(normal(generator));
+        });
+      }
+    });
+  }
+}
+
+template<typename RNG>
+struct NormalKernel {
+  void operator()(Tensor& self, double mean, double std, std::optional<Generator> gen) {
+    normal_kernel(self, mean, std, check_generator<RNG>(gen));
+  }
+};
+
+// ==================================================== Uniform =======================================================
+
+template<typename RNG>
+void uniform_kernel(TensorIteratorBase& iter, double from_, double to_, RNG generator) {
+  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "uniform_kernel_cpu", [&]() {
+    std::lock_guard<std::mutex> lock(generator->mutex_);
+    auto from = static_cast<scalar_t>(from_);
+    auto to = static_cast<scalar_t>(to_);
+    at::uniform_real_distribution<scalar_t> uniform(from, to);
+    cpu_serial_kernel(iter, [&uniform, generator]() -> scalar_t {
+      return static_cast<scalar_t>(uniform(generator));
+    });
+  });
+}
+
+template<typename RNG>
+struct UniformKernel {
+  void operator()(TensorIteratorBase& iter, double from, double to, std::optional<Generator> gen) {
+    uniform_kernel(iter, from, to, check_generator<RNG>(gen));
+  }
+};
+
+// ==================================================== Cauchy ========================================================
+
+template<typename RNG>
+void cauchy_kernel(TensorIteratorBase& iter, double median, double sigma, RNG generator) {
+  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "cauchy_cpu", [&]() {
+    std::lock_guard<std::mutex> lock(generator->mutex_);
+    at::cauchy_distribution<double> cauchy(median, sigma);
+    cpu_serial_kernel(iter, [&cauchy, generator]() -> scalar_t {
+      return static_cast<scalar_t>(cauchy(generator));
+    });
+  });
+}
+
+template<typename RNG>
+struct CauchyKernel {
+  void operator()(TensorIteratorBase& iter, double median, double sigma, std::optional<Generator> gen) {
+    cauchy_kernel(iter, median, sigma, check_generator<RNG>(gen));
+  }
+};
+
+// ================================================== LogNormal =======================================================
+
+template<typename RNG>
+void log_normal_kernel(TensorIteratorBase& iter, double mean, double std, RNG generator) {
+  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "log_normal_cpu", [&]() {
+    std::lock_guard<std::mutex> lock(generator->mutex_);
+    at::lognormal_distribution<double> logNormal(mean, std);
+    cpu_serial_kernel(iter, [&logNormal, generator]() -> scalar_t {
+      return static_cast<scalar_t>(logNormal(generator));
+    });
+  });
+}
+
+template<typename RNG>
+struct LogNormalKernel {
+  void operator()(TensorIteratorBase& iter, double mean, double std, std::optional<Generator> gen) {
+    log_normal_kernel(iter, mean, std, check_generator<RNG>(gen));
+  }
+};
+
+// =================================================== Geometric ======================================================
+
+template<typename RNG>
+void geometric_kernel(TensorIteratorBase& iter, double p, RNG generator) {
+  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "geometric_cpu", [&]() {
+    std::lock_guard<std::mutex> lock(generator->mutex_);
+    at::geometric_distribution<double> geometric(p);
+    cpu_serial_kernel(iter, [&geometric, generator]() -> scalar_t {
+      return static_cast<scalar_t>(geometric(generator));
+    });
+  });
+}
+
+template<typename RNG>
+struct GeometricKernel {
+  void operator()(TensorIteratorBase& iter, double p, std::optional<Generator> gen) {
+    geometric_kernel(iter, p, check_generator<RNG>(gen));
+  }
+};
+
+// ================================================== Exponential =====================================================
+
+template<typename RNG>
+void exponential_kernel(TensorIteratorBase& iter, double lambda, RNG generator) {
+  TORCH_CHECK(isFloatingType(iter.dtype()), "Exponential distribution is a continuous probability distribution. dtype must be a floating point but you specified ", iter.dtype());
+  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "exponential_cpu", [&]() {
+    std::lock_guard<std::mutex> lock(generator->mutex_);
+    at::exponential_distribution<double> exponential(lambda);
+    cpu_serial_kernel(iter, [&exponential, generator]() -> scalar_t {
+      return static_cast<scalar_t>(exponential(generator));
+    });
+  });
+}
+
+template<typename RNG>
+struct ExponentialKernel {
+  void operator()(TensorIteratorBase& iter, double lambda, std::optional<Generator> gen) {
+    exponential_kernel(iter, lambda, check_generator<RNG>(gen));
+  }
+};
+
+// ================================================== Bernoulli =======================================================
+
+template<typename RNG>
+void bernoulli_kernel(const TensorBase &self, const TensorBase &p_, RNG generator) {
+  AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half,
+      self.scalar_type(), "bernoulli_tensor_cpu_self_", [&] {
+    // See Note [Acquire lock when using random generators]
+    std::lock_guard<std::mutex> lock(generator->mutex_);
+    using self_t = scalar_t;
+    auto p_cpu = p_.to(kCPU);
+    auto p = expand_inplace(self, p_cpu);
+    auto iter = TensorIteratorConfig()
+        .add_output(self)
+        .add_const_input(*p)
+        .check_all_same_dtype(false)
+        .build();
+    if (p->scalar_type() == kDouble) {
+      cpu_serial_kernel(iter, [&](const double p_val) -> self_t {
+        at::bernoulli_distribution<double> bernoulli(p_val);
+        return static_cast<self_t>(bernoulli(generator));
+      });
+    } else {
+      AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::BFloat16, at::ScalarType::Half,
+          p->scalar_type(), "bernoulli_tensor_cpu_p_", [&] {
+        using p_t = scalar_t;
+        cpu_serial_kernel(iter, [&](const p_t p_val) -> self_t {
+          at::bernoulli_distribution<float> bernoulli(p_val);
+          return static_cast<self_t>(bernoulli(generator));
+        });
+      });
+    }
+  });
+}
+
+template<typename RNG>
+void bernoulli_kernel(const TensorBase &self, double p, RNG generator) {
+  AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half,
+      self.scalar_type(), "bernoulli_scalar_cpu_", [&] {
+    // See Note [Acquire lock when using random generators]
+    std::lock_guard<std::mutex> lock(generator->mutex_);
+    auto iter = TensorIterator::borrowing_nullary_op(self);
+    cpu_serial_kernel(iter, [p, generator]() -> scalar_t {
+      at::bernoulli_distribution<double> bernoulli(p);
+      return static_cast<scalar_t>(bernoulli(generator));
+    });
+  });
+}
+
+template<typename RNG>
+struct BernoulliKernel {
+  void operator()(const TensorBase &self, double p, std::optional<Generator> gen) {
+    bernoulli_kernel(self, p, check_generator<RNG>(gen));
+  }
+  void operator()(const TensorBase &self, const TensorBase &p_, std::optional<Generator> gen) {
+    bernoulli_kernel(self, p_, check_generator<RNG>(gen));
+  }
+};
+
+}}
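As a side note on the math in normal_fill_16 above, here is a minimal scalar sketch (not part of the uploaded file) of the Box-Muller transform it applies to each buffered batch of uniform samples; the seed, sample count, and constants below are arbitrary illustration values.

// Illustrative only: one pair of uniform samples yields two independent normal samples.
#include <cmath>
#include <cstdio>
#include <random>

int main() {
  std::mt19937 rng(0);
  std::uniform_real_distribution<float> uniform(0.0f, 1.0f);
  const float two_pi = 6.2831853f;
  const float mean = 0.0f;
  const float std_dev = 1.0f;
  for (int i = 0; i < 4; ++i) {
    const float u1 = 1.0f - uniform(rng);  // map [0, 1) to (0, 1] so std::log(u1) stays finite
    const float u2 = uniform(rng);
    const float radius = std::sqrt(-2.0f * std::log(u1));
    const float theta = two_pi * u2;
    std::printf("%f %f\n", radius * std::cos(theta) * std_dev + mean,
                           radius * std::sin(theta) * std_dev + mean);
  }
}

The vectorized variants (normal_fill_16_AVX2, normal_fill_16_VSX) compute exactly this, eight lanes at a time, over a buffer that was pre-filled with uniform samples.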
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/GridSamplerKernel.h
ADDED
@@ -0,0 +1,34 @@
+#pragma once
+
+#include <ATen/native/DispatchStub.h>
+
+#include <array>
+#include <cstdint>
+
+namespace at {
+class TensorBase;
+}
+
+namespace at::native {
+
+using forward_2d_fn = void (*) (
+    const TensorBase &output,
+    const TensorBase &input,
+    const TensorBase &grid,
+    int64_t interpolation_mode,
+    int64_t padding_mode,
+    bool align_corners);
+using backward_2d_fn = void (*) (
+    const TensorBase &grad_input,
+    const TensorBase &grad_grid,
+    const TensorBase &grad_output,
+    const TensorBase &input,
+    const TensorBase &grid,
+    int64_t interpolation_mode,
+    int64_t padding_mode,
+    bool align_corners,
+    std::array<bool, 2> output_mask);
+DECLARE_DISPATCH(forward_2d_fn, grid_sampler_2d_cpu_kernel);
+DECLARE_DISPATCH(backward_2d_fn, grid_sampler_2d_backward_cpu_kernel);
+
+} // namespace at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/IndexKernelUtils.h
ADDED
@@ -0,0 +1,87 @@
+#pragma once
+#include <ATen/native/TensorIterator.h>
+#include <c10/util/irange.h>
+
+namespace at::native {
+
+namespace {
+static bool is_constant_index(int ntensor, const int64_t* strides) {
+  AT_ASSERT(ntensor >= 3);
+  for (const auto arg : c10::irange(2, ntensor)) {
+    if (strides[arg] != 0) {
+      return false;
+    }
+  }
+  return true;
+}
+
+
+struct Indexer {
+  Indexer(int64_t num_indexers, char** indexers, const int64_t* indexer_strides,
+          IntArrayRef original_sizes, IntArrayRef original_strides)
+    : num_indexers(num_indexers)
+    , indexers(indexers)
+    , indexer_strides(indexer_strides)
+    , original_strides(original_strides.data())
+    , original_sizes(original_sizes.data()) {
+    AT_ASSERT(static_cast<int64_t>(original_strides.size()) == num_indexers);
+    AT_ASSERT(static_cast<int64_t>(original_sizes.size()) == num_indexers);
+  }
+
+  int64_t num_indexers;
+  char** indexers;
+  const int64_t* indexer_strides;
+  const int64_t* original_strides;
+  const int64_t* original_sizes;
+
+  int64_t get(int64_t idx) {
+    int64_t offset = 0;
+    for (const auto j : c10::irange(num_indexers)) {
+      int64_t value = *(int64_t*)&indexers[j][idx * indexer_strides[j]];
+      int64_t size = original_sizes[j];
+      TORCH_CHECK_INDEX(value >= -size && value < size,
+                        "index ", value, " is out of bounds for dimension ", j, " with size ", size);
+      if (value < 0) {
+        value += size;
+      }
+      offset += value * original_strides[j];
+    }
+    return offset;
+  }
+};
+} // anonymous namespace
+
+template <typename scalar_t, typename func_t>
+void cpu_index_kernel(TensorIteratorBase& iter, IntArrayRef index_size, IntArrayRef index_stride,
+                      const func_t& f, bool serial_execution=false)
+{
+  int ntensor = iter.ntensors();
+  // When launch the index parallel version, set a relative small grain size less than the INTERNAL::GRAIN_SIZE
+  // to make the whole available thread numbers get more balanced work load and a better cache location.
+  // The grain size here is chosen by the op benchmark to overcome the thread launch overhead
+  const int index_parallel_grain_size = 3000;
+  auto loop = [&](char** data, const int64_t* strides, int64_t n) {
+    auto indexer = Indexer(ntensor - 2, &data[2], &strides[2], index_size, index_stride);
+    char* dst = data[0];
+    char* src = data[1];
+    if (is_constant_index(ntensor, strides)) {
+      // specialization for when every element uses the same index
+      int64_t offset = indexer.get(0);
+      for (const auto i : c10::irange(n)) {
+        f(dst + strides[0] * i, src + strides[1] * i, offset);
+      }
+    } else {
+      for (const auto i : c10::irange(n)) {
+        int64_t offset = indexer.get(i);
+        f(dst + strides[0] * i, src + strides[1] * i, offset);
+      }
+    }
+  };
+  if (serial_execution) {
+    iter.serial_for_each(loop, {0, iter.numel()});
+  } else {
+    iter.for_each(loop, index_parallel_grain_size);
+  }
+}
+} // at
+// native
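For context, a small standalone sketch (not part of the uploaded file) of the byte-offset arithmetic Indexer::get performs, including the wrap-around for negative indices; the tensor shape, strides, and index values are made up for illustration.

// Illustrative only: computing the byte offset of one gathered element.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Pretend we gather from a contiguous float tensor of shape {4, 5} (element size 4 bytes).
  const std::vector<int64_t> sizes = {4, 5};
  const std::vector<int64_t> byte_strides = {5 * 4, 4};
  const std::vector<int64_t> index = {-1, 2};  // negative index on dim 0, as PyTorch allows

  int64_t offset = 0;
  for (size_t j = 0; j < index.size(); ++j) {
    int64_t value = index[j];
    if (value < 0) {
      value += sizes[j];  // -1 on a size-4 dimension becomes 3, exactly as in Indexer::get
    }
    offset += value * byte_strides[j];
  }
  std::printf("byte offset = %lld\n", static_cast<long long>(offset));  // 3*20 + 2*4 = 68
}

cpu_index_kernel then passes this offset to the element-wise functor f for every output element, or computes it only once when is_constant_index detects that all elements share the same index.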
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/Intrinsics.h
ADDED
@@ -0,0 +1,33 @@
+#pragma once
+
+#if defined(__clang__) && (defined(__x86_64__) || defined(__i386__))
+/* Clang-compatible compiler, targeting x86/x86-64 */
+#include <x86intrin.h>
+#elif defined(_MSC_VER)
+/* Microsoft C/C++-compatible compiler */
+#include <intrin.h>
+#if _MSC_VER <= 1900
+#define _mm256_extract_epi64(X, Y) (((uint64_t*)&X)[Y])
+#endif
+#elif defined(__GNUC__) && (defined(__x86_64__) || defined(__i386__))
+/* GCC-compatible compiler, targeting x86/x86-64 */
+#include <x86intrin.h>
+#elif defined(__GNUC__) && defined(__ARM_NEON__)
+/* GCC-compatible compiler, targeting ARM with NEON */
+#include <arm_neon.h>
+#elif defined(__GNUC__) && defined(__IWMMXT__)
+/* GCC-compatible compiler, targeting ARM with WMMX */
+#include <mmintrin.h>
+#elif (defined(__GNUC__) || defined(__xlC__)) && \
+    (defined(__VEC__) || defined(__ALTIVEC__))
+/* XLC or GCC-compatible compiler, targeting PowerPC with VMX/VSX */
+#include <altivec.h>
+/* We need to undef those tokens defined by <altivec.h> to avoid conflicts
+   with the C++ types. => Can still use __bool/__vector */
+#undef bool
+#undef vector
+#undef pixel
+#elif defined(__GNUC__) && defined(__SPE__)
+/* GCC-compatible compiler, targeting PowerPC with SPE */
+#include <spe.h>
+#endif
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/IsContiguous.h
ADDED
@@ -0,0 +1,62 @@
+#pragma once
+
+namespace at::native { inline namespace CPU_CAPABILITY {
+
+// n: number of function arguments (arity)
+// traits: function_traits (see FunctionTraits.h)
+// s: index of scalar argument or -1
+template <int n, int stride_index, typename traits, int s=-1>
+struct IsContiguous {
+  static bool eval(const int64_t* strides) {
+    using type = typename traits::template arg<n - 1>::type;
+    return strides[stride_index] == (s == n ? 0 : sizeof(type)) &&
+           IsContiguous<n - 1, stride_index - 1, traits, s>::eval(strides);
+  }
+};
+
+// will be called when there is an output exists
+template <typename traits, int s>
+struct IsContiguous<0, 0, traits, s> {
+  static bool eval(const int64_t* strides) {
+    return strides[0] == sizeof(typename traits::result_type);
+  }
+};
+
+// will be called when there is no output
+template <typename traits, int s>
+struct IsContiguous<0, -1, traits, s> {
+  static bool eval(const int64_t* /*strides*/) {
+    return true;
+  }
+};
+
+// output and all inputs are contiguous
+template <typename traits,
+    typename std::enable_if<std::is_void<typename traits::result_type>::value>::type* = nullptr>
+static inline bool is_contiguous(const int64_t* strides) {
+  return IsContiguous<traits::arity, traits::arity - 1, traits>::eval(strides);
+}
+
+template <typename traits,
+    typename std::enable_if<!std::is_void<typename traits::result_type>::value>::type* = nullptr>
+static inline bool is_contiguous(const int64_t* strides) {
+  return IsContiguous<traits::arity, traits::arity, traits>::eval(strides);
+}
+
+// input at `s` is scalar (stride 0); output and other inputs are contiguous
+// NB: output is typically at strides[0] so first input corresponds to s=1
+template <typename traits, int s,
+    typename std::enable_if<std::is_void<typename traits::result_type>::value>::type* = nullptr>
+static inline bool is_contiguous_scalar(const int64_t* strides) {
+  static_assert(s > 0 && s <= traits::arity, "scalar argument index out of bounds");
+  return IsContiguous<traits::arity, traits::arity - 1, traits, s>::eval(strides);
+}
+
+template <typename traits, int s,
+    typename std::enable_if<!std::is_void<typename traits::result_type>::value>::type* = nullptr>
+static inline bool is_contiguous_scalar(const int64_t* strides) {
+  static_assert(s > 0 && s <= traits::arity, "scalar argument index out of bounds");
+  return IsContiguous<traits::arity, traits::arity, traits, s>::eval(strides);
+}
+
+}}
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/LogAddExp.h
ADDED
@@ -0,0 +1,61 @@
+#pragma once
+
+#include <c10/util/complex.h>
+#include <ATen/NumericUtils.h>
+
+namespace at::native {
+inline namespace CPU_CAPABILITY {
+
+// custom min and max to be used in logcumsumexp for complex arguments
+template <typename scalar_t>
+std::pair<c10::complex<scalar_t>, c10::complex<scalar_t>> _logcumsumexp_minmax(c10::complex<scalar_t> x, c10::complex<scalar_t> y) {
+  if (at::_isnan(y)) { // either real is nan or imag is nan
+    return std::make_pair(y, y);
+  } else if (at::_isnan(x)) { // either real is nan or imag is nan
+    return std::make_pair(x, x);
+  } else {
+    return (x.real() < y.real()) ? std::make_pair(x, y) : std::make_pair(y, x);
+  }
+}
+
+template <typename scalar_t>
+scalar_t _log_add_exp_helper(scalar_t x, scalar_t y) {
+  // Reference : https://www.tensorflow.org/api_docs/python/tf/math/cumulative_logsumexp
+  scalar_t min = at::_isnan(y) ? y : std::min(x, y); // std::min returns first arg if one of the args is nan
+  scalar_t max = at::_isnan(y) ? y : std::max(x, y); // std::max returns first arg if one of the args is nan
+  if (min != max || std::isfinite(min)) {
+    // nan will be propagated here
+    return std::log1p(std::exp(min - max)) + max;
+  } else {
+    // special case to correctly handle infinite cases
+    return x;
+  }
+}
+
+template <typename scalar_t>
+c10::complex<scalar_t> _log_add_exp_helper(const c10::complex<scalar_t>& x, const c10::complex<scalar_t>& y) {
+  auto [min, max] = _logcumsumexp_minmax<scalar_t>(x, y);
+  auto min_real = std::real(min);
+  auto max_real = std::real(max);
+
+  if (at::_isnan(min)) { // either real is nan or imag is nan
+    // handling the "infectious" NaNs
+    return {std::numeric_limits<scalar_t>::quiet_NaN(), std::numeric_limits<scalar_t>::quiet_NaN()};
+  } else if (!std::isfinite(min_real) && (min_real == max_real)) {
+    if (min_real < 0) {
+      // handle the -inf case, the imaginary part here does not really matter as the exp(value)
+      // will be around 0.0 and the angle (i.e. the imaginary part) cannot be determined.
+      // It does not matter if we're taking the exp of this value
+      return min;
+    } else {
+      // handle the +inf case, we don't need the special precision for log1p for small values
+      // and to avoid producing nan in case of real(max) == real(min) == +inf
+      return std::log(std::exp(min) + std::exp(max));
+    }
+  } else {
+    return std::log1p(std::exp(min - max)) + max;
+  }
+}
+
+} // end namespace
+} //end at::native
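A brief standalone sketch (not part of the uploaded file) of why _log_add_exp_helper computes log1p(exp(min - max)) + max rather than the naive log(exp(x) + exp(y)); the input values below are arbitrary.

// Illustrative only: the naive form overflows, the rearranged form does not.
#include <algorithm>
#include <cmath>
#include <cstdio>

double naive_logaddexp(double x, double y) {
  return std::log(std::exp(x) + std::exp(y));  // exp(1000) overflows double, result becomes inf
}

double stable_logaddexp(double x, double y) {
  const double mn = std::min(x, y);
  const double mx = std::max(x, y);
  return std::log1p(std::exp(mn - mx)) + mx;   // exponent is <= 0, so exp never overflows
}

int main() {
  std::printf("%f\n", naive_logaddexp(1000.0, 1000.0));   // inf
  std::printf("%f\n", stable_logaddexp(1000.0, 1000.0));  // ~1000.693147 (= 1000 + ln 2)
}

The extra branches in the header handle the cases this sketch ignores: NaN propagation and the +/-inf inputs where min == max and the rearranged form would otherwise produce NaN.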
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/MaxUnpoolKernel.h
ADDED
@@ -0,0 +1,14 @@
+#pragma once
+#include <ATen/native/DispatchStub.h>
+
+namespace at {
+class Tensor;
+
+namespace native {
+
+using max_unpooling_fn = void(*)(Tensor&, const Tensor&, const Tensor&);
+
+DECLARE_DISPATCH(max_unpooling_fn, max_unpool2d_kernel);
+DECLARE_DISPATCH(max_unpooling_fn, max_unpool3d_kernel);
+
+}} // at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/PixelShuffleKernel.h
ADDED
@@ -0,0 +1,14 @@
+#pragma once
+#include <ATen/native/DispatchStub.h>
+
+namespace at {
+class TensorBase;
+}
+
+namespace at::native {
+
+using pixel_shuffle_fn = void(*)(TensorBase&, const TensorBase&, int64_t);
+DECLARE_DISPATCH(pixel_shuffle_fn, pixel_shuffle_kernel);
+DECLARE_DISPATCH(pixel_shuffle_fn, pixel_unshuffle_kernel);
+
+} // at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/Reduce.h
ADDED
|
@@ -0,0 +1,314 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/native/cpu/Loops.h>
|
| 4 |
+
#include <ATen/Parallel.h>
|
| 5 |
+
#include <c10/util/TypeList.h>
|
| 6 |
+
#include <c10/core/Scalar.h>
|
| 7 |
+
#include <c10/util/irange.h>
|
| 8 |
+
|
| 9 |
+
#include <sstream>
|
| 10 |
+
#include <type_traits>
|
| 11 |
+
|
| 12 |
+
namespace at { namespace native { inline namespace CPU_CAPABILITY {
|
| 13 |
+
|
| 14 |
+
using namespace vec;
|
| 15 |
+
|
| 16 |
+
#define VEC_LOOP_HEADER(func_t, data) \
|
| 17 |
+
using scalar_t = typename function_traits<func_t>::result_type; \
|
| 18 |
+
using Vec = Vectorized<scalar_t>; \
|
| 19 |
+
char* out_ptr = data[0]; \
|
| 20 |
+
(void) out_ptr;
|
| 21 |
+
|
| 22 |
+
// reduction that is contiguous over the input in dim 0
|
| 23 |
+
template <typename traits>
|
| 24 |
+
inline bool is_contiguous_reduction(const int64_t* strides) {
|
| 25 |
+
return strides[0] == 0 &&
|
| 26 |
+
strides[1] == sizeof(typename traits::arg2_t);
|
| 27 |
+
}
|
| 28 |
+
|
| 29 |
+
// reduction that is contiguous over the input in dim 1
|
| 30 |
+
template <typename traits>
|
| 31 |
+
inline bool is_outer_reduction(const int64_t* strides) {
|
| 32 |
+
return strides[0] == 0 &&
|
| 33 |
+
strides[2] == sizeof(typename traits::result_type) &&
|
| 34 |
+
strides[3] == sizeof(typename traits::arg2_t);
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
template <typename func_t, typename vec_func_t>
|
| 38 |
+
inline void vectorized_reduction(char** data, int64_t n, int64_t stride,
|
| 39 |
+
func_t op, vec_func_t vop, bool reduce) {
|
| 40 |
+
VEC_LOOP_HEADER(func_t, data)
|
| 41 |
+
const char* in1_ptr = data[1];
|
| 42 |
+
Vec acc[4];
|
| 43 |
+
for (const auto j : c10::irange(4)) {
|
| 44 |
+
acc[j] = Vec::loadu(in1_ptr + j * Vec::size() * sizeof(scalar_t));
|
| 45 |
+
}
|
| 46 |
+
for (const auto i : c10::irange(1, n)) {
|
| 47 |
+
const char* ptr = in1_ptr + stride * i;
|
| 48 |
+
acc[0] = vop(acc[0], Vec::loadu(ptr + (0 * Vec::size() * sizeof(scalar_t))));
|
| 49 |
+
acc[1] = vop(acc[1], Vec::loadu(ptr + (1 * Vec::size() * sizeof(scalar_t))));
|
| 50 |
+
acc[2] = vop(acc[2], Vec::loadu(ptr + (2 * Vec::size() * sizeof(scalar_t))));
|
| 51 |
+
acc[3] = vop(acc[3], Vec::loadu(ptr + (3 * Vec::size() * sizeof(scalar_t))));
|
| 52 |
+
}
|
| 53 |
+
if (reduce) {
|
| 54 |
+
scalar_t buffer[Vec::size()];
|
| 55 |
+
acc[0] = vop(vop(acc[0], acc[1]), vop(acc[2], acc[3]));
|
| 56 |
+
acc[0].store(buffer);
|
| 57 |
+
for (const auto j : c10::irange(1, Vec::size())) {
|
| 58 |
+
buffer[0] = op(buffer[0], buffer[j]);
|
| 59 |
+
}
|
| 60 |
+
auto dst = (scalar_t*)out_ptr;
|
| 61 |
+
*dst = op(*dst, buffer[0]);
|
| 62 |
+
} else {
|
| 63 |
+
for (const auto j : c10::irange(4)) {
|
| 64 |
+
auto dst = out_ptr + j * Vec::size() * sizeof(scalar_t);
|
| 65 |
+
acc[j] = vop(acc[j], Vec::loadu(dst));
|
| 66 |
+
acc[j].store(dst);
|
| 67 |
+
}
|
| 68 |
+
}
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
template <typename F>
|
| 72 |
+
inline void UNARY_OUTER_LOOP(char* data[2], const int64_t strides[2], int64_t n, F f) {
|
| 73 |
+
for (const auto j C10_UNUSED : c10::irange(n)) {
|
| 74 |
+
f();
|
| 75 |
+
data[0] += strides[0];
|
| 76 |
+
data[1] += strides[1];
|
| 77 |
+
}
|
| 78 |
+
}
|
| 79 |
+
|
| 80 |
+
// computes the reduction out = op(out, in)
|
| 81 |
+
template <typename func_t, typename vec_func_t>
|
| 82 |
+
inline void vectorized_inner_reduction(char** data, int64_t n, func_t op, vec_func_t vop) {
|
| 83 |
+
VEC_LOOP_HEADER(func_t, data)
|
| 84 |
+
int64_t vector_stride = 4 * Vec::size() * sizeof(scalar_t);
|
| 85 |
+
int64_t count = n / (4 * Vec::size());
|
| 86 |
+
if (count > 0) {
|
| 87 |
+
vectorized_reduction(data, count, vector_stride, op, vop, /*reduce=*/true);
|
| 88 |
+
}
|
| 89 |
+
char* ptrs[3] = { data[0], data[0], data[1] };
|
| 90 |
+
int64_t strides[] = { 0, 0, sizeof(scalar_t) };
|
| 91 |
+
basic_loop(ptrs, strides, count * 4 * Vec::size(), n, op);
|
| 92 |
+
}
|
| 93 |
+
|
| 94 |
+
// computes the reduction out = op(out, in)
|
| 95 |
+
template <typename func_t, typename vec_func_t>
|
| 96 |
+
inline void vectorized_outer_reduction(char** data, int64_t inner_stride, int64_t size0, int64_t size1, func_t op, vec_func_t vop) {
|
| 97 |
+
VEC_LOOP_HEADER(func_t, data)
|
| 98 |
+
|
| 99 |
+
// reduce down each column of 4 * Vec::size() elements (128 or 256 bytes)
|
| 100 |
+
#if defined(CPU_CAPABILITY_AVX512)
|
| 101 |
+
int64_t outer_stride[2] = { 256, 256 };
|
| 102 |
+
#else
|
| 103 |
+
int64_t outer_stride[2] = { 128, 128 };
|
| 104 |
+
#endif
|
| 105 |
+
UNARY_OUTER_LOOP(data, outer_stride, size1 / (4 * Vec::size()), [&] {
|
| 106 |
+
vectorized_reduction(data, size0, inner_stride, op, vop, /*reduce=*/false);
|
| 107 |
+
});
|
| 108 |
+
|
| 109 |
+
// reduce down the remaining columns
|
| 110 |
+
int64_t step[] = { sizeof(scalar_t), sizeof(scalar_t) };
|
| 111 |
+
int64_t remaining = size1 % (4 * Vec::size());
|
| 112 |
+
UNARY_OUTER_LOOP(data, step, remaining, [&] {
|
| 113 |
+
char* ptrs[3] = { data[0], data[0], data[1] };
|
| 114 |
+
int64_t strides[] = { 0, 0, inner_stride };
|
| 115 |
+
basic_loop(ptrs, strides, 0, size0, op);
|
| 116 |
+
});
|
| 117 |
+
}
|
| 118 |
+
|
| 119 |
+
template<typename traits, typename res_t>
|
| 120 |
+
static void set_result(const int index, const res_t result, const TensorIteratorBase &iter, const int num_outputs) {
|
| 121 |
+
// static_assert(std::is_same<res_t, typename traits::arg2_t>::value, "data types must match");
|
| 122 |
+
if (index < num_outputs) {
|
| 123 |
+
char *out = (char *) iter.data_ptr(index);
|
| 124 |
+
*(res_t *) out = result;
|
| 125 |
+
}
|
| 126 |
+
}
|
| 127 |
+
|
| 128 |
+
template<typename traits, typename res_t>
|
| 129 |
+
static void set_results(const res_t result, const TensorIteratorBase &iter, const int num_outputs) {
|
| 130 |
+
AT_ASSERT(num_outputs == 1);
|
| 131 |
+
set_result<traits>(0, result, iter, num_outputs);
|
| 132 |
+
}
|
| 133 |
+
|
| 134 |
+
template<typename traits, std::size_t i = 0, typename... tuple_t>
|
| 135 |
+
inline typename std::enable_if<i == sizeof...(tuple_t), std::size_t>::type
|
| 136 |
+
for_each_in_tuple(const std::tuple<tuple_t...>& /*t*/, const TensorIteratorBase& /*iter*/, const int /*num_outputs*/) {
|
| 137 |
+
return i;
|
| 138 |
+
}
|
| 139 |
+
|
| 140 |
+
template<typename traits, std::size_t i = 0, typename... tuple_t>
|
| 141 |
+
inline typename std::enable_if<i < sizeof...(tuple_t), std::size_t>::type
|
| 142 |
+
for_each_in_tuple(const std::tuple<tuple_t...>& t, const TensorIteratorBase &iter, const int num_outputs) {
|
| 143 |
+
if (i < (size_t)num_outputs) {
|
| 144 |
+
set_result<traits>(i, std::get<i>(t), iter, num_outputs);
|
| 145 |
+
return for_each_in_tuple<traits, i + 1, tuple_t...>(t, iter, num_outputs);
|
| 146 |
+
}
|
| 147 |
+
return i;
|
| 148 |
+
}
|
| 149 |
+
|
| 150 |
+
template<typename traits, typename... res_t>
|
| 151 |
+
static void set_results(const std::tuple<res_t...>& result, const TensorIteratorBase &iter, const int num_outputs) {
|
| 152 |
+
AT_ASSERT(num_outputs >= 1);
|
| 153 |
+
std::size_t result_size = for_each_in_tuple<traits>(result, iter, num_outputs);
|
| 154 |
+
AT_ASSERT((size_t)num_outputs == result_size);
|
| 155 |
+
}
|
| 156 |
+
|
| 157 |
+
template <typename T, typename... Args>
|
| 158 |
+
struct all_same : std::conjunction<
|
| 159 |
+
std::is_same<T, Args>...
|
| 160 |
+
> {};
|
| 161 |
+
|
| 162 |
+
// data_t is the input/output data type.
// acc_t is a type that contains all the necessary data
// to continue reducing.
// index_t is a one-dimensional index
//
// ops_t is such that &ops_t::reduce, &ops_t::combine, and &ops_t::project exist and satisfy
// the following.
// reduce: (acc_t, data_t, index_t) -> acc_t adds one data point to the accumulated value.
// combine: (acc_t, acc_t) -> acc_t combines two accumulated values into one.
// project: acc_t -> out_t finishes the reduction, getting the required output.
//
// Additionally, acc_t must be default-constructible:
// acc_t {} is an identity for combine,
// and project(acc_t {}) is the value of the operation on zero elements.
//
// The point of `combine` is to support parallelization -
// the idea is to run one sequence of `reduce` calls per thread of execution,
// and then to combine them at the end with `combine`.
//
// If there is more than one output element,
// our parallelization strategy is to use one thread for each of them,
// which means that `combine` will never be called.
//
// If, on the other hand, there is only one, then we split the input
// into several pieces, reduce each separately, and then combine them.

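For illustration, a minimal sketch of an ops_t that satisfies the contract above for a plain sum (a hypothetical example, not part of this header; `translate_idx` is the extra hook that `binary_kernel_reduce` below also expects):

// Hypothetical example only: a sum reduction satisfying the ops_t contract above.
struct SumOpsExample {
  using acc_t = double;
  // reduce: add one data point to the accumulator (the index is unused for a sum)
  acc_t reduce(acc_t acc, float data, int64_t /*idx*/) const {
    return acc + static_cast<acc_t>(data);
  }
  // combine: merge two per-thread accumulators into one
  acc_t combine(acc_t a, acc_t b) const {
    return a + b;
  }
  // project: finish the reduction; acc_t{} == 0.0 is the identity for combine
  float project(acc_t acc) const {
    return static_cast<float>(acc);
  }
  // binary_kernel_reduce also calls translate_idx(acc, base_idx); a sum does not
  // depend on absolute indices, so it simply returns the accumulator unchanged.
  acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) const {
    return acc;
  }
};
// A caller would then drive it as: binary_kernel_reduce(iter, SumOpsExample{}, /*init=*/0.0);
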
template <typename ops_t, typename init_t>
void binary_kernel_reduce(TensorIteratorBase& iter, ops_t ops, init_t init) {
  using rf_t = decltype(&ops_t::reduce);
  using cf_t = decltype(&ops_t::combine);
  using pf_t = decltype(&ops_t::project);
  using r_traits = binary_function_traits<rf_t>;
  using c_traits = binary_function_traits<cf_t>;
  using p_traits = unary_function_traits<pf_t>;
  using acc_t = typename p_traits::arg1_t;
  using data_t = typename r_traits::arg2_t;
  static_assert(
    all_same<
      acc_t,
      init_t,
      typename r_traits::arg1_t,
      typename r_traits::result_type,
      typename c_traits::arg1_t,
      typename c_traits::arg2_t,
      typename c_traits::result_type>::value,
    "all accumulate types must match");
  static_assert(
    std::is_default_constructible<acc_t>::value,
    "the accumulate type must be default-constructible"
  );
  const int num_outputs = iter.noutputs();
  iter.foreach_reduced_elt([&ops, &init, num_outputs](TensorIteratorBase &sub_iter) {
    auto reduction_body = [&ops, &sub_iter, num_outputs](acc_t acc, int64_t begin, int64_t end) -> acc_t {
      int ntensors = sub_iter.ntensors();
      sub_iter.serial_for_each([&acc, &ops, num_outputs, ntensors, begin](char** data, const int64_t* strides, int64_t size) {
        AT_ASSERT(ntensors - num_outputs == 1);
        char *in = data[ntensors - 1];
        int64_t stride = strides[ntensors - 1];
        for (const auto i : c10::irange(size)) {
          acc = ops.reduce(acc, c10::load<data_t>(in), begin + i);
          in += stride;
        }
      }, {begin, end});
      return ops.translate_idx(acc, sub_iter.view_offsets()[0]);
    };
    acc_t total_acc = init;
    auto numel = sub_iter.numel();
    if (numel < at::internal::GRAIN_SIZE || at::get_num_threads() == 1 ||
        at::in_parallel_region()) {
      total_acc = reduction_body(total_acc, 0, numel);
    } else {
      int max_threads = at::get_num_threads();
      AT_ASSERT(max_threads > 0);
      static_assert(
        !std::is_same<acc_t, bool>::value,
        "Concurrently modifying different references into std::vector<bool> is UB."
      );
      std::vector<acc_t> buffer((unsigned)max_threads, init);
      at::parallel_for(0, numel, internal::GRAIN_SIZE,
        [&](int64_t begin, int64_t end) {
          auto& acc = buffer[at::get_thread_num()];
          acc = reduction_body(acc, begin, end);
        }
      );
      for (const auto i : c10::irange(max_threads)) {
        total_acc = ops.combine(total_acc, buffer[i]);
      }
    }
    set_results<r_traits>(ops.project(total_acc), sub_iter, num_outputs);
  });
}

template <typename func_t, typename vec_func_t>
void binary_kernel_reduce_vec(TensorIteratorBase& iter, func_t op, vec_func_t vop, double ident = 0) {
  using traits = binary_function_traits<func_t>;
  static_assert(
    all_same<
      typename traits::result_type,
      typename traits::arg1_t,
      typename traits::arg2_t>::value,
    "all types must match");

  iter.output_base().fill_(ident);
  iter.parallel_reduce([&](char** data, const int64_t* strides, int64_t size0, int64_t size1) {
    int64_t outer_strides[] = { strides[2], strides[3] };
    if (is_contiguous_reduction<traits>(strides)) {
      // input is contiguous in dim 0, output is reduced in dim 0
      UNARY_OUTER_LOOP(data, outer_strides, size1, [&] {
        vectorized_inner_reduction(data, size0, op, vop);
      });
    } else if (is_outer_reduction<traits>(strides)) {
      // input and output are contiguous in dim 1
      int64_t inner_stride = strides[1]; // stride of input in dim 0
      vectorized_outer_reduction(data, inner_stride, size0, size1, op, vop);
    } else {
      UNARY_OUTER_LOOP(data, outer_strides, size1, [&] {
        char* ptrs[3] = { data[0], data[0], data[1] };
        int64_t inner_strides[3] = { strides[0], strides[0], strides[1] };
        basic_loop(ptrs, inner_strides, 0, size0, op);
      });
    }
  });
}

// When the reduction is over the innermost dimension (dim 0 in TensorIterator)
// and the input has a contiguous innermost dimension, `binary_kernel_reduce_lastdim`
// can be used.
inline bool is_reduce_lastdim(TensorIteratorBase& iter) {
  return iter.num_reduce_dims() == 1 && iter.is_dim_reduced(0)
      && iter.ninputs() == 1 && iter.strides(1)[0] == iter.element_size(1);
}

template <typename reduce_func_t>
void binary_kernel_reduce_lastdim(TensorIteratorBase& iter, reduce_func_t reduce_op) {
  auto shape = iter.shape();
  int64_t dim_size = shape[0];
  int64_t grain_size = std::max((int64_t) 1, at::internal::GRAIN_SIZE / dim_size);
  TensorIterator sub_iter(iter);
  // create a sub-iterator to parallelize over all non-reduce dims
  sub_iter.narrow(0, 0, 1);
  auto loop = [&](char** data, const int64_t* strides, int64_t size) {
    char* out = data[0];
    char* in = data[1];
    for (int64_t i = 0; i < size; ++i) {
      reduce_op(out, in, dim_size);
      out += strides[0];
      in += strides[1];
    }
  };
  sub_iter.for_each(loop, grain_size);
}

}}} // namespace at::native::<anonymous>
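
As a usage illustration for `binary_kernel_reduce_vec` (a hypothetical sketch assuming it is called from the same translation unit as this header's anonymous namespace; real ATen sum kernels are dispatched per dtype and handle more cases):

#include <ATen/cpu/vec/vec.h>

static void sum_kernel_float_example(at::TensorIteratorBase& iter) {
  using Vec = at::vec::Vectorized<float>;
  // The scalar op and the vectorized op must compute the same thing;
  // ident = 0 is the identity for addition and pre-fills the output.
  binary_kernel_reduce_vec(
      iter,
      [](float a, float b) -> float { return a + b; },
      [](Vec a, Vec b) -> Vec { return a + b; },
      /*ident=*/0);
}
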
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/SampledAddmmKernel.h
ADDED
@@ -0,0 +1,12 @@
#pragma once

#include <ATen/core/Tensor.h>
#include <ATen/native/DispatchStub.h>

namespace at::native {

using sampled_addmm_sparse_csr_fn = void(*)(const Tensor&, const Tensor&, const Scalar&, const Scalar&, const Tensor&);

DECLARE_DISPATCH(sampled_addmm_sparse_csr_fn, sampled_addmm_sparse_csr_stub);

} // at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/SerialStackImpl.h
ADDED
@@ -0,0 +1,146 @@
// Copyright 2004-present Facebook. All Rights Reserved.
#pragma once

#include <ATen/core/Tensor.h>

#include <ATen/MemoryOverlap.h>
#include <ATen/Parallel.h>
#include <ATen/TensorIterator.h>
#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>
#include <c10/util/irange.h>

namespace at::native::detail {

struct InputMeta {
  void* data_ptr;
  int64_t inner_size;

  InputMeta(const Tensor& t, int64_t dim, int64_t inner)
      : data_ptr(t.data_ptr()), inner_size(t.sizes()[dim] * inner) {}
};

// This kernel is used by two TensorList types:
// 1. stack_serial_kernel uses at::ArrayRef<Tensor>
// 2. Static runtime calls this kernel directly (csrc/jit/runtime/static/ops.cpp) with
//    ProcessedNodeInputWrapper.
// When making changes, make sure that they are compatible with both types!
template <typename scalar_t, typename TensorListType>
void stack_serial_kernel_impl(Tensor& result, TensorListType tensors, int64_t dim) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      dim >= 0 && dim <= result.dim(),
      "dim out of range in stack_serial_kernel_impl");
  int64_t outer =
      result.numel() / (result.sizes()[dim] * result.strides()[dim]);
  scalar_t* result_data = result.data_ptr<scalar_t>();
  int64_t ninputs = tensors.size();
  std::vector<InputMeta> inputs;
  inputs.reserve(ninputs);
  for (const auto& tensor : tensors) {
    inputs.emplace_back(tensor, dim, tensor.strides()[dim]);
  }

  using Vec = vec::Vectorized<scalar_t>;
  scalar_t* result_ptr = result_data;
  for (const auto i : c10::irange(outer)) {
    for (const auto j : c10::irange(ninputs)) {
      int64_t local_inner = inputs[j].inner_size;
      scalar_t* input_ptr = (scalar_t*)(inputs[j].data_ptr) + i * local_inner;

      if (local_inner < Vec::size()) {
        for (const auto k : c10::irange(local_inner)) {
          result_ptr[k] = input_ptr[k];
        }
      } else {
        vec::map(
            [](Vec x) { return x; }, result_ptr, input_ptr, local_inner);
      }
      result_ptr += local_inner;
    }
  }
}

// Checks to see whether native stack can be invoked under these conditions:
// - result and input tensors are contiguous
// - only one thread is used
// - no type promotion has to occur
// - tensors dtype is Double or Float
template <typename TensorListType>
bool can_use_native_serial_stack_impl(Tensor& result, TensorListType tensors, int64_t dim) {
  TORCH_CHECK(tensors.size() > 0, "expected a non-empty list of Tensors");
  const Tensor& first_tensor = tensors[0];
  // stack dimension should be in range [0,firstTensor.dim())
  // dim == firstTensor.dim() is a valid input, but it is handled by default code path
  // that uses unsqueeze
  if (dim >= first_tensor.dim()) return false;
  // Native stack doesn't apply if any tensor is skipped.
  if (first_tensor.numel() == 0 && first_tensor.dim() == 1) return false;
  // there should be no type promotion
  if (result.dtype() != first_tensor.dtype()) return false;

  auto first_tensor_mem_format = first_tensor.suggest_memory_format();
  ScalarType dtype = first_tensor.scalar_type();

  if (!result.is_contiguous(first_tensor_mem_format)) {
    return false;
  }

  // fast path only works for Double and Float
  if (dtype != ScalarType::Double && dtype != ScalarType::Float) {
    return false;
  }

  // check remainder of inputs
#ifndef STRIP_ERROR_MESSAGES
  auto const &first_tensor_shape = first_tensor.sizes();
#endif
  for (const auto i : c10::irange(1, tensors.size())) {
    auto const &tensor = tensors[i];
    TORCH_CHECK(tensors[i].sizes() == first_tensor.sizes(),
      "stack expects each tensor to be equal size, but got ", first_tensor_shape,
      " at entry 0 and ", tensor.sizes(), " at entry ", i);

    // every tensor must be contiguous
    // tensor sizes and strides must be the same
    // there should be no type promotion
    if (!tensor.is_contiguous(first_tensor_mem_format) ||
        tensor.strides() != first_tensor.strides() ||
        tensor.dtype() != dtype) {
      return false;
    }
  }

  // fast native stack should only be used when it is not worth using multiple threads
  // or there is only one thread. Note that we aren't checking result.numel() here because
  // it may not have been resized and we want to defer that cost till later.
  int64_t numel_in_stack = first_tensor.numel() * tensors.size();
  return numel_in_stack < at::internal::GRAIN_SIZE || at::get_num_threads() == 1;
}

template <typename TensorListType, bool should_skip_overlap_check>
struct CanUseNativeSerialStack;

template <typename TensorListType>
struct CanUseNativeSerialStack<TensorListType, false> {
  static bool call(Tensor& result, TensorListType tensors, int64_t dim) {
    // Inputs cannot alias the output tensor
    for (const auto i : c10::irange(tensors.size())) {
      auto lap = at::get_overlap_status(result, tensors[i]);
      TORCH_CHECK(lap != at::MemOverlapStatus::Partial &&
          lap != at::MemOverlapStatus::Full, 0,
          "unsupported operation: the input tensors cannot refer to any of the "
          "output memory locations. Found overlap in input tensor ", i);
    }

    return can_use_native_serial_stack_impl(result, tensors, dim);
  }
};

template <typename TensorListType>
struct CanUseNativeSerialStack<TensorListType, true> {
  static bool call(Tensor& result, TensorListType tensors, int64_t dim) {
    return can_use_native_serial_stack_impl(result, tensors, dim);
  }
};

} // namespace at::native::detail
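
A hypothetical caller showing how the overlap check and the serial kernel above fit together for float inputs (a sketch only; the real dispatch lives in ATen's stack implementation, and `result` is assumed to be already resized to the stacked output shape):

void stack_serial_float_example(at::Tensor& result, at::TensorList tensors, int64_t dim) {
  using namespace at::native::detail;
  // Checks contiguity, dtype, sizes, and input/output overlap as described above.
  if (CanUseNativeSerialStack<at::TensorList, /*should_skip_overlap_check=*/false>::call(
          result, tensors, dim)) {
    stack_serial_kernel_impl<float, at::TensorList>(result, tensors, dim);
  }
  // Otherwise a caller would fall back to the generic at::stack path (not shown).
}
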
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/SoftmaxKernel.h
ADDED
@@ -0,0 +1,28 @@
#pragma once

#include <ATen/native/DispatchStub.h>
#include <cstdint>

namespace at {
class Tensor;

namespace native {

using forward_fn = void (*)(const Tensor&, const Tensor&);
using backward_fn = void(*)(const Tensor &, const Tensor &, const Tensor&);

DECLARE_DISPATCH(forward_fn, softmax_lastdim_kernel);
DECLARE_DISPATCH(forward_fn, log_softmax_lastdim_kernel);
DECLARE_DISPATCH(backward_fn, softmax_backward_lastdim_kernel);
DECLARE_DISPATCH(backward_fn, log_softmax_backward_lastdim_kernel);

using forward_fn_with_dim = void(*)(const Tensor &, const Tensor &, const int64_t);
using backward_fn_with_dim =
    void (*)(const Tensor&, const Tensor&, const Tensor&, const int64_t);

DECLARE_DISPATCH(forward_fn_with_dim, softmax_kernel);
DECLARE_DISPATCH(forward_fn_with_dim, log_softmax_kernel);
DECLARE_DISPATCH(backward_fn_with_dim, softmax_backward_kernel);
DECLARE_DISPATCH(backward_fn_with_dim, log_softmax_backward_kernel);
}
}
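
To illustrate how these declarations are meant to be used (a sketch assuming the usual ATen DispatchStub convention, not code from this diff): a CPU kernel file registers an implementation with REGISTER_DISPATCH, and operator code invokes the stub with a device type.

// Hypothetical sketch of the DispatchStub calling convention assumed above.
// In a kernel .cpp: REGISTER_DISPATCH(softmax_lastdim_kernel, &softmax_lastdim_kernel_impl);
void call_softmax_lastdim_example(const at::Tensor& output, const at::Tensor& input) {
  // forward_fn takes (output, input); the first argument selects the backend.
  at::native::softmax_lastdim_kernel(at::kCPU, output, input);
}
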
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/SpmmReduceKernel.h
ADDED
@@ -0,0 +1,22 @@
#pragma once

#include <ATen/core/Tensor.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/ReductionType.h>

namespace at::native {

using spmm_reduce_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, ReductionType op);
using spmm_reduce_arg_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, ReductionType op);
using spmm_reduce_backward_input_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, ReductionType op);
using spmm_reduce_backward_input_arg_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, ReductionType op);
using spmm_reduce_backward_other_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, ReductionType op);

DECLARE_DISPATCH(spmm_reduce_fn, spmm_reduce_stub);
DECLARE_DISPATCH(spmm_reduce_arg_fn, spmm_reduce_arg_stub);
DECLARE_DISPATCH(spmm_reduce_backward_input_fn, spmm_reduce_backward_input_stub);
DECLARE_DISPATCH(spmm_reduce_backward_input_arg_fn, spmm_reduce_backward_input_arg_stub);
DECLARE_DISPATCH(spmm_reduce_backward_other_fn, spmm_reduce_backward_other_stub);
DECLARE_DISPATCH(spmm_reduce_backward_input_arg_fn, spmm_reduce_backward_other_arg_stub);

} // at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/StackKernel.h
ADDED
@@ -0,0 +1,12 @@
// Copyright 2004-present Facebook. All Rights Reserved.
#pragma once

#include <ATen/core/Tensor.h>
#include <ATen/native/DispatchStub.h>

namespace at::native {

using stack_serial_fn = void(*)(Tensor &, TensorList, int64_t);
DECLARE_DISPATCH(stack_serial_fn, stack_serial_stub);

} // namespace at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/UpSampleKernelAVXAntialias.h
ADDED
@@ -0,0 +1,1376 @@
/*
The Python Imaging Library (PIL) is

Copyright © 1997-2011 by Secret Labs AB
Copyright © 1995-2011 by Fredrik Lundh

Pillow is the friendly PIL fork. It is

Copyright © 2010-2022 by Alex Clark and contributors

Like PIL, Pillow is licensed under the open source HPND License
*/

// This code is heavily inspired from PILLOW-SIMD's implementation:
// https://github.com/uploadcare/pillow-simd/blob/simd/master/src/libImaging/Resample.c

#pragma once
#ifdef CPU_CAPABILITY_AVX2
// TODO: This file only supports AVX2. We could split the AVX kernels into
// smaller logical blocks in order to port them into the Vec.h logic. This would
// allow supporting other vectorization architectures and perhaps also support
// the non-vectorized fallback (we'd need to make sure it's not slower than the
// current fallback).

#include <ATen/core/Tensor.h>
#include <ATen/cpu/vec/intrinsics.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty.h>
#endif


namespace {

static inline __m128i mm_cvtsi32_si128(const uint8_t* C10_RESTRICT ptr, bool i32_aligned) {
  int32_t v;
  if (i32_aligned) {
    v = *(const int32_t*)ptr;
  } else {
    std::memcpy(&v, ptr, 4);
  }
  return _mm_cvtsi32_si128(v);
}

static inline __m128i mm_cvtepu8_epi32(const uint8_t* C10_RESTRICT ptr, bool i32_aligned) {
  return _mm_cvtepu8_epi32(mm_cvtsi32_si128(ptr, i32_aligned));
}

static inline void _write_endline_rgb_as_uint32(
    uint8_t* C10_RESTRICT output,
    uint32_t data
) {
  // data is (R G B X), output is (X1 X2 X3 | R1 B1 G1 R2 ...)
  // Here we explicitly set X as R1
  uint8_t* data_ptr = reinterpret_cast<uint8_t*>(&data);
  data_ptr[3] = output[3];
  std::memcpy(output, data_ptr, 4);
}

at::Tensor unpack_rgb(const at::Tensor& packed_tensor) {
  // Convert a "packed" tensor (typically RGBRGBRGB if channels_last) into
  // RGBARGBARGBA format where A is hard-coded to 0. Each pixel is encoded
  // into 32 bits. This generalizes to num_channels <= 4 and also works for
  // non-channels_last tensors.

  const uint8_t* packed = (const uint8_t*)packed_tensor.const_data_ptr<uint8_t>();
  auto num_pixels = packed_tensor.size(1) * packed_tensor.size(2);
  auto num_channels = packed_tensor.size(0);

  constexpr int rgba_size = 4;
  auto unpacked_tensor = at::empty({rgba_size, packed_tensor.size(1), packed_tensor.size(2)}, at::CPU(at::kByte));
  uint8_t* unpacked = (uint8_t*) unpacked_tensor.data_ptr<uint8_t>();

  auto stride_i = packed_tensor.stride(2);
  auto stride_j = packed_tensor.stride(0);

  for (const auto i : c10::irange(num_pixels)) {
    for (const auto j : c10::irange(rgba_size)) {
      unpacked[rgba_size * i + j] = (j < num_channels) ? packed[stride_i * i + stride_j * j] : 0;
    }
  }
  return unpacked_tensor;
}

void pack_rgb(
    const at::Tensor& unpacked_tensor, // IN
    const at::Tensor& packed_tensor // OUT
) {
  // Convert from unpacked channels last 3-channels or 4-channels tensor into original data layout.

  uint8_t* unpacked = (uint8_t*)unpacked_tensor.data_ptr<uint8_t>();
  uint8_t* packed = (uint8_t*)packed_tensor.data_ptr<uint8_t>();
  auto num_pixels = packed_tensor.size(1) * packed_tensor.size(2);
  auto num_channels = packed_tensor.size(0);

  auto unpacked_increment = unpacked_tensor.size(0);
  auto packed_increment = packed_tensor.stride(2);
  auto packed_stride = packed_tensor.stride(0);

  TORCH_INTERNAL_ASSERT(unpacked_increment == 3 || unpacked_increment == 4);

  for (const auto i C10_UNUSED : c10::irange(num_pixels)) {
    for (const auto j : c10::irange(num_channels)) {
      packed[j * packed_stride] = unpacked[j];
    }
    unpacked += unpacked_increment;
    packed += packed_increment;
  }
}

void ImagingResampleHorizontalConvolution8u4x(
    uint8_t* C10_RESTRICT lineOut0,
    uint8_t* C10_RESTRICT lineOut1,
    uint8_t* C10_RESTRICT lineOut2,
    uint8_t* C10_RESTRICT lineOut3,
    int64_t out_xsize,
    const uint8_t* C10_RESTRICT lineIn0,
    const uint8_t* C10_RESTRICT lineIn1,
    const uint8_t* C10_RESTRICT lineIn2,
    const uint8_t* C10_RESTRICT lineIn3,
    int64_t in_xsize,
    const int64_t* idx_ptr_xmin,
    const int64_t* idx_ptr_size,
    const int16_t* kk,
    int kmax,
    unsigned int coefs_precision,
    int64_t num_channels,
    bool is_last_line);

void ImagingResampleHorizontalConvolution8u(
    uint8_t* C10_RESTRICT lineOut,
    int64_t out_xsize,
    const uint8_t* C10_RESTRICT lineIn,
    int64_t in_xsize,
    const int64_t* idx_ptr_xmin,
    const int64_t* idx_ptr_size,
    const int16_t* kk,
    int kmax,
    unsigned int coefs_precision,
    int64_t num_channels,
    bool is_last_line);

void ImagingResampleVerticalConvolution8u(
    uint8_t* C10_RESTRICT lineOut,
    const uint8_t* C10_RESTRICT lineIn,
    int64_t xsize,
    int64_t ids_min,
    int64_t ids_size,
    const int16_t* k,
    unsigned int coefs_precision,
    int64_t num_channels);

template<int num_channels>
void ImagingResampleHorizontal(
    const at::Tensor & unpacked_output,
    const at::Tensor & unpacked_input,
    int ksize,
    const std::vector<at::Tensor>& horiz_indices_weights,
    unsigned int horiz_weights_precision) {

  // Interpolation horizontal pass: we compute x-axis (image width) interpolation outputs.

  // Input data is stored as
  //   input = [r[0], g[0], b[0], a[0], r[1], g[1], b[1], a[1], r[2], g[2], b[2], a[2], ...]
  // Weights are float values computed for each output pixel and rescaled to uint16:
  //   weights[i] = [w[i, 0], w[i, 1], ..., w[i, K-1]]
  // We want to compute the output as follows:
  //   output = [oR[0], oG[0], oB[0], oA[0], oR[1], oG[1], oB[1], oA[1], ...]
  // where
  //   oR[yoffset + i] = r[yoffset + xmin[i]] * w[i, 0] + ... + r[yoffset + xmin[i] + K-1] * w[i, K-1]
  //   oG[yoffset + i] = g[yoffset + xmin[i]] * w[i, 0] + ... + g[yoffset + xmin[i] + K-1] * w[i, K-1]
  //   oB[yoffset + i] = b[yoffset + xmin[i]] * w[i, 0] + ... + b[yoffset + xmin[i] + K-1] * w[i, K-1]
  //

  // TODO: we may want to merge that into the fallback code (currently called
  // basic_loop_aa_horizontal<uint8_t>)
  // Although this may not be needed if / when we port all this code to use
  // Vec.h since this would potentially give us another fall-back implem

  const int16_t* kk = (int16_t*)(horiz_indices_weights[3].const_data_ptr<double>());

  auto xout = unpacked_output.size(2);
  auto yout = unpacked_output.size(1);
  auto xin = unpacked_input.size(2);
  TORCH_INTERNAL_ASSERT(num_channels == unpacked_input.size(0));

  const int64_t* idx_ptr_xmin = horiz_indices_weights[0].const_data_ptr<int64_t>();
  const int64_t* idx_ptr_size = horiz_indices_weights[1].const_data_ptr<int64_t>();

  uint8_t* unpacked_output_p = unpacked_output.data_ptr<uint8_t>();
  const uint8_t* unpacked_input_p = unpacked_input.const_data_ptr<uint8_t>();

  int64_t yy = 0;
  auto xout_stride = xout * num_channels;
  auto xin_stride = xin * num_channels;
  for (; yy < yout - 3; yy += 4) {
    ImagingResampleHorizontalConvolution8u4x(
        unpacked_output_p + yy * xout_stride,
        unpacked_output_p + (yy + 1) * xout_stride,
        unpacked_output_p + (yy + 2) * xout_stride,
        unpacked_output_p + (yy + 3) * xout_stride,
        xout,
        unpacked_input_p + yy * xin_stride,
        unpacked_input_p + (yy + 1) * xin_stride,
        unpacked_input_p + (yy + 2) * xin_stride,
        unpacked_input_p + (yy + 3) * xin_stride,
        xin,
        idx_ptr_xmin,
        idx_ptr_size,
        kk,
        ksize,
        horiz_weights_precision,
        num_channels,
        yy + 3 == yout - 1);
  }
  for (; yy < yout; yy++) {
    ImagingResampleHorizontalConvolution8u(
        unpacked_output_p + yy * xout_stride,
        xout,
        unpacked_input_p + yy * xin_stride,
        xin,
        idx_ptr_xmin,
        idx_ptr_size,
        kk,
        ksize,
        horiz_weights_precision,
        num_channels,
        yy == yout - 1);
  }
}
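
The weighted-sum formula in the comment above can be read off more easily from a scalar reference version. Below is a sketch for one row and one channel with hypothetical names (float weights w and window starts xmin); the real kernels in this file use rescaled int16 weights, packed uint8 pixels, and process four rows at a time:

#include <cstdint>
#include <vector>

// Scalar reference sketch of the horizontal interpolation pass (illustrative only).
void horizontal_pass_reference(
    float* out, const float* in,
    const std::vector<std::vector<float>>& w,  // w[out_x][0..K-1], one weight row per output pixel
    const std::vector<int64_t>& xmin,          // input window start per output pixel
    int64_t out_xsize) {
  for (int64_t out_x = 0; out_x < out_xsize; out_x++) {
    float acc = 0.f;
    for (int64_t k = 0; k < (int64_t)w[out_x].size(); k++) {
      // o[out_x] = sum_k in[xmin[out_x] + k] * w[out_x][k]
      acc += in[xmin[out_x] + k] * w[out_x][k];
    }
    out[out_x] = acc;
  }
}
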
|
| 234 |
+
|
| 235 |
+
void ImagingResampleVertical(
|
| 236 |
+
const at::Tensor & unpacked_output,
|
| 237 |
+
const at::Tensor & unpacked_input,
|
| 238 |
+
int ksize,
|
| 239 |
+
const std::vector<at::Tensor>& vert_indices_weights,
|
| 240 |
+
unsigned int vert_weights_precision) {
|
| 241 |
+
|
| 242 |
+
// Interpolation vertical pass: we compute y-axis interpolation outputs.
|
| 243 |
+
// Input data is stored as
|
| 244 |
+
// input = [r[0], g[0], b[0], a[0], r[1], g[1], b[1], a[1], r[2], g[2], b[2], a[2], ...]
|
| 245 |
+
// Weights are float values computed for each output pixel and rescaled to uint16:
|
| 246 |
+
// weights[i] = [w[i, 0], w[i, 1], ..., w[i, K-1]]
|
| 247 |
+
// We want to compute the output as following:
|
| 248 |
+
// output = [oR[0], oG[0], oB[0], oA[0], oR[1], oG[1], oB[1], oA[1], ...]
|
| 249 |
+
// where
|
| 250 |
+
// oR[xoffset + i] = r[xoffset + ymin[i]] * w[i, 0] + ... + r[xoffset + ymin[i] + (K-1) * xsize] * w[i, K-1]
|
| 251 |
+
// oG[xoffset + i] = g[xoffset + ymin[i]] * w[i, 0] + ... + g[xoffset + ymin[i] + (K-1) * xsize] * w[i, K-1]
|
| 252 |
+
// oB[xoffset + i] = b[xoffset + ymin[i]] * w[i, 0] + ... + b[xoffset + ymin[i] + (K-1) * xsize] * w[i, K-1]
|
| 253 |
+
|
| 254 |
+
// TODO: we may want to merge that into the fallback code (currently called
|
| 255 |
+
// basic_loop_aa_vertical<uint8_t>)
|
| 256 |
+
// Although this may not be needed if / when we port all this code to use
|
| 257 |
+
// Vec.h since this would potentially give us another fall-back implem
|
| 258 |
+
const int16_t* kk = (int16_t*)(vert_indices_weights[3].const_data_ptr<double>());
|
| 259 |
+
|
| 260 |
+
const int64_t* idx_ptr_xmin = vert_indices_weights[0].const_data_ptr<int64_t>();
|
| 261 |
+
const int64_t* idx_ptr_size = vert_indices_weights[1].const_data_ptr<int64_t>();
|
| 262 |
+
|
| 263 |
+
uint8_t* unpacked_output_p = unpacked_output.data_ptr<uint8_t>();
|
| 264 |
+
const uint8_t* unpacked_input_p = unpacked_input.const_data_ptr<uint8_t>();
|
| 265 |
+
|
| 266 |
+
auto xout = unpacked_output.size(2);
|
| 267 |
+
auto yout = unpacked_output.size(1);
|
| 268 |
+
const auto num_channels = unpacked_input.size(0);
|
| 269 |
+
TORCH_INTERNAL_ASSERT(num_channels == unpacked_output.size(0));
|
| 270 |
+
|
| 271 |
+
auto xout_stride = xout * num_channels;
|
| 272 |
+
for (const auto yy : c10::irange(yout)) {
|
| 273 |
+
const auto* k = &kk[yy * ksize];
|
| 274 |
+
auto ids_min = idx_ptr_xmin[yy];
|
| 275 |
+
auto ids_size = idx_ptr_size[yy];
|
| 276 |
+
ImagingResampleVerticalConvolution8u(
|
| 277 |
+
unpacked_output_p + yy * xout_stride,
|
| 278 |
+
unpacked_input_p,
|
| 279 |
+
xout,
|
| 280 |
+
ids_min,
|
| 281 |
+
ids_size,
|
| 282 |
+
k,
|
| 283 |
+
vert_weights_precision,
|
| 284 |
+
num_channels);
|
| 285 |
+
}
|
| 286 |
+
}
|
| 287 |
+
|
| 288 |
+
// This is the only public entry point in this file. It supports bilinear or bicubic
|
| 289 |
+
// mode for uint8 dtype when C <= 4, with or without antialias. The
|
| 290 |
+
// implem is based on PIL-SIMD.
|
| 291 |
+
// Its equivalent implementation (fallback) for when AVX isn't supported or when
|
| 292 |
+
// C > 4 is separable_upsample_generic_Nd_kernel_impl() There are a bunch of
|
| 293 |
+
// future improvement that can be done: look for the TODOs in this file.
|
| 294 |
+
// For details on how the weights are computed and how the multiplications are
|
| 295 |
+
// run on int (instead of float weights), see
|
| 296 |
+
// [ Weights computation for uint8_t and multiplication trick ]
|
| 297 |
+
// For details on how the AVX kernels are implemented, see
|
| 298 |
+
// https://gist.github.com/NicolasHug/47c97d731f05eaad5694c173849b86f5
|
| 299 |
+
// See also [ Support for antialias=False as a subcase of antialias=True ] to
|
| 300 |
+
// learn more about how the antialias=False case is computed. The same holds
|
| 301 |
+
// here: all these kernels are general enough to handle an arbitrary number of
|
| 302 |
+
// weights, but when aa=False they could be optimized further.
|
| 303 |
+
template <typename scale_type, class F>
|
| 304 |
+
void upsample_avx_bilinear_bicubic_uint8(
|
| 305 |
+
const at::Tensor& input_,
|
| 306 |
+
const at::Tensor& output,
|
| 307 |
+
bool align_corners,
|
| 308 |
+
const scale_type& scales,
|
| 309 |
+
bool antialias) {
|
| 310 |
+
auto batch_size = input_.size(0);
|
| 311 |
+
auto num_channels = input_.size(1);
|
| 312 |
+
auto xin = input_.size(3);
|
| 313 |
+
auto yin = input_.size(2);
|
| 314 |
+
auto xout = output.size(3);
|
| 315 |
+
auto yout = output.size(2);
|
| 316 |
+
|
| 317 |
+
if (xin == xout && yin == yout) {
|
| 318 |
+
output.copy_(input_);
|
| 319 |
+
return;
|
| 320 |
+
}
|
| 321 |
+
|
| 322 |
+
at::Tensor input = input_;
|
| 323 |
+
if (!(input.is_contiguous() || input.is_contiguous(at::MemoryFormat::ChannelsLast))) {
|
| 324 |
+
// If input is not contiguous with memory format channels first or channels last,
|
| 325 |
+
// we explicitly convert the input to contiguous channels last memory format.
|
| 326 |
+
// This simplifies the rest of the code and let us assume that the format is only contiguous channels first or channels last,
|
| 327 |
+
// Most tensors going through this `if` block won't need to go through unpacking, but those having C < 3 may
|
| 328 |
+
// have to (this means 2 copies are made). We could avoid the extra copy by handling non-contiguous input
|
| 329 |
+
// directly within unpack_rgb() and pack_rgb(), but initial attempts showed that this is fairly complex.
|
| 330 |
+
input = input.contiguous(at::MemoryFormat::ChannelsLast);
|
| 331 |
+
}
|
| 332 |
+
|
| 333 |
+
auto need_horizontal = xout != xin;
|
| 334 |
+
auto need_vertical = yout != yin;
|
| 335 |
+
|
| 336 |
+
int ksize_horiz, ksize_vert;
|
| 337 |
+
std::vector<at::Tensor> horiz_indices_weights, vert_indices_weights;
|
| 338 |
+
unsigned int horiz_weights_precision, vert_weights_precision;
|
| 339 |
+
|
| 340 |
+
bool skip_unpacking = (num_channels == 3 || num_channels == 4) && input.is_contiguous(at::MemoryFormat::ChannelsLast);
|
| 341 |
+
bool skip_packing = (num_channels == 3 || num_channels == 4) && output.is_contiguous(at::MemoryFormat::ChannelsLast);
|
| 342 |
+
|
| 343 |
+
if (need_horizontal) {
|
| 344 |
+
int interp_dim = 3;
|
| 345 |
+
auto stride = (skip_unpacking) ? num_channels : 4;
|
| 346 |
+
std::tie(horiz_indices_weights, ksize_horiz, horiz_weights_precision) =
|
| 347 |
+
F::compute_index_ranges_int16_weights(
|
| 348 |
+
/*input_size=*/xin,
|
| 349 |
+
/*output_size=*/xout,
|
| 350 |
+
/*stride=*/stride,
|
| 351 |
+
/*ndims=*/4,
|
| 352 |
+
/*reshape_dim=*/interp_dim,
|
| 353 |
+
/*align_corners=*/align_corners,
|
| 354 |
+
/*opt_scale=*/scales[interp_dim - 2],
|
| 355 |
+
/*antialias=*/antialias,
|
| 356 |
+
/*align_i32=*/true);
|
| 357 |
+
}
|
| 358 |
+
|
| 359 |
+
if (need_vertical) {
|
| 360 |
+
int interp_dim = 2;
|
| 361 |
+
auto stride = (skip_unpacking) ? num_channels * xout : 4 * xout;
|
| 362 |
+
std::tie(vert_indices_weights, ksize_vert, vert_weights_precision) =
|
| 363 |
+
F::compute_index_ranges_int16_weights(
|
| 364 |
+
/*input_size=*/yin,
|
| 365 |
+
/*output_size=*/yout,
|
| 366 |
+
/*stride=*/stride,
|
| 367 |
+
/*ndims=*/4,
|
| 368 |
+
/*reshape_dim=*/interp_dim,
|
| 369 |
+
/*align_corners=*/align_corners,
|
| 370 |
+
/*opt_scale=*/scales[interp_dim - 2],
|
| 371 |
+
/*antialias=*/antialias,
|
| 372 |
+
/*align_i32=*/true);
|
| 373 |
+
}
|
| 374 |
+
|
| 375 |
+
at::Tensor buffer_horiz, buffer_vert;
|
| 376 |
+
// Minor optimization: we can avoid allocating an extra buffer if we're performing
|
| 377 |
+
// horizontal-only or vertical-only interpolation, and if the tensor doesn't
|
| 378 |
+
// need repacking
|
| 379 |
+
if (need_horizontal && (need_vertical || !skip_packing)) {
|
| 380 |
+
auto c = (skip_unpacking) ? num_channels : 4;
|
| 381 |
+
buffer_horiz = at::empty({c, yin, xout}, input.options());
|
| 382 |
+
}
|
| 383 |
+
if (need_vertical && !skip_packing) {
|
| 384 |
+
auto c = (skip_unpacking) ? num_channels : 4;
|
| 385 |
+
buffer_vert = at::empty({c, yout, xout}, input.options());
|
| 386 |
+
}
|
| 387 |
+
|
| 388 |
+
for (const auto i : c10::irange(batch_size)) {
|
| 389 |
+
|
| 390 |
+
at::Tensor unpacked_input = (skip_unpacking) ? input[i] : unpack_rgb(input[i]);
|
| 391 |
+
at::Tensor unpacked_output;
|
| 392 |
+
|
| 393 |
+
if (need_horizontal) {
|
| 394 |
+
at::Tensor unpacked_output_temp = (need_vertical || !skip_packing) ? buffer_horiz : output[i];
|
| 395 |
+
|
| 396 |
+
if (skip_unpacking && num_channels == 3) {
|
| 397 |
+
ImagingResampleHorizontal<3>(
|
| 398 |
+
unpacked_output_temp,
|
| 399 |
+
unpacked_input,
|
| 400 |
+
ksize_horiz,
|
| 401 |
+
horiz_indices_weights,
|
| 402 |
+
horiz_weights_precision);
|
| 403 |
+
} else {
|
| 404 |
+
ImagingResampleHorizontal<4>(
|
| 405 |
+
unpacked_output_temp,
|
| 406 |
+
unpacked_input,
|
| 407 |
+
ksize_horiz,
|
| 408 |
+
horiz_indices_weights,
|
| 409 |
+
horiz_weights_precision);
|
| 410 |
+
}
|
| 411 |
+
unpacked_output = unpacked_input = unpacked_output_temp;
|
| 412 |
+
}
|
| 413 |
+
if (need_vertical) {
|
| 414 |
+
unpacked_output = (skip_packing) ? output[i] : buffer_vert;
|
| 415 |
+
|
| 416 |
+
ImagingResampleVertical(
|
| 417 |
+
unpacked_output,
|
| 418 |
+
unpacked_input,
|
| 419 |
+
ksize_vert,
|
| 420 |
+
vert_indices_weights,
|
| 421 |
+
vert_weights_precision
|
| 422 |
+
);
|
| 423 |
+
}
|
| 424 |
+
|
| 425 |
+
TORCH_INTERNAL_ASSERT(unpacked_output.defined());
|
| 426 |
+
|
| 427 |
+
if (!skip_packing) {
|
| 428 |
+
pack_rgb(unpacked_output, output[i]);
|
| 429 |
+
}
|
| 430 |
+
}
|
| 431 |
+
}
|
| 432 |
+
|
| 433 |
+
void ImagingResampleHorizontalConvolution8u4x(
|
| 434 |
+
uint8_t* C10_RESTRICT lineOut0,
|
| 435 |
+
uint8_t* C10_RESTRICT lineOut1,
|
| 436 |
+
uint8_t* C10_RESTRICT lineOut2,
|
| 437 |
+
uint8_t* C10_RESTRICT lineOut3,
|
| 438 |
+
int64_t out_xsize,
|
| 439 |
+
const uint8_t* C10_RESTRICT lineIn0,
|
| 440 |
+
const uint8_t* C10_RESTRICT lineIn1,
|
| 441 |
+
const uint8_t* C10_RESTRICT lineIn2,
|
| 442 |
+
const uint8_t* C10_RESTRICT lineIn3,
|
| 443 |
+
int64_t in_xsize,
|
| 444 |
+
const int64_t* idx_ptr_xmin,
|
| 445 |
+
const int64_t* idx_ptr_size,
|
| 446 |
+
const int16_t* kk,
|
| 447 |
+
int kmax,
|
| 448 |
+
unsigned int coefs_precision,
|
| 449 |
+
int64_t num_channels,
|
| 450 |
+
bool is_last_line) {
|
| 451 |
+
|
| 452 |
+
// Interpolation horizontal pass processing together 4 vertical lines.
|
| 453 |
+
// - Input data format is RGBA or RGB with R,G,B,A being uint8. In case of RGBA
|
| 454 |
+
// we can encode 4 values as a single uint32 value.
|
| 455 |
+
// - We split the size of weight vector for a given output index as a sum:
|
| 456 |
+
// ids_size = num_blocks_4 * 4 + num_blocks_2 * 2 + num_blocks_1.
|
| 457 |
+
// - We load and process 4 weights values in a loop ("block 4") then we process 2 weights values
|
| 458 |
+
// in another loop ("block 2") and finally we process 1 weights value in the final loop ("block 1").
|
| 459 |
+
|
| 460 |
+
// Define shuffling masks (low/high) for num_channels 4 and 3
|
| 461 |
+
// Mask low casts lower half of each lane to epi16 and reorder RGBARGBA -> RRGGBBAA:
|
| 462 |
+
// [r1 g1 b1 a1 r2 g2 b2 a2 ... | R1 G1 B1 A1 R2 G2 B2 A2 ... ] ->
|
| 463 |
+
// [r1 0 r2 0 g1 0 g2 0 b1 0 b2 0 a1 0 a2 0 | R1 0 R2 0 G1 0 G2 0 B1 0 B2 0 A1 0 A2 0]
|
| 464 |
+
// Mask high casts upper half of each lane to epi16 and reorder RGBARGBA -> RRGGBBAA::
|
| 465 |
+
// [ ... r3 g3 b3 a3 r4 g4 b4 a4 | ... R3 G3 B3 A3 R4 G4 B4 A4 ] ->
|
| 466 |
+
// [r3 0 r4 0 g3 0 g4 0 b3 0 b4 0 a3 0 a4 0 | R3 0 R4 0 G3 0 G4 0 B3 0 B4 0 A3 0 A4 0]
|
| 467 |
+
|
| 468 |
+
const auto mask_low_c4 = _mm256_set_epi8(
|
| 469 |
+
-1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0,
|
| 470 |
+
-1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0);
|
| 471 |
+
const auto mask_high_c4 = _mm256_set_epi8(
|
| 472 |
+
-1, 15, -1, 11, -1, 14, -1, 10, -1, 13, -1, 9, -1, 12, -1, 8,
|
| 473 |
+
-1, 15, -1, 11, -1, 14, -1, 10, -1, 13, -1, 9, -1, 12, -1, 8);
|
| 474 |
+
const auto mask_low_c3 = _mm256_set_epi8(
|
| 475 |
+
-1, -1, -1, -1, -1, 5, -1, 2, -1, 4, -1, 1, -1, 3, -1, 0,
|
| 476 |
+
-1, -1, -1, -1, -1, 5, -1, 2, -1, 4, -1, 1, -1, 3, -1, 0);
|
| 477 |
+
const auto mask_high_c3 = _mm256_set_epi8(
|
| 478 |
+
-1, -1, -1, -1, -1, 11, -1, 8, -1, 10, -1, 7, -1, 9, -1, 6,
|
| 479 |
+
-1, -1, -1, -1, -1, 11, -1, 8, -1, 10, -1, 7, -1, 9, -1, 6);
|
| 480 |
+
|
| 481 |
+
const auto mask_low = (num_channels == 3) ? mask_low_c3 : mask_low_c4;
|
| 482 |
+
const auto mask_high = (num_channels == 3) ? mask_high_c3 : mask_high_c4;
|
| 483 |
+
|
| 484 |
+
const auto stride = num_channels * sizeof(uint8_t);
|
| 485 |
+
|
| 486 |
+
TORCH_INTERNAL_ASSERT(stride == 3 || stride == 4);
|
| 487 |
+
|
| 488 |
+
// out_xsize = output width, out_x = output x index
|
| 489 |
+
// ids_min is the input offset index corresponding to out_x
|
| 490 |
+
// ids_size is the interpolation size for out_x
|
| 491 |
+
|
| 492 |
+
// Let's precompute ids_size limits for block 4 and block 2.
|
| 493 |
+
//
|
| 494 |
+
// In block 4 (4 means we process 4 weight values together), we read input data
|
| 495 |
+
// with _mm_loadu_si128, i.e. 16 bytes, per one line:
|
| 496 |
+
// lineIn0 + stride * (i + ids_min) + 16 <= lineIn0 + stride * (ids_size + ids_min)
|
| 497 |
+
// --> i <= ids_size - 16.0 / stride
|
| 498 |
+
// Strict boundary:
|
| 499 |
+
// --> i < ids_size + 1 - int(ceil(16.0 / stride)) = ids_size - b4_delta
|
| 500 |
+
// Soft boundary for reading inside the buffer except its boundaries:
|
| 501 |
+
// --> i < ids_size + 1 - int(16.0 / stride) = ids_size - b4_delta_soft
|
| 502 |
+
// RGBA: b4_delta = b4_delta_soft = 3
|
| 503 |
+
// RGB : b4_delta = 5
|
| 504 |
+
// RGB : b4_delta_soft = 4
|
| 505 |
+
const auto b4_delta = (stride == 4) ? 3 : ((is_last_line) ? 5 : 4);
|
| 506 |
+
|
| 507 |
+
// In block 2 (2 means we process 2 weights values together), we read input data
|
| 508 |
+
// with _mm_loadl_epi64, i.e. 8 bytes, per one line:
|
| 509 |
+
// lineIn0 + stride * (i + ids_min) + 8 <= lineIn0 + stride * (ids_size + ids_min)
|
| 510 |
+
// --> i <= ids_size - 8.0 / stride
|
| 511 |
+
// Strict boundary:
|
| 512 |
+
// --> i < ids_size + 1 - int(ceil(8.0 / stride)) = ids_size - b2_delta
|
| 513 |
+
// Soft boundary for reading inside the buffer except its boundaries:
|
| 514 |
+
// --> i < ids_size + 1 - int(8.0 / stride) = ids_size - b2_delta_soft
|
| 515 |
+
// RGBA: b2_delta = b2_delta_soft = 1
|
| 516 |
+
// RGB : b2_delta = 2
|
| 517 |
+
// RGB : b2_delta_soft = 1
|
| 518 |
+
const auto b2_delta = (stride == 4) ? 1 : ((is_last_line) ? 2 : 1);
|
| 519 |
+
|
| 520 |
+
const auto max_out_x_strided = out_xsize * stride;
|
| 521 |
+
const auto max_in_x_strided = in_xsize * stride;
|
| 522 |
+
|
| 523 |
+
const auto zero = _mm256_setzero_si256();
|
| 524 |
+
const auto initial = _mm256_set1_epi32(1 << (coefs_precision - 1));
|
| 525 |
+
|
| 526 |
+
for (const auto out_x : c10::irange(out_xsize)) {
|
| 527 |
+
const auto ids_min = idx_ptr_xmin[out_x];
|
| 528 |
+
const auto ids_size = idx_ptr_size[out_x];
|
| 529 |
+
const auto * k = &kk[out_x * kmax];
|
| 530 |
+
int64_t i = 0;
|
| 531 |
+
|
| 532 |
+
auto sss0 = initial;
|
| 533 |
+
auto sss1 = initial;
|
| 534 |
+
|
| 535 |
+
const auto * lineIn0_min = lineIn0 + ids_min;
|
| 536 |
+
const auto * lineIn1_min = lineIn1 + ids_min;
|
| 537 |
+
const auto * lineIn2_min = lineIn2 + ids_min;
|
| 538 |
+
const auto * lineIn3_min = lineIn3 + ids_min;
|
| 539 |
+
|
| 540 |
+
// block 4
|
| 541 |
+
for (; i < ids_size - b4_delta; i += 4) {
|
| 542 |
+
// Load 4 values from weight vector
|
| 543 |
+
// mmk0 = [wl_0 wh_0 wl_1 wh_1 wl_0 wh_0 wl_1 wh_1 ...]
|
| 544 |
+
// mmk1 = [wl_2 wh_2 wl_3 wh_3 wl_2 wh_2 wl_3 wh_3 ...]
|
| 545 |
+
const auto mmk0 = _mm256_set1_epi32(*(int32_t*)&k[i]);
|
| 546 |
+
const auto mmk1 = _mm256_set1_epi32(*(int32_t*)&k[i + 2]);
|
| 547 |
+
|
| 548 |
+
// RGBA: Load 8 pixels (4 per line) from input lines 0 and 1:
|
| 549 |
+
// source = [
|
| 550 |
+
// r0 g0 b0 a0 r1 g1 b1 a1 r2 g2 b2 a2 r3 g3 b3 a3
|
| 551 |
+
// R0 G0 B0 A0 R1 G1 B1 A1 R2 G2 B2 A2 R3 G3 B3 A3
|
| 552 |
+
// ]
|
| 553 |
+
// RGB: Load 10 pixels (5 per line)
|
| 554 |
+
// source = [
|
| 555 |
+
// r0 g0 b0 r1 g1 b1 r2 g2 b2 r3 g3 b3 r4 g4 b4 r5
|
| 556 |
+
// R0 G0 B0 R1 G1 B1 R2 G2 B2 R3 G3 B3 R4 G4 B4 R5
|
| 557 |
+
// ]
|
| 558 |
+
auto source = _mm256_inserti128_si256(_mm256_castsi128_si256(
|
| 559 |
+
_mm_loadu_si128((__m128i *) (lineIn0_min + stride * i))),
|
| 560 |
+
_mm_loadu_si128((__m128i *) (lineIn1_min + stride * i)), 1);
|
| 561 |
+
|
| 562 |
+
// Apply mask_low:
|
| 563 |
+
// RGBA:
|
| 564 |
+
// [r0 0 r1 0 g0 0 g1 0 b0 0 b1 0 a0 0 a1 0 | R0 0 R1 0 G0 0 G1 0 B0 0 B1 0 A0 0 A1 0]
|
| 565 |
+
// RGB:
|
| 566 |
+
// [r0 0 r1 0 g0 0 g1 0 b0 0 b1 0 0 0 0 0 | R0 0 R1 0 G0 0 G1 0 B0 0 B1 0 0 0 0 0]
|
| 567 |
+
auto pix1 = _mm256_shuffle_epi8(source, mask_low);
|
| 568 |
+
// Compute output value as C += w0 * C0 + w1 * C1 for each channel in 32-bit precision
|
| 569 |
+
sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix1, mmk0));
|
| 570 |
+
|
| 571 |
+
// Apply mask_high:
|
| 572 |
+
// RGBA:
|
| 573 |
+
// [r2 0 r3 0 g2 0 g3 0 b2 0 b3 0 a2 0 a3 0 | R2 0 R3 0 G2 0 G3 0 B2 0 B3 0 A2 0 A3 0]
|
| 574 |
+
// RGB:
|
| 575 |
+
// [r2 0 r3 0 g2 0 g3 0 b2 0 b3 0 0 0 0 0 | R2 0 R3 0 G2 0 G3 0 B2 0 B3 0 0 0 0 0]
|
| 576 |
+
auto pix2 = _mm256_shuffle_epi8(source, mask_high);
|
| 577 |
+
// Compute output value as C += w2 * C2 + w3 * C3 for each channel in 32-bit precision
|
| 578 |
+
sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix2, mmk1));
|
| 579 |
+
|
| 580 |
+
// Same as above to next lines 2 and 3:
|
| 581 |
+
auto source2 = _mm256_inserti128_si256(_mm256_castsi128_si256(
|
| 582 |
+
_mm_loadu_si128((__m128i *) (lineIn2_min + stride * i))),
|
| 583 |
+
_mm_loadu_si128((__m128i *) (lineIn3_min + stride * i)), 1);
|
| 584 |
+
auto pix3 = _mm256_shuffle_epi8(source2, mask_low);
|
| 585 |
+
sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix3, mmk0));
|
| 586 |
+
auto pix4 = _mm256_shuffle_epi8(source2, mask_high);
|
| 587 |
+
sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix4, mmk1));
|
| 588 |
+
}
|

    // block 2
    for (; i < ids_size - b2_delta; i += 2) {
      // Load 2 values from weight vector
      // mmk = [wl_0 wh_0 wl_1 wh_1 wl_0 wh_0 wl_1 wh_1 ...]
      const auto mmk = _mm256_set1_epi32(*(int32_t*)&k[i]);

      // Load 4 pixels (2 per line) from input lines 0 and 1:
      // RGBA: source1 = [
      //   r0 g0 b0 a0 r1 g1 b1 a1 0 0 0 0 0 0 0 0
      //   R0 G0 B0 A0 R1 G1 B1 A1 0 0 0 0 0 0 0 0
      // ]
      // RGB: source1 = [
      //   r0 g0 b0 r1 g1 b1 r2 0 0 0 0 0 0 0 0
      //   R0 G0 B0 R1 G1 B1 R2 0 0 0 0 0 0 0 0
      // ]
      auto source1 = _mm256_inserti128_si256(_mm256_castsi128_si256(
          _mm_loadl_epi64((__m128i *) (lineIn0_min + stride * i))),
          _mm_loadl_epi64((__m128i *) (lineIn1_min + stride * i)), 1);
      // Apply mask_low:
      // RGBA:
      //   [r0 0 r1 0 g0 0 g1 0 b0 0 b1 0 a0 0 a1 0 | R0 0 R1 0 G0 0 G1 0 B0 0 B1 0 A0 0 A1 0]
      // RGB:
      //   [r0 0 r1 0 g0 0 g1 0 b0 0 b1 0 0 0 0 0 | R0 0 R1 0 G0 0 G1 0 B0 0 B1 0 0 0 0 0]
      auto pix1 = _mm256_shuffle_epi8(source1, mask_low);
      // Compute output value as C += w0 * C0 + w1 * C1 for each channel in 32-bit precision
      sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix1, mmk));

      // Same as above for lines 2 and 3:
      auto source2 = _mm256_inserti128_si256(_mm256_castsi128_si256(
          _mm_loadl_epi64((__m128i *) (lineIn2_min + stride * i))),
          _mm_loadl_epi64((__m128i *) (lineIn3_min + stride * i)), 1);
      auto pix2 = _mm256_shuffle_epi8(source2, mask_low);
      sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix2, mmk));
    }

    // block 1
    const auto i32_aligned = num_channels == 4;
    for (; i < ids_size - 1; i++) {
      // Load 1 value from weight vector
      // mmk = [wl_0 wh_0 0 0 wl_0 wh_0 0 0 ...]
      const auto mmk = _mm256_set1_epi32(k[i]);

      // Load 2 pixels (one per line) from input lines 0 and 1:
      // RGBA: pix1 = [
      //   r0 0 0 0 g0 0 0 0 b0 0 0 0 a0 0 0 0
      //   R0 0 0 0 G0 0 0 0 B0 0 0 0 A0 0 0 0
      // ]
      // RGB: pix1 = [
      //   r0 0 0 0 g0 0 0 0 b0 0 0 0 r1 0 0 0
      //   R0 0 0 0 G0 0 0 0 B0 0 0 0 R1 0 0 0
      // ]
      auto pix1 = _mm256_inserti128_si256(_mm256_castsi128_si256(
          mm_cvtepu8_epi32(lineIn0_min + stride * i, i32_aligned)),
          mm_cvtepu8_epi32(lineIn1_min + stride * i, i32_aligned), 1);
      // Compute output value as C += w0 * C0 for each channel in 32-bit precision
      sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix1, mmk));

      // Same as above for lines 2 and 3
      auto pix2 = _mm256_inserti128_si256(_mm256_castsi128_si256(
          mm_cvtepu8_epi32(lineIn2_min + stride * i, i32_aligned)),
          mm_cvtepu8_epi32(lineIn3_min + stride * i, i32_aligned), 1);
      sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix2, mmk));
    }

    if (i == ids_size - 1) {
      // last element
      auto mmk = _mm256_set1_epi32(k[i]);
      // For num_channels == 3 (3 bytes = one pixel) we tolerate reading 4 bytes;
      // lines 0, 1 and 2 won't go out of allocated memory bounds
      auto pix = _mm256_inserti128_si256(_mm256_castsi128_si256(
          mm_cvtepu8_epi32(lineIn0_min + stride * i, i32_aligned)),
          mm_cvtepu8_epi32(lineIn1_min + stride * i, i32_aligned), 1);
      sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix, mmk));

      auto p0 = mm_cvtepu8_epi32(lineIn2_min + stride * i, i32_aligned);
      __m128i p1;
      if (num_channels == 3 && C10_UNLIKELY(is_last_line && ids_min + stride * i + 4 >= max_in_x_strided)) {
        uint8_t input[4];
        std::memcpy(input, lineIn3_min + stride * i, 3);
        p1 = mm_cvtepu8_epi32(input, true);
      } else {
        p1 = mm_cvtepu8_epi32(lineIn3_min + stride * i, i32_aligned);
      }
      auto pix2 = _mm256_inserti128_si256(_mm256_castsi128_si256(p0), p1, 1);
      sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix2, mmk));
    }

    // Convert fixed point values back to integers (truncating)
    sss0 = _mm256_srai_epi32(sss0, coefs_precision);
    sss1 = _mm256_srai_epi32(sss1, coefs_precision);
    // Convert packed signed 32-bit integers to packed 16-bit integers using signed saturation
    // (a a a a b b b b c c c c d d d d) -> (a a b b c c d d 0 0 0 0 0 0 0 0)
    sss0 = _mm256_packs_epi32(sss0, zero);
    sss1 = _mm256_packs_epi32(sss1, zero);
    // Convert packed signed 16-bit integers to packed 8-bit integers using unsigned saturation
    // (a a b b c c d d) -> (a b c d 0 0 0 0)
    sss0 = _mm256_packus_epi16(sss0, zero);
    sss1 = _mm256_packus_epi16(sss1, zero);

    // Write the output into a single uint32
    // (a b c d) -> x_uint32
    auto o0 = _mm_cvtsi128_si32(_mm256_castsi256_si128(sss0));
    auto o1 = _mm_cvtsi128_si32(_mm256_extracti128_si256(sss0, 1));
    auto o2 = _mm_cvtsi128_si32(_mm256_castsi256_si128(sss1));
    auto o3 = _mm_cvtsi128_si32(_mm256_extracti128_si256(sss1, 1));

    const auto out_x_strided = stride * out_x;

    if (num_channels == 3 && C10_UNLIKELY(out_x_strided + 4 >= max_out_x_strided)) {
      // Memcpy of 4 bytes is faster than 3 bytes and this is a boundary case when we want to write
      // 4 bytes (R G B | X) to the output buffer (X1 X2 X3 | R1).
      // The 4th byte in the register (X) has a garbage value and the 4th byte in the output buffer (R1) has a correct
      // value which was previously computed by another line. In other words, it means that we can not overwrite
      // it by simply writing 4 bytes from the register to the output. We'll do the following:
      //               v----------|
      //   Output = [... X1 X2 X3 | R1 G1 B1 R2 ...]
      // First, we write the R1 value to the 4th byte of (R G B | X) -> (R G B | R1)
      // Second, we write 4 bytes from the register to the output: (X1 X2 X3 | R1) -> (R G B | R1)
      //   Output = [... R G B | R1 G1 B1 R2 ...]

      _write_endline_rgb_as_uint32(lineOut0 + out_x_strided, o0);
      _write_endline_rgb_as_uint32(lineOut1 + out_x_strided, o1);
      _write_endline_rgb_as_uint32(lineOut2 + out_x_strided, o2);

      if (C10_UNLIKELY(is_last_line)) {
        // When we handle the last line, we can not access the next 4 bytes
        // as they are out of memory bounds.
        std::memcpy(lineOut3 + out_x_strided, (uint8_t *) &o3, num_channels);
      } else {
        _write_endline_rgb_as_uint32(lineOut3 + out_x_strided, o3);
      }
    } else if (num_channels == 3) {
      // Memcpy of 4 bytes is faster than 3 bytes and here
      // we simply write 4 bytes (... R G B X 0 0 0 0 0 ...) where X is a garbage value
      // that we will overwrite on the next iteration: (... R G B R G B X 0 0 ...)
      std::memcpy(lineOut0 + out_x_strided, (uint8_t *) &o0, 4);
      std::memcpy(lineOut1 + out_x_strided, (uint8_t *) &o1, 4);
      std::memcpy(lineOut2 + out_x_strided, (uint8_t *) &o2, 4);
      std::memcpy(lineOut3 + out_x_strided, (uint8_t *) &o3, 4);
    } else {
      // num_channels = 4 -> lineOutX + out_x_strided should be uint32 aligned
      *(uint32_t *)(lineOut0 + out_x_strided) = o0;
      *(uint32_t *)(lineOut1 + out_x_strided) = o1;
      *(uint32_t *)(lineOut2 + out_x_strided) = o2;
      *(uint32_t *)(lineOut3 + out_x_strided) = o3;
    }
  }
}
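Editor's note: the epilogue above (shift right by coefs_precision, then packs_epi32, packus_epi16 and a 4-byte store) is the standard way to bring 32-bit fixed-point sums back to uint8 with saturation. A small SSE-only sketch of the same chain, with made-up accumulator values (assumes an x86 compiler with SSE2; values are illustrative only):

// Sketch of the fixed-point -> uint8 narrowing used by the kernel above.
#include <immintrin.h>
#include <cstdint>
#include <cstring>
#include <cstdio>

int main() {
  const unsigned coefs_precision = 8;                  // hypothetical precision
  // Fixed-point accumulators for R, G, B, A (value << precision).
  __m128i sss = _mm_setr_epi32(100 << 8, 200 << 8, 300 << 8, 4000 << 8);
  const __m128i zero = _mm_setzero_si128();
  sss = _mm_srai_epi32(sss, coefs_precision);           // back to integer pixel values
  sss = _mm_packs_epi32(sss, zero);                     // 32-bit -> 16-bit, signed saturation
  sss = _mm_packus_epi16(sss, zero);                    // 16-bit -> 8-bit, unsigned saturation
  uint32_t o = (uint32_t)_mm_cvtsi128_si32(sss);        // four channel bytes in one uint32
  uint8_t out[4];
  std::memcpy(out, &o, 4);
  printf("%u %u %u %u\n", out[0], out[1], out[2], out[3]);  // 100 200 255 255 (last two saturated)
  return 0;
}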
| 738 |
+
|
| 739 |
+
void ImagingResampleHorizontalConvolution8u(
|
| 740 |
+
uint8_t* C10_RESTRICT lineOut,
|
| 741 |
+
int64_t out_xsize,
|
| 742 |
+
const uint8_t* C10_RESTRICT lineIn,
|
| 743 |
+
int64_t in_xsize,
|
| 744 |
+
const int64_t* idx_ptr_xmin,
|
| 745 |
+
const int64_t* idx_ptr_size,
|
| 746 |
+
const int16_t* kk,
|
| 747 |
+
int kmax,
|
| 748 |
+
unsigned int coefs_precision,
|
| 749 |
+
int64_t num_channels,
|
| 750 |
+
bool is_last_line) {
|
| 751 |
+
|
| 752 |
+
// Interpolation horizontal pass processing only one vertical line.
|
| 753 |
+
// - Input data format is RGBA or RGB with R,G,B,A being uint8. In case of RGBA
|
| 754 |
+
// we can encode 4 values as a single uint32 value.
|
| 755 |
+
// - We split the size of weight vector for a given output index as a sum:
|
| 756 |
+
// ids_size = num_blocks_8 * 8 + num_blocks_4 * 4 + num_blocks_2 * 2 + num_blocks_1
|
| 757 |
+
// - We load and process 8 weights values in a loop ("block 8") then 4 weights and 2 weights values in
|
| 758 |
+
// in another loops ("block 4" and "block 2") and finally we process 1 weight value in the final loop ("block 1").
|
| 759 |
+
|
| 760 |
+
// Define various shuffling masks
|
| 761 |
+
const auto kmask_low = _mm256_set_epi8(
|
| 762 |
+
11, 10, 9, 8, 11, 10, 9, 8, 11, 10, 9, 8, 11, 10, 9, 8,
|
| 763 |
+
3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0);
|
| 764 |
+
const auto kmask_high = _mm256_set_epi8(
|
| 765 |
+
15, 14, 13, 12, 15, 14, 13, 12, 15, 14, 13, 12, 15, 14, 13, 12,
|
| 766 |
+
7, 6, 5, 4, 7, 6, 5, 4, 7, 6, 5, 4, 7, 6, 5, 4);
|
| 767 |
+
const auto kmask_hl = _mm256_set_epi8(
|
| 768 |
+
7, 6, 5, 4, 7, 6, 5, 4, 7, 6, 5, 4, 7, 6, 5, 4,
|
| 769 |
+
3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0);
|
| 770 |
+
|
| 771 |
+
const auto mask_low_c4 = _mm256_set_epi8(
|
| 772 |
+
-1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0,
|
| 773 |
+
-1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0);
|
| 774 |
+
const auto mask_high_c4 = _mm256_set_epi8(
|
| 775 |
+
-1, 15, -1, 11, -1, 14, -1, 10, -1, 13, -1, 9, -1, 12, -1, 8,
|
| 776 |
+
-1, 15, -1, 11, -1, 14, -1, 10, -1, 13, -1, 9, -1, 12, -1, 8);
|
| 777 |
+
const auto mask_low_c3 = _mm256_set_epi8(
|
| 778 |
+
-1, -1, -1, -1, -1, 5, -1, 2, -1, 4, -1, 1, -1, 3, -1, 0,
|
| 779 |
+
-1, -1, -1, -1, -1, 5, -1, 2, -1, 4, -1, 1, -1, 3, -1, 0);
|
| 780 |
+
const auto mask_high_c3 = _mm256_set_epi8(
|
| 781 |
+
-1, -1, -1, -1, -1, 11, -1, 8, -1, 10, -1, 7, -1, 9, -1, 6,
|
| 782 |
+
-1, -1, -1, -1, -1, 11, -1, 8, -1, 10, -1, 7, -1, 9, -1, 6);
|
| 783 |
+
const auto mask_hl_c3 = _mm256_set_epi8(
|
| 784 |
+
-1, -1, -1, -1, -1, 11, -1, 8, -1, 10, -1, 7, -1, 9, -1, 6,
|
| 785 |
+
-1, -1, -1, -1, -1, 5, -1, 2, -1, 4, -1, 1, -1, 3, -1, 0);
|
| 786 |
+
const auto mask_hl_c4 = _mm256_set_epi8(
|
| 787 |
+
-1, 15, -1, 11, -1, 14, -1, 10, -1, 13, -1, 9, -1, 12, -1, 8,
|
| 788 |
+
-1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0);
|
| 789 |
+
|
| 790 |
+
const auto mask_low128_c3 = _mm_set_epi8(
|
| 791 |
+
-1, -1, -1, -1, -1, 5, -1, 2, -1, 4, -1, 1, -1, 3, -1, 0);
|
| 792 |
+
const auto mask_low128_c4 = _mm_set_epi8(
|
| 793 |
+
-1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0);
|
| 794 |
+
|
| 795 |
+
const auto mask_low = (num_channels == 3) ? mask_low_c3 : mask_low_c4;
|
| 796 |
+
const auto mask_high = (num_channels == 3) ? mask_high_c3 : mask_high_c4;
|
| 797 |
+
const auto mask_hl = (num_channels == 3) ? mask_hl_c3 : mask_hl_c4;
|
| 798 |
+
const auto mask_low128 = (num_channels == 3) ? mask_low128_c3 : mask_low128_c4;
|
| 799 |
+
|
| 800 |
+
// out_xsize = output width, out_x = output x index
|
| 801 |
+
// ids_min is the input offset index corresponding to out_x
|
| 802 |
+
// ids_size is the interpolation size for out_x
|
| 803 |
+
|
| 804 |
+
const auto stride = num_channels * sizeof(uint8_t);
|
| 805 |
+
const auto zero = _mm_setzero_si128();
|
| 806 |
+
|
| 807 |
+
TORCH_INTERNAL_ASSERT(stride == 3 || stride == 4);
|
| 808 |
+
|
| 809 |
+
// Let's precompute ids_size limits for block 8, block 4 and block 2
|
| 810 |
+
//
|
| 811 |
+
// In block 8 (8 means we process 8 weight values together), we read at
|
| 812 |
+
// most 32 bytes input data (16 + 16 bytes for RGBA and 12 + 16 bytes for RGB)
|
| 813 |
+
// lineIn + stride * (i + ids_min) + 32 <= lineIn + stride * (ids_size + ids_min)
|
| 814 |
+
// --> i <= ids_size - 32.0 / stride
|
| 815 |
+
// Strict boundary:
|
| 816 |
+
// --> i < ids_size + 1 - int(ceil(32.0 / stride)) = ids_size - b8_delta
|
| 817 |
+
// Soft boundary for reading inside the buffer except its boundaries:
|
| 818 |
+
// --> i < ids_size + 1 - int(32.0 / stride) = ids_size - b8_delta_soft
|
| 819 |
+
// RGBA: b8_delta = b8_delta_soft = 7
|
| 820 |
+
// RGB : b8_delta = 10
|
| 821 |
+
// RGB : b8_delta_soft = 9
|
| 822 |
+
const auto b8_delta = (stride == 4) ? 7 : ((is_last_line) ? 10 : 9);
|
| 823 |
+
|
| 824 |
+
// In block 4 (4 means we process 4 weight values together), we read
|
| 825 |
+
// 16 bytes of input data.
|
| 826 |
+
// lineIn + stride * (i + ids_min) + 16 <= lineIn0 + stride * (ids_size + ids_min)
|
| 827 |
+
// --> i <= ids_size - 16.0 / stride
|
| 828 |
+
// Strict boundary:
|
| 829 |
+
// --> i < ids_size + 1 - int(ceil(16.0 / stride)) = ids_size - b4_delta
|
| 830 |
+
// Soft boundary for reading inside the buffer except its boundaries:
|
| 831 |
+
// --> i < ids_size + 1 - int(16.0 / stride) = ids_size - b4_delta_soft
|
| 832 |
+
// RGBA: b4_delta = b4_delta_soft = 3
|
| 833 |
+
// RGB : b4_delta = 5
|
| 834 |
+
// RGB : b4_delta_soft = 4
|
| 835 |
+
const auto b4_delta = (stride == 4) ? 3 : ((is_last_line) ? 5 : 4);
|
| 836 |
+
|
| 837 |
+
// In block 2 (2 means we process 2 weight values together), we read
|
| 838 |
+
// 8 bytes of input data.
|
| 839 |
+
// lineIn0 + stride * (i + ids_min) + 8 <= lineIn0 + stride * (ids_size + ids_min)
|
| 840 |
+
// --> i <= ids_size - 8.0 / stride
|
| 841 |
+
// Strict boundary:
|
| 842 |
+
// --> i < ids_size + 1 - int(ceil(8.0 / stride)) = ids_size - b2_delta
|
| 843 |
+
// Soft boundary for reading inside the buffer except its boundaries:
|
| 844 |
+
// --> i < ids_size + 1 - int(8.0 / stride) = ids_size - b2_delta_soft
|
| 845 |
+
// RGBA: b2_delta = b2_delta_soft = 1
|
| 846 |
+
// RGB : b2_delta = 2
|
| 847 |
+
// RGB : b2_delta_soft = 1
|
| 848 |
+
const auto b2_delta = (stride == 4) ? 1 : ((is_last_line) ? 2 : 1);
|
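Editor's note: the b8/b4/b2 delta constants quoted in the comments above follow directly from the stated inequality i < ids_size + 1 - ceil(bytes_read / stride) (and its floor variant for the "soft" bound). A tiny sketch that reproduces the quoted values; the helper name and layout are my own, not part of the file:

// Reproduces b8_delta / b8_delta_soft etc. from the inequalities in the comments above.
#include <cstdint>
#include <cstdio>

int64_t block_delta(int64_t stride, int64_t bytes_per_iter, bool strict) {
  // strict bound uses ceil(bytes/stride), soft bound uses floor(bytes/stride);
  // the delta subtracted from ids_size is that value minus one.
  int64_t d = strict ? (bytes_per_iter + stride - 1) / stride : bytes_per_iter / stride;
  return d - 1;
}

int main() {
  // Block 8 reads 32 bytes: RGBA (stride 4) gives 7/7, RGB (stride 3) gives 10/9.
  printf("b8 RGBA strict=%lld soft=%lld\n", (long long)block_delta(4, 32, true), (long long)block_delta(4, 32, false));
  printf("b8 RGB  strict=%lld soft=%lld\n", (long long)block_delta(3, 32, true), (long long)block_delta(3, 32, false));
  return 0;
}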
| 849 |
+
|
| 850 |
+
const auto max_out_x_strided = out_xsize * stride;
|
| 851 |
+
const auto max_in_x_strided = in_xsize * stride;
|
| 852 |
+
|
| 853 |
+
for (const auto out_x : c10::irange(out_xsize)) {
|
| 854 |
+
__m128i sss;
|
| 855 |
+
const auto ids_min = idx_ptr_xmin[out_x];
|
| 856 |
+
const auto ids_size = idx_ptr_size[out_x];
|
| 857 |
+
const auto * k = &kk[out_x * kmax];
|
| 858 |
+
int64_t i = 0;
|
| 859 |
+
|
| 860 |
+
const auto * lineIn_min = lineIn + ids_min;
|
| 861 |
+
|
| 862 |
+
if (ids_size < 8) {
|
| 863 |
+
sss = _mm_set1_epi32(1 << (coefs_precision - 1));
|
| 864 |
+
} else {
|
| 865 |
+
// Lower part will be added to higher, use only half of the error
|
| 866 |
+
auto sss256 = _mm256_set1_epi32(1 << (coefs_precision - 2));
|
| 867 |
+
|
| 868 |
+
// block 8
|
| 869 |
+
for (; i < ids_size - b8_delta; i += 8) {
|
| 870 |
+
// Load 8 values from weight vector
|
| 871 |
+
auto tmp = _mm_loadu_si128((__m128i*)&k[i]);
|
| 872 |
+
// ksource = [
|
| 873 |
+
// wl_0 wh_0 wl_1 wh_1 wl_2 wh_2 wl_3 wh_3 wl_4 wh_4 wl_5 wh_5 wl_6 wh_6 wl_7 wh_7
|
| 874 |
+
// wl_0 wh_0 wl_1 wh_1 wl_2 wh_2 wl_3 wh_3 wl_4 wh_4 wl_5 wh_5 wl_6 wh_6 wl_7 wh_7
|
| 875 |
+
// ]
|
| 876 |
+
auto ksource = _mm256_insertf128_si256(_mm256_castsi128_si256(tmp), tmp, 1);
|
| 877 |
+
|
| 878 |
+
// RGBA: Load 8 pixels from input:
|
| 879 |
+
// source = [
|
| 880 |
+
// r0 g0 b0 a0 r1 g1 b1 a1 r2 g2 b2 a2 r3 g3 b3 a3
|
| 881 |
+
// r4 g4 b4 a4 r5 g5 b5 a5 r6 g6 b6 a6 r7 g7 b7 a7
|
| 882 |
+
// ]
|
| 883 |
+
// RGB: Load 10 pixels from input (however we can process only 8 pixels):
|
| 884 |
+
// source = [
|
| 885 |
+
// r0 g0 b0 r1 g1 b1 r2 g2 b2 r3 g3 b3 r4 g4 b4 r5
|
| 886 |
+
// r4 g4 b4 r5 g5 b5 r6 g6 b6 r7 g7 b7 r8 g8 b8 r9
|
| 887 |
+
// ]
|
| 888 |
+
auto source = _mm256_inserti128_si256(_mm256_castsi128_si256(
|
| 889 |
+
_mm_loadu_si128((__m128i *) (lineIn_min + stride * i))),
|
| 890 |
+
_mm_loadu_si128((__m128i *) (lineIn_min + stride * (i + 4))), 1);
|
| 891 |
+
|
| 892 |
+
// Extract lower part of each lane, cast to epi16 and reoder RGBARGBA -> RRGGBBAA
|
| 893 |
+
// RGBA: pix1 = [
|
| 894 |
+
// r0 0 r1 0 g0 0 g1 0 b0 0 b1 0 a0 0 a1 0
|
| 895 |
+
// r4 0 r5 0 g4 0 g5 0 b4 0 b5 0 a4 0 a5 0
|
| 896 |
+
// ]
|
| 897 |
+
// RGB: pix1 = [
|
| 898 |
+
// r0 0 r1 0 g0 0 g1 0 b0 0 b1 0 0 0 0 0
|
| 899 |
+
// r4 0 r5 0 g4 0 g5 0 b4 0 b5 0 0 0 0 0
|
| 900 |
+
// ]
|
| 901 |
+
auto pix1 = _mm256_shuffle_epi8(source, mask_low);
|
| 902 |
+
// mmk1 = [
|
| 903 |
+
// wl_0 wh_0 wl_1 wh_1 wl_0 wh_0 wl_1 wh_1 ... ...
|
| 904 |
+
// wl_4 wh_4 wl_5 wh_5 wl_4 wh_4 wl_5 wh_5 ... ...
|
| 905 |
+
// ]
|
| 906 |
+
auto mmk1 = _mm256_shuffle_epi8(ksource, kmask_low);
|
| 907 |
+
// Compute output value as
|
| 908 |
+
// C += w0 * C0 + w1 * C1
|
| 909 |
+
// C += w4 * C4 + w5 * C5 for each channel in 32-bit precision
|
| 910 |
+
sss256 = _mm256_add_epi32(sss256, _mm256_madd_epi16(pix1, mmk1));
|
| 911 |
+
|
| 912 |
+
// Same as above for higher part of each lane
|
| 913 |
+
auto pix2 = _mm256_shuffle_epi8(source, mask_high);
|
| 914 |
+
auto mmk2 = _mm256_shuffle_epi8(ksource, kmask_high);
|
| 915 |
+
// Compute output value as
|
| 916 |
+
// C += w2 * C2 + w3 * C3
|
| 917 |
+
// C += w6 * C6 + w7 * C7 for each channel in 32-bit precision
|
| 918 |
+
sss256 = _mm256_add_epi32(sss256, _mm256_madd_epi16(pix2, mmk2));
|
| 919 |
+
}
|
| 920 |
+
|
| 921 |
+
// block 4
|
| 922 |
+
for (; i < ids_size - b4_delta; i += 4) {
|
| 923 |
+
// Load 4 values from weight vector
|
| 924 |
+
auto tmp = _mm_loadl_epi64((__m128i *) &k[i]);
|
| 925 |
+
// ksource = [
|
| 926 |
+
// wl_0 wh_0 wl_1 wh_1 wl_2 wh_2 wl_3 wh_3 0 0 0 0 0 0 0 0
|
| 927 |
+
// wl_0 wh_0 wl_1 wh_1 wl_2 wh_2 wl_3 wh_3 0 0 0 0 0 0 0 0
|
| 928 |
+
// ]
|
| 929 |
+
auto ksource = _mm256_insertf128_si256(_mm256_castsi128_si256(tmp), tmp, 1);
|
| 930 |
+
|
| 931 |
+
// Load pixels from input line
|
| 932 |
+
tmp = _mm_loadu_si128((__m128i *) (lineIn_min + stride * i));
|
| 933 |
+
// RGBA: source = [
|
| 934 |
+
// r0 g0 b0 a0 r1 g1 b1 a1 r2 g2 b2 a2 r3 g3 b3 a3
|
| 935 |
+
// r0 g0 b0 a0 r1 g1 b1 a1 r2 g2 b2 a2 r3 g3 b3 a3
|
| 936 |
+
// ]
|
| 937 |
+
// RGB: source = [
|
| 938 |
+
// r0 g0 b0 r1 g1 b1 r2 g2 b2 r3 g3 b3 r4 g4 b4 r5
|
| 939 |
+
// r0 g0 b0 r1 g1 b1 r2 g2 b2 r3 g3 b3 r4 g4 b4 r5
|
| 940 |
+
// ]
|
| 941 |
+
auto source = _mm256_insertf128_si256(_mm256_castsi128_si256(tmp), tmp, 1);
|
| 942 |
+
|
| 943 |
+
// Cast source to epi16 and reorder RGBARGBA -> RRGGBBAA
|
| 944 |
+
// RGBA: pix = [
|
| 945 |
+
// r0 0 r1 0 g0 0 g1 0 b0 0 b1 0 a0 0 a1 0
|
| 946 |
+
// r2 0 r3 0 g2 0 g3 0 b2 0 b3 0 a2 0 a3 0
|
| 947 |
+
// ]
|
| 948 |
+
// RGB: pix = [
|
| 949 |
+
// r0 0 r1 0 g0 0 g1 0 b0 0 b1 0 0 0 0 0
|
| 950 |
+
// r2 0 r3 0 g2 0 g3 0 b2 0 b3 0 0 0 0 0
|
| 951 |
+
// ]
|
| 952 |
+
auto pix = _mm256_shuffle_epi8(source, mask_hl);
|
| 953 |
+
// mmk = [
|
| 954 |
+
// wl_0 wh_0 wl_1 wh_1 wl_0 wh_0 wl_1 wh_1 ... ...
|
| 955 |
+
// wl_2 wh_2 wl_3 wh_3 wl_2 wh_2 wl_3 wh_3 ... ...
|
| 956 |
+
// ]
|
| 957 |
+
auto mmk = _mm256_shuffle_epi8(ksource, kmask_hl);
|
| 958 |
+
// Compute output value as
|
| 959 |
+
// C += w0 * C0 + w1 * C1
|
| 960 |
+
// C += w2 * C2 + w3 * C3 for each channel in 32-bit precision
|
| 961 |
+
sss256 = _mm256_add_epi32(sss256, _mm256_madd_epi16(pix, mmk));
|
| 962 |
+
}
|
| 963 |
+
|
| 964 |
+
// Sum results between the lanes
|
| 965 |
+
sss = _mm_add_epi32(
|
| 966 |
+
_mm256_extracti128_si256(sss256, 0),
|
| 967 |
+
_mm256_extracti128_si256(sss256, 1));
|
| 968 |
+
}
|
| 969 |
+
|
| 970 |
+
// block 2
|
| 971 |
+
for (; i < ids_size - b2_delta; i += 2) {
|
| 972 |
+
// Load 2 values from weight vector
|
| 973 |
+
// mmk = [wl_0 wh_0 wl_1 wh_1 wl_0 wh_0 wl_1 wh_1 ...]
|
| 974 |
+
auto mmk = _mm_set1_epi32(*(int32_t*)&k[i]);
|
| 975 |
+
// Load pixels from input line
|
| 976 |
+
// RGBA: source = [
|
| 977 |
+
// r0 g0 b0 a0 r1 g1 b1 a1 0 0 0 0 0 0 0 0
|
| 978 |
+
// ]
|
| 979 |
+
// RGB: source = [
|
| 980 |
+
// r0 g0 b0 r1 g1 b1 r2 g2 0 0 0 0 0 0 0 0
|
| 981 |
+
// ]
|
| 982 |
+
auto source = _mm_loadl_epi64((__m128i *) (lineIn_min + stride * i));
|
| 983 |
+
// Cast source to epi16 and reorder RGBARGBA -> RRGGBBAA
|
| 984 |
+
auto pix = _mm_shuffle_epi8(source, mask_low128);
|
| 985 |
+
// Compute output value as C += w0 * C0 + w1 * C1 for each channel in 32-bit precision
|
| 986 |
+
sss = _mm_add_epi32(sss, _mm_madd_epi16(pix, mmk));
|
| 987 |
+
}
|
| 988 |
+
|
| 989 |
+
// block 1
|
| 990 |
+
const auto i32_aligned = num_channels == 4;
|
| 991 |
+
for (; i < ids_size - 1; i++) {
|
| 992 |
+
// Load 1 value from weight vector
|
| 993 |
+
// mmk = [wl_0 wh_0 0 0 wl_0 wh_0 0 0 ...]
|
| 994 |
+
auto mmk = _mm_set1_epi32(k[i]);
|
| 995 |
+
// Load one pixel from input line
|
| 996 |
+
// RGBA: pix = [
|
| 997 |
+
// r0 0 0 0 g0 0 0 0 b0 0 0 0 a0 0 0 0
|
| 998 |
+
// ]
|
| 999 |
+
// RGB: pix = [
|
| 1000 |
+
// r0 0 0 0 g0 0 0 0 b0 0 0 0 r1 0 0 0
|
| 1001 |
+
// ]
|
| 1002 |
+
auto pix = mm_cvtepu8_epi32(lineIn_min + stride * i, i32_aligned);
|
| 1003 |
+
// Compute output value as C += w0 * C0 for each channel in 32-bit precision
|
| 1004 |
+
sss = _mm_add_epi32(sss, _mm_madd_epi16(pix, mmk));
|
| 1005 |
+
}
|
| 1006 |
+
|
| 1007 |
+
if (i == ids_size - 1) {
|
| 1008 |
+
// last element
|
| 1009 |
+
auto mmk = _mm_set1_epi32(k[i]);
|
| 1010 |
+
__m128i pix;
|
| 1011 |
+
auto p = lineIn_min + stride * i;
|
| 1012 |
+
if (num_channels == 3 && C10_UNLIKELY(is_last_line && ids_min + stride * i + 4 >= max_in_x_strided)) {
|
| 1013 |
+
uint8_t input[4];
|
| 1014 |
+
std::memcpy(input, p, 3);
|
| 1015 |
+
pix = mm_cvtepu8_epi32(input, true);
|
| 1016 |
+
} else {
|
| 1017 |
+
pix = mm_cvtepu8_epi32(p, i32_aligned);
|
| 1018 |
+
}
|
| 1019 |
+
sss = _mm_add_epi32(sss, _mm_madd_epi16(pix, mmk));
|
| 1020 |
+
}
|
| 1021 |
+
|
| 1022 |
+
// Convert fixed point values back to integers (truncating)
|
| 1023 |
+
sss = _mm_srai_epi32(sss, coefs_precision);
|
| 1024 |
+
// Convert packed signed 32-bit integers to packed 16-bit integers using signed saturation
|
| 1025 |
+
// (a a a a b b b b c c c c d d d d) -> (a a b b c c d d 0 0 0 0 0 0 0 0)
|
| 1026 |
+
sss = _mm_packs_epi32(sss, zero);
|
| 1027 |
+
// Convert packed signed 16-bit integers to packed 8-bit integers using unsigned saturation
|
| 1028 |
+
// (a a b b c c d d) -> (a b c d 0 0 0 0)
|
| 1029 |
+
sss = _mm_packus_epi16(sss, zero);
|
| 1030 |
+
// Write the output into single uint32
|
| 1031 |
+
// (a b c d) -> x_uint32
|
| 1032 |
+
auto o = _mm_cvtsi128_si32(sss);
|
| 1033 |
+
const auto out_x_strided = stride * out_x;
|
| 1034 |
+
if (num_channels == 3 && C10_UNLIKELY(out_x_strided + 4 >= max_out_x_strided)) {
|
| 1035 |
+
if (C10_UNLIKELY(is_last_line)) {
|
| 1036 |
+
// When we handle the last line, we can not access the next 4 bytes
|
| 1037 |
+
// as they are out of memory bounds.
|
| 1038 |
+
std::memcpy(lineOut + out_x_strided, (uint8_t *) &o, 3);
|
| 1039 |
+
} else {
|
| 1040 |
+
// Memcpy 4-bytes is faster than 3-bytes and this is a boundary case when we want to write
|
| 1041 |
+
// 4 bytes (R G B | X) to the output buffer (X1 X2 X3 | R1).
|
| 1042 |
+
// The 4th byte in the register (X) has a garbage value and 4th byte in the output buffer (R1) has a correct
|
| 1043 |
+
// value which was previously computed by another line. In other words, it means that we can not overwrite
|
| 1044 |
+
// it by simply writing 4 bytes from the register to the output. We'll do the following:
|
| 1045 |
+
// v----------|
|
| 1046 |
+
// Output = [... X1 X2 X3 | R1 G1 B1 R2 ...]
|
| 1047 |
+
// First, we write R1 value to the 4th byte of (R G B | X) -> (R G B | R1)
|
| 1048 |
+
// Second, we write 4 bytes from the register to the output: (X1 X2 X3 | R1) -> (R G B | R1)
|
| 1049 |
+
// Output = [... R G B | R1 G1 B1 R2 ...]
|
| 1050 |
+
_write_endline_rgb_as_uint32(lineOut + out_x_strided, o);
|
| 1051 |
+
}
|
| 1052 |
+
} else if (num_channels == 3) {
|
| 1053 |
+
// Memcpy 4-bytes is faster than 3-bytes and here
|
| 1054 |
+
// we simply write 4 bytes (... R G B X 0 0 0 0 0 ...) where X is a garbage value
|
| 1055 |
+
// that we will overwrite on the next iteration: (... R G B R G B X 0 0 ...)
|
| 1056 |
+
std::memcpy(lineOut + out_x_strided, (uint8_t *) &o, 4);
|
| 1057 |
+
} else {
|
| 1058 |
+
// num_channels = 4 -> lineOut + out_x_strided should be uint32 aligned
|
| 1059 |
+
*(uint32_t *)(lineOut + out_x_strided) = o;
|
| 1060 |
+
}
|
| 1061 |
+
}
|
| 1062 |
+
}
|
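Editor's note: _write_endline_rgb_as_uint32 is defined earlier in this file; the comments above only describe what it must do at the right edge of an RGB row. A hypothetical scalar restatement of that behaviour, with made-up buffer contents, just to make the byte bookkeeping concrete (this is not the file's actual implementation):

// Writing 4 bytes is cheaper than 3, but at the row edge the 4th byte would clobber the
// neighbouring pixel's R value, so that value is first folded into the register word.
#include <cstdint>
#include <cstring>
#include <cstdio>

static void write_endline_rgb_as_uint32_sketch(uint8_t* output, uint32_t data) {
  data = (data & 0x00FFFFFFu) | ((uint32_t)output[3] << 24);  // keep the neighbour's byte
  std::memcpy(output, &data, 4);
}

int main() {
  uint8_t out[8] = {0, 0, 0, 77, 88, 99, 111, 0};      // 77 is the next pixel's R, must survive
  write_endline_rgb_as_uint32_sketch(out, 0x00030201u); // R=1, G=2, B=3, garbage high byte
  printf("%u %u %u %u\n", out[0], out[1], out[2], out[3]);  // 1 2 3 77
  return 0;
}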
| 1063 |
+
|
| 1064 |
+
void ImagingResampleVerticalConvolution8u(
|
| 1065 |
+
uint8_t* C10_RESTRICT lineOut,
|
| 1066 |
+
const uint8_t* C10_RESTRICT lineIn,
|
| 1067 |
+
int64_t xsize,
|
| 1068 |
+
int64_t ids_min,
|
| 1069 |
+
int64_t ids_size,
|
| 1070 |
+
const int16_t* k,
|
| 1071 |
+
unsigned int coefs_precision,
|
| 1072 |
+
int64_t num_channels) {
|
| 1073 |
+
|
| 1074 |
+
// Interpolation vertical pass processing one line.
|
| 1075 |
+
// - We process x-axis data with blocks of 8, 2 and 1
|
| 1076 |
+
// - We split the size of weight vector for a given output index as a sum: K = n * 2 + m.
|
| 1077 |
+
|
| 1078 |
+
// xsize = output width, also equals to input width
|
| 1079 |
+
// ids_size = interpolation size
|
| 1080 |
+
// ids_min = input y start index
|
| 1081 |
+
const auto stride = num_channels * sizeof(uint8_t);
|
| 1082 |
+
|
| 1083 |
+
TORCH_INTERNAL_ASSERT(stride == 3 || stride == 4);
|
| 1084 |
+
|
| 1085 |
+
const int64_t data_size = xsize * stride;
|
| 1086 |
+
const int64_t data_stride = stride;
|
| 1087 |
+
constexpr auto vec_size = 256 / 8;
|
| 1088 |
+
|
| 1089 |
+
const auto initial = _mm_set1_epi32(1 << (coefs_precision - 1));
|
| 1090 |
+
const auto initial_256 = _mm256_set1_epi32(1 << (coefs_precision - 1));
|
| 1091 |
+
const auto zero = _mm_setzero_si128();
|
| 1092 |
+
const auto zero_256 = _mm256_setzero_si256();
|
| 1093 |
+
|
| 1094 |
+
int64_t j = 0;
|
| 1095 |
+
// block 8
|
| 1096 |
+
const auto b8_usable_vec_stride = (vec_size / data_stride) * data_stride;
|
| 1097 |
+
for (; j < data_size - vec_size; j += b8_usable_vec_stride) {
|
| 1098 |
+
auto sss0 = initial_256;
|
| 1099 |
+
auto sss1 = initial_256;
|
| 1100 |
+
auto sss2 = initial_256;
|
| 1101 |
+
auto sss3 = initial_256;
|
| 1102 |
+
int64_t i = 0;
|
| 1103 |
+
const auto * lineIn_min = lineIn + j + ids_min;
|
| 1104 |
+
|
| 1105 |
+
for (; i < ids_size - 1; i += 2) {
|
| 1106 |
+
// Load 2 values from weight vector
|
| 1107 |
+
auto mmk = _mm256_set1_epi32(*(int32_t*)&k[i]);
|
| 1108 |
+
|
| 1109 |
+
// RGBA: Load 8 pixels per line
|
| 1110 |
+
// source1 = [
|
| 1111 |
+
// r0 g0 b0 a0 r1 g1 b1 a1 r2 g2 b2 a2 r3 g3 b3 a3
|
| 1112 |
+
// r4 g4 b4 a4 r5 g5 b5 a5 r6 g6 b6 a6 r7 g7 b7 a7
|
| 1113 |
+
// ]
|
| 1114 |
+
// RGB: Load 10 pixels per line (however we can process only 8 pixels):
|
| 1115 |
+
// source1 = [
|
| 1116 |
+
// r0 g0 b0 r1 g1 b1 r2 g2 b2 r3 g3 b3 r4 g4 b4 r5
|
| 1117 |
+
// r4 g4 b4 r5 g5 b5 r6 g6 b6 r7 g7 b7 r8 g8 b8 r9
|
| 1118 |
+
// ]
|
| 1119 |
+
auto source1 =
|
| 1120 |
+
_mm256_loadu_si256((__m256i*)(lineIn_min + data_size * i));
|
| 1121 |
+
auto source2 =
|
| 1122 |
+
_mm256_loadu_si256((__m256i*)(lineIn_min + data_size * (i + 1)));
|
| 1123 |
+
|
| 1124 |
+
// Interleave source1 and source2 from the low half of each 128-bit lane
|
| 1125 |
+
// and cast the result to epi16
|
| 1126 |
+
// RGBA: pix1 = [
|
| 1127 |
+
// r0 0 R0 0 g0 0 G0 0 b0 0 B0 0 a0 0 A0 0
|
| 1128 |
+
// r1 0 R1 0 g1 0 G1 0 b1 0 B1 0 a1 0 A1 0
|
| 1129 |
+
// ]
|
| 1130 |
+
// RGB: pix1 = [
|
| 1131 |
+
// r0 0 R0 0 g0 0 G0 0 b0 0 B0 0 0 0 0 0
|
| 1132 |
+
// r1 0 R1 0 g1 0 G1 0 b1 0 B1 0 0 0 0 0
|
| 1133 |
+
// ]
|
| 1134 |
+
auto source_lo = _mm256_unpacklo_epi8(source1, source2);
|
| 1135 |
+
auto pix1 = _mm256_unpacklo_epi8(source_lo, zero_256);
|
| 1136 |
+
// Compute output value as
|
| 1137 |
+
// C += w0 * c0 + w1 * C0
|
| 1138 |
+
// C += w0 * c1 + w1 * C1 for each channel in 32-bit precision
|
| 1139 |
+
sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix1, mmk));
|
| 1140 |
+
|
| 1141 |
+
// RGBA: pix2 = [
|
| 1142 |
+
// r2 0 R2 0 g2 0 G2 0 b2 0 B2 0 a2 0 A2 0
|
| 1143 |
+
// r3 0 R3 0 g3 0 G3 0 b3 0 B3 0 a3 0 A3 0
|
| 1144 |
+
// ]
|
| 1145 |
+
// RGB: pix2 = [
|
| 1146 |
+
// r2 0 R2 0 g2 0 G2 0 b2 0 B2 0 0 0 0 0
|
| 1147 |
+
// r3 0 R3 0 g3 0 G3 0 b3 0 B3 0 0 0 0 0
|
| 1148 |
+
// ]
|
| 1149 |
+
auto pix2 = _mm256_unpackhi_epi8(source_lo, zero_256);
|
| 1150 |
+
// Compute output value as
|
| 1151 |
+
// C += w0 * c2 + w1 * C2
|
| 1152 |
+
// C += w0 * c3 + w1 * C3 for each channel in 32-bit precision
|
| 1153 |
+
sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix2, mmk));
|
| 1154 |
+
|
| 1155 |
+
// Same as above for the high half of each 128-bit lane
|
| 1156 |
+
auto source_hi = _mm256_unpackhi_epi8(source1, source2);
|
| 1157 |
+
auto pix3 = _mm256_unpacklo_epi8(source_hi, zero_256);
|
| 1158 |
+
sss2 = _mm256_add_epi32(sss2, _mm256_madd_epi16(pix3, mmk));
|
| 1159 |
+
auto pix4 = _mm256_unpackhi_epi8(source_hi, zero_256);
|
| 1160 |
+
sss3 = _mm256_add_epi32(sss3, _mm256_madd_epi16(pix4, mmk));
|
| 1161 |
+
}
|
| 1162 |
+
// Same processing as above but with a single weight value
|
| 1163 |
+
for (; i < ids_size; i += 1) {
|
| 1164 |
+
auto mmk = _mm256_set1_epi32(k[i]);
|
| 1165 |
+
|
| 1166 |
+
auto source1 = _mm256_loadu_si256((__m256i*)(lineIn_min + i * data_size));
|
| 1167 |
+
|
| 1168 |
+
auto source_lo = _mm256_unpacklo_epi8(source1, zero_256);
|
| 1169 |
+
auto pix1 = _mm256_unpacklo_epi8(source_lo, zero_256);
|
| 1170 |
+
sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix1, mmk));
|
| 1171 |
+
auto pix2 = _mm256_unpackhi_epi8(source_lo, zero_256);
|
| 1172 |
+
sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix2, mmk));
|
| 1173 |
+
|
| 1174 |
+
auto source_hi = _mm256_unpackhi_epi8(source1, zero_256);
|
| 1175 |
+
auto pix3 = _mm256_unpacklo_epi8(source_hi, _mm256_setzero_si256());
|
| 1176 |
+
sss2 = _mm256_add_epi32(sss2, _mm256_madd_epi16(pix3, mmk));
|
| 1177 |
+
auto pix4 = _mm256_unpackhi_epi8(source_hi, _mm256_setzero_si256());
|
| 1178 |
+
sss3 = _mm256_add_epi32(sss3, _mm256_madd_epi16(pix4, mmk));
|
| 1179 |
+
}
|
| 1180 |
+
// Convert fixed point values back to integers (truncating)
|
| 1181 |
+
sss0 = _mm256_srai_epi32(sss0, coefs_precision);
|
| 1182 |
+
sss1 = _mm256_srai_epi32(sss1, coefs_precision);
|
| 1183 |
+
sss2 = _mm256_srai_epi32(sss2, coefs_precision);
|
| 1184 |
+
sss3 = _mm256_srai_epi32(sss3, coefs_precision);
|
| 1185 |
+
// Convert packed signed 32-bit integers to packed 16-bit integers using signed saturation
|
| 1186 |
+
// (a a a a b b b b c c c c d d d d) -> (a a b b c c d d)
|
| 1187 |
+
sss0 = _mm256_packs_epi32(sss0, sss1);
|
| 1188 |
+
sss2 = _mm256_packs_epi32(sss2, sss3);
|
| 1189 |
+
// Convert packed signed 16-bit integers to packed 8-bit integers using unsigned saturation
|
| 1190 |
+
// (a a b b c c d d) -> (a b c d)
|
| 1191 |
+
sss0 = _mm256_packus_epi16(sss0, sss2);
|
| 1192 |
+
|
| 1193 |
+
// Stores 32 bytes
|
| 1194 |
+
_mm256_storeu_si256((__m256i*)(lineOut + j), sss0);
|
| 1195 |
+
}
|
| 1196 |
+
|
| 1197 |
+
// TODO: Do we also need block 4 ???
|
| 1198 |
+
// block 2
|
| 1199 |
+
const auto b2_usable_vec_stride = (8 / data_stride) * data_stride;
|
| 1200 |
+
for (; j < data_size - vec_size / 4; j += b2_usable_vec_stride) {
|
| 1201 |
+
auto sss0 = initial;
|
| 1202 |
+
auto sss1 = initial;
|
| 1203 |
+
int64_t i = 0;
|
| 1204 |
+
const auto * lineIn_min = lineIn + j + ids_min;
|
| 1205 |
+
|
| 1206 |
+
for (; i < ids_size - 1; i += 2) {
|
| 1207 |
+
// Load 2 values from weight vector
|
| 1208 |
+
// mmk = [wl_0 wh_0 wl_1 wh_1 wl_0 wh_0 wl_1 wh_1 ... ]
|
| 1209 |
+
auto mmk = _mm_set1_epi32(*(int32_t*)&k[i]);
|
| 1210 |
+
|
| 1211 |
+
// Load 2 pixels per line
|
| 1212 |
+
// RGBA: source1 = [
|
| 1213 |
+
// r0 g0 b0 a0 r1 g1 b1 a1 0 0 0 0 0 0 0 0
|
| 1214 |
+
// ]
|
| 1215 |
+
// RGB: source1 = [
|
| 1216 |
+
// r0 g0 b0 r1 g1 b1 r2 g2 0 0 0 0 0 0 0 0
|
| 1217 |
+
// ]
|
| 1218 |
+
auto source1 = _mm_loadl_epi64((__m128i *) (lineIn_min + i * data_size));
|
| 1219 |
+
auto source2 = _mm_loadl_epi64((__m128i *) (lineIn_min + (i + 1) * data_size));
|
| 1220 |
+
// Interleave source1 and source2 and cast the result to epi16
|
| 1221 |
+
// RGBA: pix = [
|
| 1222 |
+
// r0 0 R0 0 g0 0 G0 0 b0 0 B0 0 a0 0 A0 0
|
| 1223 |
+
// ]
|
| 1224 |
+
// RGB: pix = [
|
| 1225 |
+
// r0 0 R0 0 g0 0 G0 0 b0 0 B0 0 0 0 0 0
|
| 1226 |
+
// ]
|
| 1227 |
+
auto source = _mm_unpacklo_epi8(source1, source2);
|
| 1228 |
+
auto pix = _mm_unpacklo_epi8(source, zero);
|
| 1229 |
+
// Compute output value as C += w0 * c0 + w1 * C0 for each channel in 32-bit precision
|
| 1230 |
+
sss0 = _mm_add_epi32(sss0, _mm_madd_epi16(pix, mmk));
|
| 1231 |
+
// RGBA: pix = [
|
| 1232 |
+
// r1 0 R1 0 g1 0 G1 0 b1 0 B1 0 a1 0 A1 0
|
| 1233 |
+
// ]
|
| 1234 |
+
// RGB: pix = [
|
| 1235 |
+
// r1 0 R1 0 g1 0 G1 0 b1 0 B1 0 0 0 0 0
|
| 1236 |
+
// ]
|
| 1237 |
+
pix = _mm_unpackhi_epi8(source, zero);
|
| 1238 |
+
// Compute output value as C += w0 * c1 + w1 * C1 for each channel in 32-bit precision
|
| 1239 |
+
sss1 = _mm_add_epi32(sss1, _mm_madd_epi16(pix, mmk));
|
| 1240 |
+
}
|
| 1241 |
+
// Same processing as above but with a single weight value
|
| 1242 |
+
for (; i < ids_size; i += 1) {
|
| 1243 |
+
auto mmk = _mm_set1_epi32(k[i]);
|
| 1244 |
+
|
| 1245 |
+
auto source1 = _mm_loadl_epi64((__m128i*) (lineIn_min + i * data_size));
|
| 1246 |
+
|
| 1247 |
+
auto source = _mm_unpacklo_epi8(source1, zero);
|
| 1248 |
+
auto pix1 = _mm_unpacklo_epi8(source, zero);
|
| 1249 |
+
sss0 = _mm_add_epi32(sss0, _mm_madd_epi16(pix1, mmk));
|
| 1250 |
+
auto pix2 = _mm_unpackhi_epi8(source, zero);
|
| 1251 |
+
sss1 = _mm_add_epi32(sss1, _mm_madd_epi16(pix2, mmk));
|
| 1252 |
+
}
|
| 1253 |
+
// Convert fixed point values back to integers (truncating)
|
| 1254 |
+
sss0 = _mm_srai_epi32(sss0, coefs_precision);
|
| 1255 |
+
sss1 = _mm_srai_epi32(sss1, coefs_precision);
|
| 1256 |
+
// Convert packed signed 32-bit integers to packed 16-bit integers using signed saturation
|
| 1257 |
+
// (a a a a b b b b c c c c d d d d) -> (a a b b c c d d)
|
| 1258 |
+
sss0 = _mm_packs_epi32(sss0, sss1);
|
| 1259 |
+
// Convert packed signed 16-bit integers to packed 8-bit integers using unsigned saturation
|
| 1260 |
+
// (a a b b c c d d) -> (a b c d)
|
| 1261 |
+
sss0 = _mm_packus_epi16(sss0, sss0);
|
| 1262 |
+
// Store 2 pixels to the output
|
| 1263 |
+
_mm_storel_epi64((__m128i*)(lineOut + j), sss0);
|
| 1264 |
+
}
|
| 1265 |
+
|
| 1266 |
+
// block 1
|
| 1267 |
+
const auto b1_usable_vec_stride = (4 / data_stride) * data_stride;
|
| 1268 |
+
const auto i32_aligned = num_channels == 4;
|
| 1269 |
+
for (; j < data_size - 4; j += b1_usable_vec_stride) {
|
| 1270 |
+
auto sss = initial;
|
| 1271 |
+
int64_t i = 0;
|
| 1272 |
+
const auto * lineIn_min = lineIn + j + ids_min;
|
| 1273 |
+
|
| 1274 |
+
for (; i < ids_size - 1; i += 2) {
|
| 1275 |
+
// Load 2 values from weight vector
|
| 1276 |
+
// mmk = [wl_0 wh_0 wl_1 wh_1 wl_0 wh_0 wl_1 wh_1 ... ]
|
| 1277 |
+
auto mmk = _mm_set1_epi32(*(int32_t*)&k[i]);
|
| 1278 |
+
|
| 1279 |
+
// Load one pixel per line
|
| 1280 |
+
// RGBA: source1 = [
|
| 1281 |
+
// r0 g0 b0 a0 0 0 0 0 0 0 0 0 0 0 0 0
|
| 1282 |
+
// ]
|
| 1283 |
+
// RGB: source1 = [
|
| 1284 |
+
// r0 g0 b0 r1 0 0 0 0 0 0 0 0 0 0 0 0
|
| 1285 |
+
// ]
|
| 1286 |
+
auto source1 = mm_cvtsi32_si128(lineIn_min + i * data_size, i32_aligned);
|
| 1287 |
+
auto source2 = mm_cvtsi32_si128(lineIn_min + (i + 1) * data_size, i32_aligned);
|
| 1288 |
+
|
| 1289 |
+
// Interleave source1 and source2 and cast the result to epi16
|
| 1290 |
+
// RGBA: pix = [
|
| 1291 |
+
// r0 0 R0 0 g0 0 G0 0 b0 0 B0 0 a0 0 A0 0
|
| 1292 |
+
// ]
|
| 1293 |
+
// RGB: pix = [
|
| 1294 |
+
// r0 0 R0 0 g0 0 G0 0 b0 0 B0 0 0 0 0 0
|
| 1295 |
+
// ]
|
| 1296 |
+
auto source = _mm_unpacklo_epi8(source1, source2);
|
| 1297 |
+
auto pix = _mm_unpacklo_epi8(source, zero);
|
| 1298 |
+
// Compute output value as C += w0 * c0 + w1 * C0 for each channel in 32-bit precision
|
| 1299 |
+
sss = _mm_add_epi32(sss, _mm_madd_epi16(pix, mmk));
|
| 1300 |
+
}
|
| 1301 |
+
|
| 1302 |
+
for (; i < ids_size; i++) {
|
| 1303 |
+
auto mmk = _mm_set1_epi32(k[i]);
|
| 1304 |
+
auto pix = mm_cvtepu8_epi32(lineIn_min + i * data_size, i32_aligned);
|
| 1305 |
+
sss = _mm_add_epi32(sss, _mm_madd_epi16(pix, mmk));
|
| 1306 |
+
}
|
| 1307 |
+
sss = _mm_srai_epi32(sss, coefs_precision);
|
| 1308 |
+
sss = _mm_packs_epi32(sss, zero);
|
| 1309 |
+
sss = _mm_packus_epi16(sss, zero);
|
| 1310 |
+
|
| 1311 |
+
auto o = _mm_cvtsi128_si32(sss);
|
| 1312 |
+
|
| 1313 |
+
// Here we write 4 bytes to the output even if num_channels < 4, e.g o = {r,g,b,X} for num_channels=3
|
| 1314 |
+
// It is OK to write 4th byte (e.g. X) as on the next step we will overwrite it with new data.
|
| 1315 |
+
// We also wont go out of bounds of lineOut memory allocation
|
| 1316 |
+
std::memcpy(lineOut + j, (uint8_t *) &o, 4);
|
| 1317 |
+
}
|
| 1318 |
+
|
| 1319 |
+
for (; j < data_size; j += data_stride) {
|
| 1320 |
+
auto sss = initial;
|
| 1321 |
+
int64_t i = 0;
|
| 1322 |
+
const auto * lineIn_min = lineIn + j + ids_min;
|
| 1323 |
+
// For RGBA we can use (ids_size - 1) as tighter limit but for RGB we can read outside memory boundary
|
| 1324 |
+
// for the last remaining line
|
| 1325 |
+
for (; i < ids_size - 2; i += 2) {
|
| 1326 |
+
// Load two coefficients at once
|
| 1327 |
+
auto mmk = _mm_set1_epi32(*(int32_t*)&k[i]);
|
| 1328 |
+
|
| 1329 |
+
// Load 2 lines
|
| 1330 |
+
auto source1 = mm_cvtsi32_si128(lineIn_min + i * data_size, i32_aligned);
|
| 1331 |
+
auto source2 = mm_cvtsi32_si128(lineIn_min + (i + 1) * data_size, i32_aligned);
|
| 1332 |
+
|
| 1333 |
+
auto source = _mm_unpacklo_epi8(source1, source2);
|
| 1334 |
+
auto pix = _mm_unpacklo_epi8(source, zero);
|
| 1335 |
+
sss = _mm_add_epi32(sss, _mm_madd_epi16(pix, mmk));
|
| 1336 |
+
}
|
| 1337 |
+
|
| 1338 |
+
// Same processing as above but with a single weight value
|
| 1339 |
+
for (; i < ids_size; i++) {
|
| 1340 |
+
auto mmk = _mm_set1_epi32(k[i]);
|
| 1341 |
+
|
| 1342 |
+
const uint8_t * p = lineIn_min + i * data_size;
|
| 1343 |
+
__m128i pix;
|
| 1344 |
+
// There is no much perf gain using more detailed condition like
|
| 1345 |
+
// num_channels == 3 && ids_min + j + data_size * i + 4 >= in_max_size
|
| 1346 |
+
// const int64_t in_max_size = data_size * in_ysize;
|
| 1347 |
+
if (num_channels == 3) {
|
| 1348 |
+
uint8_t input[4];
|
| 1349 |
+
std::memcpy(input, p, 3);
|
| 1350 |
+
pix = mm_cvtepu8_epi32(input, true);
|
| 1351 |
+
} else {
|
| 1352 |
+
pix = mm_cvtepu8_epi32(p, true);
|
| 1353 |
+
}
|
| 1354 |
+
sss = _mm_add_epi32(sss, _mm_madd_epi16(pix, mmk));
|
| 1355 |
+
}
|
| 1356 |
+
|
| 1357 |
+
// Convert fixed point values back to integers (truncating)
|
| 1358 |
+
sss = _mm_srai_epi32(sss, coefs_precision);
|
| 1359 |
+
// Convert packed signed 32-bit integers to packed 16-bit integers using signed saturation
|
| 1360 |
+
// (a a a a b b b b c c c c d d d d) -> (a a b b c c d d)
|
| 1361 |
+
sss = _mm_packs_epi32(sss, zero);
|
| 1362 |
+
// Convert packed signed 16-bit integers to packed 8-bit integers using unsigned saturation
|
| 1363 |
+
// (a a b b c c d d) -> (a b c d)
|
| 1364 |
+
sss = _mm_packus_epi16(sss, zero);
|
| 1365 |
+
// Store one pixel to the output
|
| 1366 |
+
auto o = _mm_cvtsi128_si32(sss);
|
| 1367 |
+
if (num_channels == 3 && C10_UNLIKELY(j + 4 >= data_size)) {
|
| 1368 |
+
std::memcpy(lineOut + j, (uint8_t *) &o, 3);
|
| 1369 |
+
} else {
|
| 1370 |
+
std::memcpy(lineOut + j, (uint8_t *) &o, 4);
|
| 1371 |
+
}
|
| 1372 |
+
}
|
| 1373 |
+
}
|
| 1374 |
+
|
} // anonymous namespace
#endif // CPU_CAPABILITY_AVX2
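Editor's note: the vertical pass in this file (ImagingResampleVerticalConvolution8u) pairs the same channel from two adjacent input lines rather than two adjacent pixels: bytes from line i and line i+1 are interleaved with unpacklo_epi8, widened against zero, and madd_epi16 then yields w0*line_i + w1*line_(i+1) per channel in 32-bit precision. A minimal SSE sketch with made-up values (assumes SSE2/SSSE3-capable x86):

#include <immintrin.h>
#include <cstdint>
#include <cstring>
#include <cstdio>

int main() {
  const uint8_t line0[4] = {10, 20, 30, 40};   // R G B A from input line i
  const uint8_t line1[4] = {50, 60, 70, 80};   // R G B A from input line i+1
  const int16_t w[2] = {2, 6};                 // weight pair k[i], k[i+1] (made up)
  int32_t p0, p1, wpair;
  std::memcpy(&p0, line0, 4);
  std::memcpy(&p1, line1, 4);
  std::memcpy(&wpair, w, 4);
  __m128i s0 = _mm_cvtsi32_si128(p0);
  __m128i s1 = _mm_cvtsi32_si128(p1);
  __m128i zero = _mm_setzero_si128();
  __m128i mmk = _mm_set1_epi32(wpair);
  __m128i interleaved = _mm_unpacklo_epi8(s0, s1);     // r0 R0 g0 G0 b0 B0 a0 A0 ...
  __m128i pix = _mm_unpacklo_epi8(interleaved, zero);  // (r0, R0), (g0, G0), ... as int16 pairs
  __m128i acc = _mm_madd_epi16(pix, mmk);              // per channel: 2*line0 + 6*line1
  int32_t r[4];
  _mm_storeu_si128((__m128i*)r, acc);
  printf("%d %d %d %d\n", r[0], r[1], r[2], r[3]);      // 320 400 480 560
  return 0;
}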
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/WeightNormKernel.h
ADDED
@@ -0,0 +1,20 @@
#pragma once
#include <ATen/native/DispatchStub.h>
#include <cstdint>

namespace at {
class TensorBase;
}

namespace at::native {

using weight_norm_fn = void(*)(
    TensorBase&, TensorBase&, const TensorBase&, const TensorBase&, int64_t);
using weight_norm_backward_fn = void(*)(
    TensorBase&, TensorBase&, const TensorBase&, const TensorBase&,
    const TensorBase&, const TensorBase&, int64_t);

DECLARE_DISPATCH(weight_norm_fn, weight_norm_stub);
DECLARE_DISPATCH(weight_norm_backward_fn, weight_norm_backward_stub);

} // namespace at::native
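Editor's note: this header only declares the dispatch stubs; the real machinery lives in ATen/native/DispatchStub.h. As a rough, simplified illustration of the general idea (not the actual ATen implementation), a stub can be thought of as a named slot holding a function pointer that a kernel translation unit fills in at static-initialization time:

#include <cstdio>

using example_fn = void(*)(float*, const float*, long);

struct ExampleStub {
  example_fn impl = nullptr;   // filled in by whichever kernel gets registered
  void operator()(float* out, const float* in, long n) const { impl(out, in, n); }
};

static ExampleStub example_stub;

// A "kernel" translation unit registers its implementation:
static void scalar_kernel(float* out, const float* in, long n) {
  for (long i = 0; i < n; ++i) out[i] = in[i] * 2.0f;
}
static const bool registered = (example_stub.impl = scalar_kernel, true);

int main() {
  float in[3] = {1, 2, 3}, out[3];
  example_stub(out, in, 3);
  printf("%g %g %g\n", out[0], out[1], out[2]);  // 2 4 6
  return 0;
}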
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/avx_mathfun.h
ADDED
@@ -0,0 +1,522 @@
#pragma once
/*
   AVX implementation of sin, cos, sincos, exp and log

   Based on "sse_mathfun.h", by Julien Pommier
   http://gruntthepeon.free.fr/ssemath/

   Copyright (C) 2012 Giovanni Garberoglio
   Interdisciplinary Laboratory for Computational Science (LISC)
   Fondazione Bruno Kessler and University of Trento
   via Sommarive, 18
   I-38123 Trento (Italy)

  This software is provided 'as-is', without any express or implied
  warranty.  In no event will the authors be held liable for any damages
  arising from the use of this software.

  Permission is granted to anyone to use this software for any purpose,
  including commercial applications, and to alter it and redistribute it
  freely, subject to the following restrictions:

  1. The origin of this software must not be misrepresented; you must not
     claim that you wrote the original software. If you use this software
     in a product, an acknowledgment in the product documentation would be
     appreciated but is not required.
  2. Altered source versions must be plainly marked as such, and must not be
     misrepresented as being the original software.
  3. This notice may not be removed or altered from any source distribution.

  (this is the zlib license)
*/

#include <ATen/native/cpu/Intrinsics.h>

/* The original source of this file has been modified. */
#if defined(CPU_CAPABILITY_AVX2)

#if defined(__GNUC__)
# define ALIGN32_BEG __attribute__((aligned(32)))
#elif defined(_WIN32)
# define ALIGN32_BEG __declspec(align(32))
#endif

typedef __m256 v8sf; // vector of 8 float (avx2)
typedef __m256i v8si; // vector of 8 int (avx2)

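Editor's note: a minimal usage sketch for the helpers this header goes on to define (log256_ps, exp256_ps, sin256_ps, cos256_ps). It assumes CPU_CAPABILITY_AVX2 is defined, the ATen include directory is on the include path, and the translation unit is compiled with AVX2 enabled:

#include <ATen/native/cpu/avx_mathfun.h>
#include <cstdio>

int main() {
  v8sf x = _mm256_set1_ps(1.0f);
  v8sf y = exp256_ps(x);          // ~2.71828 in all 8 lanes
  float out[8];
  _mm256_storeu_ps(out, y);
  printf("%f\n", out[0]);
  return 0;
}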
| 47 |
+
/* declare some AVX constants -- why can't I figure a better way to do that? */
|
| 48 |
+
#define _PS256_CONST(Name, Val) \
|
| 49 |
+
static const ALIGN32_BEG float _ps256_##Name[8] = { Val, Val, Val, Val, Val, Val, Val, Val }
|
| 50 |
+
#define _PI32_CONST256(Name, Val) \
|
| 51 |
+
static const ALIGN32_BEG int _pi32_256_##Name[8] = { Val, Val, Val, Val, Val, Val, Val, Val }
|
| 52 |
+
#define _PS256_CONST_TYPE(Name, Type, Val) \
|
| 53 |
+
static const ALIGN32_BEG Type _ps256_##Name[8] = { Val, Val, Val, Val, Val, Val, Val, Val }
|
| 54 |
+
|
| 55 |
+
_PS256_CONST(1 , 1.0f);
|
| 56 |
+
_PS256_CONST(0p5, 0.5f);
|
| 57 |
+
/* the smallest non denormalized float number */
|
| 58 |
+
_PS256_CONST_TYPE(min_norm_pos, int, 0x00800000);
|
| 59 |
+
_PS256_CONST_TYPE(mant_mask, int, 0x7f800000);
|
| 60 |
+
_PS256_CONST_TYPE(inv_mant_mask, int, ~0x7f800000);
|
| 61 |
+
|
| 62 |
+
_PS256_CONST_TYPE(sign_mask, int, (int)0x80000000);
|
| 63 |
+
_PS256_CONST_TYPE(inv_sign_mask, int, ~0x80000000);
|
| 64 |
+
|
| 65 |
+
_PI32_CONST256(0, 0);
|
| 66 |
+
_PI32_CONST256(1, 1);
|
| 67 |
+
_PI32_CONST256(inv1, ~1);
|
| 68 |
+
_PI32_CONST256(2, 2);
|
| 69 |
+
_PI32_CONST256(4, 4);
|
| 70 |
+
_PI32_CONST256(0x7f, 0x7f);
|
| 71 |
+
|
| 72 |
+
_PS256_CONST(cephes_SQRTHF, 0.707106781186547524);
|
| 73 |
+
_PS256_CONST(cephes_log_p0, 7.0376836292E-2);
|
| 74 |
+
_PS256_CONST(cephes_log_p1, - 1.1514610310E-1);
|
| 75 |
+
_PS256_CONST(cephes_log_p2, 1.1676998740E-1);
|
| 76 |
+
_PS256_CONST(cephes_log_p3, - 1.2420140846E-1);
|
| 77 |
+
_PS256_CONST(cephes_log_p4, + 1.4249322787E-1);
|
| 78 |
+
_PS256_CONST(cephes_log_p5, - 1.6668057665E-1);
|
| 79 |
+
_PS256_CONST(cephes_log_p6, + 2.0000714765E-1);
|
| 80 |
+
_PS256_CONST(cephes_log_p7, - 2.4999993993E-1);
|
| 81 |
+
_PS256_CONST(cephes_log_p8, + 3.3333331174E-1);
|
| 82 |
+
_PS256_CONST(cephes_log_q1, -2.12194440e-4);
|
| 83 |
+
_PS256_CONST(cephes_log_q2, 0.693359375);
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
/* natural logarithm computed for 8 simultaneous float
|
| 87 |
+
return NaN for x <= 0
|
| 88 |
+
*/
|
| 89 |
+
inline v8sf log256_ps(v8sf x) {
|
| 90 |
+
v8si imm0;
|
| 91 |
+
v8sf one = *(v8sf*)_ps256_1;
|
| 92 |
+
|
| 93 |
+
//v8sf invalid_mask = _mm256_cmple_ps(x, _mm256_setzero_ps());
|
| 94 |
+
v8sf invalid_mask = _mm256_cmp_ps(x, _mm256_setzero_ps(), _CMP_LE_OS);
|
| 95 |
+
|
| 96 |
+
x = _mm256_max_ps(x, *(v8sf*)_ps256_min_norm_pos); /* cut off denormalized stuff */
|
| 97 |
+
|
| 98 |
+
// can be done with AVX2
|
| 99 |
+
imm0 = _mm256_srli_epi32(_mm256_castps_si256(x), 23);
|
| 100 |
+
|
| 101 |
+
/* keep only the fractional part */
|
| 102 |
+
x = _mm256_and_ps(x, *(v8sf*)_ps256_inv_mant_mask);
|
| 103 |
+
x = _mm256_or_ps(x, *(v8sf*)_ps256_0p5);
|
| 104 |
+
|
| 105 |
+
// this is again another AVX2 instruction
|
| 106 |
+
imm0 = _mm256_sub_epi32(imm0, *(v8si*)_pi32_256_0x7f);
|
| 107 |
+
v8sf e = _mm256_cvtepi32_ps(imm0);
|
| 108 |
+
|
| 109 |
+
e = _mm256_add_ps(e, one);
|
| 110 |
+
|
| 111 |
+
/* part2:
|
| 112 |
+
if( x < SQRTHF ) {
|
| 113 |
+
e -= 1;
|
| 114 |
+
x = x + x - 1.0;
|
| 115 |
+
} else { x = x - 1.0; }
|
| 116 |
+
*/
|
| 117 |
+
//v8sf mask = _mm256_cmplt_ps(x, *(v8sf*)_ps256_cephes_SQRTHF);
|
| 118 |
+
v8sf mask = _mm256_cmp_ps(x, *(v8sf*)_ps256_cephes_SQRTHF, _CMP_LT_OS);
|
| 119 |
+
v8sf tmp = _mm256_and_ps(x, mask);
|
| 120 |
+
x = _mm256_sub_ps(x, one);
|
| 121 |
+
e = _mm256_sub_ps(e, _mm256_and_ps(one, mask));
|
| 122 |
+
x = _mm256_add_ps(x, tmp);
|
| 123 |
+
|
| 124 |
+
v8sf z = _mm256_mul_ps(x,x);
|
| 125 |
+
|
| 126 |
+
v8sf y = *(v8sf*)_ps256_cephes_log_p0;
|
| 127 |
+
y = _mm256_mul_ps(y, x);
|
| 128 |
+
y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p1);
|
| 129 |
+
y = _mm256_mul_ps(y, x);
|
| 130 |
+
y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p2);
|
| 131 |
+
y = _mm256_mul_ps(y, x);
|
| 132 |
+
y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p3);
|
| 133 |
+
y = _mm256_mul_ps(y, x);
|
| 134 |
+
y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p4);
|
| 135 |
+
y = _mm256_mul_ps(y, x);
|
| 136 |
+
y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p5);
|
| 137 |
+
y = _mm256_mul_ps(y, x);
|
| 138 |
+
y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p6);
|
| 139 |
+
y = _mm256_mul_ps(y, x);
|
| 140 |
+
y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p7);
|
| 141 |
+
y = _mm256_mul_ps(y, x);
|
| 142 |
+
y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p8);
|
| 143 |
+
y = _mm256_mul_ps(y, x);
|
| 144 |
+
|
| 145 |
+
y = _mm256_mul_ps(y, z);
|
| 146 |
+
|
| 147 |
+
tmp = _mm256_mul_ps(e, *(v8sf*)_ps256_cephes_log_q1);
|
| 148 |
+
y = _mm256_add_ps(y, tmp);
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
tmp = _mm256_mul_ps(z, *(v8sf*)_ps256_0p5);
|
| 152 |
+
y = _mm256_sub_ps(y, tmp);
|
| 153 |
+
|
| 154 |
+
tmp = _mm256_mul_ps(e, *(v8sf*)_ps256_cephes_log_q2);
|
| 155 |
+
x = _mm256_add_ps(x, y);
|
| 156 |
+
x = _mm256_add_ps(x, tmp);
|
| 157 |
+
x = _mm256_or_ps(x, invalid_mask); // negative arg will be NAN
|
| 158 |
+
return x;
|
| 159 |
+
}
|
| 160 |
+
|
| 161 |
+
_PS256_CONST(exp_hi, 88.3762626647949f);
|
| 162 |
+
_PS256_CONST(exp_lo, -88.3762626647949f);
|
| 163 |
+
|
| 164 |
+
_PS256_CONST(cephes_LOG2EF, 1.44269504088896341);
|
| 165 |
+
_PS256_CONST(cephes_exp_C1, 0.693359375);
|
| 166 |
+
_PS256_CONST(cephes_exp_C2, -2.12194440e-4);
|
| 167 |
+
|
| 168 |
+
_PS256_CONST(cephes_exp_p0, 1.9875691500E-4);
|
| 169 |
+
_PS256_CONST(cephes_exp_p1, 1.3981999507E-3);
|
| 170 |
+
_PS256_CONST(cephes_exp_p2, 8.3334519073E-3);
|
| 171 |
+
_PS256_CONST(cephes_exp_p3, 4.1665795894E-2);
|
| 172 |
+
_PS256_CONST(cephes_exp_p4, 1.6666665459E-1);
|
| 173 |
+
_PS256_CONST(cephes_exp_p5, 5.0000001201E-1);
|
| 174 |
+
|
| 175 |
+
inline v8sf exp256_ps(v8sf x) {
|
| 176 |
+
v8sf tmp = _mm256_setzero_ps(), fx;
|
| 177 |
+
v8si imm0;
|
| 178 |
+
v8sf one = *(v8sf*)_ps256_1;
|
| 179 |
+
|
| 180 |
+
x = _mm256_min_ps(x, *(v8sf*)_ps256_exp_hi);
|
| 181 |
+
x = _mm256_max_ps(x, *(v8sf*)_ps256_exp_lo);
|
| 182 |
+
|
| 183 |
+
/* express exp(x) as exp(g + n*log(2)) */
|
| 184 |
+
fx = _mm256_mul_ps(x, *(v8sf*)_ps256_cephes_LOG2EF);
|
| 185 |
+
fx = _mm256_add_ps(fx, *(v8sf*)_ps256_0p5);
|
| 186 |
+
|
| 187 |
+
/* how to perform a floorf with SSE: just below */
|
| 188 |
+
//imm0 = _mm256_cvttps_epi32(fx);
|
| 189 |
+
//tmp = _mm256_cvtepi32_ps(imm0);
|
| 190 |
+
|
| 191 |
+
tmp = _mm256_floor_ps(fx);
|
| 192 |
+
|
| 193 |
+
/* if greater, subtract 1 */
|
| 194 |
+
//v8sf mask = _mm256_cmpgt_ps(tmp, fx);
|
| 195 |
+
v8sf mask = _mm256_cmp_ps(tmp, fx, _CMP_GT_OS);
|
| 196 |
+
mask = _mm256_and_ps(mask, one);
|
| 197 |
+
fx = _mm256_sub_ps(tmp, mask);
|
| 198 |
+
|
| 199 |
+
tmp = _mm256_mul_ps(fx, *(v8sf*)_ps256_cephes_exp_C1);
|
| 200 |
+
v8sf z = _mm256_mul_ps(fx, *(v8sf*)_ps256_cephes_exp_C2);
|
| 201 |
+
x = _mm256_sub_ps(x, tmp);
|
| 202 |
+
x = _mm256_sub_ps(x, z);
|
| 203 |
+
|
| 204 |
+
z = _mm256_mul_ps(x,x);
|
| 205 |
+
|
| 206 |
+
v8sf y = *(v8sf*)_ps256_cephes_exp_p0;
|
| 207 |
+
y = _mm256_mul_ps(y, x);
|
| 208 |
+
y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_exp_p1);
|
| 209 |
+
y = _mm256_mul_ps(y, x);
|
| 210 |
+
y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_exp_p2);
|
| 211 |
+
y = _mm256_mul_ps(y, x);
|
| 212 |
+
y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_exp_p3);
|
| 213 |
+
y = _mm256_mul_ps(y, x);
|
| 214 |
+
y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_exp_p4);
|
| 215 |
+
y = _mm256_mul_ps(y, x);
|
| 216 |
+
y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_exp_p5);
|
| 217 |
+
y = _mm256_mul_ps(y, z);
|
| 218 |
+
y = _mm256_add_ps(y, x);
|
| 219 |
+
  y = _mm256_add_ps(y, one);

  /* build 2^n */
  imm0 = _mm256_cvttps_epi32(fx);
  // another two AVX2 instructions
  imm0 = _mm256_add_epi32(imm0, *(v8si*)_pi32_256_0x7f);
  imm0 = _mm256_slli_epi32(imm0, 23);
  v8sf pow2n = _mm256_castsi256_ps(imm0);
  y = _mm256_mul_ps(y, pow2n);
  return y;
}

_PS256_CONST(minus_cephes_DP1, -0.78515625);
_PS256_CONST(minus_cephes_DP2, -2.4187564849853515625e-4);
_PS256_CONST(minus_cephes_DP3, -3.77489497744594108e-8);
_PS256_CONST(sincof_p0, -1.9515295891E-4);
_PS256_CONST(sincof_p1, 8.3321608736E-3);
_PS256_CONST(sincof_p2, -1.6666654611E-1);
_PS256_CONST(coscof_p0, 2.443315711809948E-005);
_PS256_CONST(coscof_p1, -1.388731625493765E-003);
_PS256_CONST(coscof_p2, 4.166664568298827E-002);
_PS256_CONST(cephes_FOPI, 1.27323954473516); // 4 / M_PI


/* evaluation of 8 sines at once using AVX intrinsics

   The code is the exact rewriting of the cephes sinf function.
   Precision is excellent as long as x < 8192 (I did not bother to
   take into account the special handling they have for greater values
   -- it does not return garbage for arguments over 8192, though, but
   the extra precision is missing).

   Note that it is such that sinf((float)M_PI) = 8.74e-8, which is the
   surprising but correct result.

*/
inline v8sf sin256_ps(v8sf x) { // any x
  v8sf xmm1, xmm2 = _mm256_setzero_ps(), xmm3, sign_bit, y;
  v8si imm0, imm2;

  sign_bit = x;
  /* take the absolute value */
  x = _mm256_and_ps(x, *(v8sf*)_ps256_inv_sign_mask);
  /* extract the sign bit (upper one) */
  sign_bit = _mm256_and_ps(sign_bit, *(v8sf*)_ps256_sign_mask);

  /* scale by 4/Pi */
  y = _mm256_mul_ps(x, *(v8sf*)_ps256_cephes_FOPI);

  /*
    Here we start a series of integer operations, which are in the
    realm of AVX2.
    If we don't have AVX, let's perform them using SSE2 directives
  */

  /* store the integer part of y in mm0 */
  imm2 = _mm256_cvttps_epi32(y);
  /* j=(j+1) & (~1) (see the cephes sources) */
  // another two AVX2 instructions
  imm2 = _mm256_add_epi32(imm2, *(v8si*)_pi32_256_1);
  imm2 = _mm256_and_si256(imm2, *(v8si*)_pi32_256_inv1);
  y = _mm256_cvtepi32_ps(imm2);

  /* get the swap sign flag */
  imm0 = _mm256_and_si256(imm2, *(v8si*)_pi32_256_4);
  imm0 = _mm256_slli_epi32(imm0, 29);
  /* get the polynomial selection mask
     there is one polynomial for 0 <= x <= Pi/4
     and another one for Pi/4 < x <= Pi/2

     Both branches will be computed.
  */
  imm2 = _mm256_and_si256(imm2, *(v8si*)_pi32_256_2);
  imm2 = _mm256_cmpeq_epi32(imm2, *(v8si*)_pi32_256_0);

  v8sf swap_sign_bit = _mm256_castsi256_ps(imm0);
  v8sf poly_mask = _mm256_castsi256_ps(imm2);
  sign_bit = _mm256_xor_ps(sign_bit, swap_sign_bit);

  /* The magic pass: "Extended precision modular arithmetic"
     x = ((x - y * DP1) - y * DP2) - y * DP3; */
  xmm1 = *(v8sf*)_ps256_minus_cephes_DP1;
  xmm2 = *(v8sf*)_ps256_minus_cephes_DP2;
  xmm3 = *(v8sf*)_ps256_minus_cephes_DP3;
  xmm1 = _mm256_mul_ps(y, xmm1);
  xmm2 = _mm256_mul_ps(y, xmm2);
  xmm3 = _mm256_mul_ps(y, xmm3);
  x = _mm256_add_ps(x, xmm1);
  x = _mm256_add_ps(x, xmm2);
  x = _mm256_add_ps(x, xmm3);

  /* Evaluate the first polynomial (0 <= x <= Pi/4) */
  y = *(v8sf*)_ps256_coscof_p0;
  v8sf z = _mm256_mul_ps(x,x);

  y = _mm256_mul_ps(y, z);
  y = _mm256_add_ps(y, *(v8sf*)_ps256_coscof_p1);
  y = _mm256_mul_ps(y, z);
  y = _mm256_add_ps(y, *(v8sf*)_ps256_coscof_p2);
  y = _mm256_mul_ps(y, z);
  y = _mm256_mul_ps(y, z);
  v8sf tmp = _mm256_mul_ps(z, *(v8sf*)_ps256_0p5);
  y = _mm256_sub_ps(y, tmp);
  y = _mm256_add_ps(y, *(v8sf*)_ps256_1);

  /* Evaluate the second polynomial (Pi/4 <= x <= Pi/2) */

  v8sf y2 = *(v8sf*)_ps256_sincof_p0;
  y2 = _mm256_mul_ps(y2, z);
  y2 = _mm256_add_ps(y2, *(v8sf*)_ps256_sincof_p1);
  y2 = _mm256_mul_ps(y2, z);
  y2 = _mm256_add_ps(y2, *(v8sf*)_ps256_sincof_p2);
  y2 = _mm256_mul_ps(y2, z);
  y2 = _mm256_mul_ps(y2, x);
  y2 = _mm256_add_ps(y2, x);

  /* select the correct result from the two polynomials */
  xmm3 = poly_mask;
  y2 = _mm256_and_ps(xmm3, y2); //, xmm3);
  y = _mm256_andnot_ps(xmm3, y);
  y = _mm256_add_ps(y,y2);
  /* update the sign */
  y = _mm256_xor_ps(y, sign_bit);

  return y;
}

/* almost the same as sin_ps */
inline v8sf cos256_ps(v8sf x) { // any x
  v8sf xmm1, xmm2 = _mm256_setzero_ps(), xmm3, y;
  v8si imm0, imm2;

  /* take the absolute value */
  x = _mm256_and_ps(x, *(v8sf*)_ps256_inv_sign_mask);

  /* scale by 4/Pi */
  y = _mm256_mul_ps(x, *(v8sf*)_ps256_cephes_FOPI);

  /* store the integer part of y in mm0 */
  imm2 = _mm256_cvttps_epi32(y);
  /* j=(j+1) & (~1) (see the cephes sources) */
  imm2 = _mm256_add_epi32(imm2, *(v8si*)_pi32_256_1);
  imm2 = _mm256_and_si256(imm2, *(v8si*)_pi32_256_inv1);
  y = _mm256_cvtepi32_ps(imm2);
  imm2 = _mm256_sub_epi32(imm2, *(v8si*)_pi32_256_2);

  /* get the swap sign flag */
  imm0 = _mm256_andnot_si256(imm2, *(v8si*)_pi32_256_4);
  imm0 = _mm256_slli_epi32(imm0, 29);
  /* get the polynomial selection mask */
  imm2 = _mm256_and_si256(imm2, *(v8si*)_pi32_256_2);
  imm2 = _mm256_cmpeq_epi32(imm2, *(v8si*)_pi32_256_0);

  v8sf sign_bit = _mm256_castsi256_ps(imm0);
  v8sf poly_mask = _mm256_castsi256_ps(imm2);

  /* The magic pass: "Extended precision modular arithmetic"
     x = ((x - y * DP1) - y * DP2) - y * DP3; */
  xmm1 = *(v8sf*)_ps256_minus_cephes_DP1;
  xmm2 = *(v8sf*)_ps256_minus_cephes_DP2;
  xmm3 = *(v8sf*)_ps256_minus_cephes_DP3;
  xmm1 = _mm256_mul_ps(y, xmm1);
  xmm2 = _mm256_mul_ps(y, xmm2);
  xmm3 = _mm256_mul_ps(y, xmm3);
  x = _mm256_add_ps(x, xmm1);
  x = _mm256_add_ps(x, xmm2);
  x = _mm256_add_ps(x, xmm3);

  /* Evaluate the first polynomial (0 <= x <= Pi/4) */
  y = *(v8sf*)_ps256_coscof_p0;
  v8sf z = _mm256_mul_ps(x,x);

  y = _mm256_mul_ps(y, z);
  y = _mm256_add_ps(y, *(v8sf*)_ps256_coscof_p1);
  y = _mm256_mul_ps(y, z);
  y = _mm256_add_ps(y, *(v8sf*)_ps256_coscof_p2);
  y = _mm256_mul_ps(y, z);
  y = _mm256_mul_ps(y, z);
  v8sf tmp = _mm256_mul_ps(z, *(v8sf*)_ps256_0p5);
  y = _mm256_sub_ps(y, tmp);
  y = _mm256_add_ps(y, *(v8sf*)_ps256_1);

  /* Evaluate the second polynomial (Pi/4 <= x <= Pi/2) */

  v8sf y2 = *(v8sf*)_ps256_sincof_p0;
  y2 = _mm256_mul_ps(y2, z);
  y2 = _mm256_add_ps(y2, *(v8sf*)_ps256_sincof_p1);
  y2 = _mm256_mul_ps(y2, z);
  y2 = _mm256_add_ps(y2, *(v8sf*)_ps256_sincof_p2);
  y2 = _mm256_mul_ps(y2, z);
  y2 = _mm256_mul_ps(y2, x);
  y2 = _mm256_add_ps(y2, x);

  /* select the correct result from the two polynomials */
  xmm3 = poly_mask;
  y2 = _mm256_and_ps(xmm3, y2); //, xmm3);
  y = _mm256_andnot_ps(xmm3, y);
  y = _mm256_add_ps(y,y2);
  /* update the sign */
  y = _mm256_xor_ps(y, sign_bit);

  return y;
}

/* since sin256_ps and cos256_ps are almost identical, sincos256_ps could replace both of them..
   it is almost as fast, and gives you a free cosine with your sine */
inline void sincos256_ps(v8sf x, v8sf *s, v8sf *c) {

  v8sf xmm1, xmm2, xmm3 = _mm256_setzero_ps(), sign_bit_sin, y;
  v8si imm0, imm2, imm4;

  sign_bit_sin = x;
  /* take the absolute value */
  x = _mm256_and_ps(x, *(v8sf*)_ps256_inv_sign_mask);
  /* extract the sign bit (upper one) */
  sign_bit_sin = _mm256_and_ps(sign_bit_sin, *(v8sf*)_ps256_sign_mask);

  /* scale by 4/Pi */
  y = _mm256_mul_ps(x, *(v8sf*)_ps256_cephes_FOPI);

  /* store the integer part of y in imm2 */
  imm2 = _mm256_cvttps_epi32(y);

  /* j=(j+1) & (~1) (see the cephes sources) */
  imm2 = _mm256_add_epi32(imm2, *(v8si*)_pi32_256_1);
  imm2 = _mm256_and_si256(imm2, *(v8si*)_pi32_256_inv1);

  y = _mm256_cvtepi32_ps(imm2);
  imm4 = imm2;

  /* get the swap sign flag for the sine */
  imm0 = _mm256_and_si256(imm2, *(v8si*)_pi32_256_4);
  imm0 = _mm256_slli_epi32(imm0, 29);
  //v8sf swap_sign_bit_sin = _mm256_castsi256_ps(imm0);

  /* get the polynomial selection mask for the sine */
  imm2 = _mm256_and_si256(imm2, *(v8si*)_pi32_256_2);
  imm2 = _mm256_cmpeq_epi32(imm2, *(v8si*)_pi32_256_0);
  //v8sf poly_mask = _mm256_castsi256_ps(imm2);

  v8sf swap_sign_bit_sin = _mm256_castsi256_ps(imm0);
  v8sf poly_mask = _mm256_castsi256_ps(imm2);

  /* The magic pass: "Extended precision modular arithmetic"
     x = ((x - y * DP1) - y * DP2) - y * DP3; */
  xmm1 = *(v8sf*)_ps256_minus_cephes_DP1;
  xmm2 = *(v8sf*)_ps256_minus_cephes_DP2;
  xmm3 = *(v8sf*)_ps256_minus_cephes_DP3;
  xmm1 = _mm256_mul_ps(y, xmm1);
  xmm2 = _mm256_mul_ps(y, xmm2);
  xmm3 = _mm256_mul_ps(y, xmm3);
  x = _mm256_add_ps(x, xmm1);
  x = _mm256_add_ps(x, xmm2);
  x = _mm256_add_ps(x, xmm3);

  imm4 = _mm256_sub_epi32(imm4, *(v8si*)_pi32_256_2);
  imm4 = _mm256_andnot_si256(imm4, *(v8si*)_pi32_256_4);
  imm4 = _mm256_slli_epi32(imm4, 29);

  v8sf sign_bit_cos = _mm256_castsi256_ps(imm4);

  sign_bit_sin = _mm256_xor_ps(sign_bit_sin, swap_sign_bit_sin);

  /* Evaluate the first polynomial (0 <= x <= Pi/4) */
  v8sf z = _mm256_mul_ps(x,x);
  y = *(v8sf*)_ps256_coscof_p0;

  y = _mm256_mul_ps(y, z);
  y = _mm256_add_ps(y, *(v8sf*)_ps256_coscof_p1);
  y = _mm256_mul_ps(y, z);
  y = _mm256_add_ps(y, *(v8sf*)_ps256_coscof_p2);
  y = _mm256_mul_ps(y, z);
  y = _mm256_mul_ps(y, z);
  v8sf tmp = _mm256_mul_ps(z, *(v8sf*)_ps256_0p5);
  y = _mm256_sub_ps(y, tmp);
  y = _mm256_add_ps(y, *(v8sf*)_ps256_1);

  /* Evaluate the second polynomial (Pi/4 <= x <= Pi/2) */

  v8sf y2 = *(v8sf*)_ps256_sincof_p0;
  y2 = _mm256_mul_ps(y2, z);
  y2 = _mm256_add_ps(y2, *(v8sf*)_ps256_sincof_p1);
  y2 = _mm256_mul_ps(y2, z);
  y2 = _mm256_add_ps(y2, *(v8sf*)_ps256_sincof_p2);
  y2 = _mm256_mul_ps(y2, z);
  y2 = _mm256_mul_ps(y2, x);
  y2 = _mm256_add_ps(y2, x);

  /* select the correct result from the two polynomials */
  xmm3 = poly_mask;
  v8sf ysin2 = _mm256_and_ps(xmm3, y2);
  v8sf ysin1 = _mm256_andnot_ps(xmm3, y);
  y2 = _mm256_sub_ps(y2,ysin2);
  y = _mm256_sub_ps(y, ysin1);

  xmm1 = _mm256_add_ps(ysin1,ysin2);
  xmm2 = _mm256_add_ps(y,y2);

  /* update the sign */
  *s = _mm256_xor_ps(xmm1, sign_bit_sin);
  *c = _mm256_xor_ps(xmm2, sign_bit_cos);
}

#endif // CPU_CAPABILITY_AVX2
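A minimal usage sketch of these helpers, not part of the header itself: it assumes a translation unit compiled with AVX2 enabled that includes avx_mathfun.h (so v8sf and sincos256_ps are visible); the buffer names and the scalar tail loop are illustrative only. As the comment above notes, precision holds for |x| < 8192.

#include <cmath>
#include <immintrin.h>

// Sketch: fill out_sin/out_cos from in[0..n), 8 lanes at a time.
void sincos_buffer(const float* in, float* out_sin, float* out_cos, int n) {
  int i = 0;
  for (; i + 8 <= n; i += 8) {
    v8sf x = _mm256_loadu_ps(in + i);  // load 8 floats
    v8sf s, c;
    sincos256_ps(x, &s, &c);           // one pass yields both sine and cosine
    _mm256_storeu_ps(out_sin + i, s);
    _mm256_storeu_ps(out_cos + i, c);
  }
  for (; i < n; ++i) {                 // scalar tail for the remainder
    out_sin[i] = std::sin(in[i]);
    out_cos[i] = std::cos(in[i]);
  }
}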
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/int_mm_kernel.h
ADDED
@@ -0,0 +1,16 @@
#pragma once

#include <ATen/core/Tensor.h>
#include <ATen/native/DispatchStub.h>

namespace at::native {

using weight_to_int4pack_fn = void(*)(const Tensor&, const Tensor&, int, int);
using int4pack_mm_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, int, const Tensor&, int, int);
using int8pack_mm_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&);

DECLARE_DISPATCH(weight_to_int4pack_fn, weight_to_int4pack_stub);
DECLARE_DISPATCH(int4pack_mm_fn, int4pack_mm_stub);
DECLARE_DISPATCH(int8pack_mm_fn, int8pack_mm_stub);

} // namespace at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/mixed_data_type.h
ADDED
@@ -0,0 +1,41 @@
#pragma once

#include <ATen/core/Tensor.h>

namespace at::native {

inline ScalarType first_type() {
  return ScalarType::Undefined;
}

template <typename... Args>
inline ScalarType first_type(const Tensor& arg, const Args&... parameters) {
  return arg.defined() ? arg.scalar_type() : first_type(parameters...);
}

template <typename... Args>
inline bool is_mixed_type(const Tensor& input, const Args&... parameters) {
  const auto parameter_type = first_type(parameters...);
  return ((parameter_type != ScalarType::Undefined) &&
          (parameter_type != input.scalar_type()));
}

// currently on CPU, mixed data type is only supported
// when input is 'BFloat16' or 'Half' and parameters are 'Float'
inline void check_mixed_data_type(const Tensor& input) {
  TORCH_CHECK(at::isReducedFloatingType(input.scalar_type()),
      "mixed dtype (CPU): all inputs must share same datatype.");
}

template <typename... Args>
inline void check_mixed_data_type(const Tensor& input, const Tensor& parameter, const Args&... parameters) {
  TORCH_CHECK(!parameter.defined() || parameter.scalar_type() == ScalarType::Float,
      "mixed dtype (CPU): expect parameter to have scalar type of Float");
  check_mixed_data_type(input, parameters...);
}

inline ScalarType param_scalar_type(const Tensor& t, bool is_mixed_type) {
  return is_mixed_type ? ScalarType::Float : t.scalar_type();
}

} // namespace at::native
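A usage sketch, not part of the header: a normalization-style CPU op with optional weight/bias parameters typically probes for the mixed BFloat16/Half input plus Float parameter case and validates it before dispatching. The function and tensor names below are illustrative assumptions.

#include <ATen/ATen.h>
#include <ATen/native/cpu/mixed_data_type.h>

// Sketch: decide whether the mixed-dtype CPU path applies for (input, weight, bias).
bool use_mixed_path(const at::Tensor& input, const at::Tensor& weight, const at::Tensor& bias) {
  const bool mixed = at::native::is_mixed_type(input, weight, bias);
  if (mixed) {
    // Only reduced-float inputs (BFloat16/Half) with Float parameters are accepted.
    at::native::check_mixed_data_type(input, weight, bias);
  }
  return mixed;
}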
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/moments_utils.h
ADDED
@@ -0,0 +1,202 @@
#pragma once

#include <array>
#include <cstring>
#include <utility>

#include <ATen/Parallel.h>
#include <ATen/OpMathType.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/native/cpu/utils.h>
#include <c10/util/SmallVector.h>
#include <c10/util/irange.h>

namespace at::native {
inline namespace CPU_CAPABILITY {

template<typename T> using opmath_t = at::opmath_type<T>;

constexpr int64_t kChunkSize = 16;

template <typename T>
void AddMoments(
    int64_t m0_add,
    const T& m1_add,
    const T& m2_add,
    int64_t& m0,
    T& m1,
    T& m2) {
  const int64_t n = m0 + m0_add;
  const T c = n == 0 ? static_cast<T>(0) : static_cast<T>(m0_add) / static_cast<T>(n);
  const T delta = m1_add - m1;
  m1 += c * delta;
  m2 += m2_add + delta * delta * c * static_cast<T>(m0);
  m0 = n;
}

template <typename T>
C10_ALWAYS_INLINE void AddMomentsVec(
    int64_t m0_add,
    const vec::Vectorized<T>& m1_add,
    const vec::Vectorized<T>& m2_add,
    int64_t& m0,
    vec::Vectorized<T>& m1,
    vec::Vectorized<T>& m2) {
  using Vec = vec::Vectorized<T>;
  const int64_t n = m0 + m0_add;
  const T c = n == 0 ? static_cast<T>(0) : static_cast<T>(m0_add) / static_cast<T>(n);
  const Vec c_vec(c);
  const Vec delta = m1_add - m1;
  m1 += c_vec * delta;
  m2 += m2_add + delta * delta * c_vec * Vec(static_cast<T>(m0));
  m0 = n;
}

template <typename T>
inline std::enable_if_t<std::is_same_v<T, opmath_t<T>>, void>
UpdateMomentsVec(
    int64_t m0,
    const T* X_ptr,
    const std::array<vec::Vectorized<opmath_t<T>>, kChunkSize>& c_vecs,
    int64_t& m0_stk0,
    vec::Vectorized<opmath_t<T>>& m1_stk0,
    vec::Vectorized<opmath_t<T>>& m2_stk0) {
  using Vec = vec::Vectorized<opmath_t<T>>;
  Vec m1_vec(0);
  Vec m2_vec(0);
  for (const auto j : c10::irange(m0)) {
    const Vec x_vec = Vec::loadu(X_ptr + j * Vec::size());
    const Vec delta_vec = x_vec - m1_vec;
    m1_vec += delta_vec * c_vecs[j];
    m2_vec += delta_vec * (x_vec - m1_vec);
  }
  AddMomentsVec(m0, m1_vec, m2_vec, m0_stk0, m1_stk0, m2_stk0);
}

// each bfloat16/half vector will be converted to two float vectors,
// and accumulated successively on m1_stk0/m2_stk0.
template <typename T>
inline std::enable_if_t<!std::is_same_v<T, at::opmath_type<T>>, void>
UpdateMomentsVec(
    int64_t m0,
    const T* X_ptr,
    const std::array<vec::Vectorized<at::opmath_type<T>>, kChunkSize>& c_vecs,
    int64_t& m0_stk0,
    vec::Vectorized<at::opmath_type<T>>& m1_stk0,
    vec::Vectorized<at::opmath_type<T>>& m2_stk0) {
  using Vec = vec::Vectorized<T>;
  using fVec = vec::Vectorized<at::opmath_type<T>>;
  fVec m1_fvec0(0), m1_fvec1(0);
  fVec m2_fvec0(0), m2_fvec1(0);
  for (const auto j : c10::irange(m0)) {
    const Vec x_bvec = Vec::loadu(X_ptr + j * Vec::size());
    auto [x_fvec0, x_fvec1] = convert_to_float<T>(x_bvec);
    const fVec delta_fvec0 = x_fvec0 - m1_fvec0;
    const fVec delta_fvec1 = x_fvec1 - m1_fvec1;
    m1_fvec0 += delta_fvec0 * c_vecs[j];
    m1_fvec1 += delta_fvec1 * c_vecs[j];
    m2_fvec0 += delta_fvec0 * (x_fvec0 - m1_fvec0);
    m2_fvec1 += delta_fvec1 * (x_fvec1 - m1_fvec1);
  }
  AddMomentsVec(m0, m1_fvec0, m2_fvec0, m0_stk0, m1_stk0, m2_stk0);
  AddMomentsVec(m0, m1_fvec1, m2_fvec1, m0_stk0, m1_stk0, m2_stk0);
}

// Compute rowwise moments by Welford algorithm and cascade sum to improve
// numerical stability.
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
// https://en.wikipedia.org/wiki/Pairwise_summation
template <typename T, int64_t kMaxDepth>
std::pair<opmath_t<T>, opmath_t<T>> RowwiseMomentsImpl(const T* X, int64_t N, int64_t ddof = 0) {
  using math_t = opmath_t<T>;

  constexpr int64_t kVecSize = vec::Vectorized<T>::size();
  constexpr int64_t kAccVecSize = vec::Vectorized<math_t>::size();
  const int64_t n = N / kVecSize;
  const int64_t m = divup(n, kChunkSize);
  const int64_t depth = utils::CeilLog2(m);

  using Vec = vec::Vectorized<math_t>;
  const Vec kZeroVec(math_t(0));
  c10::SmallVector<int64_t, kMaxDepth> m0_stk(depth, 0);
  c10::SmallVector<Vec, kMaxDepth> m1_stk(depth, kZeroVec);
  c10::SmallVector<Vec, kMaxDepth> m2_stk(depth, kZeroVec);

  for (const auto i : c10::irange(m)) {
    const T* X_ptr = X + i * kChunkSize * kVecSize;
    const int64_t m0 = std::min(kChunkSize, n - i * kChunkSize);
    static std::array<Vec, kChunkSize> c_vecs = ([]() {
      std::array<Vec, kChunkSize> result;
      for (const auto i : c10::irange(kChunkSize)) {
        result[i] = Vec(math_t(1) / static_cast<math_t>(i + 1));
      }
      return result;
    })();
    UpdateMomentsVec(m0, X_ptr, c_vecs, m0_stk[0], m1_stk[0], m2_stk[0]);

    int64_t mask = i + 1;
    for (int64_t j = 1; j < depth && (mask & 1) == 0; ++j) {
      AddMomentsVec(
          m0_stk[j - 1],
          m1_stk[j - 1],
          m2_stk[j - 1],
          m0_stk[j],
          m1_stk[j],
          m2_stk[j]);
      m0_stk[j - 1] = 0;
      m1_stk[j - 1] = kZeroVec;
      m2_stk[j - 1] = kZeroVec;
      mask >>= 1;
    }
  }
  for (const auto i : c10::irange(1, depth)) {
    AddMomentsVec(
        m0_stk[i], m1_stk[i], m2_stk[i], m0_stk[0], m1_stk[0], m2_stk[0]);
  }

  std::array<math_t, kAccVecSize> m1_arr{};
  std::array<math_t, kAccVecSize> m2_arr{};
  m1_stk[0].store(m1_arr.data());
  m2_stk[0].store(m2_arr.data());

  int64_t m0 = 0;
  math_t m1 = 0;
  math_t m2 = 0;
  for (int64_t i = n * kVecSize; i < N; ++i) {
    math_t x = static_cast<math_t>(X[i]);
    const math_t delta = x - m1;
    ++m0;
    m1 += delta / static_cast<math_t>(m0);
    m2 += delta * (x - m1);
  }
  // for BFloat16, each vector in m1_arr/m2_arr holds 2*n accumulated result
  int64_t m0_add = n * kVecSize / kAccVecSize;
  for (const auto i : c10::irange(kAccVecSize)) {
    AddMoments(m0_add, m1_arr[i], m2_arr[i], m0, m1, m2);
  }

  return std::make_pair(m1, m2 / static_cast<math_t>(N - ddof));
}

template <typename T>
std::pair<opmath_t<T>, opmath_t<T>> RowwiseMoments(const T* X, int64_t N, int64_t ddof = 0) {
  using Vec = vec::Vectorized<T>;
  constexpr int64_t kVecSize = Vec::size();
  const int64_t n = N / kVecSize;
  const int64_t m = divup(n, kChunkSize);
  const int64_t depth = utils::CeilLog2(m);
  if (depth <= 4) {
    return RowwiseMomentsImpl<T, 4>(X, N, ddof);
  } else if (depth <= 8) {
    return RowwiseMomentsImpl<T, 8>(X, N, ddof);
  } else if (depth <= 16) {
    return RowwiseMomentsImpl<T, 16>(X, N, ddof);
  } else if (depth <= 32) {
    return RowwiseMomentsImpl<T, 32>(X, N, ddof);
  } else {
    return RowwiseMomentsImpl<T, 64>(X, N, ddof);
  }
}

} // namespace CPU_CAPABILITY
} // namespace at::native
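A brief usage sketch, not part of the header: RowwiseMoments returns a (mean, variance) pair for one contiguous row, so a caller iterating over an [M, N] buffer can compute per-row statistics directly. The function name, buffer layout, and default ddof = 0 (population variance) below are illustrative assumptions about how a caller might use it.

#include <cstdint>
#include <utility>
#include <ATen/native/cpu/moments_utils.h>

// Sketch: per-row mean/variance for a contiguous M x N float buffer.
void row_mean_var(const float* X, int64_t M, int64_t N, float* mean, float* var) {
  for (int64_t i = 0; i < M; ++i) {
    // Welford + cascade-sum accumulation over one row of length N.
    auto [m1, m2] = at::native::RowwiseMoments(X + i * N, N);  // ddof defaults to 0
    mean[i] = m1;
    var[i] = m2;
  }
}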
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/utils.h
ADDED
@@ -0,0 +1,212 @@
#pragma once

#include <ATen/Parallel.h>
#include <ATen/core/TensorAccessor.h>
#include <ATen/cpu/vec/vec.h>
#include <c10/util/llvmMathExtras.h>

#ifdef USE_FBGEMM
#include <fbgemm/Fbgemm.h>
#endif

namespace at::native {

template <typename T>
inline void _store(T* dst, at::vec::Vectorized<T> src) {
  src.store(dst);
}

inline void _store(at::BFloat16* dst, at::vec::Vectorized<float> src) {
  auto res = at::vec::convert_float_bfloat16(src, src);
  res.store(dst, at::vec::Vectorized<float>::size());
}

inline void _store(at::Half* dst, at::vec::Vectorized<float> src) {
  auto res = at::vec::convert_float_half(src, src);
  res.store(dst, at::vec::Vectorized<float>::size());
}

inline namespace CPU_CAPABILITY {

template <typename T>
inline T data_index_init(T offset) {
  return offset;
}

template <typename T, typename... Args>
inline T data_index_init(T offset, T& x, const T& X, Args&&... args) {
  offset = data_index_init(offset, std::forward<Args>(args)...);
  x = offset % X;
  return offset / X;
}

inline bool data_index_step() {
  return true;
}

template <typename T, typename... Args>
inline bool data_index_step(T& x, const T& X, Args&&... args) {
  if (data_index_step(std::forward<Args>(args)...)) {
    x = ((x + 1) == X) ? 0 : (x + 1);
    return x == 0;
  }
  return false;
}

// Helper struct for bfloat16/float16 vectorization
// Useful when you need float as immediate dtype or accumulate dtype
using namespace vec;
struct Vec2 {
  Vectorized<float> val0, val1;
  Vec2(Vectorized<float> v0, Vectorized<float> v1) : val0(v0), val1(v1) {}
  Vec2(float v) : val0(v), val1(v) {}
  static Vec2 loadu(const BFloat16* ptr) {
    auto [v0, v1] = convert_bfloat16_float(Vectorized<BFloat16>::loadu(ptr));
    return {v0, v1};
  }
  static Vec2 loadu(const Half* ptr) {
    auto [v0, v1] = convert_half_float(Vectorized<Half>::loadu(ptr));
    return {v0, v1};
  }
  static Vec2 loadu(const float* ptr) {
    return {Vectorized<float>::loadu(ptr), Vectorized<float>::loadu(ptr + Vectorized<float>::size())};
  }
  void store(BFloat16* ptr) const {
    Vectorized<BFloat16> val = convert_float_bfloat16(val0, val1);
    val.store(ptr);
  }
  void store(Half* ptr) const {
    Vectorized<Half> val = convert_float_half(val0, val1);
    val.store(ptr);
  }
  void store(float* ptr) const {
    val0.store(ptr);
    val1.store(ptr + Vectorized<float>::size());
  }
};
inline Vec2 operator+(const Vec2& a, const Vec2& b) { return {a.val0 + b.val0, a.val1 + b.val1}; }
inline Vec2 operator*(const Vec2& a, const Vec2& b) { return {a.val0 * b.val0, a.val1 * b.val1}; }
inline Vec2 operator-(const Vec2& a, const Vec2& b) { return {a.val0 - b.val0, a.val1 - b.val1}; }
inline Vec2 operator/(const Vec2& a, const Vec2& b) { return {a.val0 / b.val0, a.val1 / b.val1}; }
inline Vec2 maximum(const Vec2& a, const Vec2& b) { return {vec::maximum(a.val0, b.val0), vec::maximum(a.val1, b.val1)}; }
inline Vec2 minimum(const Vec2& a, const Vec2& b) { return {vec::minimum(a.val0, b.val0), vec::minimum(a.val1, b.val1)}; }

template <typename scalar_t> struct VectorizedType { using type = Vectorized<scalar_t>; };
template <> struct VectorizedType<BFloat16> { using type = Vec2; };
template <> struct VectorizedType<Half> { using type = Vec2; };
template <typename scalar_t> using VecType = typename VectorizedType<scalar_t>::type;

// Helper for mixed data type parameter Vec::load
inline std::tuple<Vectorized<float>, Vectorized<float>> load2f(const BFloat16* ptr) {
  return convert_bfloat16_float(Vectorized<BFloat16>::loadu(ptr));
}

inline std::tuple<Vectorized<float>, Vectorized<float>> load2f(const Half* ptr) {
  return convert_half_float(Vectorized<Half>::loadu(ptr));
}

inline std::tuple<Vectorized<float>, Vectorized<float>> load2f(const float* ptr) {
  using Vec = Vectorized<float>;
  return std::make_tuple(Vec::loadu(ptr), Vec::loadu(ptr + Vec::size()));
}

inline std::tuple<Vectorized<float>, Vectorized<float>> load2f(const BFloat16* ptr, int64_t count) {
  return convert_bfloat16_float(Vectorized<BFloat16>::loadu(ptr, count));
}

inline std::tuple<Vectorized<float>, Vectorized<float>> load2f(const Half* ptr, int64_t count) {
  return convert_half_float(Vectorized<Half>::loadu(ptr, count));
}

inline std::tuple<Vectorized<float>, Vectorized<float>> load2f(const float* ptr, int64_t count) {
  using Vec = Vectorized<float>;
  if (count > Vec::size()) {
    return std::make_tuple(Vec::loadu(ptr), Vec::loadu(ptr + Vec::size(), count - Vec::size()));
  } else {
    return std::make_tuple(Vec::loadu(ptr, count), Vec(0));
  }
}

} // namespace CPU_CAPABILITY

namespace utils {

template <typename T>
T CeilLog2(const T& x) {
  if (x <= 2) {
    return 1;
  }
  // Last set bit is floor(log2(x)), floor + 1 is ceil
  // except when x is an exact power of 2, so subtract 1 first
  return static_cast<T>(llvm::findLastSet(static_cast<uint64_t>(x) - 1)) + 1;
}

// matrix transpose:
// src has shape of M by N, with leading dimension of ld_src
// dst has shape of N by M, with leading dimension of ld_dst
template <typename T>
inline void transpose(int64_t M, int64_t N, const T* src, int64_t ld_src, T* dst, int64_t ld_dst) {
  for (int64_t j = 0; j < N; j++) {
    for (int64_t i = 0; i < M; i++) {
      dst[j * ld_dst + i] = src[i * ld_src + j];
    }
  }
}

#ifdef USE_FBGEMM
template <>
inline void transpose<float>(int64_t M, int64_t N, const float* src, int64_t ld_src, float* dst, int64_t ld_dst) {
  TORCH_CHECK(fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM.");
  fbgemm::transpose_simd<float>(M, N, src, ld_src, dst, ld_dst);
}

template <>
inline void transpose<uint16_t>(int64_t M, int64_t N, const uint16_t* src, int64_t ld_src, uint16_t* dst, int64_t ld_dst) {
  TORCH_CHECK(fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM.");
  fbgemm::transpose_simd<uint16_t>(M, N, src, ld_src, dst, ld_dst);
}
#endif

template <typename index_t, typename F>
inline void parallel_sparse_csr(
    const TensorAccessor<index_t, 1>& crow_acc,
    const int64_t M,
    const int64_t nnz,
    const F& f) {
  TORCH_CHECK(crow_acc.size(0) == M + 1);

  // directly parallel on `M` may lead to load imbalance,
  // statically determine thread partition here to average payload
  // for each thread.
  int num_threads = at::get_num_threads();
  std::vector<int64_t> thread_splits(num_threads + 1, M);

  int64_t thread_averge_payload = std::max((int64_t)1, divup(nnz, num_threads));

  thread_splits[0] = 0;
  int64_t sum = 0;
  int64_t t = 1;
  for (const auto m : c10::irange(M)) {
    int64_t row_start = crow_acc[m];
    int64_t row_end = crow_acc[m + 1];
    sum += row_end - row_start;
    if (sum > t * thread_averge_payload) {
      thread_splits[t] = m;
      t++;
    }
  }
  // need to restore the last index,
  // due to rounding error when calculating `thread_averge_payload`.
  thread_splits[num_threads] = M;

  at::parallel_for(0, num_threads, 1, [&](int64_t cbegin, int64_t cend) {
    int tid = at::get_thread_num();
    int64_t begin = thread_splits[tid];
    int64_t end = thread_splits[tid + 1];
    f(begin, end);
  });
}

} // namespace utils

} // namespace at::native
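A short usage sketch, not part of the header: data_index_init/data_index_step implement the common ATen idiom of walking a flattened index range while maintaining the multi-dimensional coordinates, typically inside an at::parallel_for partition. The function name and the [B, C, W] dimension names below are illustrative assumptions.

#include <cstdint>
#include <ATen/native/cpu/utils.h>

// Sketch: visit elements [begin, end) of a flattened B x C x W range,
// keeping (b, c, w) in sync with the linear index i.
void visit_range(int64_t begin, int64_t end, int64_t B, int64_t C, int64_t W) {
  int64_t b = 0, c = 0, w = 0;
  // Dimensions are listed outermost first; w is the fastest-moving index.
  at::native::data_index_init(begin, b, B, c, C, w, W);
  for (int64_t i = begin; i < end; ++i) {
    // ... use (b, c, w) to address element i ...
    at::native::data_index_step(b, B, c, C, w, W);  // increments w, carries into c then b
  }
}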
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/zmath.h
ADDED
@@ -0,0 +1,250 @@
#pragma once

// Complex number math operations that act as no-ops for other dtypes.
#include <c10/util/complex.h>
#include <c10/util/MathConstants.h>
#include <ATen/NumericUtils.h>

namespace at::native {
inline namespace CPU_CAPABILITY {

template <typename SCALAR_TYPE, typename VALUE_TYPE=SCALAR_TYPE>
inline VALUE_TYPE zabs (SCALAR_TYPE z) {
  return z;
}

template<>
inline c10::complex<float> zabs <c10::complex<float>> (c10::complex<float> z) {
  return c10::complex<float>(std::abs(z));
}

template<>
inline float zabs <c10::complex<float>, float> (c10::complex<float> z) {
  return std::abs(z);
}

template<>
inline c10::complex<double> zabs <c10::complex<double>> (c10::complex<double> z) {
  return c10::complex<double>(std::abs(z));
}

template<>
inline double zabs <c10::complex<double>, double> (c10::complex<double> z) {
  return std::abs(z);
}

// This overload corresponds to non-complex dtypes.
// The function is consistent with its NumPy equivalent
// for non-complex dtypes where `pi` is returned for
// negative real numbers and `0` is returned for 0 or positive
// real numbers.
// Note: `nan` is propagated.
template <typename SCALAR_TYPE, typename VALUE_TYPE=SCALAR_TYPE>
inline VALUE_TYPE angle_impl (SCALAR_TYPE z) {
  if (at::_isnan(z)) {
    return z;
  }
  return z < 0 ? c10::pi<double> : 0;
}

template<>
inline c10::complex<float> angle_impl <c10::complex<float>> (c10::complex<float> z) {
  return c10::complex<float>(std::arg(z), 0.0);
}

template<>
inline float angle_impl <c10::complex<float>, float> (c10::complex<float> z) {
  return std::arg(z);
}

template<>
inline c10::complex<double> angle_impl <c10::complex<double>> (c10::complex<double> z) {
  return c10::complex<double>(std::arg(z), 0.0);
}

template<>
inline double angle_impl <c10::complex<double>, double> (c10::complex<double> z) {
  return std::arg(z);
}

template <typename SCALAR_TYPE, typename VALUE_TYPE=SCALAR_TYPE>
constexpr VALUE_TYPE real_impl (SCALAR_TYPE z) {
  return z; //No-Op
}

template<>
constexpr c10::complex<float> real_impl <c10::complex<float>> (c10::complex<float> z) {
  return c10::complex<float>(z.real(), 0.0);
}

template<>
constexpr float real_impl <c10::complex<float>, float> (c10::complex<float> z) {
  return z.real();
}

template<>
constexpr c10::complex<double> real_impl <c10::complex<double>> (c10::complex<double> z) {
  return c10::complex<double>(z.real(), 0.0);
}

template<>
constexpr double real_impl <c10::complex<double>, double> (c10::complex<double> z) {
  return z.real();
}

template <typename SCALAR_TYPE, typename VALUE_TYPE=SCALAR_TYPE>
constexpr VALUE_TYPE imag_impl (SCALAR_TYPE /*z*/) {
  return 0;
}

template<>
constexpr c10::complex<float> imag_impl <c10::complex<float>> (c10::complex<float> z) {
  return c10::complex<float>(z.imag(), 0.0);
}

template<>
constexpr float imag_impl <c10::complex<float>, float> (c10::complex<float> z) {
  return z.imag();
}

template<>
constexpr c10::complex<double> imag_impl <c10::complex<double>> (c10::complex<double> z) {
  return c10::complex<double>(z.imag(), 0.0);
}

template<>
constexpr double imag_impl <c10::complex<double>, double> (c10::complex<double> z) {
  return z.imag();
}

template <typename TYPE>
inline TYPE conj_impl (TYPE z) {
  return z; //No-Op
}

template<>
inline c10::complex<at::Half> conj_impl <c10::complex<at::Half>> (c10::complex<at::Half> z) {
  return c10::complex<at::Half>{z.real(), -z.imag()};
}

template<>
inline c10::complex<float> conj_impl <c10::complex<float>> (c10::complex<float> z) {
  return c10::complex<float>(z.real(), -z.imag());
}

template<>
inline c10::complex<double> conj_impl <c10::complex<double>> (c10::complex<double> z) {
  return c10::complex<double>(z.real(), -z.imag());
}

template <typename TYPE>
inline TYPE ceil_impl (TYPE z) {
  return std::ceil(z);
}

template <>
inline c10::complex<float> ceil_impl (c10::complex<float> z) {
  return c10::complex<float>(std::ceil(z.real()), std::ceil(z.imag()));
}

template <>
inline c10::complex<double> ceil_impl (c10::complex<double> z) {
  return c10::complex<double>(std::ceil(z.real()), std::ceil(z.imag()));
}

template<typename T>
inline c10::complex<T> sgn_impl (c10::complex<T> z) {
  if (z == c10::complex<T>(0, 0)) {
    return c10::complex<T>(0, 0);
  } else {
    return z / zabs(z);
  }
}

template <typename TYPE>
inline TYPE floor_impl (TYPE z) {
  return std::floor(z);
}

template <>
inline c10::complex<float> floor_impl (c10::complex<float> z) {
  return c10::complex<float>(std::floor(z.real()), std::floor(z.imag()));
}

template <>
inline c10::complex<double> floor_impl (c10::complex<double> z) {
  return c10::complex<double>(std::floor(z.real()), std::floor(z.imag()));
}

template <typename TYPE>
inline TYPE round_impl (TYPE z) {
  return std::nearbyint(z);
}

template <>
inline c10::complex<float> round_impl (c10::complex<float> z) {
  return c10::complex<float>(std::nearbyint(z.real()), std::nearbyint(z.imag()));
}

template <>
inline c10::complex<double> round_impl (c10::complex<double> z) {
  return c10::complex<double>(std::nearbyint(z.real()), std::nearbyint(z.imag()));
}

template <typename TYPE>
inline TYPE trunc_impl (TYPE z) {
  return std::trunc(z);
}

template <>
inline c10::complex<float> trunc_impl (c10::complex<float> z) {
  return c10::complex<float>(std::trunc(z.real()), std::trunc(z.imag()));
}

template <>
inline c10::complex<double> trunc_impl (c10::complex<double> z) {
  return c10::complex<double>(std::trunc(z.real()), std::trunc(z.imag()));
}

template <typename TYPE, std::enable_if_t<!c10::is_complex<TYPE>::value, int> = 0>
inline TYPE max_impl (TYPE a, TYPE b) {
  if (_isnan<TYPE>(a) || _isnan<TYPE>(b)) {
    return std::numeric_limits<TYPE>::quiet_NaN();
  } else {
    return std::max(a, b);
  }
}

template <typename TYPE, std::enable_if_t<c10::is_complex<TYPE>::value, int> = 0>
inline TYPE max_impl (TYPE a, TYPE b) {
  if (_isnan<TYPE>(a)) {
    return a;
  } else if (_isnan<TYPE>(b)) {
    return b;
  } else {
    return std::abs(a) > std::abs(b) ? a : b;
  }
}

template <typename TYPE, std::enable_if_t<!c10::is_complex<TYPE>::value, int> = 0>
inline TYPE min_impl (TYPE a, TYPE b) {
  if (_isnan<TYPE>(a) || _isnan<TYPE>(b)) {
    return std::numeric_limits<TYPE>::quiet_NaN();
  } else {
    return std::min(a, b);
  }
}

template <typename TYPE, std::enable_if_t<c10::is_complex<TYPE>::value, int> = 0>
inline TYPE min_impl (TYPE a, TYPE b) {
  if (_isnan<TYPE>(a)) {
    return a;
  } else if (_isnan<TYPE>(b)) {
    return b;
  } else {
    return std::abs(a) < std::abs(b) ? a : b;
  }
}

} // end namespace
} //end at::native
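A small usage sketch, not part of the header: the same template call works for real and complex element types, which is what lets scalar CPU kernels stay dtype-generic. The function name below is an illustrative assumption.

#include <c10/util/complex.h>
#include <ATen/native/cpu/zmath.h>

// Sketch: conjugate then take the "abs" helper; for real types both calls are
// no-ops, for complex types the result is the modulus packed as (|z|, 0).
template <typename scalar_t>
scalar_t conj_then_zabs(scalar_t z) {
  using at::native::conj_impl;
  using at::native::zabs;
  return zabs<scalar_t>(conj_impl(z));
}
// e.g. conj_then_zabs(c10::complex<float>(3, 4)) == c10::complex<float>(5, 0),
// while conj_then_zabs(-3.0f) == -3.0f because zabs is a no-op for real types.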
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Activation.h
ADDED
@@ -0,0 +1,20 @@
#pragma once
#include <ATen/native/Activation.h>
#include <cstdint>

namespace at {
struct TensorIteratorBase;
class TensorBase;
}

namespace at { namespace native {

void launch_glu_backward_kernel(const TensorIteratorBase& iter,
                                int64_t gI_stride, int64_t I_stride);

void launch_log_sigmoid_forward_kernel(TensorIteratorBase& iter);

void GeluCUDAKernelImpl(TensorIteratorBase& it, GeluType approximate);
void GeluBackwardCUDAKernelImpl(TensorIteratorBase& it, GeluType approximate);

}} // namespace at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/BinaryInternal.h
ADDED
@@ -0,0 +1,48 @@
// DON'T include this except from Binary*.cu files. It should not leak into
// headers.
#pragma once
#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/native/BinaryOps.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/TensorIterator.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAMathCompat.h>
#include <c10/util/TypeSafeSignMath.h>
#include <ATen/native/cuda/JitLoops.cuh>
#include <ATen/native/cuda/Loops.cuh>

#include <type_traits>

namespace at {
namespace native {
namespace binary_internal {

template <typename scalar_t>
struct DivFunctor {
  __device__ scalar_t operator()(scalar_t a, scalar_t b) const {
    return a / b;
  }
};

template <typename T>
struct MulFunctor {
  __device__ T operator()(T a, T b) const {
    return a * b;
  }
};

// Workaround for the error: '*' in boolean context, suggest '&&' instead
// [-Werror=int-in-bool-context]
template <>
struct MulFunctor<bool> {
  __device__ bool operator()(bool a, bool b) const {
    return a && b;
  }
};
void div_true_kernel_cuda(TensorIteratorBase& iter);
void div_trunc_kernel_cuda(TensorIteratorBase& iter);
} // namespace binary_internal
} // namespace native
} // namespace at
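A host-side illustration of why the MulFunctor<bool> specialization above exists, not part of the header: for bool operands, logical AND has the same truth table as multiplication, but avoids the compiler's int-in-bool-context warning that the comment cites. The standalone check below is a hypothetical sketch.

#include <cassert>

int main() {
  // Verify (a && b) agrees with multiplication on all bool pairs.
  for (bool a : {false, true}) {
    for (bool b : {false, true}) {
      int prod = a * b;                 // promoted to int, so no bool-context warning here
      assert((a && b) == (prod != 0));  // same truth table as the functor's '&&'
    }
  }
  return 0;
}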
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/CUDAJitLoops.cuh
ADDED
@@ -0,0 +1,296 @@
#pragma once
#include <ATen/jit_macros.h>

// Jiterator functions are guarded behind this macro
#if AT_USE_JITERATOR()

#include <ATen/OpMathType.h>
#include <ATen/TensorIterator.h>
#include <ATen/core/Array.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/native/cuda/jit_utils.h>
#include <ATen/native/cuda/MemoryAccess.cuh>
#include <ATen/native/cuda/thread_constants.h>

#include <ATen/native/cuda/Loops.cuh>

#include <c10/macros/Macros.h>
#include <c10/core/ScalarType.h>
#include <c10/util/SmallBuffer.h>

#include <initializer_list>
#include <type_traits>
#include <tuple>
#include <mutex>

namespace at {
namespace native {

template <typename Tuple, std::size_t... I>
constexpr auto tuple_to_array_helper(Tuple& t, std::index_sequence<I...> seq) {
  constexpr auto size = seq.size();
  (void)t; // warning : unused parameter when tuple is empty.
  return std::array<void*, size>{static_cast<void*>(&std::get<I>(t))...};
}

// Helper function convert tuple to std::array<void*, N>
// for passing the arguments to CUDA Kernel
// NOTE: We capture tuple by reference,
// so the pointers in returned array are only valid
// till tuple is alive.
template <typename ...Args>
constexpr auto tuple_to_array(std::tuple<Args...>& extra_args) {
  constexpr auto tuple_size = sizeof...(Args);
  return tuple_to_array_helper(extra_args, std::make_index_sequence<tuple_size>{});
}

struct JittedVecKernelCache {
  // Different kernels are compiled depending on what we're vectorizing up to (1, 2 or 4 elements)
  at::cuda::jit::NvrtcFunction vec1;
  at::cuda::jit::NvrtcFunction vec2;
  at::cuda::jit::NvrtcFunction vec4;
};

struct JittedKernelVariantCache {
  JittedVecKernelCache vec;
  at::cuda::jit::NvrtcFunction noncontiguous;
  at::cuda::jit::NvrtcFunction dynamic_contiguous;
  at::cuda::jit::NvrtcFunction dynamic_noncontiguous;
};

inline c10::SmallBuffer<void*, 64> pack_kernel_args(
    std::initializer_list<void*> args,
    c10::ArrayRef<void*> extra_args) {
  c10::SmallBuffer<void*, 64> ret(args.size() + extra_args.size());
  std::copy(args.begin(), args.end(), ret.data());
  std::copy(extra_args.begin(), extra_args.end(), ret.data() + args.size());
  return ret;
}

template<typename array_t,
         typename inp_calc_t,
         typename out_calc_t,
         typename loader_t,
         typename storer_t>
void launch_jitted_unrolled_kernel(
    std::mutex &jiterator_mutex,
    at::cuda::jit::NvrtcFunction &fn_cache,
    const at::cuda::jit::KernelDescriptor &desc,
    int64_t N,
    array_t data,
    inp_calc_t ic,
    out_calc_t oc,
    loader_t l,
    storer_t s,
    bool contiguous,
    at::cuda::jit::BinaryFuncVariant scalar_pos,
    void* scalar_val,
    c10::ArrayRef<void*> extra_args) {

  TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
  //casting result to int is always safe, intermediate is int64 and won't overflow
  const uint32_t grid = (N + block_work_size() - 1) / block_work_size();

  if (!fn_cache.function) {
    const std::lock_guard<std::mutex> lock{jiterator_mutex};
    if (!fn_cache.function) {
      constexpr bool dynamic_casting = !std::is_same<decltype(l), memory::LoadWithoutCast>() ||
                                       !std::is_same<decltype(s), memory::StoreWithoutCast>();
      auto code = at::cuda::jit::generate_code(
          desc, contiguous, dynamic_casting, scalar_pos);
      fn_cache = at::cuda::jit::jit_pwise_function(code, desc.name);
    }
  }

  auto args = pack_kernel_args({&N, &data, &ic, &oc, &l, &s, scalar_val}, extra_args);
  at::cuda::jit::launch_jitted_pwise_function(fn_cache, args.data(), {grid, 1u, 1u},
                                              {num_threads(), 1u, 1u});
}

template<int arity, typename array_t>
void launch_jitted_vectorized_kernel(
    std::mutex &jiterator_mutex, JittedVecKernelCache &fn_cache,
    const at::cuda::jit::KernelDescriptor &desc, int64_t N, array_t data,
    at::cuda::jit::BinaryFuncVariant scalar_pos,
    void *scalar_val, c10::ArrayRef<void*> extra_args) {
  TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
  // N is still int64_t for the computation, but it's always safe to cast result to int
  const uint32_t grid = (N + block_work_size() - 1) / block_work_size();
  const int vec_size = at::cuda::jit::can_vectorize_up_to(
      desc, c10::ArrayRef<char*>(data.data, data.size()));

  // Different kernels are compiled depending on what we're vectorizing up to (1, 2 or 4 elements)
  // fn_ptr is set to the appropriate function based on the vec size and GPU used
  at::cuda::jit::NvrtcFunction* fn_ptr;
  if (vec_size == 4) {
    fn_ptr = &fn_cache.vec4;
  } else if (vec_size == 2) {
    fn_ptr = &fn_cache.vec2;
  } else if (vec_size == 1) {
    fn_ptr = &fn_cache.vec1;
  } else {
    TORCH_INTERNAL_ASSERT(false, "unexpected vec_size for jitter vectorized kernel");
  }

  bool vectorized = vec_size > 1;

  if (!fn_ptr->function) {
    const std::lock_guard<std::mutex> lock{jiterator_mutex};
    if (!fn_ptr->function) { // cache miss!

      // Generates program
      auto code = at::cuda::jit::generate_code(
          desc, /*contiguous=*/true, /*dynamic_casting=*/false,
          scalar_pos, vectorized, vec_size);
      std::string kernel_name = vectorized ? desc.name + "_vectorized" + std::to_string(vec_size) : desc.name;

      // Acquires the program
      *fn_ptr = at::cuda::jit::jit_pwise_function(code, kernel_name);
    }
  }

  if (vectorized) {
    auto args = pack_kernel_args({&N, &data, scalar_val}, extra_args);
    at::cuda::jit::launch_jitted_pwise_function(
        *fn_ptr, args.data(), {grid, 1u, 1u}, {num_threads(), 1u, 1u});
  } else {
// NVCC complains about unused variables l and s.
// It should be false positive in most cases, so we suppress the warnings.
#pragma nv_diagnostic push
#pragma nv_diag_suppress 177
    auto ic = TrivialOffsetCalculator<arity>();
    auto oc = TrivialOffsetCalculator<1>();
    auto l = memory::LoadWithoutCast();
    auto s = memory::StoreWithoutCast();

    auto args = pack_kernel_args(
        {&N, &data, &ic, &oc, &l, &s, scalar_val}, extra_args);
    at::cuda::jit::launch_jitted_pwise_function(
        *fn_ptr, args.data(), {grid, 1u, 1u}, {num_threads(), 1u, 1u});
#pragma nv_diagnostic pop
  }
}

template <int arity>
void jitted_gpu_kernel_generic(
    std::mutex &jiterator_mutex,
    JittedKernelVariantCache &cache,
    const at::cuda::jit::KernelDescriptor &desc,
    at::cuda::jit::BinaryFuncVariant scalar_pos,
    c10::ArrayRef<void*> extra_args,
    TensorIteratorBase& iter,
    const bool dynamic_casting,
    void *scalar_val) {
  TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing());
  TORCH_INTERNAL_ASSERT(iter.ninputs() == arity);
  TORCH_INTERNAL_ASSERT(iter.noutputs() == 1);

  constexpr int ntensors = arity + 1;
  at::detail::Array<char*, ntensors> data;
  for (auto i : c10::irange(ntensors)) {
    data[i] = (char*)iter.data_ptr(i);
  }

  int64_t numel = iter.numel();
  bool contiguous = iter.is_contiguous();

  // Decides which of 4 kernel types to launch
  // Variations are:
  // - Case 1: no dynamic casting and contiguous
  // - Case 2: no dynamic casting and noncontiguous
  // - Case 3: dynamic casting and contiguous
  // - Case 4: dynamic casting and noncontiguous
  // These cases align with the non-jitted CUDALoops.cuh cases in gpu_kernel_impl

  if (!dynamic_casting) {
    if (contiguous) {
      // Case 1: no dynamic casting and contiguous
      launch_jitted_vectorized_kernel<arity>(
          jiterator_mutex, cache.vec, desc,
          numel, data, scalar_pos, scalar_val, extra_args);
      return;
    }

    // Case 2: no dynamic casting and noncontiguous
    auto input_offset_calculator = make_input_offset_calculator<arity>(iter);
    auto output_offset_calculator = make_output_offset_calculator(iter);
    auto loader = memory::LoadWithoutCast();
    auto storer = memory::StoreWithoutCast();
    launch_jitted_unrolled_kernel(
        jiterator_mutex, cache.noncontiguous, desc, numel, data,
        input_offset_calculator, output_offset_calculator, loader,
        storer, contiguous, scalar_pos, scalar_val, extra_args);
    return;
  }

  // Cases 3 and 4 are handled below
  // Both require construction of a storer (this asserts 1 output) and one or more loaders

  // Creates store cast to output (the zeroth tensor in TensorIterator)
  auto storer = memory::StoreWithCast<1>(iter);

  // Creates load casts from inputs (note offset indexing into the iterators 1...n tensors)
  auto loader = memory::LoadWithCast<arity>(iter);

  if (contiguous) {
    // Case 3: dynamic casting and contiguous
    auto input_offset_calculator = TrivialOffsetCalculator<arity>();
    auto output_offset_calculator = TrivialOffsetCalculator<1>();
    launch_jitted_unrolled_kernel(
        jiterator_mutex, cache.dynamic_contiguous, desc, numel, data, input_offset_calculator,
        output_offset_calculator, loader, storer, contiguous, scalar_pos, scalar_val, extra_args);
    return;
  }

  // Case 4: dynamic casting and noncontiguous
  auto input_offset_calculator = make_input_offset_calculator<arity>(iter);
  auto output_offset_calculator = make_output_offset_calculator(iter);
  launch_jitted_unrolled_kernel(
      jiterator_mutex, cache.dynamic_noncontiguous, desc, numel, data, input_offset_calculator,
| 251 |
+
output_offset_calculator, loader, storer, contiguous, scalar_pos, scalar_val, extra_args);
|
| 252 |
+
}
|
| 253 |
+
|
| 254 |
+
// NOTE: static to reduce chances of name collision.
|
| 255 |
+
template <
|
| 256 |
+
char const* name,
|
| 257 |
+
typename result_type,
|
| 258 |
+
typename f_inputs_type,
|
| 259 |
+
int arity,
|
| 260 |
+
at::cuda::jit::BinaryFuncVariant scalar_pos =
|
| 261 |
+
at::cuda::jit::BinaryFuncVariant::NoScalar,
|
| 262 |
+
typename... ExtraArgs>
|
| 263 |
+
static void jitted_gpu_kernel_impl(
|
| 264 |
+
TensorIteratorBase& iter,
|
| 265 |
+
const std::string &f,
|
| 266 |
+
const bool dynamic_casting,
|
| 267 |
+
at::opmath_type<f_inputs_type> scalar_val,
|
| 268 |
+
std::tuple<ExtraArgs...> extra_args) {
|
| 269 |
+
|
| 270 |
+
// TODO: Memory use can probably be optimized by re-using kernels across GPUs with
|
| 271 |
+
// the same compute capability
|
| 272 |
+
static std::mutex jiterator_mutex;
|
| 273 |
+
static std::vector<JittedKernelVariantCache> device_caches(c10::cuda::device_count());
|
| 274 |
+
|
| 275 |
+
constexpr int nInputs = arity;
|
| 276 |
+
constexpr int nOutputs = 1; // TODO: Support more than 1 output
|
| 277 |
+
static const auto desc = at::cuda::jit::make_kernel_descriptor<
|
| 278 |
+
result_type, f_inputs_type, ExtraArgs...>(name, f, nInputs, nOutputs);
|
| 279 |
+
|
| 280 |
+
auto &cache = device_caches[iter.device().index()];
|
| 281 |
+
auto extra_args_array = tuple_to_array(extra_args);
|
| 282 |
+
return jitted_gpu_kernel_generic<arity>(
|
| 283 |
+
jiterator_mutex,
|
| 284 |
+
cache,
|
| 285 |
+
desc,
|
| 286 |
+
scalar_pos,
|
| 287 |
+
extra_args_array,
|
| 288 |
+
iter,
|
| 289 |
+
dynamic_casting,
|
| 290 |
+
&scalar_val
|
| 291 |
+
);
|
| 292 |
+
}
|
| 293 |
+
|
| 294 |
+
}} // at::native
|
| 295 |
+
|
| 296 |
+
#endif // AT_USE_JITERATOR()
|
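The per-vec-size kernel cache above relies on double-checked locking: check the cached function pointer without the mutex, and only on a miss take the lock and check again before compiling. The following is a minimal standalone sketch of that pattern, not part of the header; the names CompiledKernel, compile_kernel and get_or_compile are illustrative stand-ins for NvrtcFunction, jit_pwise_function and the cache-miss branch.

// Minimal sketch of the jiterator's double-checked kernel cache (illustrative only).
#include <mutex>
#include <string>

struct CompiledKernel {
  void* function = nullptr;  // stands in for at::cuda::jit::NvrtcFunction::function
};

// Placeholder for the expensive NVRTC compilation step.
CompiledKernel compile_kernel(const std::string& code) {
  CompiledKernel k;
  k.function = reinterpret_cast<void*>(0x1);  // pretend we got a CUfunction back
  return k;
}

void get_or_compile(std::mutex& m, CompiledKernel& cache, const std::string& code) {
  if (!cache.function) {                        // fast path: no lock once compiled
    const std::lock_guard<std::mutex> lock(m);
    if (!cache.function) {                      // re-check under the lock (cache miss)
      cache = compile_kernel(code);
    }
  }
}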
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/CUDALoops.cuh
ADDED
@@ -0,0 +1,348 @@
#pragma once

// This file provides two functions to help write GPU elementwise kernels:
//
//   gpu_kernel(TensorIterator iter, <lambda>)
//   gpu_kernel_with_scalars(TensorIterator iter, <lambda>)
//
// The gpu_kernel_with_scalars generates specializations that support a
// single scalar CPU argument, such as from `cuda_tensor + 5`. The CPU scalar
// is lifted to a kernel parameter instead of copying to device memory.
// This should be used in conjunction with TensorIterator::allow_cpu_scalars_,
// which is the default for TensorIterator::binary_op. Otherwise, all inputs
// and the output must be on the GPU.
//
// For example, to write a reciprocal kernel for GPU float Tensors:
//
//   gpu_kernel(iter, []GPU_LAMBDA(float a) {
//    return 1.0f / a;
//   });
//
// To write a multiplication kernel for GPU float Tensors where one argument
// may be a CPU scalar:
//
//   gpu_kernel_with_scalars(iter, []GPU_LAMBDA(float a, float b) {
//     return a * b;
//   });
//
// See BinaryOpsKernel.cu for the complete implementation
//

#include <iostream>
#include <tuple>
#include <type_traits>

#include <ATen/core/Array.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/detail/FunctionTraits.h>
#include <ATen/native/TensorIterator.h>
#include <c10/core/DynamicCast.h>
#include <c10/core/ScalarType.h>
#include <c10/macros/Macros.h>
#include <c10/util/TypeCast.h>

#ifdef __NVCC__
#define ASSERT_HOST_DEVICE_LAMBDA(type)                       \
  static_assert(                                              \
      __nv_is_extended_host_device_lambda_closure_type(type), \
      #type " must be a __host__ __device__ lambda")
#else
#define ASSERT_HOST_DEVICE_LAMBDA(type)
#endif

namespace at {
namespace native {

template <int vec_size, typename func_t, typename array_t>
C10_LAUNCH_BOUNDS_1(num_threads())
__global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) {
  using traits = function_traits<func_t>;
  int remaining = N - block_work_size() * blockIdx.x;

  if (remaining < block_work_size()) { // if this block handles the reminder,
                                       // just do a naive unrolled loop
    auto input_calc = TrivialOffsetCalculator<traits::arity>();
    auto output_calc = TrivialOffsetCalculator<1>();
    auto loader = memory::LoadWithoutCast();
    auto storer = memory::StoreWithoutCast();
    auto policy = memory::policies::unroll<
        array_t,
        decltype(input_calc),
        decltype(output_calc),
        memory::LoadWithoutCast,
        memory::StoreWithoutCast>(
        data, remaining, input_calc, output_calc, loader, storer);
    elementwise_kernel_helper(f, policy);
  } else { // if this block has a full `block_work_size` data to handle, use
           // vectorized memory access
    elementwise_kernel_helper(
        f, memory::policies::vectorized<vec_size, array_t>(data));
  }
}

template <
    typename func_t,
    typename array_t,
    typename inp_calc_t,
    typename out_calc_t,
    typename loader_t,
    typename storer_t>
C10_LAUNCH_BOUNDS_1(num_threads())
__global__ void unrolled_elementwise_kernel(
    int N,
    func_t f,
    array_t data,
    inp_calc_t ic,
    out_calc_t oc,
    loader_t l,
    storer_t s) {
  int remaining = N - block_work_size() * blockIdx.x;
  auto policy = memory::policies::
      unroll<array_t, inp_calc_t, out_calc_t, loader_t, storer_t>(
          data, remaining, ic, oc, l, s);
  elementwise_kernel_helper(f, policy);
}

// this function assume trivial 1d and no dynamic casting
template <typename func_t, typename array_t>
static inline void launch_vectorized_kernel(
    int64_t N,
    const func_t& f,
    array_t data) {
  TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
  using traits = function_traits<func_t>;
  int64_t grid = (N + block_work_size() - 1) / block_work_size();
  auto stream = at::cuda::getCurrentCUDAStream();
  int vec_size = memory::can_vectorize_up_to<func_t>(data);

  switch (vec_size) {
    case 4:
      vectorized_elementwise_kernel<4, func_t, array_t>
          <<<grid, num_threads(), 0, stream>>>(N, f, data);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
      break;
    case 2:
      vectorized_elementwise_kernel<2, func_t, array_t>
          <<<grid, num_threads(), 0, stream>>>(N, f, data);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
      break;
    case 1: {
      auto input_calc = TrivialOffsetCalculator<traits::arity>();
      auto output_calc = TrivialOffsetCalculator<1>();
      auto loader = memory::LoadWithoutCast();
      auto storer = memory::StoreWithoutCast();
      unrolled_elementwise_kernel<func_t, array_t>
          <<<grid, num_threads(), 0, stream>>>(
              N, f, data, input_calc, output_calc, loader, storer);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
      break;
    }
    default:
      TORCH_INTERNAL_ASSERT(false, "Unexpected vectorization size");
  }
}

template <
    typename func_t,
    typename array_t,
    typename inp_calc_t,
    typename out_calc_t,
    typename loader_t,
    typename storer_t>
static inline void launch_unrolled_kernel(
    int64_t N,
    const func_t& f,
    array_t data,
    inp_calc_t ic,
    out_calc_t oc,
    loader_t l,
    storer_t s) {
  TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
  int64_t grid = (N + block_work_size() - 1) / block_work_size();
  auto stream = at::cuda::getCurrentCUDAStream();
  unrolled_elementwise_kernel<func_t, array_t>
      <<<grid, num_threads(), 0, stream>>>(N, f, data, ic, oc, l, s);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

template <int nt, int vt, typename func_t>
C10_LAUNCH_BOUNDS_2(nt, 4)
__global__ void elementwise_kernel(int N, func_t f) {
  int tid = threadIdx.x;
  int nv = nt * vt;
  int idx = nv * blockIdx.x + tid;
#pragma unroll
  for (int i = 0; i < vt; i++) {
    if (idx < N) {
      f(idx);
      idx += nt;
    }
  }
}

template <int nt, int vt, typename func_t>
static void launch_legacy_kernel(int64_t N, const func_t& f) {
  TORCH_INTERNAL_ASSERT(N >= 0 && N <= std::numeric_limits<int32_t>::max());
  if (N == 0) {
    return;
  }
  dim3 block(nt);
  dim3 grid((N + block.x * vt - 1) / (block.x * vt));
  auto stream = at::cuda::getCurrentCUDAStream();
  elementwise_kernel<nt, vt, func_t><<<grid, block, 0, stream>>>(N, f);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

template <typename traits, typename func_t, typename index_t, size_t... INDEX>
C10_HOST_DEVICE typename traits::result_type invoke_impl(
    const func_t& f,
    char* const C10_RESTRICT data[],
    const index_t strides[],
    int i,
    std::index_sequence<INDEX...>) {
  (void)strides;
  (void)i;
  return f(c10::load<typename traits::template arg<INDEX>::type>(
      data[INDEX] + i * strides[INDEX])...);
}

template <
    typename func_t,
    typename index_t,
    typename traits = function_traits<func_t>>
C10_HOST_DEVICE typename traits::result_type invoke(
    const func_t& f,
    char* const C10_RESTRICT data[],
    const index_t strides[],
    int i) {
  using Indices = std::make_index_sequence<traits::arity>;
  return invoke_impl<traits>(f, data, strides, i, Indices{});
}

template <typename traits, typename func_t, typename index_t, size_t... I>
C10_HOST_DEVICE typename traits::result_type invoke_impl(
    const func_t& f,
    char* const C10_RESTRICT data[],
    const index_t strides[],
    const ScalarType dtypes[],
    int i,
    std::index_sequence<I...>) {
  (void)strides;
  (void)i;
  return f(c10::fetch_and_cast<typename traits::template arg<I>::type>(
      dtypes[I], data[I] + i * strides[I])...);
}

template <
    typename func_t,
    typename index_t,
    typename traits = function_traits<func_t>>
C10_HOST_DEVICE typename traits::result_type invoke(
    const func_t& f,
    char* const C10_RESTRICT data[],
    const index_t strides[],
    const ScalarType dtypes[],
    int i) {
  using Indices = std::make_index_sequence<traits::arity>;
  return invoke_impl<traits>(f, data, strides, dtypes, i, Indices{});
}

template <typename func_t>
void gpu_kernel_impl_nocast(TensorIteratorBase& iter, const func_t& f) {
  using traits = function_traits<func_t>;
  using arg0_t = typename traits::result_type;
  constexpr int ntensors = traits::arity + 1;

  TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing());
  TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity);
  TORCH_INTERNAL_ASSERT(iter.noutputs() == 1);
  TORCH_INTERNAL_ASSERT(!needs_dynamic_casting<func_t>::check(iter));

  at::detail::Array<char*, ntensors> data;
  for (int i = 0; i < ntensors; i++) {
    data[i] = (char*)iter.data_ptr(i);
  }

  int64_t numel = iter.numel();

  bool contiguous = iter.is_contiguous();

  if (contiguous) {
    return launch_vectorized_kernel(numel, f, data);
  }
  auto offset_calc = ::make_offset_calculator<traits::arity + 1>(iter);
  constexpr int unroll_factor = sizeof(arg0_t) >= 4 ? 2 : 4;
  launch_legacy_kernel<128, unroll_factor>(numel, [=] GPU_LAMBDA(int idx) {
    auto offsets = offset_calc.get(idx);
    arg0_t* out = (arg0_t*)(data[0] + offsets[0]);
    *out = invoke(f, &data.data[1], &offsets.data[1], 1);
  });
}

template <typename func_t>
void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
  if (!needs_dynamic_casting<func_t>::check(iter)) {
    return gpu_kernel_impl_nocast(iter, f);
  }
  using traits = function_traits<func_t>;
  using arg0_t = typename traits::result_type;
  constexpr int ntensors = traits::arity + 1;

  TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing());
  TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity);
  TORCH_INTERNAL_ASSERT(iter.noutputs() == 1);

  at::detail::Array<char*, ntensors> data;
  for (int i = 0; i < ntensors; i++) {
    data[i] = (char*)iter.data_ptr(i);
  }

  int64_t numel = iter.numel();

  bool contiguous = iter.is_contiguous();

  if (contiguous) {
#ifdef USE_ROCM
    at::detail::Array<ScalarType, ntensors> dtypes;
    auto inner_strides = iter.get_inner_strides();
    at::detail::Array<int, ntensors> strides;
    for (int i = 0; i < ntensors; i++) {
      dtypes[i] = iter.dtype(i);
      strides[i] = inner_strides[i];
    }
    launch_legacy_kernel<512, 1>(numel, [=]GPU_LAMBDA(int idx) {
      void* out = data[0] + strides[0] * idx;
      arg0_t result = invoke(f, &data.data[1], &strides.data[1], &dtypes.data[1], idx);
      c10::cast_and_store<arg0_t>(dtypes[0], out, result);
    });
#else
    auto loader = memory::LoadWithCast<traits::arity>(iter);
    auto storer = memory::StoreWithCast<1>(iter);
    auto input_offset_calculator = TrivialOffsetCalculator<traits::arity>();
    auto output_offset_calculator = TrivialOffsetCalculator<1>();
    launch_unrolled_kernel(
        numel,
        f,
        data,
        input_offset_calculator,
        output_offset_calculator,
        loader,
        storer);
#endif
  } else {
    at::detail::Array<ScalarType, ntensors> dtypes;
    for (int i = 0; i < ntensors; i++) {
      dtypes[i] = iter.dtype(i);
    }
    auto offset_calc = ::make_offset_calculator<traits::arity + 1>(iter);
    launch_legacy_kernel<128, 4>(numel, [=] GPU_LAMBDA(int idx) {
      auto offsets = offset_calc.get(idx);
      void* out = data[0] + offsets[0];
      arg0_t result = invoke(f, &data.data[1], &offsets.data[1], &dtypes.data[1], 1);
      c10::cast_and_store<arg0_t>(dtypes[0], out, result);
    });
  }
}

} // namespace native
} // namespace at
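The header comment at the top of CUDALoops.cuh describes how callers are expected to drive gpu_kernel through a TensorIterator. The sketch below shows that wiring end to end under stated assumptions: it assumes a translation unit that includes ATen/native/cuda/Loops.cuh (which provides gpu_kernel and GPU_LAMBDA), and the function name reciprocal_like_kernel_cuda and the tensors t_out/t_in are illustrative, not ATen symbols.

// Hedged usage sketch: TensorIterator + gpu_kernel, mirroring the header's example.
#include <ATen/ATen.h>
#include <ATen/TensorIterator.h>
#include <ATen/native/cuda/Loops.cuh>

void reciprocal_like_kernel_cuda(at::Tensor& t_out, const at::Tensor& t_in) {
  // Build an iterator with one output and one input; it handles broadcasting,
  // dtype checks, and 32-bit index splitting before the kernel is launched.
  auto iter = at::TensorIteratorConfig()
                  .add_output(t_out)
                  .add_input(t_in)
                  .build();
  // Elementwise functor, same shape as the reciprocal example in the header comment.
  at::native::gpu_kernel(iter, [] GPU_LAMBDA(float a) -> float {
    return 1.0f / a;
  });
}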
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/CompositeRandomAccessor.h
ADDED
@@ -0,0 +1,35 @@
#pragma once

#include <ATen/native/CompositeRandomAccessorCommon.h>
#include <thrust/tuple.h>

namespace at { namespace native {

struct TupleInfoCPU {
  template <typename ...Types>
  using tuple = thrust::tuple<Types...>;

  template <typename ...Types>
  static constexpr auto tie(Types&... args) noexcept {
    return thrust::tie(args...);
  }
};

template <typename KeyAccessor, typename ValueAccessor>
using CompositeRandomAccessorCPU =
  CompositeRandomAccessor<KeyAccessor, ValueAccessor, TupleInfoCPU>;

template <typename Values, typename References>
void swap(
  references_holder<Values, References> rh1,
  references_holder<Values, References> rh2
) {
  return thrust::swap(rh1.data(), rh2.data());
}

template <int N, typename Values, typename References>
auto get(references_holder<Values, References> rh) -> decltype(thrust::get<N>(rh.data())) {
  return thrust::get<N>(rh.data());
}

}} // namespace at::native
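The point of a composite random accessor is that indexing one accessor yields a tuple of references into two parallel arrays, so generic algorithms can permute keys and values together (as the swap/get overloads above enable for thrust tuples). The following is an illustrative host-side analogue only, using std::tuple instead of the ATen/thrust types; PairAccessor and swap_entries are invented names for the sketch.

// Illustrative sketch of the composite-accessor idea (not the ATen API).
#include <cstddef>
#include <tuple>
#include <utility>

struct PairAccessor {
  float* keys;
  int* values;
  // operator[] returns a tuple of references, analogous to references_holder.
  std::tuple<float&, int&> operator[](std::size_t i) const {
    return std::tie(keys[i], values[i]);
  }
};

// Swapping through the accessor moves key and value as one logical element.
void swap_entries(PairAccessor acc, std::size_t i, std::size_t j) {
  std::swap(std::get<0>(acc[i]), std::get<0>(acc[j]));
  std::swap(std::get<1>(acc[i]), std::get<1>(acc[j]));
}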
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Copy.h
ADDED
@@ -0,0 +1,10 @@
#pragma once

namespace at {
struct TensorIteratorBase;

namespace native {

void direct_copy_kernel_cuda(TensorIteratorBase &iter);

}} // namespace at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/CuFFTUtils.h
ADDED
@@ -0,0 +1,73 @@
#pragma once

#include <ATen/Config.h>

#include <string>
#include <stdexcept>
#include <sstream>
#include <cufft.h>
#include <cufftXt.h>

namespace at { namespace native {

// This means that max dim is 3 + 2 = 5 with batch dimension and possible
// complex dimension
constexpr int max_rank = 3;

static inline std::string _cudaGetErrorEnum(cufftResult error)
{
  switch (error)
  {
    case CUFFT_SUCCESS:
      return "CUFFT_SUCCESS";
    case CUFFT_INVALID_PLAN:
      return "CUFFT_INVALID_PLAN";
    case CUFFT_ALLOC_FAILED:
      return "CUFFT_ALLOC_FAILED";
    case CUFFT_INVALID_TYPE:
      return "CUFFT_INVALID_TYPE";
    case CUFFT_INVALID_VALUE:
      return "CUFFT_INVALID_VALUE";
    case CUFFT_INTERNAL_ERROR:
      return "CUFFT_INTERNAL_ERROR";
    case CUFFT_EXEC_FAILED:
      return "CUFFT_EXEC_FAILED";
    case CUFFT_SETUP_FAILED:
      return "CUFFT_SETUP_FAILED";
    case CUFFT_INVALID_SIZE:
      return "CUFFT_INVALID_SIZE";
    case CUFFT_UNALIGNED_DATA:
      return "CUFFT_UNALIGNED_DATA";
    case CUFFT_INCOMPLETE_PARAMETER_LIST:
      return "CUFFT_INCOMPLETE_PARAMETER_LIST";
    case CUFFT_INVALID_DEVICE:
      return "CUFFT_INVALID_DEVICE";
    case CUFFT_PARSE_ERROR:
      return "CUFFT_PARSE_ERROR";
    case CUFFT_NO_WORKSPACE:
      return "CUFFT_NO_WORKSPACE";
    case CUFFT_NOT_IMPLEMENTED:
      return "CUFFT_NOT_IMPLEMENTED";
#if !defined(USE_ROCM)
    case CUFFT_LICENSE_ERROR:
      return "CUFFT_LICENSE_ERROR";
#endif
    case CUFFT_NOT_SUPPORTED:
      return "CUFFT_NOT_SUPPORTED";
    default:
      std::ostringstream ss;
      ss << "unknown error " << error;
      return ss.str();
  }
}

static inline void CUFFT_CHECK(cufftResult error)
{
  if (error != CUFFT_SUCCESS) {
    std::ostringstream ss;
    ss << "cuFFT error: " << _cudaGetErrorEnum(error);
    AT_ERROR(ss.str());
  }
}

}} // at::native
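CUFFT_CHECK turns any non-success cufftResult into an ATen error whose message carries the enum name from _cudaGetErrorEnum. A short usage sketch follows; it assumes a .cu translation unit that includes this header and links against cuFFT, and make_and_destroy_plan is an illustrative name.

// Hedged usage sketch: wrapping standard cuFFT calls with CUFFT_CHECK.
#include <cufft.h>

void make_and_destroy_plan(int n) {
  cufftHandle plan;
  // cufftPlan1d returns a cufftResult; a failure here raises an ATen error
  // such as "cuFFT error: CUFFT_ALLOC_FAILED".
  at::native::CUFFT_CHECK(cufftPlan1d(&plan, n, CUFFT_C2C, /*batch=*/1));
  at::native::CUFFT_CHECK(cufftDestroy(plan));
}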
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/DeviceSqrt.cuh
ADDED
@@ -0,0 +1,25 @@
#pragma once

namespace at { namespace native {
#if defined(USE_ROCM)
// take these out when ROCm implements std:: math functions
#include <math.h>
template <typename scalar_t>
static __forceinline__ __device__ scalar_t device_sqrt(scalar_t val);

template <>
__forceinline__ __device__ float device_sqrt(float val) {
  return ::sqrtf(val);
}

template <>
__forceinline__ __device__ double device_sqrt(double val) {
  return ::sqrt(val);
}
#else
template<typename scalar_t>
__forceinline__ __device__ double device_sqrt(scalar_t val) {
  return std::sqrt(val);
}
#endif
}}
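device_sqrt exists so the same device code can call one spelling of square root and get ::sqrtf/::sqrt on ROCm and std::sqrt elsewhere. A minimal sketch of a caller follows; the kernel name l2_norm_finalize is illustrative and not part of ATen.

// Hedged sketch: a device kernel calling at::native::device_sqrt.
#include <ATen/native/cuda/DeviceSqrt.cuh>

template <typename scalar_t>
__global__ void l2_norm_finalize(const scalar_t* sum_of_squares, scalar_t* out) {
  if (threadIdx.x == 0 && blockIdx.x == 0) {
    // Same source compiles for CUDA and ROCm builds of this header.
    *out = static_cast<scalar_t>(at::native::device_sqrt(*sum_of_squares));
  }
}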
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/DistributionTemplates.h
ADDED
@@ -0,0 +1,671 @@
#pragma once

#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/Dispatch_v2.h>
#include <ATen/ExpandBase.h>
#include <ATen/OpMathType.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/Loops.cuh>
#include <c10/util/Half.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/cuda/CUDAGraphsUtils.cuh>
#include <ATen/detail/FunctionTraits.h>
#include <ATen/core/DistributionsHelper.h>

#include <curand.h>
#include <curand_kernel.h>
#include <curand_philox4x32_x.h>
#include <cstdint>
#include <limits>
#include <utility>
#include <mutex>
#include <tuple>
#include <type_traits>

namespace at {
namespace native {
namespace {

// launch bounds used for kernels utilizing TensorIterator
const uint32_t block_size_bound = 256;
const uint32_t grid_size_bound = 4;
// At the time of writing, there is no curand_* call that increments the offset by more than 4.
// See: https://docs.nvidia.com/cuda/archive/11.8.0/curand/group__DEVICE.html
const uint32_t max_generator_offsets_per_curand_call = 4;

// utility function that calculates proper philox_offset
// for distributions utilizing TensorIterator. For distributions using
// TensorIterator, we are using a grid-stride loop with each
// thread yielding one element per thread. For the edge of the grid-stride
// loop, if the tensor size is large, the unroll loop will kick in and the float4
// from curand4 will start getting utilized (for common tensor sizes, we end up
// using rand.x from each thread). The philox_offset calculation was changed to
// (number of elements per thread * maximum generator increment per "curand_*" call), which makes
// sure that philox offset increment is not less than the number of randoms used
// in each thread.
std::tuple<uint64_t, dim3, dim3> calc_execution_policy(const int64_t total_elements, const uint32_t unroll_factor) {
  const uint64_t numel = static_cast<uint64_t>(total_elements);
  const uint32_t block_size = block_size_bound;
  dim3 dim_block(block_size);
  dim3 grid((numel + block_size - 1) / block_size);
  uint32_t blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor / block_size;
  grid.x = std::min(
      static_cast<uint32_t>(at::cuda::getCurrentDeviceProperties()->multiProcessorCount) * blocks_per_sm,
      grid.x);
  //number of times random will be generated per thread, to offset philox counter in thc random state
  uint64_t counter_offset = ((numel - 1) / (block_size * grid.x * unroll_factor) + 1) * max_generator_offsets_per_curand_call;
  return std::make_tuple(counter_offset, grid, dim_block);
}

// grid stride loop kernel for distributions
template<typename accscalar_t, int unroll_factor, typename dist_t, typename transform_t>
C10_LAUNCH_BOUNDS_2(block_size_bound, grid_size_bound)
__global__ void distribution_elementwise_grid_stride_kernel(int numel,
                                                            PhiloxCudaState philox_args,
                                                            const dist_t dist_func,
                                                            const transform_t transform_func) {
  auto seeds = at::cuda::philox::unpack(philox_args);
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  curandStatePhilox4_32_10_t state;
  curand_init(std::get<0>(seeds),
              idx,
              std::get<1>(seeds),
              &state);

  int rounded_size = ((numel - 1)/(blockDim.x * gridDim.x * unroll_factor)+1) *
      blockDim.x * gridDim.x * unroll_factor;
  for(int linear_index = idx; linear_index < rounded_size; linear_index += blockDim.x * gridDim.x * unroll_factor) {
    auto rand = dist_func(&state);
    #pragma unroll
    for (int ii = 0; ii < unroll_factor; ii++) {
      int li = linear_index + blockDim.x * gridDim.x * ii;
      if (li < numel) {
        transform_func(li, static_cast<accscalar_t>((&rand.x)[ii]));
      }
    }
    __syncthreads();
  }
}

/**
 * distribution_nullary_kernel is analogous to gpu_kernel in
 * ATen/native/cuda/Loops.cuh. Like gpu_kernel, it uses
 * TensorIterator to launch a kernel. However, the differences are
 *   - it launches a grid-stride loop based kernel. The kernel is not
 *     generic like elementwise_kernel in Loops.cuh and is specialized
 *     for the distribution kernels here.
 *   - For big size tensors, we can launch multiple kernels recursively
 *     (i.e. if (!iter.can_use_32bit_indexing())) and hence, the philox
 *     offset calculation is done in this function.
 *
 * FIXME: Can we specialize elementwise_kernel and launch_kernel in Loops.cuh
 * to have grid-stride loop kernel and then use that to launch our distribution
 * kernels? Note that we need a grid-stride loop kernel because, we found by testing
 * that it achieves peak effective bandwidth.
 */
template<typename scalar_t,
         typename accscalar_t,
         typename dist_func_return_t,
         typename RNG,
         typename dist_t,
         typename transform_t>
void distribution_nullary_kernel(at::TensorIteratorBase& iter,
                                 RNG gen,
                                 const dist_t& dist_func,
                                 const transform_t transform_func) {
  const int unroll_factor = sizeof(dist_func_return_t) / sizeof(accscalar_t);
  TORCH_CHECK(unroll_factor >= 1, "unroll_factor must be >= 1.");
  int64_t numel = iter.numel();
  if (numel == 0) {
    return;
  }

  auto execution_policy = calc_execution_policy(numel, unroll_factor);
  auto counter_offset = std::get<0>(execution_policy);
  auto grid = std::get<1>(execution_policy);
  auto block = std::get<2>(execution_policy);
  PhiloxCudaState rng_engine_inputs;
  {
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> lock(gen->mutex_);
    rng_engine_inputs = gen->philox_cuda_state(counter_offset);
  }

  if (!iter.can_use_32bit_indexing()) {
    for (auto& sub_iter : iter.with_32bit_indexing()) {
      distribution_nullary_kernel<scalar_t, accscalar_t, dist_func_return_t>(sub_iter,
        gen, dist_func, transform_func);
    }
    return;
  }

  char* out_data = (char*)iter.data_ptr(0);

  auto stream = at::cuda::getCurrentCUDAStream();
  if (iter.is_trivial_1d()) {
    auto strides = iter.get_inner_strides();
    int stride0 = strides[0];
    distribution_elementwise_grid_stride_kernel<accscalar_t, unroll_factor><<<grid, block, 0, stream>>>(
      numel,
      rng_engine_inputs,
      dist_func,
      [=]__device__(int idx, accscalar_t rand) {
        scalar_t* out = (scalar_t*)&out_data[stride0 * idx];
        *out = transform_func(rand);
      }
    );
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  } else {
    auto offset_calc = make_offset_calculator<1>(iter);
    distribution_elementwise_grid_stride_kernel<accscalar_t, unroll_factor><<<grid, block, 0, stream>>>(
      numel,
      rng_engine_inputs,
      dist_func,
      [=]__device__(int idx, accscalar_t rand) {
        auto offsets = offset_calc.get(idx);
        scalar_t* out = (scalar_t*)&out_data[offsets[0]];
        *out = transform_func(rand);
      }
    );
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  }
}

// Binary kernel
template <typename func_t, typename inp_offset_calc_t, typename out_offset_calc_t>
__global__ void distribution_binary_elementwise_kernel(
    int numel,
    func_t f,
    PhiloxCudaState philox_args,
    typename function_traits<func_t>::result_type *output_data,
    const typename function_traits<func_t>::template arg<1>::type *input_data_1,
    const typename function_traits<func_t>::template arg<2>::type *input_data_2,
    inp_offset_calc_t inp_calc,
    out_offset_calc_t out_calc) {
  auto seeds = at::cuda::philox::unpack(philox_args);

  using input_t_1 = typename function_traits<func_t>::template arg<1>::type;
  using input_t_2 = typename function_traits<func_t>::template arg<2>::type;

  input_t_1 inputs_1[thread_work_size()];
  input_t_2 inputs_2[thread_work_size()];

  int base_index = block_work_size() * blockIdx.x;
  int remaining = std::min<int>(numel - base_index, block_work_size());

  curandStatePhilox4_32_10_t state;
  curand_init(std::get<0>(seeds),
              blockIdx.x * blockDim.x + threadIdx.x,
              std::get<1>(seeds),
              &state);

  // load data into registers
  int thread_idx = threadIdx.x;
  #pragma unroll
  for (int i = 0; i < thread_work_size(); i++) {
    if (thread_idx >= remaining) {
      break;
    }
    int input_idx = thread_idx + base_index;
    auto offsets = inp_calc.get(input_idx);
    inputs_1[i] = input_data_1[offsets[0]];
    inputs_2[i] = input_data_2[offsets[1]];

    thread_idx += num_threads();
  }

  // compute and store
  thread_idx = threadIdx.x;
  #pragma unroll
  for (int i = 0; i < thread_work_size(); i++) {
    if (thread_idx >= remaining) {
      break;
    }
    int input_idx = thread_idx + base_index;
    auto offsets = out_calc.get(input_idx);
    output_data[offsets[0]] = f(state, inputs_1[i], inputs_2[i]);
    thread_idx += num_threads();
  }
}

template <typename func_t>
void distribution_binary_kernel(TensorIteratorBase &iter, PhiloxCudaState philox_args, const func_t &f) {
  static_assert(std::is_same<typename function_traits<func_t>::template arg<0>::type, curandStatePhilox4_32_10_t&>::value, "the first argument of functor must be curandStatePhilox4_32_10_t");
  using input_t_1 = typename function_traits<func_t>::template arg<1>::type;
  using input_t_2 = typename function_traits<func_t>::template arg<2>::type;
  using output_t = typename function_traits<func_t>::result_type;

  if (!iter.can_use_32bit_indexing()) {
    for (auto& sub_iter : iter.with_32bit_indexing()) {
      distribution_binary_kernel(sub_iter, philox_args, f);
    }
    return;
  }

  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(iter.can_use_32bit_indexing());

  int64_t numel = iter.numel();
  if (numel == 0) {
    return;
  }

  output_t *output_data = static_cast<output_t *>(iter.data_ptr(0));
  const input_t_1 *input_data_1 = static_cast<const input_t_1 *>(iter.data_ptr(1));
  const input_t_2 *input_data_2 = static_cast<const input_t_2 *>(iter.data_ptr(2));

  int64_t grid = (numel + block_work_size() - 1) / block_work_size();
  auto stream = at::cuda::getCurrentCUDAStream();

  if (iter.is_contiguous()) {
    distribution_binary_elementwise_kernel<<<grid,num_threads(), 0, stream>>>(
        numel, f, philox_args, output_data, input_data_1, input_data_2,
        TrivialOffsetCalculator<2>(), TrivialOffsetCalculator<1>());
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  } else {
    distribution_binary_elementwise_kernel<<<grid, num_threads(), 0, stream>>>(
        numel, f, philox_args, output_data, input_data_1, input_data_2,
        make_input_offset_calculator<2>(iter), make_output_offset_calculator(iter));
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  }
}

} // namespace
}} // namespace at::native


namespace at {
namespace native {
namespace templates {
namespace cuda {

// ==================================================== Random ========================================================

template<typename RNG>
void random_from_to_kernel(TensorIteratorBase& iter, uint64_t range, int64_t base, RNG gen) {
  AT_DISPATCH_V2(iter.dtype(), "random_from_to_kernel_cuda", AT_WRAP([&] {
    if ((
      std::is_same<scalar_t, int64_t>::value ||
      std::is_same<scalar_t, double>::value ||
      std::is_same<scalar_t, float>::value ||
      std::is_same<scalar_t, at::BFloat16>::value) && range >= 1ULL << 32)
    {
      // define lambda to mod with range and add base
      auto random_func = [range, base] __device__ (uint64_t rand) {
        return transformation::uniform_int_from_to<scalar_t>(rand, range, base);
      };
      distribution_nullary_kernel<scalar_t, uint64_t, ulonglong2>(iter,
        gen,
        [] __device__ (curandStatePhilox4_32_10_t* state) -> ulonglong2 {
          ulonglong2 ret;
          uint4 rand_val = curand4(state);
          ret.x = (static_cast<uint64_t>(rand_val.x) << 32) | rand_val.y;
          ret.y = (static_cast<uint64_t>(rand_val.z) << 32) | rand_val.w;
          return ret;
        },
        random_func);
    } else {
      auto random_func = [range, base] __device__ (uint32_t rand) {
        return transformation::uniform_int_from_to<scalar_t>(rand, range, base);
      };
      distribution_nullary_kernel<scalar_t, uint32_t, uint4>(iter,
        gen,
        [] __device__ (curandStatePhilox4_32_10_t* state) -> uint4 {
          return curand4(state);
        },
        random_func);
    }
  }), AT_EXPAND(AT_ALL_TYPES), kBool, kHalf, kBFloat16, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
}

// This is the special kernel to handle single specific case:
// from(inclusive) = std::numeric_limits<int64_t>::lowest()
// to(exclusive) = None (= std::numeric_limits<int64_t>::max() + 1)
template<typename RNG>
void random_full_64_bits_range_kernel(TensorIteratorBase& iter, RNG gen) {
  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::BFloat16, iter.dtype(), "random_full_64_bits_range_kernel_cuda", [&] {
    if (std::is_same<scalar_t, int64_t>::value ||
        std::is_same<scalar_t, double>::value ||
        std::is_same<scalar_t, float>::value ||
        std::is_same<scalar_t, at::BFloat16>::value) {
      auto random_func = [] __device__ (uint64_t rand) {
        return transformation::uniform_int_full_range<scalar_t>(rand);
      };
      distribution_nullary_kernel<scalar_t, uint64_t, ulonglong2>(iter,
        gen,
        [] __device__ (curandStatePhilox4_32_10_t* state) -> ulonglong2 {
          ulonglong2 ret;
          uint4 rand_val = curand4(state);
          ret.x = (static_cast<uint64_t>(rand_val.x) << 32) | rand_val.y;
          ret.y = (static_cast<uint64_t>(rand_val.z) << 32) | rand_val.w;
          return ret;
        },
        random_func);
    } else {
      TORCH_CHECK(false, "random_full_64_bits_range_kernel_cuda handles only int64, double, float and bfloat16");
    }
  });
}

template<typename RNG>
struct RandomFromToKernel {
  void operator()(TensorIteratorBase& iter, uint64_t range, int64_t base, std::optional<Generator> gen) {
    random_from_to_kernel(iter, range, base, check_generator<RNG>(gen));
  }
  void operator()(TensorIteratorBase& iter, std::optional<Generator> gen) {
    random_full_64_bits_range_kernel(iter, check_generator<RNG>(gen));
  }
};

template<typename RNG>
void random_kernel(TensorIteratorBase& iter, RNG gen) {
  AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, iter.dtype(), "random_kernel_cuda", [&] {
    if (std::is_same<scalar_t, double>::value || std::is_same<scalar_t, int64_t>::value) {
      auto random_func = [] __device__ (uint64_t rand) {
        return transformation::uniform_int<scalar_t>(rand);
      };
      distribution_nullary_kernel<scalar_t, uint64_t, ulonglong2>(iter, gen,
        [] __device__ (curandStatePhilox4_32_10_t* state) -> ulonglong2 {
          ulonglong2 ret;
          uint4 rand_val = curand4(state);
          ret.x = (static_cast<uint64_t>(rand_val.x) << 32) | rand_val.y;
          ret.y = (static_cast<uint64_t>(rand_val.z) << 32) | rand_val.w;
          return ret;
        },
        random_func);
    } else {
      auto random_func = [] __device__ (uint32_t rand) {
        return transformation::uniform_int<scalar_t>(rand);
      };
      distribution_nullary_kernel<scalar_t, uint32_t, uint4>(iter,
        gen,
        [] __device__ (curandStatePhilox4_32_10_t* state) -> uint4 {
          return curand4(state);
        },
        random_func);
    }
  });
}

template<typename RNG>
struct RandomKernel {
  void operator()(TensorIteratorBase& iter, RNG gen) {
    random_kernel(iter, gen);
  }
};

// ====================================================================================================================

template<typename scalar_t, typename accscalar_t, typename RNG, typename transform_t>
void uniform_and_transform(TensorIteratorBase& iter, RNG gen, transform_t transform) {
  if (std::is_same<scalar_t, double>::value) {
    distribution_nullary_kernel<scalar_t, accscalar_t, double2>(iter,
      gen,
      [] __device__ (curandStatePhilox4_32_10_t* state) -> double2 { return curand_uniform2_double(state); },
      transform);
  } else {
    distribution_nullary_kernel<scalar_t, accscalar_t, float4>(iter,
      gen,
      [] __device__ (curandStatePhilox4_32_10_t* state) -> float4 { return curand_uniform4(state); },
      transform);
  }
}

template<typename scalar_t, typename accscalar_t, typename RNG, typename transform_t>
void normal_and_transform(TensorIteratorBase& iter, RNG gen, transform_t transform) {
  if (std::is_same<scalar_t, double>::value) {
    distribution_nullary_kernel<scalar_t, accscalar_t, double2>(iter,
      gen,
      [] __device__ (curandStatePhilox4_32_10_t* state) -> double2 { return curand_normal2_double(state); },
      transform);
  } else {
    distribution_nullary_kernel<scalar_t, accscalar_t, float4>(iter,
      gen,
      [] __device__ (curandStatePhilox4_32_10_t* state) -> float4 { return curand_normal4(state); },
      transform);
  }
}

// ==================================================== Normal ========================================================

template<typename RNG>
void normal_kernel(const TensorBase &self, double mean_, double std_, RNG gen) {
  auto iter = TensorIterator::borrowing_nullary_op(self);
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "normal_kernel_cuda", [&] {
    using accscalar_t = at::acc_type<scalar_t, true>;
    auto mean = static_cast<accscalar_t>(mean_);
    auto std = static_cast<accscalar_t>(std_);
    // define lambda to multiply std and add mean
    auto normal_func = [mean, std] __device__ (accscalar_t rand) {
      return static_cast<scalar_t>(transformation::normal<accscalar_t>(rand, mean, std));
    };
    normal_and_transform<scalar_t, accscalar_t>(iter, gen, normal_func);
  });
}

template<typename RNG>
struct NormalKernel {
  void operator()(const TensorBase &self, double mean, double std, std::optional<Generator> gen) {
    normal_kernel(self, mean, std, check_generator<RNG>(gen));
  }
};

// ==================================================== Uniform ========================================================

template<typename RNG>
void uniform_kernel(TensorIteratorBase& iter, double from_, double to_, RNG gen) {
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "uniform_kernel_cuda", [&] {
    auto from = static_cast<scalar_t>(from_);
    auto to = static_cast<scalar_t>(to_);
    using opmath_t = at::opmath_type<scalar_t>;
    auto range = static_cast<opmath_t>(to-from);
    // define lambda to reverse bounds, multiply 'range' and add 'from_'
    auto uniform_func = [range, from, to] __device__ (opmath_t rand) {
      // Compute output value before reversing the bounds
      // BEFORE TOUCHING THIS CODE READ: https://github.com/pytorch/pytorch/issues/96947
      auto value = static_cast<scalar_t>(rand * range + from);
      // reverse the bounds of curand4 from (0, 1] to [0, 1)
      // Note that this method is from legacy THCTensorRandom and is likely to give
      // you more 0-s, since, the probability of gettings 1-s is higher than 0-s and
      // by reversing the bounds, we are flipping the probabilities of 1-s and 0-s.
      // BEFORE TOUCHING THIS CODE READ: https://github.com/pytorch/pytorch/issues/16706
      auto reverse_bound_value = value == to ? from : value;
      return reverse_bound_value;
    };
    uniform_and_transform<scalar_t, opmath_t>(iter, gen, uniform_func);
  });
}

template<typename RNG>
struct UniformKernel {
  void operator()(TensorIteratorBase& iter, double from, double to, std::optional<Generator> gen) {
    uniform_kernel(iter, from, to, check_generator<RNG>(gen));
  }
};

// ================================================== LogNormal =======================================================

template<typename RNG>
void log_normal_kernel(TensorIteratorBase& iter, double mean_, double std_, RNG gen) {
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "log_normal_cuda", [&] {
    using accscalar_t = at::acc_type<scalar_t, true>;
    auto mean = static_cast<accscalar_t>(mean_);
    auto std = static_cast<accscalar_t>(std_);
    // define lambda for log_normal transformation
    auto log_normal_func = [mean, std] __device__ (accscalar_t rand) {
      return static_cast<scalar_t>(transformation::log_normal<accscalar_t>(transformation::normal<accscalar_t>(rand, mean, std)));
    };
    normal_and_transform<scalar_t, accscalar_t>(iter, gen, log_normal_func);
  });
}

template<typename RNG>
struct LogNormalKernel {
  void operator()(TensorIteratorBase& iter, double mean, double std, std::optional<Generator> gen) {
    log_normal_kernel(iter, mean, std, check_generator<RNG>(gen));
  }
};

// =================================================== Geometric ======================================================

template<typename RNG>
void geometric_kernel(TensorIteratorBase& iter, double p, RNG gen) {
  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "geometric_cuda", [&] {
    using accscalar_t = at::DiscreteDistributionType<scalar_t>::type;
    // define lambda for geometric transformation
    auto geometric_func = [p] __device__ (accscalar_t rand) {
      return static_cast<scalar_t>(transformation::geometric<accscalar_t>(rand, p));
    };
    uniform_and_transform<scalar_t, accscalar_t>(iter, gen, geometric_func);
  });
}

template<typename RNG>
struct GeometricKernel {
  void operator()(TensorIteratorBase& iter, double p, std::optional<Generator> gen) {
    geometric_kernel(iter, p, check_generator<RNG>(gen));
  }
};

// ================================================== Exponential =====================================================

template<typename RNG>
void exponential_kernel(TensorIteratorBase& iter, double lambda_, RNG gen) {
  TORCH_CHECK(isFloatingType(iter.dtype()), "Exponential distribution is a continuous probability distribution. dtype must be a floating point but you specified ", iter.dtype());
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "exponential_cuda", [&] {
    using accscalar_t = at::acc_type<scalar_t, true>;
    auto lambda = static_cast<accscalar_t>(lambda_);
    // define lambda for exponential transformation
    auto exponential_func = [lambda] __device__ (accscalar_t rand) {
      return static_cast<scalar_t>(transformation::exponential<accscalar_t>(rand, lambda));
    };
    uniform_and_transform<scalar_t, accscalar_t>(iter, gen, exponential_func);
  });
}

template<typename RNG>
struct ExponentialKernel {
  void operator()(TensorIteratorBase& iter, double lambda, std::optional<Generator> gen) {
    exponential_kernel(iter, lambda, check_generator<RNG>(gen));
  }
};

// ==================================================== Cauchy ========================================================

template<typename RNG>
void cauchy_kernel(TensorIteratorBase& iter, double median_, double sigma_, RNG gen) {
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "cauchy_cuda", [&] {
    using accscalar_t = at::acc_type<scalar_t, true>;
    auto median = static_cast<accscalar_t>(median_);
    auto sigma = static_cast<accscalar_t>(sigma_);
    // define lambda for cauchy transformation
    auto cauchy_func = [median, sigma] __device__ (accscalar_t rand) {
      return static_cast<scalar_t>(transformation::cauchy<accscalar_t>(rand, median, sigma));
|
| 566 |
+
};
|
| 567 |
+
uniform_and_transform<scalar_t, accscalar_t>(iter, gen, cauchy_func);
|
| 568 |
+
});
|
| 569 |
+
}
|
| 570 |
+
|
| 571 |
+
template<typename RNG>
|
| 572 |
+
struct CauchyKernel {
|
| 573 |
+
void operator()(TensorIteratorBase& iter, double median, double sigma, std::optional<Generator> gen) {
|
| 574 |
+
cauchy_kernel(iter, median, sigma, check_generator<RNG>(gen));
|
| 575 |
+
}
|
| 576 |
+
};
|
| 577 |
+
|
| 578 |
+
// ==================================================== Bernoulli =====================================================
|
| 579 |
+
|
| 580 |
+
template<typename scalar_t, typename prob_t>
|
| 581 |
+
void bernoulli_tensor_cuda_kernel(
|
| 582 |
+
const TensorBase &ret, const at::TensorBase &p,
|
| 583 |
+
PhiloxCudaState philox_args) {
|
| 584 |
+
auto functor = [philox_args] __device__(
|
| 585 |
+
int n, scalar_t& v1, scalar_t& v2, scalar_t& v3, scalar_t& v4,
|
| 586 |
+
const prob_t& p1, const prob_t& p2, const prob_t& p3, const prob_t& p4) {
|
| 587 |
+
auto seeds = at::cuda::philox::unpack(philox_args);
|
| 588 |
+
curandStatePhilox4_32_10_t state;
|
| 589 |
+
curand_init(std::get<0>(seeds),
|
| 590 |
+
blockIdx.x * blockDim.x + threadIdx.x,
|
| 591 |
+
std::get<1>(seeds),
|
| 592 |
+
&state);
|
| 593 |
+
|
| 594 |
+
// See Note [Register spilling in curand call for CUDA < 10]
|
| 595 |
+
float4 rand = curand_uniform4(&state);
|
| 596 |
+
switch (n) {
|
| 597 |
+
case 4: {
|
| 598 |
+
CUDA_KERNEL_ASSERT(0 <= p4 && p4 <= 1);
|
| 599 |
+
v4 = static_cast<scalar_t>(rand.w <= p4);
|
| 600 |
+
[[fallthrough]];
|
| 601 |
+
}
|
| 602 |
+
case 3: {
|
| 603 |
+
CUDA_KERNEL_ASSERT(0 <= p3 && p3 <= 1);
|
| 604 |
+
v3 = static_cast<scalar_t>(rand.z <= p3);
|
| 605 |
+
[[fallthrough]];
|
| 606 |
+
}
|
| 607 |
+
case 2: {
|
| 608 |
+
CUDA_KERNEL_ASSERT(0 <= p2 && p2 <= 1);
|
| 609 |
+
v2 = static_cast<scalar_t>(rand.y <= p2);
|
| 610 |
+
[[fallthrough]];
|
| 611 |
+
}
|
| 612 |
+
case 1: {
|
| 613 |
+
CUDA_KERNEL_ASSERT(0 <= p1 && p1 <= 1);
|
| 614 |
+
v1 = static_cast<scalar_t>(rand.x <= p1);
|
| 615 |
+
}
|
| 616 |
+
}
|
| 617 |
+
};
|
| 618 |
+
// The template argument `4` below indicates that we want to operate on four
|
| 619 |
+
// element at each time. See NOTE [ CUDA_tensor_applyN helpers ] for details.
|
| 620 |
+
at::cuda::CUDA_tensor_apply2<scalar_t, const prob_t, 4, decltype(functor),
|
| 621 |
+
/*max_threads_per_block=*/512,
|
| 622 |
+
/*min_blocks_per_sm==*/2>(ret, p, functor);
|
| 623 |
+
}
|
| 624 |
+
|
| 625 |
+
template<typename RNG>
|
| 626 |
+
void bernoulli_kernel(const TensorBase &self, const TensorBase &p_, RNG gen) {
|
| 627 |
+
PhiloxCudaState rng_engine_inputs;
|
| 628 |
+
{
|
| 629 |
+
// See Note [Acquire lock when using random generators]
|
| 630 |
+
std::lock_guard<std::mutex> lock(gen->mutex_);
|
| 631 |
+
rng_engine_inputs = gen->philox_cuda_state(10);
|
| 632 |
+
}
|
| 633 |
+
TORCH_CHECK(at::isFloatingType(p_.scalar_type()), "expected probabilities tensor to have floating type, got ", p_.scalar_type());
|
| 634 |
+
// cast probabilities tensor to double for double `self` tensor, and to `float` for everything else
|
| 635 |
+
const auto p_type = self.dtype() == at::kDouble ? at::kDouble : at::kFloat;
|
| 636 |
+
auto p_cuda = p_.to(TensorOptions().device(self.device()).dtype(p_type));
|
| 637 |
+
auto p = expand_inplace(self, p_cuda);
|
| 638 |
+
AT_DISPATCH_ALL_TYPES_AND3(
|
| 639 |
+
at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, self.scalar_type(), "bernoulli_tensor_cuda_self_", [&] {
|
| 640 |
+
if (std::is_same<scalar_t, double>::value) {
|
| 641 |
+
return bernoulli_tensor_cuda_kernel<double, double>(self, *p, rng_engine_inputs);
|
| 642 |
+
} else {
|
| 643 |
+
return bernoulli_tensor_cuda_kernel<scalar_t, float>(self, *p, rng_engine_inputs);
|
| 644 |
+
}
|
| 645 |
+
});
|
| 646 |
+
}
|
| 647 |
+
|
| 648 |
+
template<typename RNG>
|
| 649 |
+
void bernoulli_kernel(TensorIteratorBase& iter, double p, RNG gen) {
|
| 650 |
+
AT_DISPATCH_ALL_TYPES_AND3(
|
| 651 |
+
at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, iter.dtype(), "bernoulli_scalar_cuda_", [&] {
|
| 652 |
+
using accscalar_t = at::DiscreteDistributionType<scalar_t>::type;
|
| 653 |
+
// define lambda for bernoulli transformation
|
| 654 |
+
auto bernoulli_func = [p] __device__ (accscalar_t rand) {
|
| 655 |
+
return static_cast<scalar_t>(transformation::bernoulli<accscalar_t>(rand, p));
|
| 656 |
+
};
|
| 657 |
+
uniform_and_transform<scalar_t, accscalar_t>(iter, gen, bernoulli_func);
|
| 658 |
+
});
|
| 659 |
+
}
|
| 660 |
+
|
| 661 |
+
template<typename RNG>
|
| 662 |
+
struct BernoulliKernel {
|
| 663 |
+
void operator()(TensorIteratorBase& iter, double p, std::optional<Generator> gen) {
|
| 664 |
+
bernoulli_kernel(iter, p, check_generator<RNG>(gen));
|
| 665 |
+
}
|
| 666 |
+
void operator()(const TensorBase &self, const TensorBase &p_, std::optional<Generator> gen) {
|
| 667 |
+
bernoulli_kernel(self, p_, check_generator<RNG>(gen));
|
| 668 |
+
}
|
| 669 |
+
};
|
| 670 |
+
|
| 671 |
+
}}}}
|
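The bound reversal inside `uniform_kernel` above is easy to check with ordinary host arithmetic. The sketch below is a minimal standalone approximation of that transform, not part of the header: the helper name `reverse_bounds_uniform` and the use of plain `float` are assumptions for illustration. curand's `curand_uniform4` draws from (0, 1], so `rand * range + from` can land exactly on `to`; folding that single point back to `from` yields samples in [from, to).

#include <cassert>
#include <cstdio>

// Mirrors the lambda in uniform_kernel: rand is taken from (0, 1] (curand
// convention) and the value that would land exactly on `to` is mapped back
// to `from`, so the result lies in [from, to).
float reverse_bounds_uniform(float rand, float from, float to) {
  float range = to - from;
  float value = rand * range + from;
  return value == to ? from : value;
}

int main() {
  const float from = 2.0f, to = 6.0f;
  // rand == 1.0f is the only input that can produce `to`; it maps to `from`.
  assert(reverse_bounds_uniform(1.0f, from, to) == from);
  // Interior points are unchanged.
  std::printf("%f\n", reverse_bounds_uniform(0.25f, from, to)); // prints 3.0
  return 0;
}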
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Distributions.h
ADDED
@@ -0,0 +1,25 @@
#pragma once

namespace at {
struct CUDAGeneratorImpl;
struct TensorIteratorBase;
class TensorBase;

namespace native {

void launch_poisson_cuda_kernel(
    const TensorBase &ret, const TensorBase &lambda, CUDAGeneratorImpl *gen);

void launch_gamma_kernel(
    const TensorBase &ret, const TensorBase &alpha, CUDAGeneratorImpl *gen);

void launch_binomial_cuda_kernel(
    TensorIteratorBase &iter, CUDAGeneratorImpl *gen);

void launch_dirichlet_kernel(TensorIteratorBase &iter);

void launch_standard_gamma_grad_kernel(TensorIteratorBase &iter);

void launch_dirichlet_grad_kernel(TensorIteratorBase &iter);

}} // namespace at::native
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/ForeachFunctors.cuh
ADDED
@@ -0,0 +1,681 @@
#pragma once
#include <ATen/OpMathType.h>
#include <ATen/native/ForeachUtils.h>
#include <ATen/native/cuda/MultiTensorApply.cuh>
#include <ATen/native/cuda/Pow.cuh>

namespace at::native {

namespace {

// TODO(crcrpar): Handle version bump in codegen.
// rel:
// https://github.com/pytorch/pytorch/blob/9cf84347767c8abb8feba18a9a1baba321eeb8b9/tools/autograd/gen_inplace_or_view_type.py#L481-L482
inline void increment_version(TensorList tensors) {
  for (const auto& t : tensors) {
    t.unsafeGetTensorImpl()->bump_version();
  }
}

// Initializes args and checks if all args are aligned
template <int depth, typename T>
__device__ bool init_args(
    T** args,
    TensorListMetadata<depth>& tl,
    const int64_t chunk_idx,
    const int64_t chunk_size,
    const int64_t tensor_loc) {
  bool all_aligned = true;
  for (int i = 0; i < depth; i++) {
    args[i] = (T*)tl.addresses[i][tensor_loc];
    args[i] += chunk_idx * chunk_size;

    if (!is_aligned(args[i])) {
      all_aligned = false;
    }
  }
  return all_aligned;
}

// Initializes args and checks if all args are aligned
template <int depth, typename T, typename T2>
__device__ bool init_args(
    T** args,
    TensorListScalarListMetadata<T2, depth>& tl,
    const int64_t chunk_idx,
    const int64_t chunk_size,
    const int64_t tensor_loc) {
  bool all_aligned = true;
  for (int i = 0; i < depth; i++) {
    args[i] = (T*)tl.addresses[i][tensor_loc];
    args[i] += chunk_idx * chunk_size;

    if (!is_aligned(args[i])) {
      all_aligned = false;
    }
  }
  return all_aligned;
}

template <int depth, typename T>
__device__ bool init_args(
    T** args,
    FusedOptimizerTensorListMetadata<depth>& tl,
    const int64_t chunk_idx,
    const int64_t chunk_size,
    const int64_t tensor_loc) {
  bool all_aligned = true;
  for (int i = 0; i < depth; i++) {
    args[i] = (T*)tl.addresses[i][tensor_loc];
    args[i] += chunk_idx * chunk_size;

    if (!is_aligned(args[i])) {
      all_aligned = false;
    }
  }
  return all_aligned;
}

template <int depth, typename T>
__device__ void load_args(
    T r_args[][kILP],
    T** args,
    const int64_t i_start,
    const int64_t chunk_size,
    const int64_t n) {
#pragma unroll
  for (int ii = 0; ii < kILP; ii++) {
    const auto i = i_start + threadIdx.x + ii * blockDim.x;
    for (int r_index = 0; r_index < depth; r_index++) {
      r_args[r_index][ii] = 0;
      if (i < n && i < chunk_size) {
        r_args[r_index][ii] = args[r_index][i];
      }
    }
  }
}

template <typename T>
__device__ void store_args(
    T* dst,
    T* src,
    const int64_t i_start,
    const int64_t chunk_size,
    const int64_t n) {
#pragma unroll
  for (int ii = 0; ii < kILP; ii++) {
    const int64_t i = i_start + threadIdx.x + ii * blockDim.x;
    if (i < n && i < chunk_size)
      dst[i] = src[ii];
  }
}

template <int res_arg_index, typename Op, typename T, typename opmath_t>
__device__ __forceinline__ void binary_op_scalar(
    T r_args[][kILP],
    T** args,
    opmath_t scalar,
    const int64_t n,
    const int64_t chunk_size,
    const bool all_aligned,
    Op op) {
  // to make things simple, we put aligned case in a different code path
  if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
    for (int64_t i_start = threadIdx.x;
         i_start * kILP < n && i_start * kILP < chunk_size;
         i_start += blockDim.x) {
      // load
      load_store(r_args[0], args[0], 0, i_start);
#pragma unroll
      for (int ii = 0; ii < kILP; ii++) {
        r_args[0][ii] = static_cast<T>(
            op(static_cast<opmath_t>(r_args[0][ii]),
               static_cast<opmath_t>(scalar)));
      }
      // store
      load_store(args[res_arg_index], r_args[0], i_start, 0);
    }
  } else {
    for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
         i_start += blockDim.x * kILP) {
      // Regardless if depth is 1 (for inplace) or 2 (for out of place), r_args
      // has depth 1
      load_args<1>(r_args, args, i_start, chunk_size, n);
#pragma unroll
      for (int ii = 0; ii < kILP; ii++) {
        r_args[0][ii] = static_cast<T>(
            op(static_cast<opmath_t>(r_args[0][ii]),
               static_cast<opmath_t>(scalar)));
      }
      store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
    }
  }
}

template <int res_arg_index, typename Op, typename T, typename opmath_t>
__device__ __forceinline__ void pointwise_op_scalar(
    T r_args[][kILP],
    T** args,
    opmath_t scalar,
    const int64_t n,
    const int64_t chunk_size,
    const bool all_aligned,
    Op op) {
  // to make things simple, we put aligned case in a different code path
  if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
    for (int64_t i_start = threadIdx.x;
         i_start * kILP < n && i_start * kILP < chunk_size;
         i_start += blockDim.x) {
      // load
      load_store(r_args[0], args[0], 0, i_start);
      load_store(r_args[1], args[1], 0, i_start);
      load_store(r_args[2], args[2], 0, i_start);
#pragma unroll
      for (int ii = 0; ii < kILP; ii++) {
        r_args[0][ii] = static_cast<T>(
            static_cast<opmath_t>(r_args[0][ii]) +
            scalar *
                op(static_cast<opmath_t>(r_args[1][ii]),
                   static_cast<opmath_t>(r_args[2][ii])));
      }
      // store
      load_store(args[res_arg_index], r_args[0], i_start, 0);
    }
  } else {
    for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
         i_start += blockDim.x * kILP) {
      // Regardless if depth is 3 (for inplace) or 4 (for out of place), r_args
      // has depth 3
      load_args<3>(r_args, args, i_start, chunk_size, n);
#pragma unroll
      for (int ii = 0; ii < kILP; ii++) {
        r_args[0][ii] = static_cast<T>(
            static_cast<opmath_t>(r_args[0][ii]) +
            scalar *
                op(static_cast<opmath_t>(r_args[1][ii]),
                   static_cast<opmath_t>(r_args[2][ii])));
      }
      store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
    }
  }
}

//
// Binary Functors
//
template <typename T, int depth, int r_args_depth, int res_arg_index>
struct BinaryOpScalarFunctor {
  using opmath_t = at::opmath_type<T>;
  template <typename Op>
  __device__ __forceinline__ void operator()(
      int chunk_size,
      TensorListMetadata<depth>& tl,
      Op op,
      opmath_t scalar) {
    const int tensor_loc = tl.block_to_tensor[blockIdx.x];
    const int chunk_idx = tl.block_to_chunk[blockIdx.x];
    auto n = tl.numel_for_tensor[tensor_loc];

    T* args[depth];
    const bool all_aligned =
        init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
    n -= chunk_idx * chunk_size;
    T r_args[r_args_depth][kILP];

    binary_op_scalar<res_arg_index>(
        r_args, args, scalar, n, chunk_size, all_aligned, op);
  }
};

template <typename T, int depth, int r_args_depth, int res_arg_index>
struct BinaryOpScalarListFunctor {
  using opmath_t = at::opmath_type<T>;
  template <typename Op>
  __device__ __forceinline__ void operator()(
      int chunk_size,
      TensorListScalarListMetadata<opmath_t, depth>& tl,
      Op op) {
    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
    auto n = tl.numel_for_tensor[tensor_loc];

    T* args[depth];
    const bool all_aligned =
        init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
    opmath_t scalar = tl.scalar_vals[tensor_loc];
    n -= chunk_idx * chunk_size;
    T r_args[r_args_depth][kILP];

    binary_op_scalar<res_arg_index>(
        r_args, args, scalar, n, chunk_size, all_aligned, op);
  }
};

template <typename T, int depth, int r_args_depth, int res_arg_index>
struct BinaryOpListAlphaFunctor {
  using opmath_t = at::opmath_type<T>;
  template <typename Op>
  __device__ __forceinline__ void operator()(
      int chunk_size,
      TensorListMetadata<depth>& tl,
      Op op,
      opmath_t alpha) {
    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
    auto n = tl.numel_for_tensor[tensor_loc];

    T* args[depth];
    const bool all_aligned =
        init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
    n -= chunk_idx * chunk_size;
    T r_args[r_args_depth][kILP];

    // to make things simple, we put aligned case in a different code path
    if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
      for (int64_t i_start = threadIdx.x;
           i_start * kILP < n && i_start * kILP < chunk_size;
           i_start += blockDim.x) {
        // load
        load_store(r_args[0], args[0], 0, i_start);
        load_store(r_args[1], args[1], 0, i_start);
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] = static_cast<T>(
              op(static_cast<opmath_t>(r_args[0][ii]),
                 alpha * static_cast<opmath_t>(r_args[1][ii])));
        }
        // store
        load_store(args[res_arg_index], r_args[0], i_start, 0);
      }
    } else {
      for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
           i_start += blockDim.x * kILP) {
        load_args<r_args_depth>(r_args, args, i_start, chunk_size, n);
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] = static_cast<T>(
              op(static_cast<opmath_t>(r_args[0][ii]),
                 alpha * static_cast<opmath_t>(r_args[1][ii])));
        }
        store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
      }
    }
  }
};

template <typename T, int depth, int r_args_depth, int res_arg_index>
struct BinaryOpScalarTensorFunctor {
  using opmath_t = at::opmath_type<T>;
  template <typename Op>
  __device__ __forceinline__ void operator()(
      int chunk_size,
      TensorListMetadata<depth>& tl,
      Op op,
      T* scalar,
      opmath_t alpha) {
    const int tensor_loc = tl.block_to_tensor[blockIdx.x];
    const int chunk_idx = tl.block_to_chunk[blockIdx.x];
    auto n = tl.numel_for_tensor[tensor_loc];

    T* args[depth];
    const bool all_aligned =
        init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
    n -= chunk_idx * chunk_size;
    T r_args[r_args_depth][kILP];

    // to make things simple, we put aligned case in a different code path
    if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
      for (int64_t i_start = threadIdx.x;
           i_start * kILP < n && i_start * kILP < chunk_size;
           i_start += blockDim.x) {
        // load
        load_store(r_args[0], args[0], 0, i_start);
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] = static_cast<T>(op(
              static_cast<opmath_t>(r_args[0][ii]),
              static_cast<opmath_t>(alpha) * static_cast<opmath_t>(*scalar)));
        }
        // store
        load_store(args[res_arg_index], r_args[0], i_start, 0);
      }
    } else {
      for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
           i_start += blockDim.x * kILP) {
        // Regardless if depth is 1 (for inplace) or 2 (for out of place),
        // r_args has depth 1
        load_args<1>(r_args, args, i_start, chunk_size, n);
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] = static_cast<T>(op(
              static_cast<opmath_t>(r_args[0][ii]),
              static_cast<opmath_t>(alpha) * static_cast<opmath_t>(*scalar)));
        }
        store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
      }
    }
  }
};

//
// Unary Functors
//

template <typename T, int depth, int r_args_depth, int res_arg_index>
struct ZeroFunctor {
  __device__ __forceinline__ void operator()(
      int chunk_size,
      TensorListMetadata<1>& tl) {
    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
    auto n = tl.numel_for_tensor[tensor_loc];

    T* args[depth];
    const auto all_aligned =
        init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
    n -= chunk_idx * chunk_size;
    T r_args[r_args_depth][kILP];

    // to make things simple, we put aligned case in a different code path
    if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
      for (int64_t i_start = threadIdx.x;
           i_start * kILP < n && i_start * kILP < chunk_size;
           i_start += blockDim.x) {
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] = 0;
        }
        // store
        load_store(args[0], r_args[0], i_start, 0);
      }
    } else {
      for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
           i_start += blockDim.x * kILP) {
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] = 0;
        }
        store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
      }
    }
  }
};

template <typename T, int depth, int r_args_depth, int res_arg_index>
struct UnaryOpFunctor {
  using opmath_t = at::opmath_type<T>;
  template <typename Op>
  __device__ __forceinline__ void operator()(
      int chunk_size,
      TensorListMetadata<depth>& tl,
      Op op) {
    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
    auto n = tl.numel_for_tensor[tensor_loc];

    T* args[depth];
    bool all_aligned =
        init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
    n -= chunk_idx * chunk_size;
    T r_args[r_args_depth][kILP];

    // to make things simple, we put aligned case in a different code path
    if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
      for (int64_t i_start = threadIdx.x;
           i_start * kILP < n && i_start * kILP < chunk_size;
           i_start += blockDim.x) {
        // load
        load_store(r_args[0], args[0], 0, i_start);
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] =
              static_cast<T>(op(static_cast<opmath_t>(r_args[0][ii])));
        }
        // store
        load_store(args[res_arg_index], r_args[0], i_start, 0);
      }
    } else {
      for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
           i_start += blockDim.x * kILP) {
        load_args<r_args_depth>(r_args, args, i_start, chunk_size, n);
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] =
              static_cast<T>(op(static_cast<opmath_t>(r_args[0][ii])));
        }
        store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
      }
    }
  }
};

//
// Pointwise Functors
//

template <typename T, int depth, int r_args_depth, int res_arg_index>
struct PointwiseOpScalarFunctor {
  using opmath_t = at::opmath_type<T>;
  template <typename Op>
  __device__ __forceinline__ void operator()(
      int chunk_size,
      TensorListMetadata<depth>& tl,
      Op op,
      opmath_t scalar) {
    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
    auto n = tl.numel_for_tensor[tensor_loc];

    T* args[depth];
    const bool all_aligned =
        init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
    n -= chunk_idx * chunk_size;
    T r_args[r_args_depth][kILP];

    pointwise_op_scalar<res_arg_index>(
        r_args, args, scalar, n, chunk_size, all_aligned, op);
  }
};

template <typename T, int depth, int r_args_depth, int res_arg_index>
struct PointwiseOpScalarListFunctor {
  using opmath_t = at::opmath_type<T>;
  template <typename Op>
  __device__ __forceinline__ void operator()(
      int chunk_size,
      TensorListScalarListMetadata<opmath_t, depth>& tl,
      Op op) {
    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
    auto n = tl.numel_for_tensor[tensor_loc];

    T* args[depth];
    const bool all_aligned =
        init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
    opmath_t scalar = tl.scalar_vals[tensor_loc];
    n -= chunk_idx * chunk_size;
    T r_args[r_args_depth][kILP];

    pointwise_op_scalar<res_arg_index>(
        r_args, args, scalar, n, chunk_size, all_aligned, op);
  }
};

template <typename T, int depth>
struct PointwiseOpListFunctor {
  using opmath_t = at::opmath_type<T>;
  template <typename Op>
  __device__ __forceinline__ void operator()(
      int chunk_size,
      TensorListMetadata<depth>& tl,
      Op op) {
    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
    auto n = tl.numel_for_tensor[tensor_loc];

    T* args[depth];
    const bool all_aligned =
        init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
    n -= chunk_idx * chunk_size;
    T r_args[depth - 1][kILP];

    // to make things simple, we put aligned case in a different code path
    if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
      for (int64_t i_start = threadIdx.x;
           i_start * kILP < n && i_start * kILP < chunk_size;
           i_start += blockDim.x) {
        // load
        load_store(r_args[0], args[0], 0, i_start);
        load_store(r_args[1], args[1], 0, i_start);
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] = static_cast<T>(
              op(static_cast<opmath_t>(r_args[0][ii]),
                 static_cast<opmath_t>(r_args[1][ii])));
        }
        // store
        load_store(args[2], r_args[0], i_start, 0);
      }
    } else {
      for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
           i_start += blockDim.x * kILP) {
        load_args<depth - 1>(r_args, args, i_start, chunk_size, n);
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] = static_cast<T>(
              op(static_cast<opmath_t>(r_args[0][ii]),
                 static_cast<opmath_t>(r_args[1][ii])));
        }
        store_args(args[2], r_args[0], i_start, chunk_size, n);
      }
    }
  }
};

template <typename T, int depth, int r_args_depth, int res_arg_index>
struct TernaryOpListFunctor {
  using opmath_t = at::opmath_type<T>;
  template <typename Op>
  __device__ __forceinline__ void operator()(
      int chunk_size,
      TensorListMetadata<depth>& tl,
      Op op) {
    static_assert(depth == 3 || depth == 4, "");
    static_assert(depth >= r_args_depth, "");
    static_assert(res_arg_index == depth - 1 || res_arg_index == 0, "");
    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
    auto n = tl.numel_for_tensor[tensor_loc];

    T* args[depth];
    const bool all_aligned =
        init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
    n -= chunk_idx * chunk_size;
    T r_args[r_args_depth][kILP];

    if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
      for (int64_t i_start = threadIdx.x;
           i_start * kILP < n && i_start * kILP < chunk_size;
           i_start += blockDim.x) {
        load_store(r_args[0], args[0], 0, i_start);
        load_store(r_args[1], args[1], 0, i_start);
        load_store(r_args[2], args[2], 0, i_start);
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] =
              op(static_cast<opmath_t>(r_args[0][ii]),
                 static_cast<opmath_t>(r_args[1][ii]),
                 static_cast<opmath_t>(r_args[2][ii]));
        }
        load_store(args[res_arg_index], r_args[0], i_start, 0);
      }
    } else {
      for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
           i_start += blockDim.x * kILP) {
        load_args<r_args_depth>(r_args, args, i_start, chunk_size, n);
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] =
              op(static_cast<opmath_t>(r_args[0][ii]),
                 static_cast<opmath_t>(r_args[1][ii]),
                 static_cast<opmath_t>(r_args[2][ii]));
        }
        store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
      }
    }
  }
};

template <typename T, int depth, int r_args_depth, int res_arg_index>
struct TernaryOpScalarFunctor {
  using opmath_t = at::opmath_type<T>;
  template <typename Op>
  __device__ __forceinline__ void operator()(
      int chunk_size,
      TensorListMetadata<depth>& tl,
      Op op,
      opmath_t alpha) {
    static_assert(depth == 2 || depth == 3, "");
    static_assert(depth >= r_args_depth, "");
    static_assert(res_arg_index == depth - 1 || res_arg_index == 0, "");
    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
    auto n = tl.numel_for_tensor[tensor_loc];

    T* args[depth];
    const bool all_aligned =
        init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
    n -= chunk_idx * chunk_size;
    T r_args[r_args_depth][kILP];

    // to make things simple, we put aligned case in a different code path
    if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
      for (int64_t i_start = threadIdx.x;
           i_start * kILP < n && i_start * kILP < chunk_size;
           i_start += blockDim.x) {
        // load
        load_store(r_args[0], args[0], 0, i_start);
        load_store(r_args[1], args[1], 0, i_start);
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] =
              op(static_cast<opmath_t>(r_args[0][ii]),
                 static_cast<opmath_t>(r_args[1][ii]),
                 alpha);
        }
        // store
        load_store(args[res_arg_index], r_args[0], i_start, 0);
      }
    } else {
      for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
           i_start += blockDim.x * kILP) {
        load_args<r_args_depth>(r_args, args, i_start, chunk_size, n);
#pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] =
              op(static_cast<opmath_t>(r_args[0][ii]),
                 static_cast<opmath_t>(r_args[1][ii]),
                 alpha);
        }
        store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n);
      }
    }
  }
};

template <typename T>
struct power_functor {
  C10_DEVICE T operator()(const T& a, const T& b) const {
    return at::native::pow_(a, b);
  }
};

template <typename T>
struct reverse_power_functor {
  C10_DEVICE T operator()(const T& a, const T& b) const {
    return at::native::pow_(b, a);
  }
};

} // namespace
} // namespace at::native
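The non-aligned path of `load_args`/`store_args` above assigns element `i = i_start + threadIdx.x + ii * blockDim.x` to each thread and advances `i_start` by `blockDim.x * kILP`. A minimal host-side sketch of that indexing (the block size and chunk size below are arbitrary illustrative values, not constants from the header) shows that every element of a chunk is touched exactly once:

#include <cassert>
#include <vector>

int main() {
  const int kILP = 4;          // same unroll factor as the header
  const int blockDim_x = 64;   // illustrative block size
  const int chunk_size = 1000; // deliberately not a multiple of kILP * blockDim_x
  std::vector<int> hits(chunk_size, 0);

  // Replays the indexing of the non-aligned path for one chunk
  // (n >= chunk_size), with every "thread" of the block iterated serially.
  for (int i_start = 0; i_start < chunk_size; i_start += blockDim_x * kILP) {
    for (int tid = 0; tid < blockDim_x; ++tid) {
      for (int ii = 0; ii < kILP; ++ii) {
        int i = i_start + tid + ii * blockDim_x;
        if (i < chunk_size) {
          ++hits[i];
        }
      }
    }
  }
  for (int h : hits) {
    assert(h == 1); // each element is loaded/stored exactly once
  }
  return 0;
}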
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/GridSampler.cuh
ADDED
|
@@ -0,0 +1,321 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <ATen/native/cuda/KernelUtils.cuh>
|
| 3 |
+
#include <ATen/native/GridSamplerUtils.h>
|
| 4 |
+
|
| 5 |
+
namespace at { namespace native {
|
| 6 |
+
|
| 7 |
+
using detail::GridSamplerInterpolation;
|
| 8 |
+
using detail::GridSamplerPadding;
|
| 9 |
+
|
| 10 |
+
// Unnormalizes a coordinate from the -1 to +1 scale to its pixel index value,
|
| 11 |
+
// where we view each pixel as an area between (idx - 0.5) and (idx + 0.5).
|
| 12 |
+
// if align_corners: -1 and +1 get sent to the centers of the corner pixels
|
| 13 |
+
// -1 --> 0
|
| 14 |
+
// +1 --> (size - 1)
|
| 15 |
+
// scale_factor = (size - 1) / 2
|
| 16 |
+
// if not align_corners: -1 and +1 get sent to the image edges
|
| 17 |
+
// -1 --> -0.5
|
| 18 |
+
// +1 --> (size - 1) + 0.5 == size - 0.5
|
| 19 |
+
// scale_factor = size / 2
|
| 20 |
+
template <typename scalar_t>
|
| 21 |
+
__forceinline__ __device__
|
| 22 |
+
scalar_t grid_sampler_unnormalize(scalar_t coord, int size, bool align_corners) {
|
| 23 |
+
if (align_corners) {
|
| 24 |
+
// unnormalize coord from [-1, 1] to [0, size - 1]
|
| 25 |
+
return ((coord + 1.f) / 2) * (size - 1);
|
| 26 |
+
} else {
|
| 27 |
+
// unnormalize coord from [-1, 1] to [-0.5, size - 0.5]
|
| 28 |
+
return ((coord + 1.f) * size - 1) / 2;
|
| 29 |
+
}
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
// grid_sampler_unnormalize_set_grad works the same as grid_sampler_unnormalize
|
| 33 |
+
// except that it also returns the `d output / d input` via pointer argument
|
| 34 |
+
// `grad_in`.
|
| 35 |
+
// This is useful in the backward pass of grid_sampler.
|
| 36 |
+
template <typename scalar_t>
|
| 37 |
+
__forceinline__ __device__
|
| 38 |
+
scalar_t grid_sampler_unnormalize_set_grad(scalar_t coord, int size,
|
| 39 |
+
bool align_corners, scalar_t *grad_in) {
|
| 40 |
+
if (align_corners) {
|
| 41 |
+
// unnormalize coord from [-1, 1] to [0, size - 1]
|
| 42 |
+
*grad_in = static_cast<scalar_t>(size - 1) / 2;
|
| 43 |
+
return ((coord + 1.f) / 2) * (size - 1);
|
| 44 |
+
} else {
|
| 45 |
+
// unnormalize coord from [-1, 1] to [-0.5, size - 0.5]
|
| 46 |
+
*grad_in = static_cast<scalar_t>(size) / 2;
|
| 47 |
+
return ((coord + 1.f) * size - 1) / 2;
|
| 48 |
+
}
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
// Clips coordinates to between 0 and clip_limit - 1
|
| 52 |
+
template <typename scalar_t>
|
| 53 |
+
__forceinline__ __device__
|
| 54 |
+
scalar_t clip_coordinates(scalar_t in, int clip_limit) {
|
| 55 |
+
return ::min(static_cast<scalar_t>(clip_limit - 1), ::max(in, static_cast<scalar_t>(0)));
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
// clip_coordinates_set_grad works similarly to clip_coordinates except that
|
| 59 |
+
// it also returns the `d output / d input` via pointer argument `grad_in`.
|
| 60 |
+
// This is useful in the backward pass of grid_sampler.
|
| 61 |
+
template <typename scalar_t>
|
| 62 |
+
__forceinline__ __device__
|
| 63 |
+
scalar_t clip_coordinates_set_grad(scalar_t in, int clip_limit, scalar_t *grad_in) {
|
| 64 |
+
// Note that it is important for the gradient calculation that borders
|
| 65 |
+
// are considered out of bounds.
|
| 66 |
+
if (in <= static_cast<scalar_t>(0)) {
|
| 67 |
+
*grad_in = static_cast<scalar_t>(0);
|
| 68 |
+
return static_cast<scalar_t>(0);
|
| 69 |
+
} else {
|
| 70 |
+
scalar_t max = static_cast<scalar_t>(clip_limit - 1);
|
| 71 |
+
if (in >= max) {
|
| 72 |
+
*grad_in = static_cast<scalar_t>(0);
|
| 73 |
+
return max;
|
| 74 |
+
} else {
|
| 75 |
+
*grad_in = static_cast<scalar_t>(1);
|
| 76 |
+
return in;
|
| 77 |
+
}
|
| 78 |
+
}
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
// Reflects coordinates until they fall between low and high (inclusive).
|
| 82 |
+
// The bounds are passed as twice their value so that half-integer values
|
| 83 |
+
// can be represented as ints.
|
| 84 |
+
template <typename scalar_t>
|
| 85 |
+
__forceinline__ __device__
|
| 86 |
+
scalar_t reflect_coordinates(scalar_t in, int twice_low, int twice_high) {
|
| 87 |
+
if (twice_low == twice_high) {
|
| 88 |
+
return static_cast<scalar_t>(0);
|
| 89 |
+
}
|
| 90 |
+
scalar_t min = static_cast<scalar_t>(twice_low) / 2;
|
| 91 |
+
scalar_t span = static_cast<scalar_t>(twice_high - twice_low) / 2;
|
| 92 |
+
in = ::fabs(in - min);
|
| 93 |
+
// `fmod` returns same sign as `in`, which is positive after the `fabs` above.
|
| 94 |
+
scalar_t extra = ::fmod(in, span);
|
| 95 |
+
int flips = static_cast<int>(::floor(in / span));
|
| 96 |
+
if (flips % 2 == 0) {
|
| 97 |
+
return extra + min;
|
| 98 |
+
} else {
|
| 99 |
+
return span - extra + min;
|
| 100 |
+
}
|
| 101 |
+
}
|
| 102 |
+
|
| 103 |
+
// reflect_coordinates_set_grad works similarly to reflect_coordinates except
|
| 104 |
+
// that it also returns the `d output / d input` via pointer argument
|
| 105 |
+
// `grad_in`.
|
| 106 |
+
// This is useful in the backward pass of grid_sampler.
|
| 107 |
+
template <typename scalar_t>
|
| 108 |
+
__forceinline__ __device__
|
| 109 |
+
scalar_t reflect_coordinates_set_grad(scalar_t in, int twice_low, int twice_high,
|
| 110 |
+
scalar_t *grad_in) {
|
| 111 |
+
if (twice_low == twice_high) {
|
| 112 |
+
*grad_in = static_cast<scalar_t>(0);
|
| 113 |
+
return static_cast<scalar_t>(0);
|
| 114 |
+
}
|
| 115 |
+
int grad_in_mult_;
|
| 116 |
+
scalar_t min = static_cast<scalar_t>(twice_low) / 2;
|
| 117 |
+
scalar_t span = static_cast<scalar_t>(twice_high - twice_low) / 2;
|
| 118 |
+
in = in - min;
|
| 119 |
+
if (in < static_cast<scalar_t>(0)) {
|
| 120 |
+
grad_in_mult_ = -1;
|
| 121 |
+
in = -in;
|
| 122 |
+
} else {
|
| 123 |
+
grad_in_mult_ = 1;
|
| 124 |
+
}
|
| 125 |
+
// `fmod` returns same sign as `in`, which is positive after the `if` above.
|
| 126 |
+
scalar_t extra = ::fmod(in, span);
|
| 127 |
+
int flips = static_cast<int>(::floor(in / span));
|
| 128 |
+
if (flips % 2 == 0) {
|
| 129 |
+
*grad_in = static_cast<scalar_t>(grad_in_mult_);
|
| 130 |
+
return extra + min;
|
| 131 |
+
} else {
|
| 132 |
+
*grad_in = static_cast<scalar_t>(-grad_in_mult_);
|
| 133 |
+
return span - extra + min;
|
| 134 |
+
}
|
| 135 |
+
}
|
| 136 |
+
|
| 137 |
+
template<typename scalar_t>
|
| 138 |
+
__forceinline__ __device__
|
| 139 |
+
scalar_t safe_downgrade_to_int_range(scalar_t x){
|
| 140 |
+
// -100.0 does not have special meaning. This is just to make sure
|
| 141 |
+
// it's not within_bounds_2d or within_bounds_3d, and does not cause
|
| 142 |
+
// undefined behavior. See #35506.
|
| 143 |
+
if (x > INT_MAX-1 || x < INT_MIN || !::isfinite(static_cast<double>(x)))
|
| 144 |
+
return static_cast<scalar_t>(-100.0);
|
| 145 |
+
return x;
|
| 146 |
+
}
|
| 147 |
+
|
| 148 |
+
template<typename scalar_t>
|
| 149 |
+
__forceinline__ __device__
|
| 150 |
+
scalar_t compute_coordinates(scalar_t coord, int size,
|
| 151 |
+
GridSamplerPadding padding_mode,
|
| 152 |
+
bool align_corners) {
|
| 153 |
+
if (padding_mode == GridSamplerPadding::Border) {
|
| 154 |
+
// clip coordinates to image borders
|
| 155 |
+
coord = clip_coordinates(coord, size);
|
| 156 |
+
} else if (padding_mode == GridSamplerPadding::Reflection) {
|
| 157 |
+
// reflect coordinates by image borders
|
| 158 |
+
if (align_corners) {
|
| 159 |
+
coord = reflect_coordinates(coord, 0, 2*(size - 1));
|
| 160 |
+
} else {
|
| 161 |
+
coord = reflect_coordinates(coord, -1, 2*size - 1);
|
| 162 |
+
}
|
| 163 |
+
// clip coordinates to image borders
|
| 164 |
+
coord = clip_coordinates(coord, size);
|
| 165 |
+
}
|
| 166 |
+
|
| 167 |
+
coord = safe_downgrade_to_int_range(coord);
|
| 168 |
+
return coord;
|
| 169 |
+
}
|
| 170 |
+
|
| 171 |
+
// Computes the pixel source index value for a grid coordinate
|
| 172 |
+
template <typename scalar_t>
|
| 173 |
+
__forceinline__ __device__
|
| 174 |
+
scalar_t grid_sampler_compute_source_index(
|
| 175 |
+
scalar_t coord,
|
| 176 |
+
int size,
|
| 177 |
+
GridSamplerPadding padding_mode,
|
| 178 |
+
bool align_corners) {
|
| 179 |
+
coord = grid_sampler_unnormalize(coord, size, align_corners);
|
| 180 |
+
coord = compute_coordinates(coord, size, padding_mode, align_corners);
|
| 181 |
+
return coord;
|
| 182 |
+
}
|
| 183 |
+
|
| 184 |
+
// grid_sampler_compute_source_index_set_grad works similarly to
|
| 185 |
+
// grid_sampler_compute_source_index except that it also returns the
|
| 186 |
+
// `d output / d input` via pointer argument `grad_in`.
|
| 187 |
+
// This is useful in the backward pass of grid_sampler.
|
| 188 |
+
template <typename scalar_t>
|
| 189 |
+
__forceinline__ __device__
|
| 190 |
+
scalar_t grid_sampler_compute_source_index_set_grad(
|
| 191 |
+
scalar_t coord,
|
| 192 |
+
int size,
|
| 193 |
+
GridSamplerPadding padding_mode,
|
| 194 |
+
bool align_corners,
|
| 195 |
+
scalar_t *grad_in) {
|
| 196 |
+
scalar_t grad_clip, grad_refl;
|
| 197 |
+
coord = grid_sampler_unnormalize_set_grad(coord, size, align_corners, grad_in);
|
| 198 |
+
if (padding_mode == GridSamplerPadding::Border) {
|
| 199 |
+
// clip coordinates to image borders
|
| 200 |
+
coord = clip_coordinates_set_grad(coord, size, &grad_clip);
|
| 201 |
+
*grad_in = (*grad_in) * grad_clip;
|
| 202 |
+
} else if (padding_mode == GridSamplerPadding::Reflection) {
|
| 203 |
+
// reflect coordinates by image borders
|
| 204 |
+
if (align_corners) {
|
| 205 |
+
coord = reflect_coordinates_set_grad(coord, 0, 2*(size - 1), &grad_refl);
|
| 206 |
+
} else {
|
| 207 |
+
coord = reflect_coordinates_set_grad(coord, -1, 2*size - 1, &grad_refl);
|
| 208 |
+
}
|
| 209 |
+
// clip coordinates to image borders
|
| 210 |
+
coord = clip_coordinates_set_grad(coord, size, &grad_clip);
|
| 211 |
+
*grad_in = (*grad_in) * grad_refl * grad_clip;
|
| 212 |
+
}
|
| 213 |
+
|
| 214 |
+
coord = safe_downgrade_to_int_range(coord);
|
| 215 |
+
return coord;
|
| 216 |
+
}
|
| 217 |
+
|
| 218 |
+
__forceinline__ __device__
|
| 219 |
+
bool within_bounds_2d(int h, int w, int H, int W) {
|
| 220 |
+
return h >= 0 && h < H && w >= 0 && w < W;
|
| 221 |
+
}
|
| 222 |
+
|
| 223 |
+
__forceinline__ __device__
bool within_bounds_3d(int d, int h, int w, int D, int H, int W) {
  return d >= 0 && d < D && h >= 0 && h < H && w >= 0 && w < W;
}

template<typename scalar_t>
__forceinline__ __device__
scalar_t get_value_bounded(
    const scalar_t *data, scalar_t x, scalar_t y, int W, int H, int sW, int sH,
    GridSamplerPadding padding_mode,
    bool align_corners) {

  x = compute_coordinates(x, W, padding_mode, align_corners);
  y = compute_coordinates(y, H, padding_mode, align_corners);

  int ix = static_cast<int>(x);
  int iy = static_cast<int>(y);

  if (within_bounds_2d(iy, ix, H, W)) {
    return data[iy * sH + ix * sW];
  }
  return static_cast<scalar_t>(0);
}

template<typename scalar_t, typename index_t>
__forceinline__ __device__
void safe_add_2d(scalar_t *data, int h, int w,
                 int sH, int sW, int H, int W,
                 scalar_t delta,
                 const index_t NC_offset,
                 const index_t memory_span) {
  if (within_bounds_2d(h, w, H, W)) {
    fastAtomicAdd(data,
                  NC_offset + h * sH + w * sW,
                  memory_span,
                  delta,
                  true);
  }
}

template<typename scalar_t, typename index_t>
__forceinline__ __device__
void safe_add_3d(scalar_t *data, int d, int h, int w,
                 int sD, int sH, int sW, int D, int H, int W,
                 scalar_t delta,
                 const index_t NC_offset,
                 const index_t memory_span) {
  if (within_bounds_3d(d, h, w, D, H, W)) {
    fastAtomicAdd(data,
                  NC_offset + d * sD + h * sH + w * sW,
                  memory_span,
                  delta,
                  true);
  }
}

template<typename scalar_t, typename index_t>
__forceinline__ __device__
void add_value_bounded(
    scalar_t* data, scalar_t x, scalar_t y, int W, int H, int sW, int sH,
    scalar_t delta,
    GridSamplerPadding padding_mode,
    bool align_corners,
    const index_t NC_offset,
    const index_t memory_span) {

  x = compute_coordinates(x, W, padding_mode, align_corners);
  y = compute_coordinates(y, H, padding_mode, align_corners);

  int ix = static_cast<int>(x);
  int iy = static_cast<int>(y);

  safe_add_2d(data, iy, ix, sH, sW, H, W, delta, NC_offset, memory_span);
}

// Calculate the differential of the cubic convolution, i.e. `d coeff / d x`
template<typename scalar_t>
__forceinline__ __device__
void get_cubic_coefficients_grad(
    scalar_t coeffs[4],
    scalar_t t) {

  // Must be the same as forward calculation in
  // aten/src/ATen/native/cuda/UpSample.cuh:get_cubic_upsample_coefficients
  scalar_t A = -0.75;

  scalar_t x;
  x = -1 - t; // 1 < x = |-1 - tx| < 2
  coeffs[0] = (-3 * A * x - 10 * A) * x - 8 * A;
  x = -t;     // x = |0 - tx| <= 1
  coeffs[1] = (-3 * (A + 2) * x - 2 * (A + 3)) * x;
  x = 1 - t;  // x = |1 - tx| <= 1
  coeffs[2] = (3 * (A + 2) * x - 2 * (A + 3)) * x;
  x = 2 - t;  // 1 < x = |2 - tx| < 2
  coeffs[3] = (3 * A * x - 10 * A) * x + 8 * A;
}

}} // namespace at::native
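Reviewer note: the four gradient coefficients above are the per-tap derivatives of the forward cubic interpolation weights, and those weights form a partition of unity for any fractional offset t, so the derivatives should sum to zero. A minimal host-side sanity check (illustrative only, not part of the uploaded header; it just repeats the same arithmetic with doubles):

#include <cassert>
#include <cmath>
#include <cstdio>

// Same arithmetic as get_cubic_coefficients_grad, run on the host.
static void cubic_coeffs_grad_ref(double coeffs[4], double t) {
  const double A = -0.75;
  double x;
  x = -1 - t; coeffs[0] = (-3 * A * x - 10 * A) * x - 8 * A;
  x = -t;     coeffs[1] = (-3 * (A + 2) * x - 2 * (A + 3)) * x;
  x = 1 - t;  coeffs[2] = (3 * (A + 2) * x - 2 * (A + 3)) * x;
  x = 2 - t;  coeffs[3] = (3 * A * x - 10 * A) * x + 8 * A;
}

int main() {
  for (double t = 0.0; t <= 1.0; t += 0.125) {
    double c[4];
    cubic_coeffs_grad_ref(c, t);
    // Forward weights always sum to 1, so their derivatives must cancel.
    assert(std::fabs(c[0] + c[1] + c[2] + c[3]) < 1e-12);
  }
  std::printf("gradient coefficients sum to 0 for all tested t\n");
}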
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/IndexKernel.h
ADDED
@@ -0,0 +1,16 @@
#pragma once
#include <c10/core/ScalarType.h>
#include <cstdint>

namespace at {
struct TensorIteratorBase;
class TensorBase;
}

namespace at {
namespace native {
/// @param maskPrefixSum[in,out]
void launch_masked_scatter_kernel(
    const TensorBase &self, const TensorBase &mask,
    const TensorBase &maskPrefixSum, const TensorBase &source);
}}
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/KernelUtils.cuh
ADDED
@@ -0,0 +1,149 @@
#pragma once
#include <ATen/cuda/Atomic.cuh>

#if !(defined(USE_ROCM) || ((defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))))
#include <cuda_bf16.h>
#endif

namespace at {
namespace native {

__device__ __forceinline__ size_t
idx(const size_t nc,
    const size_t height,
    const size_t width,
    const size_t h,
    const size_t w) {
  return (nc * height + h) * width + w;
}

// for channels-last
__device__ __forceinline__ size_t
idx_cl(
  const size_t n, const size_t h, const size_t w, const size_t c,
  const size_t height, const size_t width, const size_t channel
) {
  return ((n * height + h) * width + w) * channel + c;
}

// fastSpecializedAtomicAdd (and fastAtomicAdd) are an optimization
// that speed up half-precision atomics. The situation with half
// precision atomics is that we have a slow __half atomic, and
// a fast vectored __half2 atomic (this can be worth up to a 6x
// speedup, see https://github.com/pytorch/pytorch/pull/21879).
// We can convert a __half atomic into a __half2 atomic by simply
// pairing the __half with a zero entry on the left/right depending
// on alignment... but only if this wouldn't cause an out of bounds
// access! Thus, you must specify tensor and numel so we can check
// if you would be out-of-bounds and use a plain __half atomic if
// you would be.
template <
    typename scalar_t,
    typename index_t,
    typename std::enable_if<std::is_same<c10::Half, scalar_t>::value>::type* =
        nullptr>
__device__ __forceinline__ void fastSpecializedAtomicAdd(
    scalar_t* tensor,
    index_t index,
    const index_t numel,
    scalar_t value) {
#if (                      \
    (defined(USE_ROCM)) || \
    (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700)))
  gpuAtomicAddNoReturn(
      reinterpret_cast<at::Half*>(tensor) + index,
      static_cast<at::Half>(value));
#else
  // Accounts for the chance tensor falls on an odd 16 bit alignment (ie, not 32 bit aligned)
  __half* target_addr = reinterpret_cast<__half*>(tensor + index);
  bool low_byte = (reinterpret_cast<std::uintptr_t>(target_addr) % sizeof(__half2) == 0);

  if (low_byte && index < (numel - 1)) {
    __half2 value2;
    value2.x = static_cast<__half>(value);
    value2.y = __int2half_rz(0);
    atomicAdd(reinterpret_cast<__half2*>(target_addr), value2);

  } else if (!low_byte && index > 0) {
    __half2 value2;
    value2.x = __int2half_rz(0);
    value2.y = static_cast<__half>(value);
    atomicAdd(reinterpret_cast<__half2*>(target_addr - 1), value2);

  } else {
    atomicAdd(
        reinterpret_cast<__half*>(tensor) + index, static_cast<__half>(value));
  }
#endif
}

template <
    typename scalar_t,
    typename index_t,
    typename std::enable_if<std::is_same<c10::BFloat16, scalar_t>::value>::type* =
        nullptr>
__device__ __forceinline__ void fastSpecializedAtomicAdd(
    scalar_t* tensor,
    index_t index,
    const index_t numel,
    scalar_t value) {
#if (                      \
    (defined(USE_ROCM)) || \
    (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)))
  gpuAtomicAddNoReturn(
      reinterpret_cast<at::BFloat16*>(tensor) + index,
      static_cast<at::BFloat16>(value));
#else
  // Accounts for the chance tensor falls on an odd 16 bit alignment (ie, not 32 bit aligned)
  __nv_bfloat16* target_addr = reinterpret_cast<__nv_bfloat16*>(tensor + index);
  bool low_byte = (reinterpret_cast<std::uintptr_t>(target_addr) % sizeof(__nv_bfloat162) == 0);

  if (low_byte && index < (numel - 1)) {
    __nv_bfloat162 value2;
    value2.x = *reinterpret_cast<__nv_bfloat16*>(&value);
    value2.y = __int2bfloat16_rz(0);
    atomicAdd(reinterpret_cast<__nv_bfloat162*>(target_addr), value2);

  } else if (!low_byte && index > 0) {
    __nv_bfloat162 value2;
    value2.x = __int2bfloat16_rz(0);
    value2.y = *reinterpret_cast<__nv_bfloat16*>(&value);
    atomicAdd(reinterpret_cast<__nv_bfloat162*>(target_addr - 1), value2);

  } else {
    atomicAdd(
        reinterpret_cast<__nv_bfloat16*>(tensor) + index, *reinterpret_cast<__nv_bfloat16*>(&value));
  }
#endif
}


template <
    typename scalar_t,
    typename index_t,
    typename std::enable_if<!std::is_same<c10::Half, scalar_t>::value && !std::is_same<c10::BFloat16, scalar_t>::value >::type* =
        nullptr>
__device__ __forceinline__ void fastSpecializedAtomicAdd(
    scalar_t* tensor,
    index_t index,
    const index_t numel,
    scalar_t value) {
  gpuAtomicAddNoReturn(tensor + index, value);
}

template <class scalar_t, class index_t>
__device__ __forceinline__ void fastAtomicAdd(
    scalar_t* tensor,
    index_t index,
    const index_t numel,
    scalar_t value,
    bool fast_atomics) {
  if (fast_atomics) {
    fastSpecializedAtomicAdd(tensor, index, numel, value);
  } else {
    gpuAtomicAddNoReturn(tensor + index, value);
  }
}

} // namespace native
} // namespace at
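Reviewer note: a minimal sketch of how a kernel would use fastAtomicAdd (illustrative only, not part of the uploaded header; index_add_kernel and its launch configuration are hypothetical). Passing true for the last argument lets Half/BFloat16 additions take the packed __half2 / __nv_bfloat162 path when the destination element is suitably aligned and not at the tensor boundary; otherwise it silently falls back to a plain atomic.

#include <ATen/native/cuda/KernelUtils.cuh>

// Hypothetical scatter-add kernel: out[index[i]] += src[i] for i < n_src.
template <typename scalar_t, typename index_t>
__global__ void index_add_kernel(
    scalar_t* out, const scalar_t* src, const index_t* index,
    index_t n_src, index_t out_numel) {
  const index_t i = blockIdx.x * (index_t)blockDim.x + threadIdx.x;
  if (i < n_src) {
    at::native::fastAtomicAdd(
        out, index[i], out_numel, src[i], /*fast_atomics=*/true);
  }
}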
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/LaunchUtils.h
ADDED
@@ -0,0 +1,18 @@
#pragma once
#include <algorithm>

namespace at {
namespace native {

// returns 2**floor(log2(n))
static int lastPow2(unsigned int n) {
  n |= (n >> 1);
  n |= (n >> 2);
  n |= (n >> 4);
  n |= (n >> 8);
  n |= (n >> 16);
  return std::max<int>(1, n - (n >> 1));
}

} // namespace native
} // namespace at
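Reviewer note: the OR-cascade smears the highest set bit of n downward, so n - (n >> 1) keeps only that bit; inputs below 1 clamp to 1. A quick host-side check (illustrative only, not part of the uploaded header):

#include <ATen/native/cuda/LaunchUtils.h>
#include <cassert>

int main() {
  assert(at::native::lastPow2(0) == 1);
  assert(at::native::lastPow2(1) == 1);
  assert(at::native::lastPow2(3) == 2);
  assert(at::native::lastPow2(1000) == 512);
  assert(at::native::lastPow2(1024) == 1024);
}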
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/MemoryAccess.cuh
ADDED
@@ -0,0 +1,389 @@
#pragma once

#include <cstdint>
#include <type_traits>
#include <c10/core/DynamicCast.h>
#include <c10/util/Exception.h>
#include <c10/util/TypeCast.h>
#include <c10/macros/Macros.h>
#include <ATen/core/Array.h>
#include <ATen/detail/FunctionTraits.h>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/native/cuda/thread_constants.h>

#include <thrust/tuple.h>

// References:
// https://devblogs.nvidia.com/cuda-pro-tip-increase-performance-with-vectorized-memory-access/

namespace at { namespace native { namespace memory {

namespace detail {

// What does the `static_unroll` do?
//
// We want to do something like:
//
//   using args_t = typename traits::ArgsTuple;
//   args_t args;
//   #pragma unroll
//   for (int i = 0; i < traits::arity; i++) {
//     std::get<i>(args) = ....
//   }
//
// but unfortunately the above code does not work because
// the template argument has to be a compile time constant
// so `static_unroll` is created to simulate `#pragma unroll`
// using template metaprogramming.

template<template<int i> typename func, int end, int current=0>
struct static_unroll {
  template<typename... Args>
  static inline C10_HOST_DEVICE void with_args(Args&&... args) {
    func<current>::apply(std::forward<Args>(args)...);
    static_unroll<func, end, current+1>::with_args(args...);
  }
};

template<template<int i> typename func, int end>
struct static_unroll<func, end, end> {
  template<typename... Args>
  static inline C10_HOST_DEVICE void with_args(Args... args) {}
};

// helper structs to be used with static_unroll to load arguments
// one by one

template<int arg_index>
struct vectorized_load_helper {
  template <typename args_t, typename policy_t>
  static __device__ void apply(policy_t &self, args_t *args, int idx) {
    using arg_t = std::tuple_element_t<arg_index, args_t>;
    // `data` hold the data_ptr for tensors [output, input0, input1, ...], so we
    // need a +1 offset to get the input
    auto ptr = reinterpret_cast<arg_t *>(self.data[arg_index + 1]) + block_work_size() * idx;
    auto args_accessor = [&args] __device__ (int thread_unroll_idx) -> arg_t & { return std::get<arg_index>(args[thread_unroll_idx]); };
    self.load_single_arg(args_accessor, ptr);
  }
};

template<int arg_index>
struct unroll_load_helper {
  template <typename args_t, typename policy_t, typename offset_t, typename loader_t>
  static __device__ void apply(policy_t &self, args_t *args, offset_t offset, loader_t loader, int j, int num_outputs) {
    using arg_t = std::tuple_element_t<arg_index, args_t>;
    // `data` hold the data_ptr for tensors [output, input0, input1, ...], so we
    // need a +1 offset to get the input
    std::get<arg_index>(args[j]) = loader.template load<arg_t>(self.data[arg_index + num_outputs], offset[arg_index], arg_index);
  }
};

template <int current>
struct multi_outputs_store_helper {
  template<int ntensors, int num_outputs, typename ...Args>
  C10_HOST_DEVICE static void apply(
      at::detail::Array<char*, ntensors> data,
      at::detail::Array<uint32_t, num_outputs> offsets,
      thrust::tuple<Args...> ret) {
    using T = typename thrust::tuple_element<current, thrust::tuple<Args...>>::type;
    T *to = reinterpret_cast<T *>(data[current]) + offsets[current];
    *to = thrust::get<current>(ret);
  }
};

} // namespace detail

struct LoadWithoutCast {
  template<typename scalar_t>
  __device__ scalar_t load(char *base_ptr, uint32_t offset, int arg) {
    return c10::load(reinterpret_cast<scalar_t *>(base_ptr) + offset);
  }
};

template <int N>
struct LoadWithCast {
  using array_t = at::detail::Array<at::ScalarType, std::max<int>(N, 1)>;
  using size_array_t = at::detail::Array<uint32_t, std::max<int>(N, 1)>;

  array_t dtypes;
  size_array_t element_sizes;

  LoadWithCast(const TensorIteratorBase& iter) {
    CUDA_KERNEL_ASSERT(iter.ninputs() == N);
    #pragma unroll
    for (auto i = 0; i < N; ++i) {
      this->dtypes[i] = iter.dtype(i + iter.noutputs());
      element_sizes[i] = c10::elementSize(iter.dtype(i + iter.noutputs()));
    }
  }

  template<typename scalar_t>
  __device__ scalar_t load(char *base_ptr, uint32_t offset, int arg) {
    void *ptr = base_ptr + element_sizes[arg] * offset;
    return c10::fetch_and_cast<scalar_t>(dtypes[arg], ptr);
  }
};

struct StoreWithoutCast {
  template<typename scalar_t>
  __device__ void store(scalar_t value, char *base_ptr, uint32_t offset, int arg = 0) {
    *(reinterpret_cast<scalar_t *>(base_ptr) + offset) = value;
  }
};

template <int N = 1>
struct StoreWithCast {
  using array_t = at::detail::Array<at::ScalarType, std::max<int>(N, 1)>;
  using size_array_t = at::detail::Array<uint32_t, std::max<int>(N, 1)>;

  array_t dtypes;
  size_array_t element_sizes;

  StoreWithCast(const TensorIteratorBase& iter) {
    CUDA_KERNEL_ASSERT(iter.noutputs() == N);
    #pragma unroll
    for (auto i = 0; i < N; ++i) {
      this->dtypes[i] = iter.dtype(i);
      element_sizes[i] = c10::elementSize(iter.dtype(i));
    }
  }

  template<typename scalar_t>
  __device__ void store(scalar_t value, char *base_ptr, uint32_t offset, int arg = 0) {
    void *ptr = base_ptr + element_sizes[arg] * offset;
    c10::cast_and_store<scalar_t>(dtypes[arg], ptr, value);
  }
};

// aligned vector generates vectorized load/store on CUDA
template<typename scalar_t, int vec_size>
struct alignas(sizeof(scalar_t) * vec_size) aligned_vector {
  scalar_t val[vec_size];
};

template <int vec_size, typename scalar_t>
__device__ aligned_vector<scalar_t, vec_size> load_vector(const scalar_t *base_ptr, uint32_t offset) {
  using vec_t = aligned_vector<scalar_t, vec_size>;
  auto *from = reinterpret_cast<const vec_t *>(base_ptr);
  return from[offset];
}

template <int vec_size>
__device__ aligned_vector<bool, vec_size> load_vector(const bool *base_ptr, uint32_t offset) {
  // See NOTE [Loading boolean values]
  auto tmp = load_vector<vec_size>(reinterpret_cast<const uint8_t*>(base_ptr), offset);
  aligned_vector<bool, vec_size> ret;
  for (int i = 0; i < vec_size; ++i) {
    ret.val[i] = bool(tmp.val[i]);
  }
  return ret;
}

namespace policies {

// Assumption:
// all tensors are contiguous, that is: stride == sizeof(type) for all tensors
template<typename data_t, typename inp_calc_t, typename out_calc_t, typename loader_t, typename storer_t, int num_outputs = 1>
struct unroll {

  data_t data;
  int remaining;
  inp_calc_t input_offset_calculator;
  out_calc_t output_offset_calculator;
  loader_t loader;
  storer_t storer;

  __device__ unroll(data_t data, int remaining, inp_calc_t ic, out_calc_t oc, loader_t l, storer_t s):
    data(data), remaining(remaining), input_offset_calculator(ic), output_offset_calculator(oc), loader(l), storer(s) {}

  __device__ inline bool check_inbounds(int thread_work_elem) {
    return ((int)(threadIdx.x + thread_work_elem*num_threads()) < remaining);
  }

  template<typename args_t>
  __device__ inline void load(args_t *args, int idx) {
    constexpr int arity = std::tuple_size<args_t>::value;
    int thread_idx = threadIdx.x;
    #pragma unroll
    for (int i = 0; i < thread_work_size(); i++) {
      if (thread_idx >= remaining) {
        return;
      }
      int linear_idx = thread_idx + block_work_size() * idx;
      auto offset = input_offset_calculator.get(linear_idx);
      detail::static_unroll<detail::unroll_load_helper, arity>::with_args(*this, args, offset, loader, i, num_outputs);
      thread_idx += num_threads();
    }
  }

  template<typename scalar_t>
  __device__ inline void store(scalar_t *from, int idx) {
    int thread_idx = threadIdx.x;
    #pragma unroll
    for (int i = 0; i < thread_work_size(); i++) {
      if (thread_idx >= remaining) {
        return;
      }
      int linear_idx = thread_idx + block_work_size() * idx;
      int offset = output_offset_calculator.get(linear_idx)[0];
      storer.store(from[i], data[0], offset);
      thread_idx += num_threads();
    }
  }
};

// Assumption:
// all tensors are contiguous, that is: stride == sizeof(type) for all tensors
// Note:
// Functions in vectorized policy does not do boundary check. It assumes the whole block
// has its job to do. So the reminders should be handled by the caller manually.
template <int vec_size, typename data_t>  // vec_size: number of scalars, can be 1, 2, or 4.
struct vectorized {

  static_assert(thread_work_size() % vec_size == 0, "The workload per thread must be a multiple of vec_size");
  static constexpr int loop_size = thread_work_size() / vec_size;

  data_t data;

  __device__ vectorized(data_t data) : data(data) {}

  __device__ inline constexpr bool check_inbounds(int thread_work_elem) {
    return true;
  }

  template<typename accessor_t, typename scalar_t>
  __device__ inline void load_single_arg(accessor_t to, scalar_t *from) {
    int thread_idx = threadIdx.x;
    #pragma unroll
    for (int i = 0; i < loop_size; i++) {
      int index = thread_idx + i * num_threads();
      auto v = load_vector<vec_size>(from, index);
      #pragma unroll
      for (int j = 0; j < vec_size; j++) {
        to(vec_size * i + j) = v.val[j];
      }
    }
  }

  template<typename args_t>
  __device__ inline void load(args_t *args, int idx) {
    constexpr int arity = std::tuple_size<args_t>::value;
    detail::static_unroll<detail::vectorized_load_helper, arity>::with_args(*this, args, idx);
  }

  template<typename scalar_t>
  __device__ inline void store(scalar_t *from, int idx) {
    using vec_t = aligned_vector<scalar_t, vec_size>;
    scalar_t *to = reinterpret_cast<scalar_t *>(data[0]) + block_work_size() * idx;
    vec_t *to_ = reinterpret_cast<vec_t *>(to);
    int thread_idx = threadIdx.x;
    #pragma unroll
    for (int i = 0; i < loop_size; i++) {
      int index = thread_idx + i * num_threads();
      vec_t v;
      for (int j = 0; j < vec_size; j++) {
        v.val[j] = from[vec_size * i + j];
      }
      to_[index] = v;
    }
  }
};

template <typename data_t, typename inp_calc_t, typename out_calc_t, int num_outputs>
struct multi_outputs_unroll {
  //multi_outputs_unroll struct members and check_inbounds and load methods are copypasted from unroll struct
  //we don't use inheritance because of compiler bug in cuda 10.2+
  data_t data;
  int remaining;
  inp_calc_t input_offset_calculator;
  out_calc_t output_offset_calculator;
  LoadWithoutCast loader;
  StoreWithoutCast storer;

  __device__ multi_outputs_unroll(data_t data, int remaining, inp_calc_t ic, out_calc_t oc):
    data(data), remaining(remaining), input_offset_calculator(ic), output_offset_calculator(oc) {}

  __device__ inline bool check_inbounds(int thread_work_elem) {
    return ((int)(threadIdx.x + thread_work_elem*num_threads()) < remaining);
  }

  template<typename args_t>
  __device__ inline void load(args_t *args, int idx) {
    constexpr int arity = std::tuple_size<args_t>::value;
    int thread_idx = threadIdx.x;
    #pragma unroll
    for (int i = 0; i < thread_work_size(); i++) {
      if (thread_idx >= remaining) {
        return;
      }
      int linear_idx = thread_idx + block_work_size() * idx;
      auto offset = input_offset_calculator.get(linear_idx);
      detail::static_unroll<detail::unroll_load_helper, arity>::with_args(*this, args, offset, loader, i, num_outputs);
      thread_idx += num_threads();
    }
  }


  template <typename return_t>
  __device__ inline void store(return_t *from, int idx) {
    int thread_idx = threadIdx.x;
    #pragma unroll
    for (int i = 0; i < thread_work_size(); i++) {
      if (thread_idx >= this->remaining) {
        return;
      }
      int linear_idx = thread_idx + block_work_size() * idx;
      auto offsets = this->output_offset_calculator.get(linear_idx);
      memory::detail::static_unroll<detail::multi_outputs_store_helper, num_outputs>::with_args(this->data, offsets, from[i]);
      thread_idx += num_threads();
    }
  }
};

} // namespace policies

// This is only used in host, but we will wrap this into some templates
// which is C10_HOST_DEVICE, so we have to make this C10_HOST_DEVICE
// in order to compile
template<typename scalar_t>
inline C10_HOST_DEVICE int can_vectorize_up_to(const char *pointer) {
  uint64_t address = reinterpret_cast<uint64_t>(pointer);
  constexpr int vec2_alignment = std::alignment_of<aligned_vector<scalar_t, 2>>::value;
  constexpr int vec4_alignment = std::alignment_of<aligned_vector<scalar_t, 4>>::value;
  if (address % vec4_alignment == 0) {
    return 4;
  } else if (address % vec2_alignment == 0) {
    return 2;
  }
  return 1;
}

template<typename scalar_t>
inline C10_HOST_DEVICE int can_vectorize_up_to(char *pointer) {
  return can_vectorize_up_to<scalar_t>(static_cast<const char*>(pointer));
}

template<int i>
struct can_vectorize_up_to_helper {
  template <typename array_t, typename traits>
  static C10_HOST_DEVICE void apply(int &result, array_t pointers, traits _) {
    using arg_t = typename traits::template arg<i>::type;
    // `pointers` hold the data_ptr for tensors [output, input0, input1, ...], so we
    // need a +1 offset to get the input
    result = std::min<int>(result, can_vectorize_up_to<arg_t>(pointers[i + 1]));
  }
};

template<typename func_t, typename array_t>
inline int can_vectorize_up_to(array_t pointers) {
  using traits = function_traits<func_t>;
  using return_t = typename traits::result_type;
  constexpr int arity = traits::arity;
  int result = can_vectorize_up_to<return_t>(pointers[0]);
  // We need to get the type for each argument of `func_t`, this can only
  // be done at compile time.
  detail::static_unroll<can_vectorize_up_to_helper, arity>::with_args(result, pointers, traits());
  return result;
}

}}} // namespace at::native::memory
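Reviewer note: the vector width picked by can_vectorize_up_to is driven purely by pointer alignment, since aligned_vector<float, 4> is 16-byte aligned and aligned_vector<float, 2> is 8-byte aligned. A small illustration (not part of the uploaded header; the example function is hypothetical and assumes a 16-byte-aligned allocation):

#include <ATen/native/cuda/MemoryAccess.cuh>

void alignment_example(float* buf /* assume 16-byte aligned */) {
  using namespace at::native::memory;
  int w0 = can_vectorize_up_to<float>(reinterpret_cast<char*>(buf));      // 4
  int w1 = can_vectorize_up_to<float>(reinterpret_cast<char*>(buf + 1));  // 1 (4-byte offset)
  int w2 = can_vectorize_up_to<float>(reinterpret_cast<char*>(buf + 2));  // 2 (8-byte offset)
  (void)w0; (void)w1; (void)w2;
}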
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/MiscUtils.h
ADDED
@@ -0,0 +1,32 @@
#pragma once
#include <ATen/cuda/Exceptions.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAConfig.h>
#include <ATen/cuda/PinnedMemoryAllocator.h>

namespace at {
namespace native {

static inline int cuda_int_cast(int64_t value, const char* varname) {
  auto result = static_cast<int>(value);
  TORCH_CHECK(static_cast<int64_t>(result) == value,
              "cuda_int_cast: The value of ", varname, "(", (long long)value,
              ") is too large to fit into a int (", sizeof(int), " bytes)");
  return result;
}

// Creates an array of size elements of type T, backed by pinned memory
// wrapped in a Storage
template<class T>
static inline Storage pin_memory(int64_t size) {
  auto* allocator = cuda::getPinnedMemoryAllocator();
  int64_t adjusted_size = size * sizeof(T);
  return Storage(
      Storage::use_byte_size_t(),
      adjusted_size,
      allocator,
      /*resizable=*/false);
}

} // namespace native
} // namespace at
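Reviewer note: a typical use of pin_memory is staging small scratch buffers (pivot arrays, info flags) in page-locked host memory so a later asynchronous device copy can overlap with compute. A rough sketch (illustrative only; it assumes c10::Storage exposes a mutable_data() accessor and that batch_size is defined by the caller):

// Hypothetical host-side staging of `batch_size` ints in pinned memory.
auto storage = at::native::pin_memory<int>(batch_size);
int* host_ptr = static_cast<int*>(storage.mutable_data());  // assumption: Storage::mutable_data()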
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/MultiTensorApply.cuh
ADDED
@@ -0,0 +1,379 @@
#pragma once
#include <ATen/core/Tensor.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/cuda/MemoryAccess.cuh>
#include <vector>

namespace at::native {

namespace {

static constexpr int64_t kILP = 4;
static constexpr int64_t kChunkSize = 65536;
static constexpr int64_t kBlockSize = 512;

// TODO(crcrpar): Add `n>5` for `low prec params & their higher prec copy`
// TensorListMetadata has to be < 4KB - the limit for kernel launch argument
static constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30};
static constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320};
static constexpr int depth_to_max_tensors_scalarlist[5] = {96, 64, 48, 36, 30};
static constexpr int depth_to_max_tensors_scalarlist_of_complex_double[2] = {
    72,
    60};

template <typename T>
__device__ __forceinline__ bool is_aligned(T* p) {
  return ((uint64_t)p) % (kILP * sizeof(T)) == 0;
}

template <typename T>
__device__ __forceinline__ void load_store(
    T* dst,
    T* src,
    int64_t dst_offset,
    int64_t src_offset) {
  using LT = at::native::memory::aligned_vector<T, kILP>;
  ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
}

template <int n>
struct TensorListMetadata {
  const void* addresses[n][depth_to_max_tensors[n - 1]];
  int64_t numel_for_tensor[depth_to_max_tensors[n - 1]];
  unsigned char block_to_tensor[depth_to_max_blocks[n - 1]];
  int block_to_chunk[depth_to_max_blocks[n - 1]];
  int start_tensor_this_launch;
};

template <typename scalar_vals_t, int n>
struct TensorListScalarListMetadata {
  const void* addresses[n][depth_to_max_tensors_scalarlist[n - 1]];
  int64_t numel_for_tensor[depth_to_max_tensors_scalarlist[n - 1]];
  scalar_vals_t scalar_vals[depth_to_max_tensors_scalarlist[n - 1]];
  unsigned char block_to_tensor[depth_to_max_blocks[n - 1]];
  int block_to_chunk[depth_to_max_blocks[n - 1]];
};

// note(mkozuki): `n` of 1&2 violate the limit of cuda kernel argument size of
// 4kb with `c10::complex<double>`
template <>
struct TensorListScalarListMetadata<c10::complex<double>, 1> {
  const void* addresses[1]
                       [depth_to_max_tensors_scalarlist_of_complex_double[0]];
  int64_t
      numel_for_tensor[depth_to_max_tensors_scalarlist_of_complex_double[0]];
  c10::complex<double>
      scalar_vals[depth_to_max_tensors_scalarlist_of_complex_double[0]];
  unsigned char block_to_tensor[depth_to_max_blocks[1 - 1]];
  int block_to_chunk[depth_to_max_blocks[1 - 1]];
};

template <>
struct TensorListScalarListMetadata<c10::complex<double>, 2> {
  const void* addresses[2]
                       [depth_to_max_tensors_scalarlist_of_complex_double[1]];
  int64_t
      numel_for_tensor[depth_to_max_tensors_scalarlist_of_complex_double[1]];
  c10::complex<double>
      scalar_vals[depth_to_max_tensors_scalarlist_of_complex_double[1]];
  unsigned char block_to_tensor[depth_to_max_blocks[2 - 1]];
  int block_to_chunk[depth_to_max_blocks[2 - 1]];
};

// NOTE(crcrpar): This is a conservative resolution to handle `state_steps`
// whose each element is `at::Tensor` of 1 element representing the number of
// `step`s called so far.
template <int n>
struct FusedOptimizerTensorListMetadata {
  const void* addresses[n][depth_to_max_tensors[n - 1]];
  int64_t numel_for_tensor[depth_to_max_tensors[n - 1]];
  const void* state_steps_addresses[depth_to_max_tensors_scalarlist[n - 1]];
  unsigned char block_to_tensor[depth_to_max_blocks[n - 1]];
  int block_to_chunk[depth_to_max_blocks[n - 1]];
  int start_tensor_this_launch;
};

template <typename T, typename U, typename... ArgTypes>
C10_LAUNCH_BOUNDS_1(kBlockSize)
__global__ void multi_tensor_apply_kernel(
    T tensorListMeta,
    U callable,
    ArgTypes... args) {
  // Hand the chunk information to the user-supplied functor to process however
  // it likes.
  callable(kChunkSize, tensorListMeta, args...);
}

} // namespace

// multi_tensor_apply enables horizontal fusion across lists of tensors.
// For example, whereas you once had a for-loop of a + b = c, where a, b,
// and c are individual tensors in lists as, bs, and cs, you can now with
// fewer kernel launches compute as + bs = cs.
//
// You can also imagine bs to be a scalar list vs a tensor list.
//
// The function below takes in tensor lists, scalars, and a callable and
// chunks up the computation to launch as few kernels as possible by iterating
// through every "chunk" in every tensor (thus the nested for loops). In the
// simplest case, everything gets bundled into just one kernel launch, but
// due to blocksize constraints, we may need to launch multiple kernels.
// Each kernel launch is defined by one tensorListMeta construct, which we
// use to track and reset the necessary metadata for each launch.
template <int depth, typename scalar_T, typename T, typename... ArgTypes>
void multi_tensor_apply(
    std::vector<std::vector<at::Tensor>>& tensor_lists,
    at::ArrayRef<Scalar> scalars,
    T callable,
    ArgTypes... args) {
  TORCH_CHECK(
      tensor_lists.size() == depth,
      "Number of tensor lists has to match the depth.");
  const size_t n_tensors = tensor_lists[0].size();
  using scalar_vals_t = typename T::opmath_t;
  TensorListScalarListMetadata<scalar_vals_t, depth> tensorListMeta;

  int loc_block_info = 0;
  int loc_tensor_info = 0;
  for (size_t t = 0; t < n_tensors; t++) {
    // short-circuit to avoid adding empty tensors to tensorListMeta
    if (tensor_lists[0][t].numel() == 0) {
      continue;
    }
    tensorListMeta.scalar_vals[loc_tensor_info] = scalars[t].to<scalar_T>();
    tensorListMeta.numel_for_tensor[loc_tensor_info] =
        tensor_lists[0][t].numel();
    for (int d = 0; d < depth; d++) {
      tensorListMeta.addresses[d][loc_tensor_info] =
          tensor_lists[d][t].const_data_ptr();
    }
    loc_tensor_info++;

    // now we enter [chunking territory].
    // we will launch a kernel when EITHER the blocks get filled up OR
    // the tensors get filled up. There will always be at least one block
    // per tensor since the zero-sized ones will not enter the loop, so
    // the nested forloop within represents iterating through the chunks
    // of a single tensor.
    const auto numel = tensor_lists[0][t].numel();
    const auto chunks = numel / kChunkSize + (numel % kChunkSize != 0);
    for (auto chunk = 0; chunk < chunks; chunk++) {
      tensorListMeta.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
      tensorListMeta.block_to_chunk[loc_block_info] = chunk;
      loc_block_info++;

      // a tensor is not considered full unless all its chunks have been
      // processed
      const bool tensors_full =
          (loc_tensor_info == depth_to_max_tensors_scalarlist[depth - 1] &&
           chunk == chunks - 1);
      const bool blocks_full =
          (loc_block_info == depth_to_max_blocks[depth - 1]);

      if (tensors_full || blocks_full) {
        multi_tensor_apply_kernel<<<
            loc_block_info,
            kBlockSize,
            0,
            at::cuda::getCurrentCUDAStream()>>>(
            tensorListMeta, callable, args...);
        C10_CUDA_KERNEL_LAUNCH_CHECK();

        // Reset.
        loc_block_info = 0;
        // all chunks have already been handled in the kernel
        if (chunk == chunks - 1) {
          loc_tensor_info = 0;
        } else { // blocks were full and tensor chunks remain
          tensorListMeta.numel_for_tensor[0] =
              tensorListMeta.numel_for_tensor[loc_tensor_info - 1];
          tensorListMeta.scalar_vals[0] =
              tensorListMeta.scalar_vals[loc_tensor_info - 1];
          for (int d = 0; d < depth; d++) {
            tensorListMeta.addresses[d][0] =
                tensorListMeta.addresses[d][loc_tensor_info - 1];
          }
          loc_tensor_info = 1;
        }
      }
    }
  }

  // note: [finishing what we started]
  // if there's remaining work to be done but the tensors/blocks aren't full
  // yet we are at the end, submit the kernel to do the work!
  if (loc_block_info != 0) {
    multi_tensor_apply_kernel<<<
        loc_block_info,
        kBlockSize,
        0,
        at::cuda::getCurrentCUDAStream()>>>(tensorListMeta, callable, args...);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  }
}

template <int depth, typename T, typename... ArgTypes>
void multi_tensor_apply(
    std::vector<std::vector<at::Tensor>>& tensor_lists,
    T callable,
    ArgTypes... args) {
  TORCH_CHECK(
      tensor_lists.size() == depth,
      "Number of tensor lists has to match the depth.");
  const size_t n_tensors = tensor_lists[0].size();
  TensorListMetadata<depth> tensorListMeta;
  tensorListMeta.start_tensor_this_launch = 0;

  int loc_block_info = 0;
  int loc_tensor_info = 0;
  for (size_t t = 0; t < n_tensors; t++) {
    // short-circuit to avoid adding empty tensors to tensorListMeta
    if (tensor_lists[0][t].numel() == 0) {
      continue;
    }
    tensorListMeta.numel_for_tensor[loc_tensor_info] =
        tensor_lists[0][t].numel();
    for (int d = 0; d < depth; d++) {
      tensorListMeta.addresses[d][loc_tensor_info] =
          tensor_lists[d][t].const_data_ptr();
    }
    loc_tensor_info++;

    // see note: [chunking territory].
    const auto numel = tensor_lists[0][t].numel();
    const auto chunks = numel / kChunkSize + (numel % kChunkSize != 0);
    for (auto chunk = 0; chunk < chunks; chunk++) {
      tensorListMeta.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
      tensorListMeta.block_to_chunk[loc_block_info] = chunk;
      loc_block_info++;

      const bool tensors_full =
          (loc_tensor_info == depth_to_max_tensors[depth - 1] &&
           chunk == chunks - 1);
      const bool blocks_full =
          (loc_block_info == depth_to_max_blocks[depth - 1]);

      if (tensors_full || blocks_full) {
        multi_tensor_apply_kernel<<<
            loc_block_info,
            kBlockSize,
            0,
            at::cuda::getCurrentCUDAStream()>>>(
            tensorListMeta, callable, args...);
        C10_CUDA_KERNEL_LAUNCH_CHECK();

        // Reset.
        loc_block_info = 0;
        if (chunk == chunks - 1) {
          loc_tensor_info = 0;
          tensorListMeta.start_tensor_this_launch = t + 1;
        } else {
          tensorListMeta.numel_for_tensor[0] =
              tensorListMeta.numel_for_tensor[loc_tensor_info - 1];
          for (int d = 0; d < depth; d++) {
            tensorListMeta.addresses[d][0] =
                tensorListMeta.addresses[d][loc_tensor_info - 1];
          }
          loc_tensor_info = 1;
          tensorListMeta.start_tensor_this_launch = t;
        }
      }
    }
  }

  // see note: [finishing what we started]
  if (loc_block_info != 0) {
    multi_tensor_apply_kernel<<<
        loc_block_info,
        kBlockSize,
        0,
        at::cuda::getCurrentCUDAStream()>>>(tensorListMeta, callable, args...);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  }
}

template <int depth, typename T, typename... ArgTypes>
void multi_tensor_apply_for_fused_optimizer(
    std::vector<std::vector<at::Tensor>>& tensor_lists,
    at::TensorList state_steps,
    T callable,
    ArgTypes... args) {
  TORCH_CHECK(
      tensor_lists.size() == depth,
      "Number of tensor lists has to match the depth");
  const auto num_tensors = tensor_lists[0].size();
  FusedOptimizerTensorListMetadata<depth> tensorListMeta;

  int loc_block_info = 0;
  int loc_tensor_info = 0;
  for (const auto& tensor_index : c10::irange(num_tensors)) {
    // short-circuit to avoid adding empty tensors to tensorListMeta
    if (tensor_lists[0][tensor_index].numel() == 0) {
      continue;
    }
    tensorListMeta.state_steps_addresses[loc_tensor_info] =
        state_steps[tensor_index].const_data_ptr();
    tensorListMeta.numel_for_tensor[loc_tensor_info] =
        tensor_lists[0][tensor_index].numel();
    for (const auto& d : c10::irange(depth)) {
      tensorListMeta.addresses[d][loc_tensor_info] =
          tensor_lists[d][tensor_index].const_data_ptr();
    }
    loc_tensor_info++;

    // see above note: [chunking territory]
    const auto numel = tensor_lists[0][tensor_index].numel();
    const auto chunks = numel / kChunkSize + (numel % kChunkSize != 0);
    TORCH_CHECK(chunks > -1);
    for (const auto& chunk : c10::irange(chunks)) {
      tensorListMeta.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
      tensorListMeta.block_to_chunk[loc_block_info] = chunk;
      loc_block_info++;

      const auto tensor_full =
          (loc_tensor_info == depth_to_max_tensors[depth - 1] &&
           chunk == chunks - 1);
      const auto blocks_full = loc_block_info == depth_to_max_blocks[depth - 1];

      if (tensor_full || blocks_full) {
        multi_tensor_apply_kernel<<<
            loc_block_info,
            kBlockSize,
            0,
            at::cuda::getCurrentCUDAStream()>>>(
            tensorListMeta, callable, args...);
        C10_CUDA_KERNEL_LAUNCH_CHECK();

        // Reset.
        loc_block_info = 0;
        if (chunk == chunks - 1) {
          loc_tensor_info = 0;
        } else {
          tensorListMeta.numel_for_tensor[0] =
              tensorListMeta.numel_for_tensor[loc_tensor_info - 1];
          tensorListMeta.state_steps_addresses[0] =
              tensorListMeta.state_steps_addresses[loc_tensor_info - 1];
          for (const auto& d : c10::irange(depth)) {
            tensorListMeta.addresses[d][0] =
                tensorListMeta.addresses[d][loc_tensor_info - 1];
          }
          loc_tensor_info = 1;
        }
      }
    }
  }

  // see above note: [finishing what we've started]
  if (loc_block_info != 0) {
    multi_tensor_apply_kernel<<<
        loc_block_info,
        kBlockSize,
        0,
        at::cuda::getCurrentCUDAStream()>>>(tensorListMeta, callable, args...);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  }
}

} // namespace at::native
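Reviewer note: a hedged sketch of the shape of the callable that multi_tensor_apply hands each chunk to. ZeroFunctor, the float-only cast, and the host-side call below are hypothetical; real functors (see ForeachFunctors.cuh) are templated on the scalar type and cast away the const that const_data_ptr() placed on the stored addresses for whichever list is the output.

#include <ATen/native/cuda/MultiTensorApply.cuh>

namespace at::native {

// Each block looks up its (tensor, chunk) assignment and zeroes one chunk.
struct ZeroFunctor {
  __device__ void operator()(int chunk_size, TensorListMetadata<1>& tl) {
    const int tensor_loc = tl.block_to_tensor[blockIdx.x];
    const int chunk_idx = tl.block_to_chunk[blockIdx.x];
    // Metadata stores const void*; cast for the (output) list being written.
    float* x = (float*)tl.addresses[0][tensor_loc] + chunk_idx * (int64_t)chunk_size;
    int64_t n = tl.numel_for_tensor[tensor_loc] - chunk_idx * (int64_t)chunk_size;
    if (n > chunk_size) n = chunk_size;
    for (int64_t i = threadIdx.x; i < n; i += blockDim.x) {
      x[i] = 0.f;
    }
  }
};

} // namespace at::native

// Host side: one list of float tensors, depth 1, no extra kernel arguments.
//   std::vector<std::vector<at::Tensor>> lists{self_tensors};
//   at::native::multi_tensor_apply<1>(lists, at::native::ZeroFunctor{});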
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Normalization.cuh
ADDED
@@ -0,0 +1,1742 @@
#pragma once

#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/AccumulateType.h>
#include <ATen/ceil_div.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/DeviceUtils.cuh>
#include <ATen/native/cuda/block_reduce.cuh>
#include <ATen/native/cuda/DeviceSqrt.cuh>
#include <ATen/native/cuda/LaunchUtils.h>
#include <c10/macros/Macros.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/zeros.h>
#endif

namespace at { namespace native {

// The maximum number of threads in a block
#if defined(USE_ROCM)
constexpr int MAX_BLOCK_SIZE = 256;
#else
constexpr int MAX_BLOCK_SIZE = 512;
#endif

constexpr unsigned MAX_GRID_SIZE = 65535u;

// Number of threads in a block given an input size up to MAX_BLOCK_SIZE
static int getNumThreads(int nElem) {
#if defined(USE_ROCM)
  int threadSizes[5] = { 16, 32, 64, 128, MAX_BLOCK_SIZE };
#else
  int threadSizes[5] = { 32, 64, 128, 256, MAX_BLOCK_SIZE };
#endif
  for (int i = 0; i != 5; ++i) {
    if (nElem <= threadSizes[i]) {
      return threadSizes[i];
    }
  }
  return MAX_BLOCK_SIZE;
}

// Returns the index of the most significant 1 bit in `val`.
__device__ __forceinline__ int getMSB(int val) {
  return 31 - __clz(val);
}

template <typename scalar_t, typename accscalar_t>
struct Float2 {
  accscalar_t v1, v2;
  __device__ Float2() {}
  __device__ Float2(scalar_t v1, scalar_t v2) : v1(static_cast<accscalar_t>(v1)), v2(static_cast<accscalar_t>(v2)) {}
  __device__ Float2(int v) : v1(static_cast<accscalar_t>(v)), v2(static_cast<accscalar_t>(v)) {}
  __device__ Float2& operator+=(const Float2& a) {
    v1 += a.v1;
    v2 += a.v2;
    return *this;
  }
  __device__ friend Float2 operator+(Float2 a, const Float2& b) {
    a += b;
    return a;
  }
};

template <typename scalar_t, typename accscalar_t, typename PTA>
struct GradOp {
  __device__ GradOp(accscalar_t m, const PTA& i, const PTA& g)
    : mean(m), input(i), grad_output(g) {}
  __device__ __forceinline__ Float2<scalar_t, accscalar_t> operator()(int batch, int plane, int n) {
    accscalar_t g = grad_output[batch][plane][n];
    accscalar_t c = static_cast<accscalar_t>(input[batch][plane][n]) - mean;
    return Float2<scalar_t, accscalar_t>(g, g * c);
  }
  const accscalar_t mean;
  const PTA& input;
  const PTA& grad_output;
};

template <typename acc_t>
struct SumReduceOp {
  __device__ __forceinline__ acc_t combine(acc_t a, acc_t b) const { return a + b; }

  __device__ __forceinline__ acc_t warp_shfl_down(acc_t data, int offset) const {
    return WARP_SHFL_DOWN(data, offset);
  }
};

template <typename scalar_t, typename accscalar_t>
struct SumReduceOp<Float2<scalar_t, accscalar_t>> {
  using acc_t = Float2<scalar_t, accscalar_t>;

  __device__ __forceinline__ acc_t combine(acc_t a, acc_t b) const { return a + b; }

  __device__ __forceinline__ acc_t warp_shfl_down(acc_t data, int offset) const {
    return {WARP_SHFL_DOWN(data.v1, offset), WARP_SHFL_DOWN(data.v2, offset)};
  }
};

// Sum across (batch, x/y/z) applying Op() pointwise
// this works by first having each thread sum it's part
// of the data. Then there is a double-shuffling reduction.
// First each warp (of C10_WARP_SIZE threads) uses warpSum to reduce its
// data to the "warp leader", who writes its value into shared memory.
// Then a single warp reads the remaining (at most C10_WARP_SIZE) items
// and reduces them using another warpSum.
// The implicit assumption is that there are no more
// than C10_WARP_SIZE**2 threads.
template<typename scalar_t, typename Op, typename PTA>
__device__ scalar_t reduce(Op op, PTA tensor, int plane) {
  // first the reductions each thread does separately
  scalar_t sum = static_cast<scalar_t>(0);
  for (int batch = threadIdx.y; batch < tensor.size(0); batch += blockDim.y) {
    for (int x = threadIdx.x; x < tensor.size(2); x += blockDim.x) {
      sum += op(batch, plane, x);
    }
  }
  __shared__ scalar_t shared[C10_WARP_SIZE];
  SumReduceOp<scalar_t> reduce_op;
  sum = cuda_utils::BlockReduce<scalar_t, SumReduceOp<scalar_t>, cuda_utils::Block2D>(sum, reduce_op, 0, shared);
  if (threadIdx.x == 0 && threadIdx.y == 0) {
    shared[0] = sum;
  }
  __syncthreads();
  // Everyone picks it up, should be broadcast into the whole grad_input
  return shared[0];
}

constexpr int ELEMENTS_PER_ITER = 4; // enables concurrency within each thread to hide latency
constexpr int ELEMENTS_PER_THREAD = 16;
constexpr int OPTIMAL_TILE_W = 32;
constexpr int MAX_H_BLOCK = 128;

__host__ void flexible_launch_configs(
    const int reduction,
    const int stride,
    dim3 &block,
    dim3 &grid,
    const bool coop_flag = false) {
  int block_x = std::min(lastPow2(stride), OPTIMAL_TILE_W);
  int block_y = std::min(lastPow2(at::ceil_div(reduction , ELEMENTS_PER_THREAD)),
                         MAX_BLOCK_SIZE / block_x);
  if (block_x * block_y != MAX_BLOCK_SIZE) {
    block_x = std::min(lastPow2(stride), MAX_BLOCK_SIZE / block_y);
  }

  int grid_x = at::ceil_div(stride, block_x);
  int grid_y = std::min(at::ceil_div(reduction, block_y * ELEMENTS_PER_THREAD), MAX_H_BLOCK);
  if (coop_flag) {
    // it's not worth having a grid reduction if the reduction dimension is not big enough
    grid_y = grid_y < 8 ? 1 : grid_y;
  }

  block.x = block_x;
  block.y = block_y;
  block.z = 1;
  grid.x = grid_x;
  grid.y = grid_y;
  grid.z = 1;
}

template<typename T, typename C>
__device__ __forceinline__ void welford_merge_element(C& count,
                                                      T& mean,
                                                      T& m2n,
                                                      const C& count_new,
                                                      const T& mean_new,
                                                      const T& m2n_new) {
  T factor = T(1.0) / ::max(1, (count + count_new));
  T delta0 = mean - mean_new;
  mean = (mean_new * count_new + mean * count) * factor;
  m2n += m2n_new + delta0 * delta0 * count_new * count * factor;
  count += count_new;
}

// merge mean/m2n among threadIdx.y within block
template<typename T, typename C>
__device__ __forceinline__ void welford_merge_block_vertical(C& count,
                                                             T& mean,
                                                             T& m2n,
                                                             C* shmem_count,
                                                             T* shmem_mean,
                                                             T* shmem_m2n) {
  // write to shared memory
  auto address_base = threadIdx.x + threadIdx.y * blockDim.x;

#pragma unroll
  for (int offset = blockDim.y/2; offset > 0; offset >>= 1) {
    if (threadIdx.y < offset*2) {
      shmem_mean[address_base] = mean;
      shmem_m2n[address_base] = m2n;
      shmem_count[address_base] = count;
    }
    __syncthreads();
    if (threadIdx.y < offset && threadIdx.y + offset < blockDim.y) {
      auto address = address_base + offset * blockDim.x;
      // read shared memory back to register for reduction
      auto count_new = shmem_count[address];
      auto mean_new = shmem_mean[address];
      auto m2n_new = shmem_m2n[address];

      welford_merge_element(count, mean, m2n, count_new, mean_new, m2n_new);
    }
  }
}

template <typename input_scalar_t, typename stat_scalar_t, typename stat_accscalar_t, bool train, typename index_t>
__global__ void batch_norm_transform_input_kernel(
    const GenericPackedTensorAccessor<const input_scalar_t, 3, RestrictPtrTraits, index_t> input,
    GenericPackedTensorAccessor<input_scalar_t, 3, RestrictPtrTraits, index_t> output,
    const GenericPackedTensorAccessor<typename std::conditional<train, stat_accscalar_t, stat_scalar_t>::type, 1, RestrictPtrTraits, index_t> mean_,
    const GenericPackedTensorAccessor<typename std::conditional<train, stat_accscalar_t, stat_scalar_t>::type, 1, RestrictPtrTraits, index_t> var_or_invstd,
    const GenericPackedTensorAccessor<const stat_scalar_t, 1, RestrictPtrTraits, index_t> weight,
    const GenericPackedTensorAccessor<const stat_scalar_t, 1, RestrictPtrTraits, index_t> bias,
    stat_accscalar_t epsilon) {

  index_t plane = blockIdx.x;

  if (plane >= input.size(1)) {
    return;
  }

  stat_accscalar_t gamma = weight.size(0) > 0 ? static_cast<stat_accscalar_t>(weight[plane]) : static_cast<stat_accscalar_t>(1);
  stat_accscalar_t beta = bias.size(0) > 0 ? static_cast<stat_accscalar_t>(bias[plane]) : static_cast<stat_accscalar_t>(0);
  stat_accscalar_t mean = static_cast<stat_accscalar_t>(mean_[plane]);
  stat_accscalar_t invstd;
  if (train) {
    invstd = var_or_invstd[plane];
  } else {
    invstd = static_cast<stat_accscalar_t>(1) / device_sqrt(static_cast<stat_accscalar_t>(var_or_invstd[plane]) + epsilon);
  }

  index_t bs = input.size(0);
  index_t fs = input.size(2);

  index_t bstep = blockDim.y * gridDim.y;
  for (index_t batch = threadIdx.y + blockIdx.y * blockDim.y; batch < bs; batch += bstep) {
    auto o = output[batch][plane];
    auto i = input[batch][plane];
    for (index_t feature = threadIdx.x; feature < fs; feature += blockDim.x) {
      o[feature] = static_cast<input_scalar_t>(gamma * (i[feature] - mean) * invstd + beta);
    }
  }
}

struct InvStd {
  template <typename T>
  __device__ __forceinline__ T operator()(T var, double epsilon) const {
    T invstd = 0;
    if (var != static_cast<T>(0) || epsilon != static_cast<T>(0)) {
      invstd = static_cast<T>(1) / device_sqrt(var + epsilon);
    }
    return invstd;
  }
};

struct Var {
  template <typename T>
  __device__ __forceinline__ T operator()(T var, double epsilon) const {
    return var;
  }
};

template <typename VarTransform, typename input_scalar_t, typename stat_scalar_t, typename stat_accscalar_t, typename index_t>
__global__ void batch_norm_collect_statistics_kernel(
    const GenericPackedTensorAccessor<const input_scalar_t, 3, RestrictPtrTraits, index_t> input,
    const stat_accscalar_t epsilon,
    const stat_accscalar_t momentum,
    GenericPackedTensorAccessor<stat_accscalar_t, 1, RestrictPtrTraits, index_t> save_mean,
    GenericPackedTensorAccessor<stat_accscalar_t, 1, RestrictPtrTraits, index_t> save_transformed_var) {

  __shared__ int shared_n[2 * 2 * C10_WARP_SIZE + C10_WARP_SIZE];

  int plane = blockIdx.x;
  int N = input.size(0) * input.size(2);
  int tid = threadIdx.x + threadIdx.y * blockDim.x;

  // Compute the mean and variance across (batch, x/y/z)
  // this uses the Welford (in the for loop)/parallel algorithm (to sum across the block)
  // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_Online_algorithm
  // and the parallel algorithm on the same page.
  // We use two shuffles to reduce across the entire block.
  // https://devblogs.nvidia.com/faster-parallel-reductions-kepler/ has a description.
  stat_accscalar_t* shared_avg_var = (stat_accscalar_t*) &shared_n[C10_WARP_SIZE];

  // first the reductions each thread does separately
  stat_accscalar_t avg = 0;
  stat_accscalar_t var_n = 0;
  int n = 0;
  for (int batch = threadIdx.y; batch < input.size(0); batch += blockDim.y) {
    for (int x = threadIdx.x; x < input.size(2); x += blockDim.x) {
      stat_accscalar_t v = input[batch][plane][x];
      stat_accscalar_t d1 = v - avg;
      n++;
      avg += d1 / n;
      var_n += d1 * (v - avg);
    }
  }

  // first warpSum to get one value per thread to
  // one value per warp
  for (int i = 0; i < getMSB(C10_WARP_SIZE); ++i) {
    stat_accscalar_t o_avg = WARP_SHFL_XOR(avg, 1 << i, C10_WARP_SIZE);
    int o_n = WARP_SHFL_XOR(n, 1 << i, C10_WARP_SIZE);
    stat_accscalar_t factor = 1.0 / fmaxf(1.0, n+o_n);
    var_n += WARP_SHFL_XOR(var_n, 1 << i, C10_WARP_SIZE) + (avg - o_avg) * (avg - o_avg) * n * o_n * factor;
    avg = (n * avg + o_n * o_avg) * factor;
    n += o_n;
  }

  // this writes each warps item into shared memory
  // there are at most C10_WARP_SIZE items left because
  // there are at most C10_WARP_SIZE**2 threads at the beginning
  __syncthreads();
  if (tid % C10_WARP_SIZE == 0) {
    shared_n[tid / C10_WARP_SIZE] = n;
    shared_avg_var[tid / C10_WARP_SIZE * 2] = avg;
    shared_avg_var[tid / C10_WARP_SIZE * 2 + 1] = var_n;
  }
  __syncthreads();
  // now have a second warpSum to reduce the intermediate values
  // from shared memory to a single number. The very first
  // thread writes it to shared memory.

  if (tid < C10_WARP_SIZE) {
    n = (tid < blockDim.x * blockDim.y / C10_WARP_SIZE ? shared_n[tid] : 0);
    avg = (tid < blockDim.x * blockDim.y / C10_WARP_SIZE ? shared_avg_var[2 * tid] : stat_accscalar_t(0));
    var_n = (tid < blockDim.x * blockDim.y / C10_WARP_SIZE ? shared_avg_var[2 * tid + 1] : stat_accscalar_t(0));
  }
  for (int i = 0; i < getMSB(C10_WARP_SIZE); ++i) {
    stat_accscalar_t o_avg = WARP_SHFL_XOR(avg, 1 << i, C10_WARP_SIZE);
    int o_n = WARP_SHFL_XOR(n, 1 << i, C10_WARP_SIZE);
    stat_accscalar_t factor = 1.0 / fmaxf(1.0, n+o_n);
    var_n += WARP_SHFL_XOR(var_n, 1 << i, C10_WARP_SIZE) + (avg - o_avg) * (avg - o_avg) * n * o_n * factor;
    avg = (n * avg + o_n * o_avg) * factor;
    n += o_n;
  }

  // Save the mean, variance, and moving averages
  if (tid == 0) {
    if (save_mean.data() != NULL) {
      save_mean[plane] = avg;
    }
    if (save_transformed_var.data() != NULL) {
      save_transformed_var[plane] = VarTransform{}(var_n / N, epsilon);
    }
  }

}

template <typename input_scalar_t, typename stat_scalar_t, typename stat_accscalar_t, typename index_t>
__global__ void batch_norm_backward_kernel(
    const GenericPackedTensorAccessor<const input_scalar_t, 3, DefaultPtrTraits, index_t> input,
    const GenericPackedTensorAccessor<const input_scalar_t, 3, DefaultPtrTraits, index_t> grad_output,
    GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> grad_input,
    GenericPackedTensorAccessor<stat_scalar_t, 1, DefaultPtrTraits, index_t> grad_weight,
    GenericPackedTensorAccessor<stat_scalar_t, 1, DefaultPtrTraits, index_t> grad_bias,
    const GenericPackedTensorAccessor<const stat_scalar_t, 1, DefaultPtrTraits, index_t> weight,
    const GenericPackedTensorAccessor<const stat_scalar_t, 1, DefaultPtrTraits, index_t> running_mean,
    const GenericPackedTensorAccessor<const stat_scalar_t, 1, DefaultPtrTraits, index_t> running_var,
    const GenericPackedTensorAccessor<const stat_accscalar_t, 1, DefaultPtrTraits, index_t> save_mean,
    const GenericPackedTensorAccessor<const stat_accscalar_t, 1, DefaultPtrTraits, index_t> save_invstd,
    bool train,
    stat_accscalar_t epsilon) {

  index_t plane = blockIdx.x;
  index_t N = grad_output.size(0) * grad_output.size(2);

  stat_accscalar_t mean, invstd;
  if (train) {
    mean = save_mean[plane];
    invstd = save_invstd[plane];
  } else {
    mean = static_cast<stat_accscalar_t>(running_mean[plane]);
    invstd = static_cast<stat_accscalar_t>(1) / device_sqrt(static_cast<stat_accscalar_t>(running_var[plane]) + epsilon);
  }

  stat_accscalar_t weight_val = weight.size(0) > 0 ? static_cast<stat_accscalar_t>(weight[plane]) : stat_accscalar_t(1);
  stat_accscalar_t norm = stat_accscalar_t(1) / N;

  // Compute two values across (batch, x/y/z) in one pass:
  // 1. Sum(grad_output)
  // 2. DotProduct(input - mean, grad_output)
  GradOp<input_scalar_t, stat_accscalar_t, GenericPackedTensorAccessor<const input_scalar_t, 3, DefaultPtrTraits, index_t>> g(mean, input, grad_output);
  auto res = reduce<Float2<input_scalar_t, stat_accscalar_t>>(g, grad_output, plane);

  stat_accscalar_t grad_output_sum = res.v1;
  stat_accscalar_t dot_p = res.v2;

  stat_accscalar_t grad_mean = grad_output_sum * norm;
  stat_accscalar_t proj_scale = dot_p * norm * invstd * invstd;
  stat_accscalar_t grad_scale = invstd * weight_val;

  if (grad_input.data() != NULL) {
    for (int batch = threadIdx.y; batch < grad_output.size(0); batch += blockDim.y) {
      for (int x = threadIdx.x; x < grad_output.size(2); x += blockDim.x) {
        input_scalar_t go = grad_output[batch][plane][x];
        if (train) {
          stat_accscalar_t inp = input[batch][plane][x];
          stat_accscalar_t proj = (inp - mean) * proj_scale;
          grad_input[batch][plane][x] = static_cast<input_scalar_t>((go - proj - grad_mean) * grad_scale);
        } else {
          grad_input[batch][plane][x] = static_cast<input_scalar_t>(go * grad_scale);
        }
      }
    }
  }

  if (grad_weight.size(0) > 0) {
    if (threadIdx.x == 0) {
      grad_weight[plane] = static_cast<stat_scalar_t>(dot_p * invstd);
    }
  }

  if (grad_bias.size(0) > 0) {
    if (threadIdx.x == 0) {
      grad_bias[plane] = static_cast<stat_scalar_t>(grad_output_sum);
    }
  }
}

template <typename scalar_t, typename accscalar_t, typename index_t>
__global__ void batch_norm_reduce_statistics_kernel(
    const GenericPackedTensorAccessor<accscalar_t, 2, RestrictPtrTraits, index_t> vec_mean,
    const GenericPackedTensorAccessor<accscalar_t, 2, RestrictPtrTraits, index_t> vec_invstd,
    GenericPackedTensorAccessor<accscalar_t, 1, RestrictPtrTraits, index_t> mean,
    GenericPackedTensorAccessor<accscalar_t, 1, RestrictPtrTraits, index_t> invstd,
    GenericPackedTensorAccessor<scalar_t, 1, RestrictPtrTraits, index_t> running_mean,
    GenericPackedTensorAccessor<scalar_t, 1, RestrictPtrTraits, index_t> running_var,
    const accscalar_t epsilon,
    const accscalar_t momentum,
    const GenericPackedTensorAccessor<scalar_t, 1, RestrictPtrTraits, index_t> counts) {

  int feature_size = vec_mean.size(1);
  int world_size = vec_mean.size(0);

  int bid = blockIdx.x;
  int tid = threadIdx.x;

  // first the reductions each thread does separately
  for (int i = bid*blockDim.x+tid; i < feature_size; i += gridDim.x*blockDim.x) {
    accscalar_t avg = 0;
    accscalar_t var_n = 0;
    index_t n = 0;
    for (int j = 0; j < world_size; j++) {
      scalar_t count = counts[j];
      accscalar_t m = vec_mean[j][i];
      accscalar_t v = accscalar_t(1.0) / (vec_invstd[j][i]);
      v = (v * v - epsilon) * count;
      accscalar_t factor = 1.0 / (n + count);
      var_n += v + (avg - m) * (avg - m) * n * count * factor;
      avg = n * factor * avg + count * factor * m;
      n += count;
    }
    mean[i] = avg;
    invstd[i] = static_cast<accscalar_t>(1) / device_sqrt(var_n / n + epsilon);
    if (running_mean.data() != NULL) {
      running_mean[i] = static_cast<scalar_t>((1 - momentum) * running_mean[i] + momentum * avg);
    }
    accscalar_t unbiasedVar = var_n / (n - 1);
    if (running_var.data() != NULL) {
      running_var[i] = static_cast<scalar_t>((1 - momentum) * running_var[i] + momentum * unbiasedVar);
    }
  }

}

template <typename input_scalar_t, typename stat_scalar_t, typename stat_accscalar_t, typename index_t>
__global__ void batch_norm_backward_reduce_kernel(
    const GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> input,
    const GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> grad_output,
    GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> mean,
    GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> invstd,
    GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> sum_dy,
    GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> sum_dy_xmu,
    GenericPackedTensorAccessor<stat_scalar_t, 1, DefaultPtrTraits, index_t> grad_weight,
    GenericPackedTensorAccessor<stat_scalar_t, 1, DefaultPtrTraits, index_t> grad_bias) {

  index_t plane = blockIdx.x;

  stat_accscalar_t r_mean = mean[plane];
  stat_accscalar_t factor = invstd[plane];

  GradOp<input_scalar_t, stat_accscalar_t, GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t>> g(r_mean, input, grad_output);
  auto res = reduce<Float2<input_scalar_t, stat_accscalar_t>>(g, grad_output, plane);

  if (threadIdx.x == 0) {
    if (grad_weight.size(0) > 0) {
      grad_weight[plane] = static_cast<stat_scalar_t>(res.v2 * factor);
    }
    if (grad_bias.size(0) > 0) {
      grad_bias[plane] = static_cast<stat_scalar_t>(res.v1);
    }
    if (sum_dy.size(0) > 0) {
      sum_dy[plane] = static_cast<stat_accscalar_t>(res.v1);
    }
    if (sum_dy_xmu.size(0) > 0) {
      sum_dy_xmu[plane] = static_cast<stat_accscalar_t>(res.v2);
    }
  }
}

template <typename input_scalar_t, typename stat_scalar_t, typename stat_accscalar_t, typename index_t>
__device__ __forceinline__ void batch_norm_backward_elemt_kernel_impl(
    const GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> input,
    const GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> grad_output,
    const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> mean,
    const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> invstd,
    const GenericPackedTensorAccessor<stat_scalar_t, 1, DefaultPtrTraits, index_t> weight,
    const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> sum_dy,
    const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> sum_dy_xmu,
    GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> grad_input,
    const stat_accscalar_t norm_fct) {
  index_t plane = blockIdx.x;

  if (plane >= input.size(1)) {
    return;
  }

  stat_accscalar_t m_c = mean[plane];
  stat_accscalar_t m_dy_c = sum_dy[plane] * norm_fct;
  stat_accscalar_t factor_1_c = invstd[plane];
  stat_accscalar_t factor_2_c = weight.size(0) > 0 ? static_cast<stat_accscalar_t>(weight[plane]) : stat_accscalar_t(1);
  factor_2_c *= factor_1_c;
  factor_1_c = factor_1_c * factor_1_c * sum_dy_xmu[plane] * norm_fct;

  index_t bs = input.size(0);
  index_t fs = input.size(2);

  index_t bstep = blockDim.y * gridDim.y;
  for (index_t batch = threadIdx.y + blockIdx.y * blockDim.y; batch < bs; batch += bstep) {
    auto g_i = grad_input[batch][plane];
    auto g_o = grad_output[batch][plane];
    auto i = input[batch][plane];
    for (index_t feature = threadIdx.x; feature < fs; feature += blockDim.x) {
      g_i[feature] = static_cast<input_scalar_t>((g_o[feature] - m_dy_c - (i[feature] - m_c) * factor_1_c) * factor_2_c);
    }
  }
}

template <typename input_scalar_t, typename stat_scalar_t, typename stat_accscalar_t, typename index_t>
__global__ void batch_norm_backward_elemt_kernel(
    const GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> input,
    const GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> grad_output,
    const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> mean,
    const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> invstd,
    const GenericPackedTensorAccessor<stat_scalar_t, 1, DefaultPtrTraits, index_t> weight,
    const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> sum_dy,
    const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> sum_dy_xmu,
    GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> grad_input,
    const int* __restrict__ numel, const int world_size) {
  int64_t total_numel = 0;
  for (int i = 0; i < world_size; i ++) {
    total_numel += numel[i];
  }

  const stat_accscalar_t norm_fct =
      static_cast<stat_accscalar_t>(1) / static_cast<stat_accscalar_t>(total_numel);
  batch_norm_backward_elemt_kernel_impl(
      input, grad_output, mean, invstd, weight, sum_dy, sum_dy_xmu, grad_input, norm_fct);
}

template <typename input_scalar_t, typename stat_scalar_t, typename stat_accscalar_t, typename index_t>
__global__ void batch_norm_backward_elemt_kernel(
    const GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> input,
    const GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> grad_output,
    const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> mean,
    const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> invstd,
    const GenericPackedTensorAccessor<stat_scalar_t, 1, DefaultPtrTraits, index_t> weight,
    const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> sum_dy,
    const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> sum_dy_xmu,
    GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> grad_input,
    const stat_accscalar_t norm_fct) {
  batch_norm_backward_elemt_kernel_impl(
      input, grad_output, mean, invstd, weight, sum_dy, sum_dy_xmu, grad_input, norm_fct);
}

template <typename scalar_t, int64_t dim, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
static GenericPackedTensorAccessor<scalar_t, dim, PtrTraits, index_t> get_packed_accessor(
    const Tensor& t, c10::string_view var_name) {
  constexpr auto expect_type = c10::CppTypeToScalarType<typename std::remove_const<scalar_t>::type>::value;
  const auto actual_type = t.scalar_type();
  TORCH_CHECK(actual_type == expect_type, "Expected ", var_name,
              " to have type ", expect_type, " but got ", actual_type);
  return t.generic_packed_accessor<scalar_t, dim, PtrTraits, index_t>();
}

template <typename scalar_t, int64_t dim, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
static GenericPackedTensorAccessor<scalar_t, dim, PtrTraits, index_t> packed_accessor_or_dummy(
    const Tensor& t, c10::string_view var_name) {
  if (!t.defined()) {
    const std::array<index_t, dim> zeros{{0}};
    return GenericPackedTensorAccessor<scalar_t, dim, PtrTraits, index_t>(nullptr, zeros.data(), zeros.data());
  }
  return get_packed_accessor<scalar_t, dim, PtrTraits, index_t>(t, var_name);
}

template<typename input_scalar_t, typename stat_scalar_t, typename index_t>
std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cuda_template(const Tensor& grad_out_, const Tensor& input_, const Tensor& weight_,
    const Tensor& running_mean_, const Tensor& running_var_, const Tensor& save_mean_, const Tensor& save_invstd_,
    bool train, double epsilon, std::array<bool,3> grad_input_mask) {

  using accscalar_t = at::acc_type<stat_scalar_t, true>;
  Tensor grad_input_;
  Tensor grad_input_reshaped;
  Tensor grad_weight_;
  Tensor grad_bias_;
  auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1});
  auto grad_output_reshaped = grad_out_.reshape(input_reshaped.sizes());

  if (grad_input_mask[0]) {
    grad_input_ = at::empty_like(input_, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
    grad_input_reshaped = grad_input_.view(input_reshaped.sizes());
  }
  if (grad_input_mask[1]) {
    grad_weight_ = at::empty_like(weight_, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  }
  if (grad_input_mask[2]) {
    grad_bias_ = at::empty_like(weight_, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  }

  auto input = get_packed_accessor<
      const input_scalar_t, 3, DefaultPtrTraits, index_t>(input_reshaped, "input");
  auto grad_output = get_packed_accessor<
      const input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_output_reshaped, "grad_output");
  auto grad_input = packed_accessor_or_dummy<
      input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_input_reshaped, "grad_input");
  auto weight = packed_accessor_or_dummy<
      const stat_scalar_t, 1, DefaultPtrTraits, index_t>(weight_, "weight");
  auto grad_weight = packed_accessor_or_dummy<
      stat_scalar_t, 1, DefaultPtrTraits, index_t>(grad_weight_, "grad_weight");
  auto grad_bias = packed_accessor_or_dummy<
      stat_scalar_t, 1, DefaultPtrTraits, index_t>(grad_bias_, "grad_bias");
  auto running_mean = packed_accessor_or_dummy<
      const stat_scalar_t, 1, DefaultPtrTraits, index_t>(running_mean_, "running_mean");
  auto running_var = packed_accessor_or_dummy<
      const stat_scalar_t, 1, DefaultPtrTraits, index_t>(running_var_, "running_var");
  auto save_mean = packed_accessor_or_dummy<
      const accscalar_t, 1, DefaultPtrTraits, index_t>(save_mean_, "save_mean");
  auto save_invstd = packed_accessor_or_dummy<
      const accscalar_t, 1, DefaultPtrTraits, index_t>(save_invstd_, "save_invstd");

  auto stream = at::cuda::getCurrentCUDAStream();
  dim3 blocks(input.size(1));
  int tf = getNumThreads(input.size(2));
  dim3 threads(tf, std::max<int>(1, MAX_BLOCK_SIZE/tf));

  batch_norm_backward_kernel<input_scalar_t, stat_scalar_t, accscalar_t, index_t> <<<blocks, threads, 0, stream>>>
    (input, grad_output, grad_input, grad_weight, grad_bias, weight, running_mean, running_var,
     save_mean, save_invstd, train, epsilon);
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  return std::make_tuple(grad_input_, grad_weight_, grad_bias_);
}

template<typename scalar_t, typename index_t, typename VarTransform>
void batch_norm_stats_cuda_template(
    const Tensor& out_mean, const Tensor& out_invstd, const Tensor& input_, double epsilon) {

  using accscalar_t = at::acc_type<scalar_t, true>;
  int64_t n_input = input_.size(1);
  Tensor dummy_mean_;
  Tensor dummy_var_;
  auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions

  resize_output(out_mean, {n_input});
  resize_output(out_invstd, {n_input});
  auto input = get_packed_accessor<
      const scalar_t, 3, RestrictPtrTraits, index_t>(input_reshaped, "input");
  TORCH_INTERNAL_ASSERT(out_invstd.dim() == 1 && out_invstd.is_contiguous() &&
                        out_invstd.sizes()[0]);
  TORCH_INTERNAL_ASSERT(out_mean.dim() == 1 && out_mean.is_contiguous() &&
                        out_mean.sizes()[0]);

  auto mean = packed_accessor_or_dummy<
      accscalar_t, 1, RestrictPtrTraits, index_t>(out_mean, "out_mean");
  auto invstd = packed_accessor_or_dummy<
      accscalar_t, 1, RestrictPtrTraits, index_t>(out_invstd, "out_invstd");
  auto stream = at::cuda::getCurrentCUDAStream();

  dim3 blocks(input.size(1));
  int tf = getNumThreads(input.size(2));
  dim3 threads(tf, std::max<int>(1, MAX_BLOCK_SIZE/tf));
  batch_norm_collect_statistics_kernel<VarTransform, scalar_t, scalar_t, accscalar_t, index_t> <<<blocks, threads, 0, stream>>>
    (input, epsilon, 0.0, mean, invstd);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

template<typename input_scalar_t, typename stat_scalar_t, typename index_t>
void batch_norm_elemt_cuda_template(const Tensor& output_, const Tensor& input_, const Tensor& weight_,
                                    const Tensor& bias_, const Tensor& mean_, const Tensor& invstd_) {

  using stat_accscalar_t = at::acc_type<stat_scalar_t, true>;
  int64_t n_input = input_.size(1);
  auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions
  auto output_reshaped = output_.view({input_.size(0), input_.size(1), -1});

  auto input = get_packed_accessor<
      const input_scalar_t, 3, RestrictPtrTraits, index_t>(input_reshaped, "input");
  auto output = get_packed_accessor<
      input_scalar_t, 3, RestrictPtrTraits, index_t>(output_reshaped, "output");
  auto weight = packed_accessor_or_dummy<
      const stat_scalar_t, 1, RestrictPtrTraits, index_t>(weight_, "weight");
  auto bias = packed_accessor_or_dummy<
      const stat_scalar_t, 1, RestrictPtrTraits, index_t>(bias_, "bias");
  auto mean = packed_accessor_or_dummy<
      stat_accscalar_t, 1, RestrictPtrTraits, index_t>(mean_, "mean");
  auto invstd = packed_accessor_or_dummy<
      stat_accscalar_t, 1, RestrictPtrTraits, index_t>(invstd_, "invstd");
  auto stream = at::cuda::getCurrentCUDAStream();

  // NOTE: We use transform_input_kernel in training mode, which ignores epsilon
  const double dummy_epsilon = 1e-5;

  // The input_transform kernel is pointwise, but we need to balance reading parameters (save_var/mean,
  // weight/bias) - which we only do once and have a for loop afterwards - with having many threads and blocks
  // and good occupancy. Quiet likely, we could go with even more blocks than 1024.
  // The various planes are independent, so we use blocks for them.
  int tf = std::max<int>(getNumThreads(input.size(2)/4),
                         std::min<int>(getNumThreads(input.size(2)), 64));
  int tb = std::max<int>(64/tf, 1);
  dim3 blocks_trans(input.size(1), std::max<int>(1, std::min<int>((256*1024)/input.size(1),
                                                                  (input.size(0)+tb-1)/tb)));
  blocks_trans.y = std::min(blocks_trans.y, MAX_GRID_SIZE);
  dim3 threads_trans(tf, tb);
  batch_norm_transform_input_kernel<input_scalar_t, stat_scalar_t, stat_accscalar_t, true, index_t> <<<blocks_trans, threads_trans, 0, stream>>>
    (input, output, mean, invstd, weight, bias, dummy_epsilon);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

template<typename scalar_t, typename accscalar_t, typename index_t>
std::tuple<Tensor, Tensor> batch_norm_gather_stats_cuda_template(const Tensor& mean_, const Tensor& invstd_,
                                                                 const Tensor& running_mean_, const Tensor& running_var_,
                                                                 double momentum, double epsilon, const Tensor& counts_) {

  Tensor save_mean_;
  Tensor save_invstd_;

  auto features = mean_.size(1);
  auto input_options = mean_.options();
  if (mean_.scalar_type() == at::ScalarType::Half || mean_.scalar_type() == at::ScalarType::BFloat16) {
    input_options = input_options.dtype(ScalarType::Float);
  }
  save_mean_ = at::empty({features}, input_options);
  save_invstd_ = at::empty({features}, input_options);

  auto mean = packed_accessor_or_dummy<
      accscalar_t, 2, RestrictPtrTraits, index_t>(mean_, "mean");
  auto invstd = packed_accessor_or_dummy<
      accscalar_t, 2, RestrictPtrTraits, index_t>(invstd_, "invstd");
  auto running_mean = packed_accessor_or_dummy<
      scalar_t, 1, RestrictPtrTraits, index_t>(running_mean_, "running_mean");
  auto running_var = packed_accessor_or_dummy<
      scalar_t, 1, RestrictPtrTraits, index_t>(running_var_, "running_mean");
  auto counts = packed_accessor_or_dummy<
      scalar_t, 1, RestrictPtrTraits, index_t>(counts_, "counts");

  auto save_mean = get_packed_accessor<
      accscalar_t, 1, RestrictPtrTraits, index_t>(save_mean_, "save_mean");
  auto save_invstd = get_packed_accessor<
      accscalar_t, 1, RestrictPtrTraits, index_t>(save_invstd_, "save_invstd");
  auto stream = at::cuda::getCurrentCUDAStream();

  int block = getNumThreads(features);
  int grid = std::max<int>(1, features/block);
  batch_norm_reduce_statistics_kernel<scalar_t, accscalar_t, index_t> <<<grid, block, 0, stream>>>
    (mean, invstd, save_mean, save_invstd, running_mean, running_var, epsilon, momentum, counts);
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  return std::make_tuple(save_mean_, save_invstd_);
}

template<typename input_scalar_t, typename stat_scalar_t, typename index_t>
std::tuple<Tensor, Tensor, Tensor, Tensor> batch_norm_backward_reduce_cuda_template(const Tensor& grad_out_, const Tensor& input_,
                                                                                    const Tensor& mean_, const Tensor& invstd_, const Tensor& weight_,
                                                                                    const bool input_g, const bool weight_g, const bool bias_g) {

  using stat_accscalar_t = at::acc_type<stat_scalar_t, true>;
  int64_t n_input = input_.size(1);
  Tensor sum_dy_;
  Tensor sum_dy_xmu_;
  Tensor grad_weight_;
  Tensor grad_bias_;
  auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions
  auto grad_output_reshaped = grad_out_.reshape(input_reshaped.sizes());

  if (input_g) {
    sum_dy_ = at::empty_like(mean_, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
    sum_dy_xmu_ = at::empty_like(mean_, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  }
  if (weight_g) {
    grad_weight_ = at::empty({n_input}, weight_.options());
  }
  if (bias_g) {
    grad_bias_ = at::empty({n_input}, weight_.options());
  }

  auto input = get_packed_accessor<
      input_scalar_t, 3, DefaultPtrTraits, index_t>(input_reshaped, "input");
  auto grad_output = get_packed_accessor<
      input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_output_reshaped, "grad_output");
  auto grad_weight = packed_accessor_or_dummy<
      stat_scalar_t, 1, DefaultPtrTraits, index_t>(grad_weight_, "grad_weight");
  auto grad_bias = packed_accessor_or_dummy<
      stat_scalar_t, 1, DefaultPtrTraits, index_t>(grad_bias_, "grad_bias");
  auto mean = packed_accessor_or_dummy<
      stat_accscalar_t, 1, DefaultPtrTraits, index_t>(mean_, "mean");
  auto invstd = packed_accessor_or_dummy<
      stat_accscalar_t, 1, DefaultPtrTraits, index_t>(invstd_, "invstd");
  auto sum_dy = packed_accessor_or_dummy<
      stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_, "sum_dy");
  auto sum_dy_xmu = packed_accessor_or_dummy<
      stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_xmu_, "sum_dy_xmu");

  auto batch_size = input_reshaped.size(0);
  auto feature_size = input_reshaped.size(2);
  auto stream = at::cuda::getCurrentCUDAStream();

  int warp_size = at::cuda::warp_size();
  int block_y = std::min<int>(lastPow2(batch_size), MAX_BLOCK_SIZE/warp_size);
  // We want block_x to be at least a warp width
  int block_x = std::min<int>(std::max<int>(getNumThreads(feature_size), warp_size), MAX_BLOCK_SIZE/block_y);
  const dim3 block(block_x, block_y);
  const dim3 grid(n_input);

  batch_norm_backward_reduce_kernel<input_scalar_t, stat_scalar_t, stat_accscalar_t, index_t> <<<grid, block, 0, stream>>>
    (input, grad_output, mean, invstd, sum_dy, sum_dy_xmu, grad_weight, grad_bias);
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  return std::make_tuple(sum_dy_, sum_dy_xmu_, grad_weight_, grad_bias_);
}

template<typename input_scalar_t, typename stat_scalar_t, typename index_t>
Tensor batch_norm_backward_elemt_cuda_template(const Tensor& grad_out_, const Tensor& input_,
                                               const Tensor& mean_, const Tensor& invstd_,
                                               const Tensor& weight_, const Tensor& sum_dy_, const Tensor& sum_dy_xmu_) {

  using stat_accscalar_t = at::acc_type<stat_scalar_t, true>;
  int64_t n_input = input_.size(1);
  auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions
  auto grad_output_reshaped = grad_out_.reshape(input_reshaped.sizes());
  auto grad_input_reshaped = at::empty_like(input_reshaped, LEGACY_CONTIGUOUS_MEMORY_FORMAT);

  auto input = get_packed_accessor<
      input_scalar_t, 3, DefaultPtrTraits, index_t>(input_reshaped, "input");
  auto grad_input = get_packed_accessor<
      input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_input_reshaped, "grad_input");
  auto grad_output = get_packed_accessor<
      input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_output_reshaped, "grad_output");
  auto mean = packed_accessor_or_dummy<
      stat_accscalar_t, 1, DefaultPtrTraits, index_t>(mean_, "mean");
  auto invstd = packed_accessor_or_dummy<
      stat_accscalar_t, 1, DefaultPtrTraits, index_t>(invstd_, "invstd");
  auto weight = packed_accessor_or_dummy<
      stat_scalar_t, 1, DefaultPtrTraits, index_t>(weight_, "weight");
  auto sum_dy = packed_accessor_or_dummy<
      stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_, "sum_dy");
  auto sum_dy_xmu = packed_accessor_or_dummy<
      stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_xmu_, "sum_dy_xmu");

  auto stream = at::cuda::getCurrentCUDAStream();

  // The kernel is pointwise, but we need to balance reading parameters (save_var/mean,
  // weight/bias) - which we only do once and have a for loop afterwards - with having many threads and blocks
  // and good occupancy. Quiet likely, we could go with even more blocks than 1024.
  // The various planes are independent, so we use blocks for them.
  int tf = std::max<int>(getNumThreads(input.size(2)/4),
                         std::min<int>(getNumThreads(input.size(2)), 64));
  int tb = std::max<int>(64/tf, 1);
  dim3 blocks_trans(input.size(1), std::max<int>(1, std::min<int>((256*1024)/input.size(1),
                                                                  (input.size(0)+tb-1)/tb)));
  blocks_trans.y = std::min(blocks_trans.y, MAX_GRID_SIZE);
  dim3 threads_trans(tf, tb);
  auto reduction_size = input_.numel() / n_input;
  auto norm_fct = static_cast<stat_accscalar_t>(1.0 / reduction_size);
  batch_norm_backward_elemt_kernel<input_scalar_t, stat_scalar_t, stat_accscalar_t, index_t>
      <<<blocks_trans, threads_trans, 0, stream>>>
      (input, grad_output, mean, invstd, weight, sum_dy, sum_dy_xmu, grad_input, norm_fct);
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  return grad_input_reshaped.view(input_.sizes());
}

template<typename input_scalar_t, typename stat_scalar_t, typename index_t>
Tensor batch_norm_backward_elemt_cuda_template(const Tensor& grad_out_, const Tensor& input_,
                                               const Tensor& mean_, const Tensor& invstd_,
                                               const Tensor& weight_, const Tensor& sum_dy_, const Tensor& sum_dy_xmu_, const Tensor& count) {

  using stat_accscalar_t = at::acc_type<stat_scalar_t, true>;
  int64_t n_input = input_.size(1);
  auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions
  auto grad_output_reshaped = grad_out_.reshape(input_reshaped.sizes());
  auto grad_input_reshaped = at::empty_like(input_reshaped, LEGACY_CONTIGUOUS_MEMORY_FORMAT);

  auto input = get_packed_accessor<
      input_scalar_t, 3, DefaultPtrTraits, index_t>(input_reshaped, "input");
  auto grad_input = get_packed_accessor<
      input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_input_reshaped, "grad_input");
  auto grad_output = get_packed_accessor<
      input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_output_reshaped, "grad_output");
  auto mean = packed_accessor_or_dummy<
      stat_accscalar_t, 1, DefaultPtrTraits, index_t>(mean_, "mean");
  auto invstd = packed_accessor_or_dummy<
      stat_accscalar_t, 1, DefaultPtrTraits, index_t>(invstd_, "invstd");
  auto weight = packed_accessor_or_dummy<
      stat_scalar_t, 1, DefaultPtrTraits, index_t>(weight_, "weight");
  auto sum_dy = packed_accessor_or_dummy<
      stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_, "sum_dy");
  auto sum_dy_xmu = packed_accessor_or_dummy<
      stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_xmu_, "sum_dy_xmu");

  auto stream = at::cuda::getCurrentCUDAStream();

  // The kernel is pointwise, but we need to balance reading parameters (save_var/mean,
  // weight/bias) - which we only do once and have a for loop afterwards - with having many threads and blocks
  // and good occupancy. Quiet likely, we could go with even more blocks than 1024.
  // The various planes are independent, so we use blocks for them.
  int tf = std::max<int>(getNumThreads(input.size(2)/4),
                         std::min<int>(getNumThreads(input.size(2)), 64));
  int tb = std::max<int>(64/tf, 1);
  dim3 blocks_trans(input.size(1), std::max<int>(1, std::min<int>((256*1024)/input.size(1),
                                                                  (input.size(0)+tb-1)/tb)));
  blocks_trans.y = std::min(blocks_trans.y, MAX_GRID_SIZE);
  dim3 threads_trans(tf, tb);
  batch_norm_backward_elemt_kernel<input_scalar_t, stat_scalar_t, stat_accscalar_t, index_t> <<<blocks_trans, threads_trans, 0, stream>>>
    (input, grad_output, mean, invstd, weight, sum_dy, sum_dy_xmu, grad_input, count.const_data_ptr<int>(), count.numel());
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  return grad_input_reshaped.view(input_.sizes());
}

// welford kernel for c last tensor calculating mean/biased_variance/unbiased_variance
// original apex name: welford_kernel_c_last
template
  <typename VarTransform,
   typename scalar_t,
   typename accscalar_t,
   int PARALLEL_LOADS>
__global__ void
batch_norm_collect_statistics_channels_last_kernel(
    const scalar_t* __restrict__ input,
    accscalar_t* __restrict__ out_mean,
    accscalar_t* __restrict__ out_invstd,
    volatile accscalar_t* staging_data,
    int* semaphores,
    const int reduction_size,
    const int stride,
    accscalar_t epsilon) {
  // hide latency with concurrency
  accscalar_t x_mean[PARALLEL_LOADS];
  accscalar_t m_2_n[PARALLEL_LOADS];
  int count[PARALLEL_LOADS];

#pragma unroll
  for (int i = 0; i < PARALLEL_LOADS; i++) {
    x_mean[i] = accscalar_t(0);
    m_2_n[i] = accscalar_t(0);
    count[i] = accscalar_t(0);
  }
  // tensor dimension (m,c)

  // loop along m dimension
  int inner_loop_stride = blockDim.y * gridDim.y;

  // offset along m dimension
  int m_offset = blockIdx.y * blockDim.y + threadIdx.y;
  int c_offset = blockIdx.x * blockDim.x + threadIdx.x;

  int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
  int address_base = m_offset * stride + c_offset;
  int address_increment = inner_loop_stride * stride;

  for (int i = 0; i < loop_count; i++) {
    accscalar_t x_math[PARALLEL_LOADS];
    accscalar_t x_count_inv[PARALLEL_LOADS];
    accscalar_t is_valid[PARALLEL_LOADS];

    // load multiple data in
#pragma unroll
    for (int j = 0; j < PARALLEL_LOADS; j++) {
      if (c_offset < stride && m_offset < reduction_size) {
        x_math[j] = input[address_base];
        count[j]++;
        x_count_inv[j] = accscalar_t(1) / count[j];
        is_valid[j] = accscalar_t(1);
      } else {
        x_math[j] = accscalar_t(0);
        x_count_inv[j] = accscalar_t(0);
        is_valid[j] = accscalar_t(0);
      }
      m_offset += inner_loop_stride;
      address_base += address_increment;
    }

    // calculate mean/m2n with welford
#pragma unroll
    for (int j = 0; j < PARALLEL_LOADS; j++) {
      accscalar_t delta0 = x_math[j] - x_mean[j];
      x_mean[j] += delta0 * x_count_inv[j];
      accscalar_t delta1 = x_math[j] - x_mean[j];
      m_2_n[j] += delta0 * delta1 * is_valid[j];
    }
  }

  // thread reduction to accumulate mean/m_2_n/count between PARALLEL_LOADS
#pragma unroll
  for (int j = 1; j < PARALLEL_LOADS; j++) {
    welford_merge_element(count[0], x_mean[0], m_2_n[0], count[j], x_mean[j], m_2_n[j]);
  }

  // release x_mean / m_2_n
  auto mean_th = x_mean[0];
  auto m2_th = m_2_n[0];
  auto count_th = count[0];

  // block-wise reduction with shared memory (since reduction cannot be done within a warp)
  static __shared__ accscalar_t shmem_mean[MAX_BLOCK_SIZE];
  static __shared__ accscalar_t shmem_m2n[MAX_BLOCK_SIZE];
  static __shared__ int shmem_count[MAX_BLOCK_SIZE];

  welford_merge_block_vertical(count_th, mean_th, m2_th, shmem_count, shmem_mean, shmem_m2n);

  if (gridDim.y > 1) {
    volatile accscalar_t* staging_mean = staging_data;
    volatile accscalar_t* staging_m2n = &staging_data[stride*gridDim.y];
    volatile int* staging_count = reinterpret_cast<volatile int*>(&staging_m2n[stride*gridDim.y]);

    address_base = c_offset + blockIdx.y * stride;
    // write data to staging_data;
    if (threadIdx.y == 0 && c_offset < stride) {
      staging_mean[address_base] = mean_th;
      staging_m2n[address_base] = m2_th;
      staging_count[address_base] = count_th;
    }

    __threadfence();
    __syncthreads(); // ensuring writes to staging_ is visible to all blocks

    __shared__ bool is_last_block_done;
    // mark block done
    if (threadIdx.x == 0 && threadIdx.y == 0) {
      int old = atomicAdd(&semaphores[blockIdx.x], 1);
      is_last_block_done = (old == (gridDim.y-1));
    }

    __syncthreads();

    // check that all data is now available in global memory
    if (is_last_block_done) {
      count_th = 0;
      mean_th = accscalar_t(0.0);
      m2_th = accscalar_t(0.0);

      for (int y = threadIdx.y; y < gridDim.y; y += blockDim.y) {
        address_base = c_offset + y * stride;
        int count_new = c_offset < stride ? staging_count[address_base] : 0;
        accscalar_t mean_new = c_offset < stride ? staging_mean[address_base] : accscalar_t(0.0);
        accscalar_t m2n_new = c_offset < stride ? staging_m2n[address_base] : accscalar_t(0.0);

        welford_merge_element(count_th, mean_th, m2_th, count_new, mean_new, m2n_new);
      }

      welford_merge_block_vertical(count_th, mean_th, m2_th, shmem_count, shmem_mean, shmem_m2n);
      if (threadIdx.y == 0 && c_offset < stride) {
        out_mean[c_offset] = static_cast<accscalar_t>(mean_th);
        out_invstd[c_offset] = VarTransform{}(m2_th/count_th, epsilon);
      }
    }
  } else {
    if (blockIdx.y == 0 && threadIdx.y == 0 && c_offset < stride) {
      out_mean[c_offset] = static_cast<accscalar_t>(mean_th);
| 1075 |
+
out_mean[c_offset] = static_cast<accscalar_t>(mean_th);
|
| 1076 |
+
out_invstd[c_offset] = VarTransform{}(m2_th/count_th, epsilon);
|
| 1077 |
+
}
|
| 1078 |
+
}
|
| 1079 |
+
}
|
| 1080 |
+
|
| 1081 |
+
// elementwise BN kernel
|
| 1082 |
+
// original apex name: batchnorm_forward_c_last_kernel
|
| 1083 |
+
template <
|
| 1084 |
+
typename scalar_t,
|
| 1085 |
+
typename accscalar_t,
|
| 1086 |
+
typename layerscalar_t,
|
| 1087 |
+
int PARALLEL_LOADS>
|
| 1088 |
+
__global__ void batch_norm_transform_input_channels_last_kernel(
|
| 1089 |
+
const scalar_t* __restrict__ input,
|
| 1090 |
+
const scalar_t* __restrict__ z,
|
| 1091 |
+
const accscalar_t* __restrict__ mean,
|
| 1092 |
+
const accscalar_t* __restrict__ inv_std,
|
| 1093 |
+
const layerscalar_t* __restrict__ weight,
|
| 1094 |
+
const layerscalar_t* __restrict__ shift,
|
| 1095 |
+
scalar_t* __restrict__ out,
|
| 1096 |
+
const int reduction_size,
|
| 1097 |
+
const int stride,
|
| 1098 |
+
const bool fuse_relu) {
|
| 1099 |
+
// tensor dimension (m,c)
|
| 1100 |
+
// loop along m dimension
|
| 1101 |
+
int inner_loop_stride = blockDim.y * gridDim.y;
|
| 1102 |
+
|
| 1103 |
+
// offset along m dimension
|
| 1104 |
+
int m_offset = blockIdx.y * blockDim.y + threadIdx.y;
|
| 1105 |
+
int c_offset = blockIdx.x * blockDim.x + threadIdx.x;
|
| 1106 |
+
|
| 1107 |
+
if (c_offset >= stride || m_offset >= reduction_size) {
|
| 1108 |
+
return;
|
| 1109 |
+
}
|
| 1110 |
+
|
| 1111 |
+
auto m_c = mean[c_offset];
|
| 1112 |
+
auto inv_std_c = static_cast<accscalar_t>(inv_std[c_offset]);
|
| 1113 |
+
auto w_c = weight == nullptr ? accscalar_t(1.0) : static_cast<accscalar_t>(weight[c_offset]);
|
| 1114 |
+
auto s_c = shift == nullptr ? accscalar_t(0.0) : static_cast<accscalar_t>(shift[c_offset]);
|
| 1115 |
+
|
| 1116 |
+
int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
|
| 1117 |
+
int address_base = m_offset * stride + c_offset;
|
| 1118 |
+
int address_increment = inner_loop_stride * stride;
|
| 1119 |
+
|
| 1120 |
+
for (int i = 0; i < loop_count; i++) {
|
| 1121 |
+
#pragma unroll
|
| 1122 |
+
for (int j = 0; j < PARALLEL_LOADS; j++) {
|
| 1123 |
+
if (c_offset < stride && m_offset < reduction_size) {
|
| 1124 |
+
auto tmp = w_c * (static_cast<accscalar_t>(input[address_base]) - m_c ) * inv_std_c + s_c;
|
| 1125 |
+
if (z != nullptr) {
|
| 1126 |
+
tmp += z[address_base];
|
| 1127 |
+
}
|
| 1128 |
+
out[address_base] = (fuse_relu && tmp <= accscalar_t(0.0) ? scalar_t(0.0) : static_cast<scalar_t>(tmp));
|
| 1129 |
+
}
|
| 1130 |
+
m_offset += inner_loop_stride;
|
| 1131 |
+
address_base += address_increment;
|
| 1132 |
+
}
|
| 1133 |
+
}
|
| 1134 |
+
}
|
| 1135 |
+
|
| 1136 |
+
template<typename T>
|
| 1137 |
+
__device__ __forceinline__ void merge_block_vertical_backward(T& sum_dy,
|
| 1138 |
+
T& sum_dy_xmu,
|
| 1139 |
+
T* shmem_sum_dy,
|
| 1140 |
+
T* shmem_sum_dy_xmu) {
|
| 1141 |
+
// write to shared memory
|
| 1142 |
+
auto address_base = threadIdx.x + threadIdx.y * blockDim.x;
|
| 1143 |
+
|
| 1144 |
+
#pragma unroll
|
| 1145 |
+
for (int offset = blockDim.y/2; offset > 0; offset >>= 1) {
|
| 1146 |
+
if (threadIdx.y < offset*2) {
|
| 1147 |
+
shmem_sum_dy[address_base] = sum_dy;
|
| 1148 |
+
shmem_sum_dy_xmu[address_base] = sum_dy_xmu;
|
| 1149 |
+
}
|
| 1150 |
+
__syncthreads();
|
| 1151 |
+
if (threadIdx.y < offset && threadIdx.y + offset < blockDim.y) {
|
| 1152 |
+
auto address = address_base + offset * blockDim.x;
|
| 1153 |
+
|
| 1154 |
+
sum_dy += shmem_sum_dy[address];
|
| 1155 |
+
sum_dy_xmu += shmem_sum_dy_xmu[address];
|
| 1156 |
+
}
|
| 1157 |
+
}
|
| 1158 |
+
}
|
| 1159 |
+
|
| 1160 |
+
// batchnorm backward kernel for c last tensor
|
| 1161 |
+
// original apex name: reduce_bn_c_last_kernel
|
| 1162 |
+
template <
|
| 1163 |
+
int PARALLEL_LOADS,
|
| 1164 |
+
typename scalar_t,
|
| 1165 |
+
typename accscalar_t,
|
| 1166 |
+
typename layerscalar_t>
|
| 1167 |
+
__global__ void batch_norm_backward_reduce_channels_last_kernel(
|
| 1168 |
+
const scalar_t* __restrict__ input,
|
| 1169 |
+
const scalar_t* __restrict__ grad_output,
|
| 1170 |
+
const accscalar_t* __restrict__ mean,
|
| 1171 |
+
const accscalar_t* __restrict__ inv_std,
|
| 1172 |
+
accscalar_t* __restrict__ sum_dy_o,
|
| 1173 |
+
accscalar_t* __restrict__ sum_dy_xmu_o,
|
| 1174 |
+
layerscalar_t* __restrict__ grad_weight,
|
| 1175 |
+
layerscalar_t* __restrict__ grad_bias,
|
| 1176 |
+
volatile accscalar_t* staging_data,
|
| 1177 |
+
int* semaphores,
|
| 1178 |
+
const int reduction_size,
|
| 1179 |
+
const int stride) {
|
| 1180 |
+
|
| 1181 |
+
// hide latency with concurrency
|
| 1182 |
+
accscalar_t sum_dy[PARALLEL_LOADS];
|
| 1183 |
+
accscalar_t sum_dy_xmu[PARALLEL_LOADS];
|
| 1184 |
+
|
| 1185 |
+
#pragma unroll
|
| 1186 |
+
for (int i = 0; i < PARALLEL_LOADS; i++) {
|
| 1187 |
+
sum_dy[i] = accscalar_t(0);
|
| 1188 |
+
sum_dy_xmu[i] = accscalar_t(0);
|
| 1189 |
+
}
|
| 1190 |
+
// tensor dimension (m,c)
|
| 1191 |
+
|
| 1192 |
+
// loop along m dimension
|
| 1193 |
+
int inner_loop_stride = blockDim.y * gridDim.y;
|
| 1194 |
+
|
| 1195 |
+
// offset along m dimension
|
| 1196 |
+
int m_offset = blockIdx.y * blockDim.y + threadIdx.y;
|
| 1197 |
+
int c_offset = blockIdx.x * blockDim.x + threadIdx.x;
|
| 1198 |
+
|
| 1199 |
+
if (c_offset >= stride || m_offset >= reduction_size) {
|
| 1200 |
+
return;
|
| 1201 |
+
}
|
| 1202 |
+
|
| 1203 |
+
int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
|
| 1204 |
+
int address_base = m_offset * stride + c_offset;
|
| 1205 |
+
int address_increment = inner_loop_stride * stride;
|
| 1206 |
+
|
| 1207 |
+
auto r_mean = mean[c_offset];
|
| 1208 |
+
auto factor = inv_std[c_offset];
|
| 1209 |
+
|
| 1210 |
+
for (int i = 0; i < loop_count; i++) {
|
| 1211 |
+
accscalar_t x_input[PARALLEL_LOADS];
|
| 1212 |
+
accscalar_t x_grad_output[PARALLEL_LOADS];
|
| 1213 |
+
|
| 1214 |
+
// load multiple data in
|
| 1215 |
+
#pragma unroll
|
| 1216 |
+
for (int j = 0; j < PARALLEL_LOADS; j++) {
|
| 1217 |
+
if (c_offset < stride && m_offset < reduction_size) {
|
| 1218 |
+
x_input[j] = input[address_base];
|
| 1219 |
+
x_grad_output[j] = grad_output[address_base];
|
| 1220 |
+
} else {
|
| 1221 |
+
x_input[j] = accscalar_t(0);
|
| 1222 |
+
x_grad_output[j] = accscalar_t(0);
|
| 1223 |
+
}
|
| 1224 |
+
m_offset += inner_loop_stride;
|
| 1225 |
+
address_base += address_increment;
|
| 1226 |
+
}
|
| 1227 |
+
|
| 1228 |
+
// calculate sum_dy / sum_dy_xmu
|
| 1229 |
+
#pragma unroll
|
| 1230 |
+
for (int j = 0; j < PARALLEL_LOADS; j++) {
|
| 1231 |
+
sum_dy[j] += x_grad_output[j];
|
| 1232 |
+
sum_dy_xmu[j] += x_grad_output[j] * (x_input[j] - r_mean);
|
| 1233 |
+
}
|
| 1234 |
+
}
|
| 1235 |
+
|
| 1236 |
+
// thread reduction to accumulate sum_dy / sum_dy_xmu between PARALLEL_LOADS
|
| 1237 |
+
#pragma unroll
|
| 1238 |
+
for (int j = 1; j < PARALLEL_LOADS; j++) {
|
| 1239 |
+
sum_dy[0] += sum_dy[j];
|
| 1240 |
+
sum_dy_xmu[0] += sum_dy_xmu[j];
|
| 1241 |
+
}
|
| 1242 |
+
|
| 1243 |
+
// release array of registers
|
| 1244 |
+
auto sum_dy_th = sum_dy[0];
|
| 1245 |
+
auto sum_dy_xmu_th = sum_dy_xmu[0];
|
| 1246 |
+
|
| 1247 |
+
// block-wise reduction with shared memory (since reduction cannot be done within a warp)
|
| 1248 |
+
static __shared__ accscalar_t shmem_sum_dy[MAX_BLOCK_SIZE];
|
| 1249 |
+
static __shared__ accscalar_t shmem_sum_dy_xmu[MAX_BLOCK_SIZE];
|
| 1250 |
+
|
| 1251 |
+
merge_block_vertical_backward(sum_dy_th, sum_dy_xmu_th, shmem_sum_dy, shmem_sum_dy_xmu);
|
| 1252 |
+
|
| 1253 |
+
if (gridDim.y > 1) {
|
| 1254 |
+
volatile accscalar_t* staging_sum_dy = staging_data;
|
| 1255 |
+
volatile accscalar_t* staging_sum_dy_xmu = &staging_data[stride*gridDim.y];
|
| 1256 |
+
|
| 1257 |
+
address_base = c_offset + blockIdx.y * stride;
|
| 1258 |
+
// write data to staging_data;
|
| 1259 |
+
if (threadIdx.y == 0 && c_offset < stride) {
|
| 1260 |
+
staging_sum_dy[address_base] = sum_dy_th;
|
| 1261 |
+
staging_sum_dy_xmu[address_base] = sum_dy_xmu_th;
|
| 1262 |
+
}
|
| 1263 |
+
|
| 1264 |
+
__threadfence();
|
| 1265 |
+
__syncthreads(); // ensuring writes to staging_ is visible to all blocks
|
| 1266 |
+
|
| 1267 |
+
__shared__ bool is_last_block_done;
|
| 1268 |
+
// mark block done
|
| 1269 |
+
if (threadIdx.x == 0 && threadIdx.y == 0) {
|
| 1270 |
+
int old = atomicAdd(&semaphores[blockIdx.x], 1);
|
| 1271 |
+
is_last_block_done = (old == (gridDim.y-1));
|
| 1272 |
+
}
|
| 1273 |
+
|
| 1274 |
+
__syncthreads();
|
| 1275 |
+
|
| 1276 |
+
// check that all data is now available in global memory
|
| 1277 |
+
if (is_last_block_done) {
|
| 1278 |
+
sum_dy_th = accscalar_t(0.0);
|
| 1279 |
+
sum_dy_xmu_th = accscalar_t(0.0);
|
| 1280 |
+
|
| 1281 |
+
for (int y = threadIdx.y; y < gridDim.y; y += blockDim.y) {
|
| 1282 |
+
address_base = c_offset + y * stride;
|
| 1283 |
+
sum_dy_th += (c_offset < stride ? staging_sum_dy[address_base] : accscalar_t(0.0));
|
| 1284 |
+
sum_dy_xmu_th += (c_offset < stride ? staging_sum_dy_xmu[address_base] : accscalar_t(0.0));
|
| 1285 |
+
}
|
| 1286 |
+
|
| 1287 |
+
merge_block_vertical_backward(sum_dy_th, sum_dy_xmu_th, shmem_sum_dy, shmem_sum_dy_xmu);
|
| 1288 |
+
if (threadIdx.y == 0 && c_offset < stride) {
|
| 1289 |
+
if (grad_bias != nullptr) {
|
| 1290 |
+
grad_bias[c_offset] = static_cast<layerscalar_t>(sum_dy_th);
|
| 1291 |
+
}
|
| 1292 |
+
if (grad_weight != nullptr) {
|
| 1293 |
+
grad_weight[c_offset] = static_cast<layerscalar_t>(sum_dy_xmu_th * factor);
|
| 1294 |
+
}
|
| 1295 |
+
//mean_dy[c_offset] = sum_dy_th / reduction_size;
|
| 1296 |
+
//mean_dy_xmu[c_offset] = sum_dy_xmu_th / reduction_size;
|
| 1297 |
+
sum_dy_o[c_offset] = sum_dy_th;
|
| 1298 |
+
sum_dy_xmu_o[c_offset] = sum_dy_xmu_th;
|
| 1299 |
+
}
|
| 1300 |
+
}
|
| 1301 |
+
} else {
|
| 1302 |
+
if (blockIdx.y == 0 && threadIdx.y == 0 && c_offset < stride) {
|
| 1303 |
+
if (grad_bias != nullptr) {
|
| 1304 |
+
grad_bias[c_offset] = static_cast<layerscalar_t>(sum_dy_th);
|
| 1305 |
+
}
|
| 1306 |
+
if (grad_weight != nullptr) {
|
| 1307 |
+
grad_weight[c_offset] = static_cast<layerscalar_t>(sum_dy_xmu_th * factor);
|
| 1308 |
+
}
|
| 1309 |
+
//mean_dy[c_offset] = sum_dy_th / reduction_size;
|
| 1310 |
+
//mean_dy_xmu[c_offset] = sum_dy_xmu_th / reduction_size;
|
| 1311 |
+
sum_dy_o[c_offset] = sum_dy_th;
|
| 1312 |
+
sum_dy_xmu_o[c_offset] = sum_dy_xmu_th;
|
| 1313 |
+
}
|
| 1314 |
+
}
|
| 1315 |
+
}
|
| 1316 |
+
|
| 1317 |
+
// elementwise BN kernel
|
| 1318 |
+
// original apex name: batchnorm_backward_c_last_kernel
|
| 1319 |
+
template <
|
| 1320 |
+
int PARALLEL_LOADS,
|
| 1321 |
+
typename scalar_t,
|
| 1322 |
+
typename accscalar_t,
|
| 1323 |
+
typename layerscalar_t>
|
| 1324 |
+
__device__ __forceinline__ void batch_norm_backward_elemt_channels_last_kernel_impl(
|
| 1325 |
+
const scalar_t* __restrict__ grad_output,
|
| 1326 |
+
const scalar_t* __restrict__ input,
|
| 1327 |
+
const accscalar_t* __restrict__ mean,
|
| 1328 |
+
const accscalar_t* __restrict__ inv_std,
|
| 1329 |
+
const layerscalar_t* __restrict__ weight,
|
| 1330 |
+
const accscalar_t* __restrict__ sum_dy,
|
| 1331 |
+
const accscalar_t* __restrict__ sum_dy_xmu,
|
| 1332 |
+
scalar_t* __restrict__ grad_input,
|
| 1333 |
+
const accscalar_t norm_fct,
|
| 1334 |
+
const int reduction_size,
|
| 1335 |
+
const int stride) {
|
| 1336 |
+
// tensor dimension (m,c)
|
| 1337 |
+
// loop along m dimension
|
| 1338 |
+
int inner_loop_stride = blockDim.y * gridDim.y;
|
| 1339 |
+
|
| 1340 |
+
// offset along m dimension
|
| 1341 |
+
int m_offset = blockIdx.y * blockDim.y + threadIdx.y;
|
| 1342 |
+
int c_offset = blockIdx.x * blockDim.x + threadIdx.x;
|
| 1343 |
+
|
| 1344 |
+
if (c_offset >= stride || m_offset >= reduction_size) {
|
| 1345 |
+
return;
|
| 1346 |
+
}
|
| 1347 |
+
|
| 1348 |
+
auto m_c = mean[c_offset];
|
| 1349 |
+
auto m_dy_c = sum_dy[c_offset] * norm_fct;
|
| 1350 |
+
auto factor_1_c = inv_std[c_offset];
|
| 1351 |
+
auto factor_2_c = (weight == nullptr? accscalar_t(1.0) : static_cast<accscalar_t>(weight[c_offset])) * factor_1_c;
|
| 1352 |
+
factor_1_c = factor_1_c * factor_1_c * sum_dy_xmu[c_offset] * norm_fct;
|
| 1353 |
+
|
| 1354 |
+
int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
|
| 1355 |
+
int address_base = m_offset * stride + c_offset;
|
| 1356 |
+
int address_increment = inner_loop_stride * stride;
|
| 1357 |
+
|
| 1358 |
+
for (int i = 0; i < loop_count; i++) {
|
| 1359 |
+
#pragma unroll
|
| 1360 |
+
for (int j = 0; j < PARALLEL_LOADS; j++) {
|
| 1361 |
+
if (c_offset < stride && m_offset < reduction_size) {
|
| 1362 |
+
grad_input[address_base] = static_cast<scalar_t>(
|
| 1363 |
+
(static_cast<accscalar_t>(grad_output[address_base]) - m_dy_c -
|
| 1364 |
+
(static_cast<accscalar_t>(input[address_base]) - m_c) * factor_1_c)
|
| 1365 |
+
* factor_2_c);
|
| 1366 |
+
}
|
| 1367 |
+
m_offset += inner_loop_stride;
|
| 1368 |
+
address_base += address_increment;
|
| 1369 |
+
}
|
| 1370 |
+
}
|
| 1371 |
+
}
|
| 1372 |
+
|
| 1373 |
+
template <
|
| 1374 |
+
int PARALLEL_LOADS,
|
| 1375 |
+
typename scalar_t,
|
| 1376 |
+
typename accscalar_t,
|
| 1377 |
+
typename layerscalar_t>
|
| 1378 |
+
__global__ void batch_norm_backward_elemt_channels_last_kernel(
|
| 1379 |
+
const scalar_t* __restrict__ grad_output,
|
| 1380 |
+
const scalar_t* __restrict__ input,
|
| 1381 |
+
const accscalar_t* __restrict__ mean,
|
| 1382 |
+
const accscalar_t* __restrict__ inv_std,
|
| 1383 |
+
const layerscalar_t* __restrict__ weight,
|
| 1384 |
+
const accscalar_t* __restrict__ sum_dy,
|
| 1385 |
+
const accscalar_t* __restrict__ sum_dy_xmu,
|
| 1386 |
+
const int* __restrict__ numel,
|
| 1387 |
+
scalar_t* __restrict__ grad_input,
|
| 1388 |
+
const int64_t world_size,
|
| 1389 |
+
const int reduction_size,
|
| 1390 |
+
const int stride) {
|
| 1391 |
+
|
| 1392 |
+
int64_t total_numel = 0;
|
| 1393 |
+
for (int i = 0; i < world_size; i++) {
|
| 1394 |
+
total_numel += numel[i];
|
| 1395 |
+
}
|
| 1396 |
+
|
| 1397 |
+
auto norm_fct = static_cast<accscalar_t>(1) / static_cast<accscalar_t>(total_numel);
|
| 1398 |
+
batch_norm_backward_elemt_channels_last_kernel_impl<PARALLEL_LOADS>(
|
| 1399 |
+
grad_output, input, mean, inv_std, weight, sum_dy, sum_dy_xmu,
|
| 1400 |
+
grad_input, norm_fct, reduction_size, stride);
|
| 1401 |
+
}
|
| 1402 |
+
|
| 1403 |
+
template <
|
| 1404 |
+
int PARALLEL_LOADS,
|
| 1405 |
+
typename scalar_t,
|
| 1406 |
+
typename accscalar_t,
|
| 1407 |
+
typename layerscalar_t>
|
| 1408 |
+
__global__ void batch_norm_backward_elemt_channels_last_kernel(
|
| 1409 |
+
const scalar_t* __restrict__ grad_output,
|
| 1410 |
+
const scalar_t* __restrict__ input,
|
| 1411 |
+
const accscalar_t* __restrict__ mean,
|
| 1412 |
+
const accscalar_t* __restrict__ inv_std,
|
| 1413 |
+
const layerscalar_t* __restrict__ weight,
|
| 1414 |
+
const accscalar_t* __restrict__ sum_dy,
|
| 1415 |
+
const accscalar_t* __restrict__ sum_dy_xmu,
|
| 1416 |
+
scalar_t* __restrict__ grad_input,
|
| 1417 |
+
const accscalar_t norm_fct,
|
| 1418 |
+
const int reduction_size,
|
| 1419 |
+
const int stride) {
|
| 1420 |
+
batch_norm_backward_elemt_channels_last_kernel_impl<PARALLEL_LOADS>(
|
| 1421 |
+
grad_output, input, mean, inv_std, weight, sum_dy, sum_dy_xmu,
|
| 1422 |
+
grad_input, norm_fct, reduction_size, stride);
|
| 1423 |
+
}
|
| 1424 |
+
|
| 1425 |
+
template<typename scalar_t, typename VarTransform>
|
| 1426 |
+
void batch_norm_stats_channels_last_cuda_template(
|
| 1427 |
+
const Tensor& out_mean, const Tensor& out_invstd, const Tensor& input, double epsilon) {
|
| 1428 |
+
using accscalar_t = at::acc_type<scalar_t, true>;
|
| 1429 |
+
|
| 1430 |
+
const auto stride = input.sizes()[1];
|
| 1431 |
+
const auto reduction_size = input.numel() / stride;
|
| 1432 |
+
|
| 1433 |
+
resize_output(out_mean, {stride});
|
| 1434 |
+
resize_output(out_invstd, {stride});
|
| 1435 |
+
TORCH_INTERNAL_ASSERT(out_invstd.dim() == 1 && out_invstd.is_contiguous() &&
|
| 1436 |
+
out_invstd.sizes()[0]);
|
| 1437 |
+
TORCH_INTERNAL_ASSERT(out_mean.dim() == 1 && out_mean.is_contiguous() &&
|
| 1438 |
+
out_mean.sizes()[0]);
|
| 1439 |
+
|
| 1440 |
+
dim3 block;
|
| 1441 |
+
dim3 grid;
|
| 1442 |
+
flexible_launch_configs(reduction_size, stride, block, grid, true);
|
| 1443 |
+
|
| 1444 |
+
at::Tensor staging_data;
|
| 1445 |
+
at::Tensor semaphores;
|
| 1446 |
+
if (grid.y > 1) {
|
| 1447 |
+
staging_data = at::empty({4*stride*grid.y}, out_mean.options());
|
| 1448 |
+
semaphores = at::zeros({grid.x}, input.options().dtype(at::kInt));
|
| 1449 |
+
}
|
| 1450 |
+
|
| 1451 |
+
accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.mutable_data_ptr<accscalar_t>() : nullptr;
|
| 1452 |
+
int* semaphores_ptr = grid.y > 1 ? semaphores.mutable_data_ptr<int>() : nullptr;
|
| 1453 |
+
batch_norm_collect_statistics_channels_last_kernel<VarTransform, scalar_t, accscalar_t, ELEMENTS_PER_ITER>
|
| 1454 |
+
<<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
|
| 1455 |
+
input.const_data_ptr<scalar_t>(),
|
| 1456 |
+
out_mean.mutable_data_ptr<accscalar_t>(),
|
| 1457 |
+
out_invstd.mutable_data_ptr<accscalar_t>(),
|
| 1458 |
+
staging_data_ptr,
|
| 1459 |
+
semaphores_ptr,
|
| 1460 |
+
reduction_size,
|
| 1461 |
+
stride,
|
| 1462 |
+
epsilon);
|
| 1463 |
+
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
| 1464 |
+
}
|
| 1465 |
+
|
| 1466 |
+
void batch_norm_elemt_channels_last_cuda_template(
|
| 1467 |
+
const at::Tensor& output,
|
| 1468 |
+
const at::Tensor& input,
|
| 1469 |
+
const at::Tensor& weight,
|
| 1470 |
+
const at::Tensor& shift, // bias of BN
|
| 1471 |
+
const at::Tensor& mean,
|
| 1472 |
+
const at::Tensor& inv_std,
|
| 1473 |
+
const std::optional<at::Tensor>& z = std::nullopt, // bias after BN
|
| 1474 |
+
const bool fuse_relu = false) {
|
| 1475 |
+
const auto stride = input.sizes()[1];
|
| 1476 |
+
const auto reduction_size = input.numel() / stride;
|
| 1477 |
+
|
| 1478 |
+
dim3 block;
|
| 1479 |
+
dim3 grid;
|
| 1480 |
+
flexible_launch_configs(reduction_size, stride, block, grid);
|
| 1481 |
+
|
| 1482 |
+
auto stream = at::cuda::getCurrentCUDAStream();
|
| 1483 |
+
const auto second_dtype = weight.defined() ? weight.scalar_type() :
|
| 1484 |
+
(shift.defined() ? shift.scalar_type() : input.scalar_type());
|
| 1485 |
+
|
| 1486 |
+
if (input.scalar_type() != second_dtype) {
|
| 1487 |
+
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_forward", [&] {
|
| 1488 |
+
using accscalar_t = at::acc_type<scalar_t, true>;
|
| 1489 |
+
batch_norm_transform_input_channels_last_kernel<scalar_t, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
|
| 1490 |
+
<<<grid, block, 0, stream>>>(
|
| 1491 |
+
input.const_data_ptr<scalar_t>(),
|
| 1492 |
+
z.has_value() ? z.value().const_data_ptr<scalar_t>() : nullptr,
|
| 1493 |
+
mean.const_data_ptr<accscalar_t>(),
|
| 1494 |
+
inv_std.const_data_ptr<accscalar_t>(),
|
| 1495 |
+
weight.defined() ? weight.const_data_ptr<accscalar_t>() : nullptr,
|
| 1496 |
+
shift.defined() ? shift.const_data_ptr<accscalar_t>() : nullptr,
|
| 1497 |
+
output.mutable_data_ptr<scalar_t>(),
|
| 1498 |
+
reduction_size,
|
| 1499 |
+
stride,
|
| 1500 |
+
fuse_relu);
|
| 1501 |
+
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
| 1502 |
+
});
|
| 1503 |
+
} else {
|
| 1504 |
+
if (weight.defined()){
|
| 1505 |
+
TORCH_CHECK(input.scalar_type() == weight.scalar_type(), "batchnorm_forward: input.scalar_type() ", input.scalar_type(),
|
| 1506 |
+
" is not supported with weight.scalar_type() ", weight.scalar_type());
|
| 1507 |
+
}
|
| 1508 |
+
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_forward", [&] {
|
| 1509 |
+
using accscalar_t = at::acc_type<scalar_t, true>;
|
| 1510 |
+
batch_norm_transform_input_channels_last_kernel<scalar_t, accscalar_t, scalar_t, ELEMENTS_PER_ITER>
|
| 1511 |
+
<<<grid, block, 0, stream>>>(
|
| 1512 |
+
input.const_data_ptr<scalar_t>(),
|
| 1513 |
+
z.has_value() ? z.value().const_data_ptr<scalar_t>() : nullptr,
|
| 1514 |
+
mean.const_data_ptr<accscalar_t>(),
|
| 1515 |
+
inv_std.const_data_ptr<accscalar_t>(),
|
| 1516 |
+
weight.defined() ? weight.const_data_ptr<scalar_t>() : nullptr,
|
| 1517 |
+
shift.defined() ? shift.const_data_ptr<scalar_t>(): nullptr,
|
| 1518 |
+
output.mutable_data_ptr<scalar_t>(),
|
| 1519 |
+
reduction_size,
|
| 1520 |
+
stride,
|
| 1521 |
+
fuse_relu);
|
| 1522 |
+
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
| 1523 |
+
});
|
| 1524 |
+
}
|
| 1525 |
+
}
|
| 1526 |
+
|
| 1527 |
+
std::tuple<Tensor, Tensor, Tensor, Tensor>
|
| 1528 |
+
batch_norm_backward_reduce_cuda_channels_last_template(const at::Tensor& grad_output,
|
| 1529 |
+
const at::Tensor& input,
|
| 1530 |
+
const at::Tensor& mean,
|
| 1531 |
+
const at::Tensor& inv_std,
|
| 1532 |
+
const at::Tensor& weight,
|
| 1533 |
+
const bool input_g, const bool weight_g, const bool bias_g) {
|
| 1534 |
+
const auto stride = input.sizes()[1];
|
| 1535 |
+
const auto reduction_size = input.numel() / stride;
|
| 1536 |
+
|
| 1537 |
+
at::Tensor sumn_dy = at::empty({stride}, mean.options());
|
| 1538 |
+
at::Tensor sum_dy_xmu = at::empty({stride}, mean.options());
|
| 1539 |
+
|
| 1540 |
+
at::Tensor grad_weight;
|
| 1541 |
+
at::Tensor grad_bias;
|
| 1542 |
+
if (weight.defined()) {
|
| 1543 |
+
grad_weight = at::empty({stride}, weight.options());
|
| 1544 |
+
grad_bias = at::empty({stride}, weight.options());
|
| 1545 |
+
} else {
|
| 1546 |
+
// because I cannot return an uninitialized at::Tensor
|
| 1547 |
+
grad_weight = at::empty({0}, mean.options());
|
| 1548 |
+
grad_bias = at::empty({0}, mean.options());
|
| 1549 |
+
}
|
| 1550 |
+
|
| 1551 |
+
dim3 block;
|
| 1552 |
+
dim3 grid;
|
| 1553 |
+
flexible_launch_configs(reduction_size, stride, block, grid, true);
|
| 1554 |
+
|
| 1555 |
+
at::Tensor staging_data;
|
| 1556 |
+
at::Tensor semaphores;
|
| 1557 |
+
if (grid.y > 1) {
|
| 1558 |
+
staging_data = at::empty({2*stride*grid.y}, mean.options());
|
| 1559 |
+
semaphores = at::zeros({grid.x}, input.options().dtype(at::kInt));
|
| 1560 |
+
}
|
| 1561 |
+
auto stream = at::cuda::getCurrentCUDAStream();
|
| 1562 |
+
|
| 1563 |
+
if (weight.defined() && input.scalar_type() != weight.scalar_type()) {
|
| 1564 |
+
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_backward_reduce", [&] {
|
| 1565 |
+
using accscalar_t = at::acc_type<scalar_t, true>;
|
| 1566 |
+
accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.mutable_data_ptr<accscalar_t>() : nullptr;
|
| 1567 |
+
int* semaphores_ptr = grid.y > 1 ? semaphores.mutable_data_ptr<int>() : nullptr;
|
| 1568 |
+
batch_norm_backward_reduce_channels_last_kernel<ELEMENTS_PER_ITER>
|
| 1569 |
+
<<<grid, block, 0, stream>>>(
|
| 1570 |
+
input.const_data_ptr<scalar_t>(),
|
| 1571 |
+
grad_output.const_data_ptr<scalar_t>(),
|
| 1572 |
+
mean.const_data_ptr<accscalar_t>(),
|
| 1573 |
+
inv_std.const_data_ptr<accscalar_t>(),
|
| 1574 |
+
sumn_dy.mutable_data_ptr<accscalar_t>(),
|
| 1575 |
+
sum_dy_xmu.mutable_data_ptr<accscalar_t>(),
|
| 1576 |
+
grad_weight.mutable_data_ptr<accscalar_t>(),
|
| 1577 |
+
grad_bias.mutable_data_ptr<accscalar_t>(),
|
| 1578 |
+
staging_data_ptr,
|
| 1579 |
+
semaphores_ptr,
|
| 1580 |
+
reduction_size,
|
| 1581 |
+
stride);
|
| 1582 |
+
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
| 1583 |
+
});
|
| 1584 |
+
} else {
|
| 1585 |
+
if (weight.defined()) {
|
| 1586 |
+
TORCH_CHECK(input.scalar_type() == weight.scalar_type(), "batchnorm_backward_reduce: input.scalar_type() ", input.scalar_type(),
|
| 1587 |
+
" is not supported with weight.scalar_type() ", weight.scalar_type());
|
| 1588 |
+
}
|
| 1589 |
+
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_backward_reduce", [&] {
|
| 1590 |
+
using accscalar_t = at::acc_type<scalar_t, true>;
|
| 1591 |
+
accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.mutable_data_ptr<accscalar_t>() : nullptr;
|
| 1592 |
+
int* semaphores_ptr = grid.y > 1 ? semaphores.mutable_data_ptr<int>() : nullptr;
|
| 1593 |
+
batch_norm_backward_reduce_channels_last_kernel<ELEMENTS_PER_ITER>
|
| 1594 |
+
<<<grid, block, 0, stream>>>(
|
| 1595 |
+
input.const_data_ptr<scalar_t>(),
|
| 1596 |
+
grad_output.const_data_ptr<scalar_t>(),
|
| 1597 |
+
mean.const_data_ptr<accscalar_t>(),
|
| 1598 |
+
inv_std.const_data_ptr<accscalar_t>(),
|
| 1599 |
+
sumn_dy.mutable_data_ptr<accscalar_t>(),
|
| 1600 |
+
sum_dy_xmu.mutable_data_ptr<accscalar_t>(),
|
| 1601 |
+
weight.defined() ? grad_weight.mutable_data_ptr<scalar_t>() : nullptr,
|
| 1602 |
+
weight.defined() ? grad_bias.mutable_data_ptr<scalar_t>() : nullptr,
|
| 1603 |
+
staging_data_ptr,
|
| 1604 |
+
semaphores_ptr,
|
| 1605 |
+
reduction_size,
|
| 1606 |
+
stride);
|
| 1607 |
+
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
| 1608 |
+
});
|
| 1609 |
+
}
|
| 1610 |
+
|
| 1611 |
+
return std::make_tuple(sumn_dy, sum_dy_xmu, grad_weight, grad_bias);
|
| 1612 |
+
}
|
| 1613 |
+
|
| 1614 |
+
at::Tensor batch_norm_backward_elemt_channels_last_cuda_template(
|
| 1615 |
+
const at::Tensor& grad_output,
|
| 1616 |
+
const at::Tensor& input,
|
| 1617 |
+
const at::Tensor& mean,
|
| 1618 |
+
const at::Tensor& inv_std,
|
| 1619 |
+
const at::Tensor& weight,
|
| 1620 |
+
const at::Tensor& sum_dy,
|
| 1621 |
+
const at::Tensor& sum_dy_xmu,
|
| 1622 |
+
const at::Tensor& count) {
|
| 1623 |
+
const auto stride = input.sizes()[1];
|
| 1624 |
+
const auto reduction_size = input.numel() / stride;
|
| 1625 |
+
|
| 1626 |
+
// Input is guarunteed to be channels-last compatible
|
| 1627 |
+
at::Tensor grad_input = at::empty_like(input);
|
| 1628 |
+
|
| 1629 |
+
dim3 block;
|
| 1630 |
+
dim3 grid;
|
| 1631 |
+
flexible_launch_configs(reduction_size, stride, block, grid);
|
| 1632 |
+
|
| 1633 |
+
auto stream = at::cuda::getCurrentCUDAStream();
|
| 1634 |
+
|
| 1635 |
+
if (weight.defined() && weight.scalar_type() != input.scalar_type()) {
|
| 1636 |
+
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_backward_element", [&] {
|
| 1637 |
+
using accscalar_t = at::acc_type<scalar_t, true>;
|
| 1638 |
+
batch_norm_backward_elemt_channels_last_kernel<ELEMENTS_PER_ITER>
|
| 1639 |
+
<<<grid, block, 0, stream>>>(
|
| 1640 |
+
grad_output.const_data_ptr<scalar_t>(),
|
| 1641 |
+
input.const_data_ptr<scalar_t>(),
|
| 1642 |
+
mean.const_data_ptr<accscalar_t>(),
|
| 1643 |
+
inv_std.const_data_ptr<accscalar_t>(),
|
| 1644 |
+
weight.const_data_ptr<accscalar_t>(),
|
| 1645 |
+
sum_dy.const_data_ptr<accscalar_t>(),
|
| 1646 |
+
sum_dy_xmu.const_data_ptr<accscalar_t>(),
|
| 1647 |
+
count.const_data_ptr<int>(),
|
| 1648 |
+
grad_input.mutable_data_ptr<scalar_t>(),
|
| 1649 |
+
count.numel(),
|
| 1650 |
+
reduction_size,
|
| 1651 |
+
stride);
|
| 1652 |
+
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
| 1653 |
+
});
|
| 1654 |
+
} else {
|
| 1655 |
+
if (weight.defined()) {
|
| 1656 |
+
TORCH_CHECK(input.scalar_type() == weight.scalar_type(), "batchnorm_backward_element: input.scalar_type() ", input.scalar_type(),
|
| 1657 |
+
" is not supported with weight.scalar_type() ", weight.scalar_type());
|
| 1658 |
+
}
|
| 1659 |
+
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, input.scalar_type(), "batchnorm_backward_element", [&] {
|
| 1660 |
+
using accscalar_t = at::acc_type<scalar_t, true>;
|
| 1661 |
+
batch_norm_backward_elemt_channels_last_kernel<ELEMENTS_PER_ITER>
|
| 1662 |
+
<<<grid, block, 0, stream>>>(
|
| 1663 |
+
grad_output.const_data_ptr<scalar_t>(),
|
| 1664 |
+
input.const_data_ptr<scalar_t>(),
|
| 1665 |
+
mean.const_data_ptr<accscalar_t>(),
|
| 1666 |
+
inv_std.const_data_ptr<accscalar_t>(),
|
| 1667 |
+
weight.defined() ? weight.const_data_ptr<scalar_t>() : nullptr,
|
| 1668 |
+
sum_dy.const_data_ptr<accscalar_t>(),
|
| 1669 |
+
sum_dy_xmu.const_data_ptr<accscalar_t>(),
|
| 1670 |
+
count.const_data_ptr<int>(),
|
| 1671 |
+
grad_input.mutable_data_ptr<scalar_t>(),
|
| 1672 |
+
count.numel(),
|
| 1673 |
+
reduction_size,
|
| 1674 |
+
stride);
|
| 1675 |
+
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
| 1676 |
+
});
|
| 1677 |
+
}
|
| 1678 |
+
|
| 1679 |
+
return grad_input;
|
| 1680 |
+
}
|
| 1681 |
+
|
| 1682 |
+
at::Tensor batch_norm_backward_elemt_channels_last_cuda_template(
|
| 1683 |
+
const at::Tensor& grad_output,
|
| 1684 |
+
const at::Tensor& input,
|
| 1685 |
+
const at::Tensor& mean,
|
| 1686 |
+
const at::Tensor& inv_std,
|
| 1687 |
+
const at::Tensor& weight,
|
| 1688 |
+
const at::Tensor& sum_dy,
|
| 1689 |
+
const at::Tensor& sum_dy_xmu) {
|
| 1690 |
+
const auto stride = input.sizes()[1];
|
| 1691 |
+
const auto reduction_size = input.numel() / stride;
|
| 1692 |
+
auto norm_fct = 1.0 / reduction_size;
|
| 1693 |
+
|
| 1694 |
+
// Input is guarunteed to be channels-last compatible
|
| 1695 |
+
at::Tensor grad_input = at::empty_like(input);
|
| 1696 |
+
|
| 1697 |
+
dim3 block;
|
| 1698 |
+
dim3 grid;
|
| 1699 |
+
flexible_launch_configs(reduction_size, stride, block, grid);
|
| 1700 |
+
|
| 1701 |
+
auto stream = at::cuda::getCurrentCUDAStream();
|
| 1702 |
+
|
| 1703 |
+
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_backward_element", [&] {
|
| 1704 |
+
using accscalar_t = at::acc_type<scalar_t, true>;
|
| 1705 |
+
|
| 1706 |
+
if (weight.defined() && weight.scalar_type() != input.scalar_type()) {
|
| 1707 |
+
batch_norm_backward_elemt_channels_last_kernel<ELEMENTS_PER_ITER>
|
| 1708 |
+
<<<grid, block, 0, stream>>>(
|
| 1709 |
+
grad_output.const_data_ptr<scalar_t>(),
|
| 1710 |
+
input.const_data_ptr<scalar_t>(),
|
| 1711 |
+
mean.const_data_ptr<accscalar_t>(),
|
| 1712 |
+
inv_std.const_data_ptr<accscalar_t>(),
|
| 1713 |
+
weight.const_data_ptr<accscalar_t>(),
|
| 1714 |
+
sum_dy.const_data_ptr<accscalar_t>(),
|
| 1715 |
+
sum_dy_xmu.const_data_ptr<accscalar_t>(),
|
| 1716 |
+
grad_input.mutable_data_ptr<scalar_t>(),
|
| 1717 |
+
static_cast<accscalar_t>(norm_fct),
|
| 1718 |
+
reduction_size,
|
| 1719 |
+
stride);
|
| 1720 |
+
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
| 1721 |
+
} else {
|
| 1722 |
+
batch_norm_backward_elemt_channels_last_kernel<ELEMENTS_PER_ITER>
|
| 1723 |
+
<<<grid, block, 0, stream>>>(
|
| 1724 |
+
grad_output.const_data_ptr<scalar_t>(),
|
| 1725 |
+
input.const_data_ptr<scalar_t>(),
|
| 1726 |
+
mean.const_data_ptr<accscalar_t>(),
|
| 1727 |
+
inv_std.const_data_ptr<accscalar_t>(),
|
| 1728 |
+
weight.defined() ? weight.const_data_ptr<scalar_t>() : nullptr,
|
| 1729 |
+
sum_dy.const_data_ptr<accscalar_t>(),
|
| 1730 |
+
sum_dy_xmu.const_data_ptr<accscalar_t>(),
|
| 1731 |
+
grad_input.mutable_data_ptr<scalar_t>(),
|
| 1732 |
+
static_cast<accscalar_t>(norm_fct),
|
| 1733 |
+
reduction_size,
|
| 1734 |
+
stride);
|
| 1735 |
+
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
| 1736 |
+
}
|
| 1737 |
+
});
|
| 1738 |
+
|
| 1739 |
+
return grad_input;
|
| 1740 |
+
}
|
| 1741 |
+
|
| 1742 |
+
} } // namespace at::native
|
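The channels-last statistics kernel above accumulates per-thread (count, mean, m_2_n) triples and repeatedly combines them with welford_merge_element, whose definition sits earlier in this header. Below is a minimal host-side C++ sketch of that combination step, assuming the standard Chan et al. pairwise Welford formula; it is an illustration only, not a copy of the ATen helper.

// Sketch of merging two partial Welford accumulators (count, mean, M2).
#include <cstdio>

struct Welford { long long count; double mean; double m2; };

Welford welford_merge(Welford a, Welford b) {
  if (b.count == 0) return a;
  if (a.count == 0) return b;
  long long n = a.count + b.count;
  double delta = b.mean - a.mean;           // difference of the two partial means
  Welford out;
  out.count = n;
  out.mean  = a.mean + delta * (double)b.count / (double)n;
  out.m2    = a.m2 + b.m2 + delta * delta * (double)a.count * (double)b.count / (double)n;
  return out;
}

int main() {
  Welford a{2, 1.5, 0.5};           // statistics of {1, 2}
  Welford b{3, 4.0, 2.0};           // statistics of {3, 4, 5}
  Welford m = welford_merge(a, b);  // statistics of {1, 2, 3, 4, 5}: mean 3, M2 10
  std::printf("count=%lld mean=%f var=%f\n", m.count, m.mean, m.m2 / m.count);
}

The same combination is what the staging-buffer pass applies across thread blocks when gridDim.y > 1, before VarTransform turns m2/count into an inverse standard deviation.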
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Pow.cuh
ADDED
@@ -0,0 +1,58 @@
#pragma once
#include <ATen/native/Pow.h>
#include <c10/core/Scalar.h>

namespace at { namespace native {

namespace {


// SFINAE doesn't work well with NVCC under Windows for math functions like pow and sqrt.
// So we need to define the functions with the explicit function signatures.
// As for pow, the following signatures are defined as the device function:
//   pow(float, int)
//   pow(double, int)
//   pow(float, float)
//   pow(double, double)
#ifdef _MSC_VER
// Functions for pow
// pow for at::Half
static inline __host__ __device__ at::Half pow_(at::Half base, at::Half exp) {
  return static_cast<at::Half>(std::pow(static_cast<float>(base), static_cast<float>(exp)));
}
// pow for at::BFloat16
static inline __host__ __device__ at::BFloat16 pow_(at::BFloat16 base, at::BFloat16 exp) {
  return static_cast<at::BFloat16>(std::pow(static_cast<float>(base), static_cast<float>(exp)));
}
// pow (floating, floating/int)
template <typename Base_type, typename Exp_type>
static inline __host__ __device__ typename std::enable_if<std::is_floating_point<Base_type>::value && (std::is_same<Base_type, Exp_type>::value || std::is_same<Exp_type, int>::value), Base_type>::type
  pow_(Base_type base, Exp_type exp) {
  return std::pow(base, exp);
}
// pow (Otherwise)
template <typename Base_type, typename Exp_type>
static inline __host__ __device__ typename std::enable_if<!std::is_same<Base_type, Exp_type>::value && !std::is_same<Exp_type, int>::value, Base_type>::type
  pow_(Base_type base, Exp_type exp) {
  return static_cast<Base_type>(std::pow(static_cast<double>(base), static_cast<double>(exp)));
}
#else
template <typename Base_type, typename Exp_type>
static inline __host__ __device__ Base_type pow_(Base_type base, Exp_type exp) {
  return ::pow(base, exp);
}
#endif

template <typename T>
static inline __host__ __device__ std::enable_if_t<std::is_integral<T>::value, T> pow_(
    T base, T exp) {
  return at::native::powi(base, exp);
}

template <typename T>
static inline __host__ __device__ c10::complex<T> pow_(c10::complex<T> base, c10::complex<T> exp) {
  return c10_complex_math::pow(base, exp);
}

} // namespace
}} // namespace at::native
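The integral overload above delegates to at::native::powi, whose implementation is not part of this header. The sketch below only illustrates the usual exponentiation-by-squaring technique such an integer-power helper is typically built on; powi_sketch is a hypothetical name, not the ATen function, and negative exponents and overflow are ignored for brevity.

// Integer power by repeated squaring: O(log exp) multiplications.
#include <cstdint>
#include <cstdio>

int64_t powi_sketch(int64_t base, uint64_t exp) {
  int64_t result = 1;
  while (exp > 0) {
    if (exp & 1) result *= base;  // fold in the current binary digit of exp
    base *= base;                 // square the base for the next digit
    exp >>= 1;
  }
  return result;
}

int main() {
  std::printf("%lld\n", (long long)powi_sketch(3, 13));  // prints 1594323
}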
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Randperm.cuh
ADDED
@@ -0,0 +1,58 @@
#include <ATen/cuda/CUDAGeneratorImpl.h>
#include <ATen/cuda/CUDAGraphsUtils.cuh>
#include <ATen/Utils.h>

#include <curand.h>
#include <curand_kernel.h>
#include <curand_philox4x32_x.h>

namespace {

// See note [Algorithm of randperm]
template<typename T, typename scalar_t>
__global__ void randperm_handle_duplicate_keys_kernel(T *keys, scalar_t *data, T mask, int n, at::PhiloxCudaState philox_args) {
  int tid = threadIdx.x + blockDim.x * blockIdx.x;

  // find the beginning of islands
  if (tid >= n - 1) return;  // out of range
  if ((keys[tid] & mask) != (keys[tid + 1] & mask)) return;  // not in an island
  if (tid != 0 && (keys[tid] & mask) == (keys[tid - 1] & mask)) return;  // not the beginning of an island

  // find the size of islands
  int island_size = 0;
  do { island_size++; }
  while ((tid + island_size < n) && (keys[tid + island_size] & mask) == (keys[tid] & mask));

  // do random permutation inside each island.
  data += tid;
  auto seeds = at::cuda::philox::unpack(philox_args);
  curandStatePhilox4_32_10_t state;
  curand_init(std::get<0>(seeds), tid, std::get<1>(seeds), &state);
  for (int i = island_size - 1; i > 0; i--) {
    unsigned int r = curand(&state) % (i + 1);
    if (i != r) {
      scalar_t tmp = data[i];
      data[i] = data[r];
      data[r] = tmp;
    }
  }
}

// See note [Algorithm of randperm]
template<typename T, typename scalar_t>
void randperm_handle_duplicate_keys(T *keys, scalar_t *data, int bits, int64_t n, std::optional<at::Generator> &gen_) {
  auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(gen_, at::cuda::detail::getDefaultCUDAGenerator());
  int64_t counter_offset = n;
  at::PhiloxCudaState rng_engine_inputs;
  {
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> lock(gen->mutex_);
    rng_engine_inputs = gen->philox_cuda_state(counter_offset);
  }
  T mask = static_cast<T>((1UL << bits) - 1);
  randperm_handle_duplicate_keys_kernel<<<(n + 511) / 512, 512, 0, at::cuda::getCurrentCUDAStream()>>>(
    keys, data, mask, n, rng_engine_inputs);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

}
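The kernel above breaks ties between duplicate sort keys by shuffling the payload inside each island of equal masked keys with a Fisher-Yates pass. A host-side C++ analogue of that inner loop follows as an illustration only; shuffle_island is a hypothetical name, and the real kernel draws its randomness from a per-thread Philox state rather than std::mt19937.

// Re-permute data[start .. start+island_size-1] uniformly at random.
#include <cstdio>
#include <random>
#include <utility>
#include <vector>

void shuffle_island(std::vector<int>& data, int start, int island_size, std::mt19937& rng) {
  for (int i = island_size - 1; i > 0; i--) {
    std::uniform_int_distribution<int> pick(0, i);  // choose a swap partner in [0, i]
    int r = pick(rng);
    if (i != r) std::swap(data[start + i], data[start + r]);
  }
}

int main() {
  std::vector<int> data = {10, 11, 12, 13, 14};
  std::mt19937 rng(0);
  shuffle_island(data, 1, 3, rng);  // shuffle the island covering indices 1..3
  for (int v : data) std::printf("%d ", v);
  std::printf("\n");
}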
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Resize.h
ADDED
@@ -0,0 +1,53 @@
#pragma once

#include <ATen/EmptyTensor.h>
#include <ATen/native/ResizeCommon.h>

#include <c10/cuda/CUDAGuard.h>

namespace at { namespace native {

TORCH_CUDA_CPP_API void resize_bytes_cuda(StorageImpl* storage, size_t size_bytes);

static inline void maybe_resize_storage_cuda(TensorImpl* self, size_t new_size_bytes) {
  // It does not make sense to try to resize a storage
  // to hold 0 elements, and this can break
  // if storage_offset is positive but
  // new_size is 0, so just bail in that case
  // (same comment is in Resize.h)
  if (self->numel() == 0) {
    return;
  }

  const Storage &storage = self->unsafe_storage();
  TORCH_CHECK(storage, "Tensor: invalid null storage");
  if (new_size_bytes > storage.nbytes()) {
    resize_bytes_cuda(storage.unsafeGetStorageImpl(), new_size_bytes);
  }
}

inline TensorImpl* resize_impl_cuda_(
    TensorImpl* self,
    IntArrayRef size,
    at::OptionalIntArrayRef stride) {
  if (self->sizes() == size && (!stride || self->strides() == stride)) {
    return self;
  }
  const auto itemsize = self->dtype().itemsize();
  const auto storage_offset = self->storage_offset();
  size_t storage_size = 1;
  if (stride) {
    self->set_sizes_and_strides(size, *stride);
    storage_size = at::detail::computeStorageNbytes(
        size, *stride, itemsize, storage_offset);
  } else {
    self->set_sizes_contiguous(size);
    storage_size = at::detail::computeStorageNbytesContiguous(
        size, itemsize, storage_offset);
  }
  maybe_resize_storage_cuda(self, storage_size);

  return self;
}

}}
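resize_impl_cuda_ asks at::detail::computeStorageNbytes how many bytes the backing storage must hold before calling maybe_resize_storage_cuda. The following is a rough host-side sketch of that arithmetic for the strided case; storage_nbytes_sketch is a hypothetical stand-in, and the real helper additionally handles symbolic sizes and overflow checking.

// Bytes needed so the farthest reachable element of a strided view fits in storage.
#include <cstdint>
#include <cstdio>
#include <vector>

int64_t storage_nbytes_sketch(const std::vector<int64_t>& sizes,
                              const std::vector<int64_t>& strides,
                              int64_t itemsize, int64_t storage_offset) {
  for (int64_t s : sizes) {
    if (s == 0) return 0;               // an empty tensor needs no storage
  }
  int64_t last = storage_offset;        // flat index of the farthest element
  for (size_t d = 0; d < sizes.size(); ++d) {
    last += (sizes[d] - 1) * strides[d];
  }
  return (last + 1) * itemsize;
}

int main() {
  // A 2x3 float tensor with row-major strides {3, 1} and no offset: 24 bytes.
  std::printf("%lld\n", (long long)storage_nbytes_sketch({2, 3}, {3, 1}, 4, 0));
}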
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/RowwiseScaledMM.h
ADDED
@@ -0,0 +1,15 @@
#pragma once
#include <ATen/core/TensorBase.h>
#include <optional>


namespace at::cuda::detail {
TORCH_API void f8f8bf16_rowwise(
    at::Tensor XQ, // FP8
    at::Tensor WQ, // FP8
    at::Tensor x_scale, // FP32
    at::Tensor w_scale, // FP32
    std::optional<at::Tensor> bias, // BF16
    bool use_fast_accum,
    at::Tensor& out);
} // at::cuda::detail
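f8f8bf16_rowwise is only declared here. As a plain-C++ numeric sketch of what a rowwise-scaled matmul of this shape computes, assuming one FP32 scale per row of XQ and one per row of WQ applied after accumulation (an assumption about the semantics, not something stated in this header), consider:

// out[m][n] = (sum_k XQ[m][k] * WQ[n][k]) * x_scale[m] * w_scale[n] + bias[n]
#include <cstdio>
#include <vector>

float scaled_dot(const std::vector<float>& xq_row, const std::vector<float>& wq_row,
                 float x_scale, float w_scale, float bias) {
  float acc = 0.f;                                   // high-precision accumulator
  for (size_t k = 0; k < xq_row.size(); ++k) acc += xq_row[k] * wq_row[k];
  return acc * x_scale * w_scale + bias;             // de-quantize, then add bias
}

int main() {
  std::printf("%f\n", scaled_dot({1.f, 2.f}, {3.f, 4.f}, 0.5f, 0.25f, 1.f));  // 2.375
}

In the real kernel the operands are FP8, the output is BF16, and use_fast_accum trades accumulation precision for speed; the sketch only shows where the two scale vectors enter the arithmetic.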
.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/ScanUtils.cuh
ADDED
@@ -0,0 +1,459 @@
| 1 |
+
#pragma once
|
| 2 |
+
#include <ATen/NumericUtils.h>
|
| 3 |
+
#include <ATen/core/TensorBase.h>
|
| 4 |
+
#include <ATen/cuda/cub.cuh>
|
| 5 |
+
#include <ATen/cuda/CUDAContext.h>
|
| 6 |
+
|
| 7 |
+
#include <c10/util/Load.h>
|
| 8 |
+
#include <limits>
|
| 9 |
+
#include <cmath>
|
| 10 |
+
|
| 11 |
+
namespace at {
|
| 12 |
+
namespace native {
|
| 13 |
+
|
| 14 |
+
template <typename integer>
|
| 15 |
+
constexpr inline integer ceil_div(integer n, integer m) {
|
| 16 |
+
return (n + m - 1) / m;
|
| 17 |
+
}
|
| 18 |
+
|
| 19 |
+
template <typename integer>
|
| 20 |
+
constexpr inline integer get_log_num_threads_x_inner_scan(integer num_rows, integer row_size) {
|
| 21 |
+
integer log_num_threads_x = 0;
|
| 22 |
+
integer log_num_threads_y = 0;
|
| 23 |
+
while (((integer)1 << log_num_threads_x) < row_size) {
|
| 24 |
+
++log_num_threads_x;
|
| 25 |
+
}
|
| 26 |
+
while (((integer)1 << log_num_threads_y) < num_rows) {
|
| 27 |
+
++log_num_threads_y;
|
| 28 |
+
}
|
| 29 |
+
// we want to keep the ratio between the x-threads and y-threads about the same as
|
| 30 |
+
// the ratio between the row_size and num_rows, but the total number of threads in
|
| 31 |
+
// a block should be about 512
|
| 32 |
+
integer diff = log_num_threads_x - log_num_threads_y;
|
| 33 |
+
// 9 is from log2(512)
|
| 34 |
+
log_num_threads_x = ((integer)9 + diff) / (integer)2;
|
| 35 |
+
// I found that in having larger log_num_threads_x can give significant speed up in some cases,
|
| 36 |
+
// but detrimental in another case, so just keep the lower bound to be log2(16) == 4 to make it
|
| 37 |
+
// similar to the previous implementation
|
| 38 |
+
// Keeping the upper bound to be log2(512) == 9 as the maximum number of threads in a block.
|
| 39 |
+
log_num_threads_x = std::min(std::max((integer)4, log_num_threads_x), (integer)9);
|
| 40 |
+
return log_num_threads_x;
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
template<typename scalar_t, typename idx_t, typename BinaryOperation>
|
| 44 |
+
__device__ void binary_op_update(const scalar_t lhs, scalar_t& rhs, const idx_t lhs_idx, idx_t& rhs_idx, BinaryOperation binary_op) {
|
| 45 |
+
if(!at::_isnan(rhs) && (at::_isnan(lhs) || !binary_op(rhs, lhs))) {
|
| 46 |
+
rhs = lhs;
|
| 47 |
+
rhs_idx = lhs_idx;
|
| 48 |
+
}
|
| 49 |
+
}
|
| 50 |
+
/* Perform an inclusive scan along the innermost dimension of a tensor.
|
| 51 |
+
*
|
| 52 |
+
* - num_rows is the size of the flattened outer dimensions;
|
| 53 |
+
* - row_size is the size of the innermost dimension;
|
| 54 |
+
*
|
| 55 |
+
* The outer dimensions of the tensor are considered as a single dimension, i.e. the tensor is
|
| 56 |
+
* considered as having 'num_rows' rows of size 'row_size'.
|
| 57 |
+
* Each thread block processes one or more sets of contiguous rows (processing multiple rows
|
| 58 |
+
* per thread block is quicker than processing a single row, especially for short rows).
|
| 59 |
+
*/
|
| 60 |
+
template<typename scalar_t, class BinaryFunction>
|
| 61 |
+
__global__ void tensor_kernel_scan_innermost_dim_with_indices(const scalar_t *self_, scalar_t *values_, int64_t *indices_,
|
| 62 |
+
int num_rows, int row_size,
|
| 63 |
+
const uint32_t num_threads, const uint32_t log_num_threads_x,
|
| 64 |
+
scalar_t init, BinaryFunction binary_op) {
|
| 65 |
+
// dynamic memory allocation for vbuf and ibuf
|
| 66 |
+
alignas(sizeof(double)) extern __shared__ char buf[];
|
| 67 |
+
scalar_t* vbuf = reinterpret_cast<scalar_t*>(buf); // the size is num_threads * 2
|
| 68 |
+
int64_t* ibuf = reinterpret_cast<int64_t*>(vbuf + num_threads * 2);
|
| 69 |
+
const uint32_t num_threads_x = 1 << log_num_threads_x;
|
| 70 |
+
scalar_t* row_buf = vbuf + 2 * num_threads_x * threadIdx.y;
|
| 71 |
+
int64_t* row_idx_buf = ibuf + 2 * num_threads_x * threadIdx.y;
|
| 72 |
+
|
| 73 |
+
for (int block_row = blockIdx.x * blockDim.y;
|
| 74 |
+
block_row < num_rows;
|
| 75 |
+
block_row += blockDim.y * gridDim.x) {
|
| 76 |
+
int row = block_row + threadIdx.y;
|
| 77 |
+
const scalar_t *row_self = self_ + row * row_size;
|
| 78 |
+
scalar_t *row_values = values_ + row * row_size;
|
| 79 |
+
int64_t *row_indices = indices_ + row * row_size;
|
| 80 |
+
scalar_t block_total = init;
|
| 81 |
+
int64_t block_idx_final = 0;
|
| 82 |
+
const bool row_exists = row < num_rows;
|
| 83 |
+
// Perform scan on one block at a time, keeping track of the total value of
|
| 84 |
+
// all blocks processed so far.
|
| 85 |
+
for (int block_col = 0; block_col < row_size; block_col += 2 * num_threads_x) {
|
| 86 |
+
// Load data into shared memory (two values per thread).
|
| 87 |
+
int col1 = block_col + threadIdx.x;
|
| 88 |
+
int col2 = block_col + num_threads_x + threadIdx.x;
|
| 89 |
+
if (row_exists) {
|
| 90 |
+
if (col1 < row_size) {
|
| 91 |
+
row_buf[threadIdx.x] = c10::load(&row_self[col1]);
|
| 92 |
+
row_idx_buf[threadIdx.x] = col1;
|
| 93 |
+
} else {
|
| 94 |
+
row_buf[threadIdx.x] = init;
|
| 95 |
+
// No need to set the index here as the value in init will never be selected
|
| 96 |
+
}
|
| 97 |
+
|
| 98 |
+
if (col2 < row_size) {
|
| 99 |
+
row_buf[num_threads_x + threadIdx.x] = c10::load(&row_self[col2]);
|
| 100 |
+
row_idx_buf[num_threads_x + threadIdx.x] = col2;
|
| 101 |
+
} else {
|
| 102 |
+
row_buf[num_threads_x + threadIdx.x] = init;
|
| 103 |
+
// No need to set the index here as the value in init will never be selected
|
| 104 |
+
}
|
| 105 |
+
|
| 106 |
+
// Add the total value of all previous blocks to the first value of this block.
|
| 107 |
+
if (threadIdx.x == 0) {
|
| 108 |
+
binary_op_update(block_total, row_buf[0], block_idx_final, row_idx_buf[0], binary_op);
|
| 109 |
+
}
|
| 110 |
+
}
|
| 111 |
+
__syncthreads();
|
| 112 |
+
|
| 113 |
+
// Parallel reduction with Sklansky method. The diagram can be seen on this paper:
|
| 114 |
+
// https://research.nvidia.com/publication/single-pass-parallel-prefix-scan-decoupled-look-back
|
| 115 |
+
for (uint32_t s = 1; s <= num_threads_x; s <<= 1) {
|
| 116 |
+
if (row_exists) {
|
| 117 |
+
uint32_t a = (threadIdx.x / s) * (2 * s) + s;
|
| 118 |
+
uint32_t ti = a + (threadIdx.x % s);
|
| 119 |
+
uint32_t si = a - 1;
|
| 120 |
+
binary_op_update(row_buf[si], row_buf[ti], row_idx_buf[si], row_idx_buf[ti], binary_op);
|
| 121 |
+
}
|
| 122 |
+
__syncthreads();
|
| 123 |
+
}
|
| 124 |
+
|
| 125 |
+
// Write back to output.
|
| 126 |
+
if (row_exists) {
|
| 127 |
+
if (col1 < row_size){
|
| 128 |
+
row_values[col1] = row_buf[threadIdx.x];
|
| 129 |
+
row_indices[col1] = row_idx_buf[threadIdx.x];
|
| 130 |
+
}
|
| 131 |
+
if (col2 < row_size) {
|
| 132 |
+
row_values[col2] = row_buf[num_threads_x + threadIdx.x];
|
| 133 |
+
row_indices[col2] = row_idx_buf[num_threads_x + threadIdx.x];
|
| 134 |
+
}
|
| 135 |
+
}
|
| 136 |
+
block_total = row_buf[2 * num_threads_x - 1];
|
| 137 |
+
block_idx_final = row_idx_buf[2 * num_threads_x - 1];
|
| 138 |
+
__syncthreads();
|
| 139 |
+
}
|
| 140 |
+
}
|
| 141 |
+
}
|
| 142 |
+
|
| 143 |
+
/* Perform an inclusive scan along an outer dimension of a tensor.
|
| 144 |
+
*
|
| 145 |
+
* - num_orows is the size of the flattened outer dimensions;
|
| 146 |
+
* - num_irows is the size of the flattened inner dimensions;
|
| 147 |
+
* - row_size is the size of the dimension along which to compute the variance;
|
| 148 |
+
*
|
| 149 |
+
* The dimensions to the outside and inside of the specified dimension are considered as flattened.
|
| 150 |
+
* Thread blocks with the same blockIdx.y process an "outer row" (i.e. an element of the flattened
|
| 151 |
+
* outer dimensions, which contains several "inner rows").
|
| 152 |
+
* Each thread processes a single inner row at a time.
|
| 153 |
+
*/
|
| 154 |
+
template<typename scalar_t, class BinaryFunction>
|
| 155 |
+
__global__ void tensor_kernel_scan_outer_dim_with_indices(const scalar_t *self_, scalar_t *values_, int64_t *indices_,
|
| 156 |
+
const uint32_t num_orows, const uint32_t num_irows, const uint32_t row_size, scalar_t init, BinaryFunction binary_op) {
|
| 157 |
+
for (uint32_t orow = blockIdx.x; orow < num_orows; orow += gridDim.x) {
|
| 158 |
+
for (uint32_t irow = blockIdx.y * blockDim.x + threadIdx.x; irow < num_irows; irow += gridDim.y * blockDim.x) {
|
| 159 |
+
const scalar_t *self = self_ + orow * row_size * num_irows + irow;
|
| 160 |
+
scalar_t *values = values_ + orow * row_size * num_irows + irow;
|
| 161 |
+
int64_t *indices = indices_ + orow * row_size * num_irows + irow;
|
| 162 |
+
scalar_t out = init;
|
| 163 |
+
int64_t out_idx = 0;
|
| 164 |
+
|
| 165 |
+
for (auto col = decltype(row_size){0}; col < row_size; ++col) {
|
| 166 |
+
const auto val = c10::load(self);
|
| 167 |
+
if(at::_isnan(val) || (!at::_isnan(out) && binary_op(val, out))) {
|
| 168 |
+
out = val;
|
| 169 |
+
out_idx = col;
|
| 170 |
+
}
|
| 171 |
+
*values = out;
|
| 172 |
+
*indices = out_idx;
|
| 173 |
+
self += num_irows;
|
| 174 |
+
values += num_irows;
|
| 175 |
+
indices += num_irows;
|
| 176 |
+
}
|
| 177 |
+
}
|
| 178 |
+
}
|
| 179 |
+
}
|
| 180 |
+
|
| 181 |
+
inline void check_fits_in_unsigned(int64_t val, const char* name) {
|
| 182 |
+
constexpr auto umax = std::numeric_limits<uint32_t>::max();
|
| 183 |
+
TORCH_CHECK(
|
| 184 |
+
val >= 0 && val <= umax, name, " must fit in a 32-bit uint32_t value");
|
| 185 |
+
}

template<typename scalar_t, class BinaryFunction>
__host__ void scan_outer_dim_with_indices(
    const TensorBase& self, const TensorBase& values, const TensorBase& indices,
    int dim, scalar_t init, BinaryFunction binary_op) {
  int64_t row_size = self.size(dim);
  auto sizes = self.sizes();

  // Treat all outer dimensions (i.e. dim_ < dim) as one.
  const int64_t num_orows = c10::multiply_integers(sizes.begin(), sizes.begin() + dim);

  // Treat all inner dimensions (i.e. dim_ > dim) as one.
  const int64_t num_irows = c10::multiply_integers(sizes.begin() + dim + 1, sizes.end());
  // For performance reasons, the CUDA kernels use uint32_t for their loops over irows, orows and
  // the row, so make sure the input does not exceed what uint32_t can represent.
  check_fits_in_unsigned(num_irows, "num_irows");
  check_fits_in_unsigned(num_orows, "num_orows");
  check_fits_in_unsigned(row_size, "row_size");

  dim3 threads(std::min(512, int(num_irows)));
  int64_t maxGridDim = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
  dim3 grid(std::min(maxGridDim, num_orows), std::min(maxGridDim, ceil_div(num_irows, int64_t{threads.x})));
  tensor_kernel_scan_outer_dim_with_indices<scalar_t><<<grid, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
      self.const_data_ptr<scalar_t>(), values.mutable_data_ptr<scalar_t>(), indices.mutable_data_ptr<int64_t>(),
      num_orows, num_irows, row_size, init, binary_op);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
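// Launch-shape example for the launcher above (illustrative numbers): for a (4, 1000, 6) tensor
// scanned along dim = 1, num_orows = 4, row_size = 1000 and num_irows = 6, so threads.x =
// min(512, 6) = 6 and grid = (4, ceil_div(6, 6)) = (4, 1); each of the 4 outer rows is handled by
// the blocks sharing its blockIdx.x.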

template <typename scalar_t, class BinaryFunction>
__host__ void scan_innermost_dim_with_indices(
    const TensorBase& self, const TensorBase& values, const TensorBase& indices,
    scalar_t init, BinaryFunction binary_op) {
  int ndim = self.dim();
  // Treat all outer dimensions as a single dimension.
  int row_size = self.size(ndim - 1);
  int num_rows = self.numel() / row_size;

  // assuming max_num_threads per block is 512
  const uint32_t num_threads = 512;
  const uint32_t log_num_threads_x = get_log_num_threads_x_inner_scan<uint32_t>(num_rows, row_size);
  const uint32_t num_threads_x = (1 << log_num_threads_x);
  const uint32_t num_threads_y = num_threads / num_threads_x;
  dim3 threads(num_threads_x, num_threads_y);
  dim3 grid(std::min(at::cuda::getCurrentDeviceProperties()->maxGridSize[0], ceil_div(num_rows, int(threads.y))));

  const uint32_t mem_size = 2 * num_threads * (sizeof(scalar_t) + sizeof(int64_t));
  tensor_kernel_scan_innermost_dim_with_indices<scalar_t><<<grid, threads, mem_size,
      at::cuda::getCurrentCUDAStream()>>>(
      self.const_data_ptr<scalar_t>(), values.mutable_data_ptr<scalar_t>(), indices.mutable_data_ptr<int64_t>(),
      num_rows, row_size, num_threads, log_num_threads_x, init, binary_op);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
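// Shared-memory example for the launch above (illustrative numbers): with num_threads = 512 and
// scalar_t = float, mem_size = 2 * 512 * (4 + 8) = 12288 bytes -- room for two scalar entries and
// two int64_t index entries per thread.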

template<typename scalar_t, typename BinaryFunction>
void scan_dim_with_indices(const TensorBase& self, const TensorBase& values, const TensorBase& indices,
    int64_t dim, scalar_t init, BinaryFunction binary_op) {
  int ndim = self.dim();
  auto self_ = self.expect_contiguous();
  TORCH_INTERNAL_ASSERT(values.is_contiguous() && indices.is_contiguous());
  if (dim == ndim - 1) {
    scan_innermost_dim_with_indices<scalar_t>(*self_, values, indices, init, binary_op);
  } else {
    scan_outer_dim_with_indices<scalar_t>(*self_, values, indices, dim, init, binary_op);
  }
}
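// Usage sketch (hypothetical call, not taken from this file): a cummax-style operator could
// dispatch here as, for example,
//   scan_dim_with_indices<scalar_t>(self, values, indices, dim,
//       at::numeric_limits<scalar_t>::lower_bound(), std::greater_equal<scalar_t>());
// With greater_equal as the comparator, ties keep the later index, and the _isnan checks in the
// kernels make NaNs win once encountered.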

// TODO: The implementations of `tensor_kernel_scan_outer_dim` and
// `tensor_kernel_scan_innermost_dim` are similar to
// `tensor_kernel_scan_outer_dim_with_indices` and
// `tensor_kernel_scan_innermost_dim_with_indices`, and should be refactored
// to remove the duplication.

/* Perform an inclusive scan along an outer dimension of a tensor.
 *
 * - num_orows is the size of the flattened outer dimensions;
 * - num_irows is the size of the flattened inner dimensions;
 * - row_size is the size of the dimension along which to scan;
 *
 * The dimensions to the outside and inside of the specified dimension are considered as flattened.
 * Thread blocks with the same blockIdx.x process an "outer row" (i.e. an element of the flattened
 * outer dimensions, which contains several "inner rows").
 * Each thread processes a single inner row at a time.
 */
template<typename scalar_t, class BinaryOp>
__global__ void tensor_kernel_scan_outer_dim(scalar_t *tgt_, const scalar_t *src_,
    const uint32_t num_orows, const uint32_t num_irows, const uint32_t row_size,
    const scalar_t init, BinaryOp binary_op)
{
  for (uint32_t orow = blockIdx.x; orow < num_orows; orow += gridDim.x) {
    for (uint32_t irow = blockIdx.y * blockDim.x + threadIdx.x; irow < num_irows; irow += gridDim.y * blockDim.x) {
      const scalar_t *src = src_ + orow * row_size * num_irows + irow;
      scalar_t *tgt = tgt_ + orow * row_size * num_irows + irow;
      scalar_t acc = init;

      for (uint32_t col = 0; col < row_size; ++col) {
        acc = binary_op(acc, c10::load(src));
        *tgt = acc;

        src += num_irows;
        tgt += num_irows;
      }
    }
  }
}
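// Note on parallelism: in the kernel above each thread walks its row serially (acc carries the
// running result), so it relies on num_orows * num_irows being large enough to fill the GPU; the
// innermost-dim kernels below instead parallelize within a single row.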

/* Perform an inclusive scan along the innermost dimension of a tensor.
 *
 * - num_rows is the size of the flattened outer dimensions;
 * - row_size is the size of the innermost dimension;
 *
 * The outer dimensions of the tensor are considered as a single dimension, i.e. the tensor is
 * considered as having 'num_rows' rows of size 'row_size'.
 * Each thread block processes one or more sets of contiguous rows (processing multiple rows
 * per thread block is quicker than processing a single row, especially for short rows).
 */
template<typename T, class BinaryFunction>
__device__ void tensor_kernel_scan_innermost_dim_impl(T* row_buf, T *tgt_, const T *src_,
    const uint32_t num_rows, const uint32_t row_size,
    const uint32_t log_num_threads_x,
    T init, BinaryFunction binary_op) {
  const uint32_t num_threads_x = 1 << log_num_threads_x;
  for (uint32_t block_row = blockIdx.x * blockDim.y;
       block_row < num_rows;
       block_row += blockDim.y * gridDim.x) {
    uint32_t row = block_row + threadIdx.y;
    T block_total = init;

    const T *row_src = src_ + row * row_size;
    T *row_tgt = tgt_ + row * row_size;
    const bool row_exists = row < num_rows;

    // Perform scan on one block at a time, keeping track of the total value of
    // all blocks processed so far.
    for (uint32_t block_col = 0; block_col < row_size; block_col += 2 * num_threads_x) {
      // Load data into shared memory (two values per thread).
      uint32_t col1 = block_col + threadIdx.x;
      uint32_t col2 = block_col + num_threads_x + threadIdx.x;
      if (row_exists) {
        if (col1 < row_size) {
          row_buf[threadIdx.x] = row_src[col1];
        } else {
          row_buf[threadIdx.x] = init;
        }

        if (col2 < row_size) {
          row_buf[num_threads_x + threadIdx.x] = row_src[col2];
        } else {
          row_buf[num_threads_x + threadIdx.x] = init;
        }

        // Add the total value of all previous blocks to the first value of this block.
        if (threadIdx.x == 0) {
          row_buf[0] = binary_op(row_buf[0], block_total);
        }
      }
      __syncthreads();

      // Parallel reduction with the Sklansky method. The diagram can be seen in this paper:
      // https://research.nvidia.com/publication/single-pass-parallel-prefix-scan-decoupled-look-back
      for (uint32_t m = 0; m <= log_num_threads_x; ++m) {
        if (row_exists) {
          uint32_t s = 1 << m; // s = 2 ^ m
          uint32_t a = ((threadIdx.x >> m) << (m + 1)) | s; // a = (threadIdx.x / s) * (2 * s) + s
          uint32_t ti = a + (threadIdx.x % s);
          uint32_t si = a - 1;
          row_buf[ti] = binary_op(row_buf[ti], row_buf[si]);
        }
        __syncthreads();
      }

      // Write back to output.
      if (row_exists) {
        if (col1 < row_size) row_tgt[col1] = row_buf[threadIdx.x];
        if (col2 < row_size) row_tgt[col2] = row_buf[num_threads_x + threadIdx.x];
      }
      block_total = row_buf[2 * num_threads_x - 1];
      __syncthreads();
    }
  }
}
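// Worked example of the Sklansky step above (illustrative numbers): with num_threads_x = 4
// (log_num_threads_x = 2) the shared row holds 8 entries. At m = 0 each thread combines entry
// 2*tid into 2*tid+1; at m = 1 threads 0-1 add entry 1 into entries 2-3 and threads 2-3 add entry
// 5 into entries 6-7; at m = 2 all four threads add entry 3 into entries 4-7. After
// log_num_threads_x + 1 rounds every entry holds the inclusive prefix of its tile.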

template <
    typename T,
    class BinaryFunction>
__global__ void tensor_kernel_scan_innermost_dim(
    T* tgt_,
    const T* src_,
    const uint32_t num_rows,
    const uint32_t row_size,
    const uint32_t log_num_threads_x,
    T init,
    BinaryFunction binary_op) {
  alignas(sizeof(double)) extern __shared__ char sbuf[];
  T* sbuf2 = reinterpret_cast<T*>(sbuf);
  const uint32_t num_threads_x = 1 << log_num_threads_x;
  T* row_buf = reinterpret_cast<T*>(sbuf2 + num_threads_x * 2 * threadIdx.y);

  tensor_kernel_scan_innermost_dim_impl<T>(
      row_buf, tgt_, src_, num_rows, row_size, log_num_threads_x, init, binary_op);
}
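// Shared-memory layout of the wrapper above: each row of threads in the block (a fixed
// threadIdx.y) gets its own window of 2 * num_threads_x elements of T inside the dynamic buffer,
// so a launch needs at least blockDim.y * 2 * num_threads_x * sizeof(T) bytes -- which is what
// scan_innermost_dim below passes as num_threads * 2 * sizeof(scalar_t).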


template<typename scalar_t, class BinaryFunction>
__host__ void scan_outer_dim(const TensorBase& self, const TensorBase& result,
    int dim, scalar_t init, BinaryFunction binary_op) {
  const int64_t row_size = self.size(dim);
  auto sizes = self.sizes();

  // Treat all outer dimensions (i.e. dim_ < dim) as one.
  const int64_t num_orows = c10::multiply_integers(sizes.begin(), sizes.begin() + dim);

  // Treat all inner dimensions (i.e. dim_ > dim) as one.
  const int64_t num_irows = c10::multiply_integers(sizes.begin() + dim + 1, sizes.end());

  dim3 threads(std::min(512, int(num_irows)));
  int64_t maxGridDim = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
  dim3 grid(std::min(maxGridDim, num_orows), std::min(maxGridDim, ceil_div(num_irows, int64_t{threads.x})));

  check_fits_in_unsigned(num_irows, "num_irows");
  check_fits_in_unsigned(num_orows, "num_orows");
  check_fits_in_unsigned(row_size, "row_size");

  tensor_kernel_scan_outer_dim<scalar_t><<<grid, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
      result.mutable_data_ptr<scalar_t>(), self.const_data_ptr<scalar_t>(),
      num_orows, num_irows, row_size, init, binary_op);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

template <typename scalar_t, class BinaryFunction>
void scan_innermost_dim(const TensorBase& self, const TensorBase& result,
    scalar_t init, BinaryFunction binary_op) {
  int64_t ndim = self.dim();
  // Treat all outer dimensions as a single dimension.
  int64_t row_size = self.size(ndim - 1);
  int64_t num_rows = self.numel() / row_size;

  // assuming max_num_threads per block is 512
  const uint32_t num_threads = 512;
  const uint32_t log_num_threads_x = get_log_num_threads_x_inner_scan<uint32_t>(num_rows, row_size);
  const uint32_t num_threads_x = (1 << log_num_threads_x);
  const uint32_t num_threads_y = num_threads / num_threads_x;
  dim3 threads(num_threads_x, num_threads_y);
  int64_t maxGridDim = at::cuda::getCurrentDeviceProperties()->maxGridSize[0];
  dim3 grid(std::min(maxGridDim, ceil_div(num_rows, int64_t{threads.y})));

  check_fits_in_unsigned(num_rows, "Number of rows (self.numel()/self.size(self.dim()-1))");
  check_fits_in_unsigned(row_size, "row_size");

  tensor_kernel_scan_innermost_dim<scalar_t><<<grid, threads, num_threads * 2 * sizeof(scalar_t),
      at::cuda::getCurrentCUDAStream()>>>(
      result.mutable_data_ptr<scalar_t>(), self.const_data_ptr<scalar_t>(),
      num_rows, row_size, log_num_threads_x, init, binary_op);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
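// Thread-shape note for the launcher above: the 512 threads of a block are split into
// num_threads_x lanes that cooperate within one row and num_threads_y rows per block
// (x * y = 512); get_log_num_threads_x_inner_scan picks the split from (num_rows, row_size),
// which -- per the kernel comment earlier in this file -- favours covering several short rows
// per block.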

template<typename scalar_t, typename BinaryFunction>
void scan_dim(const TensorBase& self, const TensorBase& result,
    int64_t dim, scalar_t init, BinaryFunction binary_op) {
  int ndim = self.dim();
  auto self_ = self.expect_contiguous();
  TORCH_INTERNAL_ASSERT(result.is_contiguous());

  if (self.numel() == self.size(dim)) {
    cuda::cub::inclusive_scan(self_->const_data_ptr<scalar_t>(), result.mutable_data_ptr<scalar_t>(), binary_op, self.numel());
  } else if (dim == ndim - 1) {
    scan_innermost_dim<scalar_t>(*self_, result, init, binary_op);
  } else {
    scan_outer_dim<scalar_t>(*self_, result, dim, init, binary_op);
  }
}
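// Dispatch summary with a usage sketch (illustrative; actual callers are not part of this file):
// a 1-D input takes the cub::inclusive_scan fast path because self.numel() == self.size(dim);
// otherwise the innermost- or outer-dim kernel is chosen. E.g. a cumsum-style caller might do:
//   scan_dim<scalar_t>(self, result, dim, scalar_t(0), std::plus<scalar_t>());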

}} // namespace at::native