diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/AtomicAddFloat.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/AtomicAddFloat.h new file mode 100644 index 0000000000000000000000000000000000000000..5b24ee4821c45baab25f37a3bfa3399eff8a1716 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/AtomicAddFloat.h @@ -0,0 +1,37 @@ +#ifndef ATOMIC_ADD_FLOAT +#define ATOMIC_ADD_FLOAT + +#if (defined(__x86_64__) || defined(__i386__) || defined(__aarch64__)) +#include +#else +#define _mm_pause() +#endif + +#include + +static inline void cpu_atomic_add_float(float* dst, float fvalue) +{ + typedef union { + unsigned intV; + float floatV; + } uf32_t; + + uf32_t new_value, old_value; + std::atomic* dst_intV = (std::atomic*)(dst); + + old_value.floatV = *dst; + new_value.floatV = old_value.floatV + fvalue; + + unsigned* old_intV = (unsigned*)(&old_value.intV); + while (!std::atomic_compare_exchange_strong(dst_intV, old_intV, new_value.intV)) { +#ifdef __aarch64__ + __asm__ __volatile__("yield;" : : : "memory"); +#else + _mm_pause(); +#endif + old_value.floatV = *dst; + new_value.floatV = old_value.floatV + fvalue; + } +} + +#endif diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/ChannelShuffleKernel.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/ChannelShuffleKernel.h new file mode 100644 index 0000000000000000000000000000000000000000..387c301c25f030f73d73eb7e052950ce0eee0ae8 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/ChannelShuffleKernel.h @@ -0,0 +1,14 @@ +#pragma once +#include +#include + +namespace at { +class TensorBase; +} + +namespace at::native { + +using channel_shuffle_fn = void(*)(TensorBase&, const TensorBase&, int64_t); +DECLARE_DISPATCH(channel_shuffle_fn, channel_shuffle_kernel); + +} // at::native diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/CopyKernel.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/CopyKernel.h new file mode 100644 index 0000000000000000000000000000000000000000..3378e16f93d23e6c317b98f4469e660086b0082a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/CopyKernel.h @@ -0,0 +1,14 @@ +#pragma once + +#include + +namespace at { +struct TensorIteratorBase; + +namespace native { +inline namespace CPU_CAPABILITY { + +void direct_copy_kernel(TensorIteratorBase &iter); +void copy_kernel(TensorIterator& iter, bool /*non_blocking*/); + +}}} // namespace at::native::CPU_CAPABILITY diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/DepthwiseConvKernel.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/DepthwiseConvKernel.h new file mode 100644 index 0000000000000000000000000000000000000000..80970074b8e6c99d079f26aa6f576e67228a04f7 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/DepthwiseConvKernel.h @@ -0,0 +1,21 @@ +#pragma once + +#include +#include + +/* + Depthwise 3x3 Winograd convolution operator +*/ + +namespace at { +class Tensor; + +namespace native { + +using convolution_depthwise3x3_winograd_fn = + Tensor (*)(const Tensor &, const Tensor &, const Tensor &, IntArrayRef, IntArrayRef, int64_t); + +DECLARE_DISPATCH(convolution_depthwise3x3_winograd_fn, convolution_depthwise3x3_winograd_stub); + +} // namespace native +} // namespace at diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/DistributionTemplates.h 
b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/DistributionTemplates.h new file mode 100644 index 0000000000000000000000000000000000000000..8171ae8e79ad2a1311f4a8600decd202c66649d5 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/DistributionTemplates.h @@ -0,0 +1,425 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef CPU_CAPABILITY_AVX2 +#include +#include +#endif + + + + +namespace at::native::templates::cpu { +namespace { + +// ==================================================== Random ======================================================== + +template +void random_from_to_kernel(TensorIteratorBase& iter, uint64_t range, int64_t base, RNG generator) { + AT_DISPATCH_V2(iter.dtype(), "random_from_to_kernel_cpu", AT_WRAP([&] { + std::lock_guard lock(generator->mutex_); + cpu_serial_kernel(iter, [range, base, generator]() -> scalar_t { + uniform_int_from_to_distribution random(range, base); + return random(generator); + }); + }), kBool, kHalf, kBFloat16, AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)); +} + +// This is the special kernel to handle single specific case: +// from(inclusive) = std::numeric_limits::lowest() +// to(exclusive) = None (= std::numeric_limits::max() + 1) +template +void random_full_64_bits_range_kernel(TensorIteratorBase& iter, RNG generator) { + AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::BFloat16, iter.dtype(), "random_full_64_bits_range_kernel_cpu", [&] { + if constexpr (std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v) { + std::lock_guard lock(generator->mutex_); + cpu_serial_kernel(iter, [generator]() -> scalar_t { + uniform_int_full_range_distribution random; + return random(generator); + }); + } else { + TORCH_CHECK(false, "random_full_64_bits_range_kernel_cpu handles only int64, double, float and bfloat16"); + } + }); +} + +template +struct RandomFromToKernel { + void operator()(TensorIteratorBase& iter, uint64_t range, int64_t base, std::optional gen) { + random_from_to_kernel(iter, range, base, check_generator(gen)); + } + void operator()(TensorIteratorBase& iter, std::optional gen) { + random_full_64_bits_range_kernel(iter, check_generator(gen)); + } +}; + +template +void random_kernel(TensorIteratorBase& iter, RNG generator) { + std::lock_guard lock(generator->mutex_); + AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, iter.dtype(), "random_kernel_cpu", [&] { + cpu_serial_kernel(iter, [generator]() -> scalar_t { + uniform_int_distribution random; + return random(generator); + }); + }); +} + +template +struct RandomKernel { + void operator()(TensorIteratorBase& iter, std::optional gen) { + random_kernel(iter, check_generator(gen)); + } +}; + +// ==================================================== Normal ======================================================== + +#ifdef CPU_CAPABILITY_AVX2 +static void normal_fill_16_AVX2(float *data, + const __m256* two_pi, + const __m256* one, + const __m256* minus_two, + const __m256* mean, + const __m256* std_v) { + const __m256 u1 = _mm256_sub_ps(*one, _mm256_loadu_ps(data)); + const __m256 u2 = _mm256_loadu_ps(data + 8); + // sincos256_ps and log256_ps are from avx_mathfun.h + const __m256 radius = _mm256_sqrt_ps(_mm256_mul_ps(*minus_two, log256_ps(u1))); + const __m256 theta = _mm256_mul_ps(*two_pi, u2); + __m256 sintheta, costheta; + sincos256_ps(theta, &sintheta, &costheta); + const __m256 n1 = 
_mm256_mul_ps(radius, costheta); + const __m256 n2 = _mm256_mul_ps(radius, sintheta); + _mm256_storeu_ps(data, _mm256_fmadd_ps(n1, *std_v, *mean)); + _mm256_storeu_ps(data + 8, _mm256_fmadd_ps(n2, *std_v, *mean)); +} + +template +void normal_fill_AVX2(const TensorBase &self, const float mean, const float std, RNG generator) { + float *data = self.data_ptr(); + auto size = self.numel(); + std::lock_guard lock(generator->mutex_); + for (const auto i : c10::irange(size)) { + at::uniform_real_distribution uniform(0, 1); + data[i] = uniform(generator); + } + const __m256 two_pi = _mm256_set1_ps(2.0f * c10::pi); + const __m256 one = _mm256_set1_ps(1.0f); + const __m256 minus_two = _mm256_set1_ps(-2.0f); + const __m256 mean_v = _mm256_set1_ps(mean); + const __m256 std_v = _mm256_set1_ps(std); + + for (int64_t i = 0; i < size - 15; i += 16) { + normal_fill_16_AVX2(data + i, &two_pi, &one, &minus_two, &mean_v, &std_v); + } + + if (size % 16 != 0) { + // Recompute the last 16 values. + data = data + size - 16; + for (const auto i : c10::irange(16)) { + at::uniform_real_distribution uniform(0, 1); + data[i] = uniform(generator); + } + normal_fill_16_AVX2(data, &two_pi, &one, &minus_two, &mean_v, &std_v); + } +} +#endif + +template +static void normal_fill_16(scalar_t *data, const scalar_t mean, const scalar_t std) { + for (const auto j : c10::irange(8)) { + const scalar_t u1 = 1 - data[j]; // [0, 1) -> (0, 1] for log. + const scalar_t u2 = data[j + 8]; + const scalar_t radius = std::sqrt(-2 * std::log(u1)); + const scalar_t theta = 2.0f * c10::pi * u2; + data[j] = radius * std::cos(theta) * std + mean; + data[j + 8] = radius * std::sin(theta) * std + mean; + } +} + +#if defined(__VSX__) || defined(CPU_CAPABILITY_VSX) +static void normal_fill_16_VSX(float *data,const Vectorized &two_pi,const Vectorized &one,const Vectorized &minus_two,const Vectorized &mean,const Vectorized &std) { + using Vec = Vectorized; + Vec u1=one-Vec::loadu(data); + Vec u2=Vec::loadu(data+8); + Vec radius=(minus_two * u1.log()); + radius=radius.sqrt(); + Vec theta=two_pi * u2; + Vec output_vec=radius * theta.cos() * std + mean; + Vec output_vec2=radius * theta.sin() * std + mean; + output_vec.store(data); + output_vec2.store(data+8); +} + +template +void normal_fill_VSX(const TensorBase &self, const scalar_t mean, const scalar_t std, RNG generator) { + float *data = self.data_ptr(); + auto size = self.numel(); + std::lock_guard lock(generator->mutex_); + for (const auto i : c10::irange(size)) { + at::uniform_real_distribution uniform(0, 1); + data[i] = uniform(generator); + } + + using Vec = Vectorized; + const Vec two_pi = Vec(2.0f * c10::pi); + const Vec one = Vec(1.0f); + const Vec minus_two = Vec(-2.0f); + const Vec var_vec = Vec(std); + const Vec mean_vec = Vec(mean); + + for (int64_t i = 0; i < size - 15; i += 16) { + if(Vec::size()==8) { + normal_fill_16_VSX(data + i, two_pi, one, minus_two, mean_vec, var_vec); + } + else{ + normal_fill_16(data + i, mean, std); + } + } + if (size % 16 != 0) { + // Recompute the last 16 values. 
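For reference, here is a minimal standalone sketch (not part of the header above) of the Box-Muller step that normal_fill_16 applies to each pair of uniform samples: two values drawn from [0, 1) are mapped to two independent draws from N(mean, stddev). The function name and signature are illustrative only.

#include <cmath>
#include <utility>

// Convert one pair of uniforms in [0, 1) into one pair of normal samples.
inline std::pair<float, float> box_muller_pair(float u1, float u2,
                                               float mean, float stddev) {
  constexpr float kTwoPi = 6.28318530717958647692f;
  // Shift u1 from [0, 1) to (0, 1] so the logarithm stays finite.
  const float radius = std::sqrt(-2.0f * std::log(1.0f - u1));
  const float theta = kTwoPi * u2;
  return {radius * std::cos(theta) * stddev + mean,
          radius * std::sin(theta) * stddev + mean};
}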
+ data = data + size - 16; + for (const auto i : c10::irange(16)) { + at::uniform_real_distribution uniform(0, 1); + data[i] = uniform(generator); + } + if(Vec::size()==8){ + normal_fill_16_VSX(data, two_pi, one, minus_two, mean_vec, var_vec); + } + else{ + normal_fill_16(data, mean, std); + } + } +} +#endif //VSX + +template +void normal_fill(const TensorBase &self, const scalar_t mean, const scalar_t std, RNG generator) { + scalar_t *data = self.data_ptr(); + auto size = self.numel(); + std::lock_guard lock(generator->mutex_); + for (const auto i : c10::irange(size)) { + at::uniform_real_distribution uniform(0, 1); + data[i] = uniform(generator); + } + + for (int64_t i = 0; i < size - 15; i += 16) { + normal_fill_16(data + i, mean, std); + } + if (size % 16 != 0) { + // Recompute the last 16 values. + data = data + size - 16; + for (const auto i : c10::irange(16)) { + at::uniform_real_distribution uniform(0, 1); + data[i] = uniform(generator); + } + normal_fill_16(data, mean, std); + } +} + +template +void normal_kernel(const TensorBase &self, double mean, double std, RNG generator) { + auto size = self.numel(); + if (self.scalar_type() == ScalarType::Float && size >= 16 && self.is_contiguous()) { +#ifdef CPU_CAPABILITY_AVX2 + normal_fill_AVX2(self, static_cast(mean), static_cast(std), generator); +#elif defined(__VSX__) || defined(CPU_CAPABILITY_VSX) + normal_fill_VSX(self, static_cast(mean), static_cast(std), generator); +#else + normal_fill(self, static_cast(mean), static_cast(std), generator); +#endif + } else { + AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, self.scalar_type(), "normal_kernel_cpu", [&] { + if (size >= 16 && self.is_contiguous()) { + normal_fill(self, static_cast(mean), static_cast(std), generator); + } else { + auto iter = TensorIterator::borrowing_nullary_op(self); + std::lock_guard lock(generator->mutex_); + cpu_serial_kernel(iter, [mean, std, generator]() -> scalar_t { + at::normal_distribution normal(mean, std); + return static_cast(normal(generator)); + }); + } + }); + } +} + +template +struct NormalKernel { + void operator()(Tensor& self, double mean, double std, std::optional gen) { + normal_kernel(self, mean, std, check_generator(gen)); + } +}; + +// ==================================================== Uniform ======================================================= + +template +void uniform_kernel(TensorIteratorBase& iter, double from_, double to_, RNG generator) { + AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "uniform_kernel_cpu", [&]() { + std::lock_guard lock(generator->mutex_); + auto from = static_cast(from_); + auto to = static_cast(to_); + at::uniform_real_distribution uniform(from, to); + cpu_serial_kernel(iter, [&uniform, generator]() -> scalar_t { + return static_cast(uniform(generator)); + }); + }); +} + +template +struct UniformKernel { + void operator()(TensorIteratorBase& iter, double from, double to, std::optional gen) { + uniform_kernel(iter, from, to, check_generator(gen)); + } +}; + +// ==================================================== Cauchy ======================================================== + +template +void cauchy_kernel(TensorIteratorBase& iter, double median, double sigma, RNG generator) { + AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "cauchy_cpu", [&]() { + std::lock_guard lock(generator->mutex_); + at::cauchy_distribution cauchy(median, sigma); + cpu_serial_kernel(iter, [&cauchy, generator]() -> scalar_t { + return static_cast(cauchy(generator)); + }); + }); +} + +template +struct 
CauchyKernel { + void operator()(TensorIteratorBase& iter, double median, double sigma, std::optional gen) { + cauchy_kernel(iter, median, sigma, check_generator(gen)); + } +}; + +// ================================================== LogNormal ======================================================= + +template +void log_normal_kernel(TensorIteratorBase& iter, double mean, double std, RNG generator) { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "log_normal_cpu", [&]() { + std::lock_guard lock(generator->mutex_); + at::lognormal_distribution logNormal(mean, std); + cpu_serial_kernel(iter, [&logNormal, generator]() -> scalar_t { + return static_cast(logNormal(generator)); + }); + }); +} + +template +struct LogNormalKernel { + void operator()(TensorIteratorBase& iter, double mean, double std, std::optional gen) { + log_normal_kernel(iter, mean, std, check_generator(gen)); + } +}; + +// =================================================== Geometric ====================================================== + +template +void geometric_kernel(TensorIteratorBase& iter, double p, RNG generator) { + AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "geometric_cpu", [&]() { + std::lock_guard lock(generator->mutex_); + at::geometric_distribution geometric(p); + cpu_serial_kernel(iter, [&geometric, generator]() -> scalar_t { + return static_cast(geometric(generator)); + }); + }); +} + +template +struct GeometricKernel { + void operator()(TensorIteratorBase& iter, double p, std::optional gen) { + geometric_kernel(iter, p, check_generator(gen)); + } +}; + +// ================================================== Exponential ===================================================== + +template +void exponential_kernel(TensorIteratorBase& iter, double lambda, RNG generator) { + TORCH_CHECK(isFloatingType(iter.dtype()), "Exponential distribution is a continuous probability distribution. 
dtype must be a floating point but you specified ", iter.dtype()); + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "exponential_cpu", [&]() { + std::lock_guard lock(generator->mutex_); + at::exponential_distribution exponential(lambda); + cpu_serial_kernel(iter, [&exponential, generator]() -> scalar_t { + return static_cast(exponential(generator)); + }); + }); +} + +template +struct ExponentialKernel { + void operator()(TensorIteratorBase& iter, double lambda, std::optional gen) { + exponential_kernel(iter, lambda, check_generator(gen)); + } +}; + +// ================================================== Bernoulli ======================================================= + +template +void bernoulli_kernel(const TensorBase &self, const TensorBase &p_, RNG generator) { + AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half, + self.scalar_type(), "bernoulli_tensor_cpu_self_", [&] { + // See Note [Acquire lock when using random generators] + std::lock_guard lock(generator->mutex_); + using self_t = scalar_t; + auto p_cpu = p_.to(kCPU); + auto p = expand_inplace(self, p_cpu); + auto iter = TensorIteratorConfig() + .add_output(self) + .add_const_input(*p) + .check_all_same_dtype(false) + .build(); + if (p->scalar_type() == kDouble) { + cpu_serial_kernel(iter, [&](const double p_val) -> self_t { + at::bernoulli_distribution bernoulli(p_val); + return static_cast(bernoulli(generator)); + }); + } else { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::BFloat16, at::ScalarType::Half, + p->scalar_type(), "bernoulli_tensor_cpu_p_", [&] { + using p_t = scalar_t; + cpu_serial_kernel(iter, [&](const p_t p_val) -> self_t { + at::bernoulli_distribution bernoulli(p_val); + return static_cast(bernoulli(generator)); + }); + }); + } + }); +} + +template +void bernoulli_kernel(const TensorBase &self, double p, RNG generator) { + AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half, + self.scalar_type(), "bernoulli_scalar_cpu_", [&] { + // See Note [Acquire lock when using random generators] + std::lock_guard lock(generator->mutex_); + auto iter = TensorIterator::borrowing_nullary_op(self); + cpu_serial_kernel(iter, [p, generator]() -> scalar_t { + at::bernoulli_distribution bernoulli(p); + return static_cast(bernoulli(generator)); + }); + }); +} + +template +struct BernoulliKernel { + void operator()(const TensorBase &self, double p, std::optional gen) { + bernoulli_kernel(self, p, check_generator(gen)); + } + void operator()(const TensorBase &self, const TensorBase &p_, std::optional gen) { + bernoulli_kernel(self, p_, check_generator(gen)); + } +}; + +}} diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/GridSamplerKernel.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/GridSamplerKernel.h new file mode 100644 index 0000000000000000000000000000000000000000..3d332f88fc7cbd69cb266535e6ed1db07f95f3e3 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/GridSamplerKernel.h @@ -0,0 +1,34 @@ +#pragma once + +#include + +#include +#include + +namespace at { +class TensorBase; +} + +namespace at::native { + +using forward_2d_fn = void (*) ( + const TensorBase &output, + const TensorBase &input, + const TensorBase &grid, + int64_t interpolation_mode, + int64_t padding_mode, + bool align_corners); +using backward_2d_fn = void (*) ( + const TensorBase &grad_input, + const TensorBase &grad_grid, + const TensorBase 
&grad_output, + const TensorBase &input, + const TensorBase &grid, + int64_t interpolation_mode, + int64_t padding_mode, + bool align_corners, + std::array output_mask); +DECLARE_DISPATCH(forward_2d_fn, grid_sampler_2d_cpu_kernel); +DECLARE_DISPATCH(backward_2d_fn, grid_sampler_2d_backward_cpu_kernel); + +} // namespace at::native diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/IndexKernelUtils.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/IndexKernelUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..876f759e130f20c46be1b3176caf6829af046019 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/IndexKernelUtils.h @@ -0,0 +1,87 @@ +#pragma once +#include +#include + +namespace at::native { + +namespace { +static bool is_constant_index(int ntensor, const int64_t* strides) { + AT_ASSERT(ntensor >= 3); + for (const auto arg : c10::irange(2, ntensor)) { + if (strides[arg] != 0) { + return false; + } + } + return true; +} + + +struct Indexer { + Indexer(int64_t num_indexers, char** indexers, const int64_t* indexer_strides, + IntArrayRef original_sizes, IntArrayRef original_strides) + : num_indexers(num_indexers) + , indexers(indexers) + , indexer_strides(indexer_strides) + , original_strides(original_strides.data()) + , original_sizes(original_sizes.data()) { + AT_ASSERT(static_cast(original_strides.size()) == num_indexers); + AT_ASSERT(static_cast(original_sizes.size()) == num_indexers); + } + + int64_t num_indexers; + char** indexers; + const int64_t* indexer_strides; + const int64_t* original_strides; + const int64_t* original_sizes; + + int64_t get(int64_t idx) { + int64_t offset = 0; + for (const auto j : c10::irange(num_indexers)) { + int64_t value = *(int64_t*)&indexers[j][idx * indexer_strides[j]]; + int64_t size = original_sizes[j]; + TORCH_CHECK_INDEX(value >= -size && value < size, + "index ", value, " is out of bounds for dimension ", j, " with size ", size); + if (value < 0) { + value += size; + } + offset += value * original_strides[j]; + } + return offset; + } +}; +} // anonymous namespace + +template +void cpu_index_kernel(TensorIteratorBase& iter, IntArrayRef index_size, IntArrayRef index_stride, + const func_t& f, bool serial_execution=false) +{ + int ntensor = iter.ntensors(); + // When launch the index parallel version, set a relative small grain size less than the INTERNAL::GRAIN_SIZE + // to make the whole available thread numbers get more balanced work load and a better cache location. 
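The Indexer::get loop used below reduces each advanced-indexing element to a byte offset: every index tensor contributes value * stride, with bounds checking and Python-style wrap-around for negative values. A minimal sketch of the same arithmetic over plain vectors (the helper name and error handling are illustrative, not part of the header):

#include <cstdint>
#include <stdexcept>
#include <vector>

int64_t index_offset(const std::vector<int64_t>& values,
                     const std::vector<int64_t>& sizes,
                     const std::vector<int64_t>& strides) {
  int64_t offset = 0;
  for (size_t j = 0; j < values.size(); ++j) {
    int64_t v = values[j];
    const int64_t size = sizes[j];
    if (v < -size || v >= size) {
      throw std::out_of_range("index out of bounds");  // mirrors the bounds check above
    }
    if (v < 0) {
      v += size;  // wrap negative indices into [0, size)
    }
    offset += v * strides[j];
  }
  return offset;
}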
+ // The grain size here is chosen by the op benchmark to overcome the thread launch overhead + const int index_parallel_grain_size = 3000; + auto loop = [&](char** data, const int64_t* strides, int64_t n) { + auto indexer = Indexer(ntensor - 2, &data[2], &strides[2], index_size, index_stride); + char* dst = data[0]; + char* src = data[1]; + if (is_constant_index(ntensor, strides)) { + // specialization for when every element uses the same index + int64_t offset = indexer.get(0); + for (const auto i : c10::irange(n)) { + f(dst + strides[0] * i, src + strides[1] * i, offset); + } + } else { + for (const auto i : c10::irange(n)) { + int64_t offset = indexer.get(i); + f(dst + strides[0] * i, src + strides[1] * i, offset); + } + } + }; + if (serial_execution) { + iter.serial_for_each(loop, {0, iter.numel()}); + } else { + iter.for_each(loop, index_parallel_grain_size); + } +} +} // at +// native diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/Intrinsics.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/Intrinsics.h new file mode 100644 index 0000000000000000000000000000000000000000..f3b35328f1882729a9158eaed7eb2abf77097484 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/Intrinsics.h @@ -0,0 +1,33 @@ +#pragma once + +#if defined(__clang__) && (defined(__x86_64__) || defined(__i386__)) +/* Clang-compatible compiler, targeting x86/x86-64 */ +#include +#elif defined(_MSC_VER) +/* Microsoft C/C++-compatible compiler */ +#include +#if _MSC_VER <= 1900 +#define _mm256_extract_epi64(X, Y) (((uint64_t*)&X)[Y]) +#endif +#elif defined(__GNUC__) && (defined(__x86_64__) || defined(__i386__)) +/* GCC-compatible compiler, targeting x86/x86-64 */ +#include +#elif defined(__GNUC__) && defined(__ARM_NEON__) +/* GCC-compatible compiler, targeting ARM with NEON */ +#include +#elif defined(__GNUC__) && defined(__IWMMXT__) +/* GCC-compatible compiler, targeting ARM with WMMX */ +#include +#elif (defined(__GNUC__) || defined(__xlC__)) && \ + (defined(__VEC__) || defined(__ALTIVEC__)) +/* XLC or GCC-compatible compiler, targeting PowerPC with VMX/VSX */ +#include +/* We need to undef those tokens defined by to avoid conflicts + with the C++ types. => Can still use __bool/__vector */ +#undef bool +#undef vector +#undef pixel +#elif defined(__GNUC__) && defined(__SPE__) +/* GCC-compatible compiler, targeting PowerPC with SPE */ +#include +#endif diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/IsContiguous.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/IsContiguous.h new file mode 100644 index 0000000000000000000000000000000000000000..ddbbb6fb8f5afca7b2fa822cfa9e7281cec9909a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/IsContiguous.h @@ -0,0 +1,62 @@ +#pragma once + +namespace at::native { inline namespace CPU_CAPABILITY { + +// n: number of function arguments (arity) +// traits: function_traits (see FunctionTraits.h) +// s: index of scalar argument or -1 +template +struct IsContiguous { + static bool eval(const int64_t* strides) { + using type = typename traits::template arg::type; + return strides[stride_index] == (s == n ? 
0 : sizeof(type)) && + IsContiguous::eval(strides); + } +}; + +// will be called when there is an output exists +template +struct IsContiguous<0, 0, traits, s> { + static bool eval(const int64_t* strides) { + return strides[0] == sizeof(typename traits::result_type); + } +}; + +// will be called when there is no output +template +struct IsContiguous<0, -1, traits, s> { + static bool eval(const int64_t* /*strides*/) { + return true; + } +}; + +// output and all inputs are contiguous +template ::value>::type* = nullptr> +static inline bool is_contiguous(const int64_t* strides) { + return IsContiguous::eval(strides); +} + +template ::value>::type* = nullptr> +static inline bool is_contiguous(const int64_t* strides) { + return IsContiguous::eval(strides); +} + +// input at `s` is scalar (stride 0); output and other inputs are contiguous +// NB: output is typically at strides[0] so first input corresponds to s=1 +template ::value>::type* = nullptr> +static inline bool is_contiguous_scalar(const int64_t* strides) { + static_assert(s > 0 && s <= traits::arity, "scalar argument index out of bounds"); + return IsContiguous::eval(strides); +} + +template ::value>::type* = nullptr> +static inline bool is_contiguous_scalar(const int64_t* strides) { + static_assert(s > 0 && s <= traits::arity, "scalar argument index out of bounds"); + return IsContiguous::eval(strides); +} + +}} diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/LogAddExp.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/LogAddExp.h new file mode 100644 index 0000000000000000000000000000000000000000..e2b80a648df6b11a99ceadff5488dd597af4f9ac --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/LogAddExp.h @@ -0,0 +1,61 @@ +#pragma once + +#include +#include + +namespace at::native { +inline namespace CPU_CAPABILITY { + +// custom min and max to be used in logcumsumexp for complex arguments +template +std::pair, c10::complex> _logcumsumexp_minmax(c10::complex x, c10::complex y) { + if (at::_isnan(y)) { // either real is nan or imag is nan + return std::make_pair(y, y); + } else if (at::_isnan(x)) { // either real is nan or imag is nan + return std::make_pair(x, x); + } else { + return (x.real() < y.real()) ? std::make_pair(x, y) : std::make_pair(y, x); + } +} + +template +scalar_t _log_add_exp_helper(scalar_t x, scalar_t y) { + // Reference : https://www.tensorflow.org/api_docs/python/tf/math/cumulative_logsumexp + scalar_t min = at::_isnan(y) ? y : std::min(x, y); // std::min returns first arg if one of the args is nan + scalar_t max = at::_isnan(y) ? y : std::max(x, y); // std::max returns first arg if one of the args is nan + if (min != max || std::isfinite(min)) { + // nan will be propagated here + return std::log1p(std::exp(min - max)) + max; + } else { + // special case to correctly handle infinite cases + return x; + } +} + +template +c10::complex _log_add_exp_helper(const c10::complex& x, const c10::complex& y) { + auto [min, max] = _logcumsumexp_minmax(x, y); + auto min_real = std::real(min); + auto max_real = std::real(max); + + if (at::_isnan(min)) { // either real is nan or imag is nan + // handling the "infectious" NaNs + return {std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN()}; + } else if (!std::isfinite(min_real) && (min_real == max_real)) { + if (min_real < 0) { + // handle the -inf case, the imaginary part here does not really matter as the exp(value) + // will be around 0.0 and the angle (i.e. 
the imaginary part) cannot be determined. + // It does not matter if we're taking the exp of this value + return min; + } else { + // handle the +inf case, we don't need the special precision for log1p for small values + // and to avoid producing nan in case of real(max) == real(min) == +inf + return std::log(std::exp(min) + std::exp(max)); + } + } else { + return std::log1p(std::exp(min - max)) + max; + } +} + +} // end namespace +} //end at::native diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/MaxUnpoolKernel.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/MaxUnpoolKernel.h new file mode 100644 index 0000000000000000000000000000000000000000..1c6507909ca4aa7e49fbaa420e407b211023b1b7 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/MaxUnpoolKernel.h @@ -0,0 +1,14 @@ +#pragma once +#include + +namespace at { +class Tensor; + +namespace native { + +using max_unpooling_fn = void(*)(Tensor&, const Tensor&, const Tensor&); + +DECLARE_DISPATCH(max_unpooling_fn, max_unpool2d_kernel); +DECLARE_DISPATCH(max_unpooling_fn, max_unpool3d_kernel); + +}} // at::native diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/PixelShuffleKernel.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/PixelShuffleKernel.h new file mode 100644 index 0000000000000000000000000000000000000000..d5eee58c1ab1510957f2d1dad1163bc2e1724226 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/PixelShuffleKernel.h @@ -0,0 +1,14 @@ +#pragma once +#include + +namespace at { +class TensorBase; +} + +namespace at::native { + +using pixel_shuffle_fn = void(*)(TensorBase&, const TensorBase&, int64_t); +DECLARE_DISPATCH(pixel_shuffle_fn, pixel_shuffle_kernel); +DECLARE_DISPATCH(pixel_shuffle_fn, pixel_unshuffle_kernel); + +} // at::native diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/Reduce.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/Reduce.h new file mode 100644 index 0000000000000000000000000000000000000000..37bd32d1c4c13a7a00747058a741684e1e6e391a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/Reduce.h @@ -0,0 +1,314 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include +#include + +namespace at { namespace native { inline namespace CPU_CAPABILITY { + +using namespace vec; + +#define VEC_LOOP_HEADER(func_t, data) \ + using scalar_t = typename function_traits::result_type; \ + using Vec = Vectorized; \ + char* out_ptr = data[0]; \ + (void) out_ptr; + +// reduction that is contiguous over the input in dim 0 +template +inline bool is_contiguous_reduction(const int64_t* strides) { + return strides[0] == 0 && + strides[1] == sizeof(typename traits::arg2_t); +} + +// reduction that is contiguous over the input in dim 1 +template +inline bool is_outer_reduction(const int64_t* strides) { + return strides[0] == 0 && + strides[2] == sizeof(typename traits::result_type) && + strides[3] == sizeof(typename traits::arg2_t); +} + +template +inline void vectorized_reduction(char** data, int64_t n, int64_t stride, + func_t op, vec_func_t vop, bool reduce) { + VEC_LOOP_HEADER(func_t, data) + const char* in1_ptr = data[1]; + Vec acc[4]; + for (const auto j : c10::irange(4)) { + acc[j] = Vec::loadu(in1_ptr + j * Vec::size() * sizeof(scalar_t)); + } + for (const auto i : c10::irange(1, n)) { + const char* ptr = in1_ptr + stride * i; + acc[0] = vop(acc[0], Vec::loadu(ptr + 
(0 * Vec::size() * sizeof(scalar_t)))); + acc[1] = vop(acc[1], Vec::loadu(ptr + (1 * Vec::size() * sizeof(scalar_t)))); + acc[2] = vop(acc[2], Vec::loadu(ptr + (2 * Vec::size() * sizeof(scalar_t)))); + acc[3] = vop(acc[3], Vec::loadu(ptr + (3 * Vec::size() * sizeof(scalar_t)))); + } + if (reduce) { + scalar_t buffer[Vec::size()]; + acc[0] = vop(vop(acc[0], acc[1]), vop(acc[2], acc[3])); + acc[0].store(buffer); + for (const auto j : c10::irange(1, Vec::size())) { + buffer[0] = op(buffer[0], buffer[j]); + } + auto dst = (scalar_t*)out_ptr; + *dst = op(*dst, buffer[0]); + } else { + for (const auto j : c10::irange(4)) { + auto dst = out_ptr + j * Vec::size() * sizeof(scalar_t); + acc[j] = vop(acc[j], Vec::loadu(dst)); + acc[j].store(dst); + } + } +} + +template +inline void UNARY_OUTER_LOOP(char* data[2], const int64_t strides[2], int64_t n, F f) { + for (const auto j C10_UNUSED : c10::irange(n)) { + f(); + data[0] += strides[0]; + data[1] += strides[1]; + } +} + +// computes the reduction out = op(out, in) +template +inline void vectorized_inner_reduction(char** data, int64_t n, func_t op, vec_func_t vop) { + VEC_LOOP_HEADER(func_t, data) + int64_t vector_stride = 4 * Vec::size() * sizeof(scalar_t); + int64_t count = n / (4 * Vec::size()); + if (count > 0) { + vectorized_reduction(data, count, vector_stride, op, vop, /*reduce=*/true); + } + char* ptrs[3] = { data[0], data[0], data[1] }; + int64_t strides[] = { 0, 0, sizeof(scalar_t) }; + basic_loop(ptrs, strides, count * 4 * Vec::size(), n, op); +} + +// computes the reduction out = op(out, in) +template +inline void vectorized_outer_reduction(char** data, int64_t inner_stride, int64_t size0, int64_t size1, func_t op, vec_func_t vop) { + VEC_LOOP_HEADER(func_t, data) + + // reduce down each column of 4 * Vec::size() elements (128 or 256 bytes) +#if defined(CPU_CAPABILITY_AVX512) + int64_t outer_stride[2] = { 256, 256 }; +#else + int64_t outer_stride[2] = { 128, 128 }; +#endif + UNARY_OUTER_LOOP(data, outer_stride, size1 / (4 * Vec::size()), [&] { + vectorized_reduction(data, size0, inner_stride, op, vop, /*reduce=*/false); + }); + + // reduce down the remaining columns + int64_t step[] = { sizeof(scalar_t), sizeof(scalar_t) }; + int64_t remaining = size1 % (4 * Vec::size()); + UNARY_OUTER_LOOP(data, step, remaining, [&] { + char* ptrs[3] = { data[0], data[0], data[1] }; + int64_t strides[] = { 0, 0, inner_stride }; + basic_loop(ptrs, strides, 0, size0, op); + }); +} + +template +static void set_result(const int index, const res_t result, const TensorIteratorBase &iter, const int num_outputs) { + // static_assert(std::is_same::value, "data types must match"); + if (index < num_outputs) { + char *out = (char *) iter.data_ptr(index); + *(res_t *) out = result; + } +} + +template +static void set_results(const res_t result, const TensorIteratorBase &iter, const int num_outputs) { + AT_ASSERT(num_outputs == 1); + set_result(0, result, iter, num_outputs); +} + +template +inline typename std::enable_if::type +for_each_in_tuple(const std::tuple& /*t*/, const TensorIteratorBase& /*iter*/, const int /*num_outputs*/) { + return i; +} + +template +inline typename std::enable_if::type +for_each_in_tuple(const std::tuple& t, const TensorIteratorBase &iter, const int num_outputs) { + if (i < (size_t)num_outputs) { + set_result(i, std::get(t), iter, num_outputs); + return for_each_in_tuple(t, iter, num_outputs); + } + return i; +} + +template +static void set_results(const std::tuple& result, const TensorIteratorBase &iter, const int num_outputs) { + 
AT_ASSERT(num_outputs >= 1); + std::size_t result_size = for_each_in_tuple(result, iter, num_outputs); + AT_ASSERT((size_t)num_outputs == result_size); +} + +template +struct all_same : std::conjunction< + std::is_same... +> {}; + +// data_t is the input/output data type. +// acc_t is a type that contains all the necessary data +// to continue reducing. +// index_t is a one-dimensional index +// +// ops_t is such that &ops_t::reduce, &ops_t::combine, and &ops_t::project exist and satisfy +// the following. +// reduce: (acc_t, data_t, index_t) -> acc_t adds one data point to the accumulated value. +// combine: (acc_t, acc_t) -> acc_t combines two accumulated values into one. +// project: acc_t -> out_t finishes the reduction, getting the required output. +// +// Additionally, acc_t must be default-constructible: +// acc_t {} is an identity for combine, +// and project(acc_t {}) is the value of the operation on zero elements. +// +// The point of `combine` is to support parallelization - +// the idea is to one sequence of `reduce` calls per thread of execution, +// and then to combine them at the end with `combine`. +// +// If there is more than one output element, +// our parallelization strategy is to use one thread for each of them, +// which means that `combine` will never be called. +// +// If, on the other hand, there is only one, then we split the input into +// into several pieces, reduce each separately, and then combine them. + +template +void binary_kernel_reduce(TensorIteratorBase& iter, ops_t ops, init_t init) { + using rf_t = decltype(&ops_t::reduce); + using cf_t = decltype(&ops_t::combine); + using pf_t = decltype(&ops_t::project); + using r_traits = binary_function_traits; + using c_traits = binary_function_traits; + using p_traits = unary_function_traits; + using acc_t = typename p_traits::arg1_t; + using data_t = typename r_traits::arg2_t; + static_assert( + all_same< + acc_t, + init_t, + typename r_traits::arg1_t, + typename r_traits::result_type, + typename c_traits::arg1_t, + typename c_traits::arg2_t, + typename c_traits::result_type>::value, + "all accumulate types must match"); + static_assert( + std::is_default_constructible::value, + "the accumulate type must be default-constructible" + ); + const int num_outputs = iter.noutputs(); + iter.foreach_reduced_elt([&ops, &init, num_outputs](TensorIteratorBase &sub_iter) { + auto reduction_body = [&ops, &sub_iter, num_outputs](acc_t acc, int64_t begin, int64_t end) -> acc_t { + int ntensors = sub_iter.ntensors(); + sub_iter.serial_for_each([&acc, &ops, num_outputs, ntensors, begin](char** data, const int64_t* strides, int64_t size) { + AT_ASSERT(ntensors - num_outputs == 1); + char *in = data[ntensors - 1]; + int64_t stride = strides[ntensors - 1]; + for (const auto i : c10::irange(size)) { + acc = ops.reduce(acc, c10::load(in), begin + i); + in += stride; + } + }, {begin, end}); + return ops.translate_idx(acc, sub_iter.view_offsets()[0]); + }; + acc_t total_acc = init; + auto numel = sub_iter.numel(); + if (numel < at::internal::GRAIN_SIZE || at::get_num_threads() == 1 || + at::in_parallel_region()) { + total_acc = reduction_body(total_acc, 0, numel); + } else { + int max_threads = at::get_num_threads(); + AT_ASSERT(max_threads > 0); + static_assert( + !std::is_same::value, + "Concurrently modifying different references into std::vector is UB." 
+ ); + std::vector buffer((unsigned)max_threads, init); + at::parallel_for(0, numel, internal::GRAIN_SIZE, + [&](int64_t begin, int64_t end) { + auto& acc = buffer[at::get_thread_num()]; + acc = reduction_body(acc, begin, end); + } + ); + for (const auto i : c10::irange(max_threads)) { + total_acc = ops.combine(total_acc, buffer[i]); + } + } + set_results(ops.project(total_acc), sub_iter, num_outputs); + }); +} + +template +void binary_kernel_reduce_vec(TensorIteratorBase& iter, func_t op, vec_func_t vop, double ident = 0) { + using traits = binary_function_traits; + static_assert( + all_same< + typename traits::result_type, + typename traits::arg1_t, + typename traits::arg2_t>::value, + "all types must match"); + + iter.output_base().fill_(ident); + iter.parallel_reduce([&](char** data, const int64_t* strides, int64_t size0, int64_t size1) { + int64_t outer_strides[] = { strides[2], strides[3] }; + if (is_contiguous_reduction(strides)) { + // input is contiguous in dim 0, output is reduced in dim 0 + UNARY_OUTER_LOOP(data, outer_strides, size1, [&] { + vectorized_inner_reduction(data, size0, op, vop); + }); + } else if (is_outer_reduction(strides)) { + // input and output are contiguous in dim 1 + int64_t inner_stride = strides[1]; // stride of input in dim 0 + vectorized_outer_reduction(data, inner_stride, size0, size1, op, vop); + } else { + UNARY_OUTER_LOOP(data, outer_strides, size1, [&] { + char* ptrs[3] = { data[0], data[0], data[1] }; + int64_t inner_strides[3] = { strides[0], strides[0], strides[1] }; + basic_loop(ptrs, inner_strides, 0, size0, op); + }); + } + }); +} + +// when reduction is on most inner dimension (dim 0 in TensorIterator) +// and input has contiguous most inner dimension, `binary_kernel_reduce_lastdim` +// can be used. +inline bool is_reduce_lastdim(TensorIteratorBase& iter) { + return iter.num_reduce_dims() == 1 && iter.is_dim_reduced(0) + && iter.ninputs() == 1 && iter.strides(1)[0] == iter.element_size(1); +} + +template +void binary_kernel_reduce_lastdim(TensorIteratorBase& iter, reduce_func_t reduce_op) { + auto shape = iter.shape(); + int64_t dim_size = shape[0]; + int64_t grain_size = std::max((int64_t) 1, at::internal::GRAIN_SIZE / dim_size); + TensorIterator sub_iter(iter); + // create sub iterator to parallel on all non-reduce-dims + sub_iter.narrow(0, 0, 1); + auto loop = [&](char** data, const int64_t* strides, int64_t size) { + char* out = data[0]; + char* in = data[1]; + for (int64_t i = 0; i < size; ++i) { + reduce_op(out, in, dim_size); + out += strides[0]; + in += strides[1]; + } + }; + sub_iter.for_each(loop, grain_size); +} + +}}} // namespace at::native:: diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/SampledAddmmKernel.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/SampledAddmmKernel.h new file mode 100644 index 0000000000000000000000000000000000000000..e1d75b17698c2e2f0d23bc831ad5cf466f691d59 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/SampledAddmmKernel.h @@ -0,0 +1,12 @@ +#pragma once + +#include +#include + +namespace at::native { + +using sampled_addmm_sparse_csr_fn = void(*)(const Tensor&, const Tensor&, const Scalar&, const Scalar&, const Tensor&); + +DECLARE_DISPATCH(sampled_addmm_sparse_csr_fn, sampled_addmm_sparse_csr_stub); + +} // at::native diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/SerialStackImpl.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/SerialStackImpl.h new file mode 
100644 index 0000000000000000000000000000000000000000..88ba1c91b6c8cb30cae8a55718e400153a9699a7 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/SerialStackImpl.h @@ -0,0 +1,146 @@ +// Copyright 2004-present Facebook. All Rights Reserved. +#pragma once + +#include + +#include +#include +#include +#include +#include +#include + +namespace at::native::detail { + +struct InputMeta { + void* data_ptr; + int64_t inner_size; + + InputMeta(const Tensor& t, int64_t dim, int64_t inner) + : data_ptr(t.data_ptr()), inner_size(t.sizes()[dim] * inner) {} +}; + +// This kernel is used by two TensorList types: +// 1. stack_serial_kernel uses at::ArrayRef +// 2. Static runtime calls this kernel directly (csrc/jit/runtime/static/ops.cpp) with +// ProcessedNodeInputWrapper. +// When making changes, make sure that they are compatible with both types! +template +void stack_serial_kernel_impl(Tensor& result, TensorListType tensors, int64_t dim) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + dim >= 0 && dim <= result.dim(), + "dim out of range in stack_serial_kernel_impl"); + int64_t outer = + result.numel() / (result.sizes()[dim] * result.strides()[dim]); + scalar_t* result_data = result.data_ptr(); + int64_t ninputs = tensors.size(); + std::vector inputs; + inputs.reserve(ninputs); + for (const auto& tensor : tensors) { + inputs.emplace_back(tensor, dim, tensor.strides()[dim]); + } + + using Vec = vec::Vectorized; + scalar_t* result_ptr = result_data; + for (const auto i : c10::irange(outer)) { + for (const auto j : c10::irange(ninputs)) { + int64_t local_inner = inputs[j].inner_size; + scalar_t* input_ptr = (scalar_t*)(inputs[j].data_ptr) + i * local_inner; + + if (local_inner < Vec::size()) { + for (const auto k : c10::irange(local_inner)) { + result_ptr[k] = input_ptr[k]; + } + } else { + vec::map( + [](Vec x) { return x; }, result_ptr, input_ptr, local_inner); + } + result_ptr += local_inner; + } + } +} + +// Checks to see whether native stack can be invoked under these conditions: +// - result and input tensors are contiguous +// - only one thread is used +// - no type promotion has to occur +// - tensors dtype is Double or Float +template +bool can_use_native_serial_stack_impl(Tensor& result, TensorListType tensors, int64_t dim) { + TORCH_CHECK(tensors.size() > 0, "expected a non-empty list of Tensors"); + const Tensor& first_tensor = tensors[0]; + // stack dimension should be in range [0,firstTensor.dim()) + // dim == firstTensor.dim() is a valid input, but it is handled by default code path + // that uses unsqueeze + if (dim >= first_tensor.dim()) return false; + // Native stack doesn't apply any tensor is skipped. 
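The copy pattern that stack_serial_kernel_impl performs can be summarized as: for each outer index, append the contiguous slice that each input contributes to the output, in input order. A rough standalone sketch under simplified assumptions (float data, contiguous equally-shaped inputs; the function name is illustrative):

#include <cstdint>
#include <vector>

void serial_stack_sketch(float* out,
                         const std::vector<const float*>& inputs,
                         int64_t outer,   // number of outer iterations over the result
                         int64_t inner) { // contiguous elements each input contributes per outer index
  for (int64_t i = 0; i < outer; ++i) {
    for (const float* in : inputs) {
      const float* src = in + i * inner;
      for (int64_t k = 0; k < inner; ++k) {
        *out++ = src[k];  // slices from the inputs are interleaved per outer index
      }
    }
  }
}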
+ if (first_tensor.numel() == 0 && first_tensor.dim() == 1) return false; + // there should be no type promotion + if (result.dtype() != first_tensor.dtype()) return false; + + auto first_tensor_mem_format = first_tensor.suggest_memory_format(); + ScalarType dtype = first_tensor.scalar_type(); + + if (!result.is_contiguous(first_tensor_mem_format)) { + return false; + } + + // fast path only works for Double and Float + if (dtype != ScalarType::Double && dtype != ScalarType::Float) { + return false; + } + + // check remainder of inputs +#ifndef STRIP_ERROR_MESSAGES + auto const &first_tensor_shape = first_tensor.sizes(); +#endif + for (const auto i : c10::irange(1, tensors.size())) { + auto const &tensor = tensors[i]; + TORCH_CHECK(tensors[i].sizes() == first_tensor.sizes(), + "stack expects each tensor to be equal size, but got ", first_tensor_shape, + " at entry 0 and ", tensor.sizes(), " at entry ", i); + + // every tensor must be contiguous + // tensor sizes and strides must be the same + // there should be no type promotion + if (!tensor.is_contiguous(first_tensor_mem_format) || + tensor.strides() != first_tensor.strides() || + tensor.dtype() != dtype) { + return false; + } + } + + // fast native stack should only be used when it is not worth using multiple threads + // or there is only one thread. Note that we aren't checking result.numel() here because + // it may not have been resized and we want to defer that cost till later. + int64_t numel_in_stack = first_tensor.numel() * tensors.size(); + return numel_in_stack < at::internal::GRAIN_SIZE || at::get_num_threads() == 1; +} + +template +struct CanUseNativeSerialStack; + +template +struct CanUseNativeSerialStack { + static bool call(Tensor& result, TensorListType tensors, int64_t dim) { + // Inputs cannot alias the output tensor + for (const auto i : c10::irange(tensors.size())) { + auto lap = at::get_overlap_status(result, tensors[i]); + TORCH_CHECK(lap != at::MemOverlapStatus::Partial && + lap != at::MemOverlapStatus::Full, 0, + "unsupported operation: the input tensors cannot refer to any of the " + "output memory locations. 
Found overlap in input tensor ", i); + } + + return can_use_native_serial_stack_impl(result, tensors, dim); + } +}; + +template +struct CanUseNativeSerialStack { + static bool call(Tensor& result, TensorListType tensors, int64_t dim) { + return can_use_native_serial_stack_impl(result, tensors, dim); + } +}; + +} // namespace at::native::detail diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/SoftmaxKernel.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/SoftmaxKernel.h new file mode 100644 index 0000000000000000000000000000000000000000..ee9fac647ad6241c97e28a7af6f091d5d613bc3a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/SoftmaxKernel.h @@ -0,0 +1,28 @@ +#pragma once + +#include +#include + +namespace at { +class Tensor; + +namespace native { + +using forward_fn = void (*)(const Tensor&, const Tensor&); +using backward_fn = void(*)(const Tensor &, const Tensor &, const Tensor&); + +DECLARE_DISPATCH(forward_fn, softmax_lastdim_kernel); +DECLARE_DISPATCH(forward_fn, log_softmax_lastdim_kernel); +DECLARE_DISPATCH(backward_fn, softmax_backward_lastdim_kernel); +DECLARE_DISPATCH(backward_fn, log_softmax_backward_lastdim_kernel); + +using forward_fn_with_dim = void(*)(const Tensor &, const Tensor &, const int64_t); +using backward_fn_with_dim = + void (*)(const Tensor&, const Tensor&, const Tensor&, const int64_t); + +DECLARE_DISPATCH(forward_fn_with_dim, softmax_kernel); +DECLARE_DISPATCH(forward_fn_with_dim, log_softmax_kernel); +DECLARE_DISPATCH(backward_fn_with_dim, softmax_backward_kernel); +DECLARE_DISPATCH(backward_fn_with_dim, log_softmax_backward_kernel); +} +} diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/SpmmReduceKernel.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/SpmmReduceKernel.h new file mode 100644 index 0000000000000000000000000000000000000000..cbcbf3c63d9984ab4d8727f06e50dede5d840fb8 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/SpmmReduceKernel.h @@ -0,0 +1,22 @@ +#pragma once + +#include +#include +#include + +namespace at::native { + +using spmm_reduce_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, ReductionType op); +using spmm_reduce_arg_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, ReductionType op); +using spmm_reduce_backward_input_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, ReductionType op); +using spmm_reduce_backward_input_arg_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, ReductionType op); +using spmm_reduce_backward_other_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, ReductionType op); + +DECLARE_DISPATCH(spmm_reduce_fn, spmm_reduce_stub); +DECLARE_DISPATCH(spmm_reduce_arg_fn, spmm_reduce_arg_stub); +DECLARE_DISPATCH(spmm_reduce_backward_input_fn, spmm_reduce_backward_input_stub); +DECLARE_DISPATCH(spmm_reduce_backward_input_arg_fn, spmm_reduce_backward_input_arg_stub); +DECLARE_DISPATCH(spmm_reduce_backward_other_fn, spmm_reduce_backward_other_stub); +DECLARE_DISPATCH(spmm_reduce_backward_input_arg_fn, spmm_reduce_backward_other_arg_stub); + +} // at::native diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/StackKernel.h 
b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/StackKernel.h new file mode 100644 index 0000000000000000000000000000000000000000..6c96d83b9eaa03fda7f37fc1ccde1fa0d035959e --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/StackKernel.h @@ -0,0 +1,12 @@ +// Copyright 2004-present Facebook. All Rights Reserved. +#pragma once + +#include +#include + +namespace at::native { + +using stack_serial_fn = void(*)(Tensor &, TensorList, int64_t); +DECLARE_DISPATCH(stack_serial_fn, stack_serial_stub); + +} // namespace at::native diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/UpSampleKernelAVXAntialias.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/UpSampleKernelAVXAntialias.h new file mode 100644 index 0000000000000000000000000000000000000000..726a83c20963d065a41b6239a456a0af0837174a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/UpSampleKernelAVXAntialias.h @@ -0,0 +1,1376 @@ +/* +The Python Imaging Library (PIL) is + + Copyright © 1997-2011 by Secret Labs AB + Copyright © 1995-2011 by Fredrik Lundh + +Pillow is the friendly PIL fork. It is + + Copyright © 2010-2022 by Alex Clark and contributors + +Like PIL, Pillow is licensed under the open source HPND License +*/ + +// This code is heavily inspired from PILLOW-SIMD's implementation: +// https://github.com/uploadcare/pillow-simd/blob/simd/master/src/libImaging/Resample.c + +#pragma once +#ifdef CPU_CAPABILITY_AVX2 +// TODO: This file only supports AVX2. We could split the AVX kernels into +// smaller logical blocks in order to port them into the Vec.h logic. This would +// allow to support other vectorization architectures and perhaps also support +// the non-vectorized fallback (we'd need to make sure it's not slower than the +// current fallback). + +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#endif + + +namespace { + +static inline __m128i mm_cvtsi32_si128(const uint8_t* C10_RESTRICT ptr, bool i32_aligned) { + int32_t v; + if (i32_aligned) { + v = *(const int32_t*)ptr; + } else { + std::memcpy(&v, ptr, 4); + } + return _mm_cvtsi32_si128(v); +} + +static inline __m128i mm_cvtepu8_epi32(const uint8_t* C10_RESTRICT ptr, bool i32_aligned) { + return _mm_cvtepu8_epi32(mm_cvtsi32_si128(ptr, i32_aligned)); +} + +static inline void _write_endline_rgb_as_uint32( + uint8_t* C10_RESTRICT output, + uint32_t data +) { + // data is (R G B X), output is (X1 X2 X3 | R1 B1 G1 R2 ...) + // Here we explicitly set X as R1 + uint8_t* data_ptr = reinterpret_cast(&data); + data_ptr[3] = output[3]; + std::memcpy(output, data_ptr, 4); +} + +at::Tensor unpack_rgb(const at::Tensor& packed_tensor) { + // Convert a "packed" tensor (typically RGBRGBRGB if channels_last) into + // RGBARGBARGBA format where A is hard-coded to 0. Each pixel is encoded + // into as 32 bits. This generalizes to num_channels <= 4 and also works for + // non-channels_last tensors. 
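The unpacking just described can be illustrated with a small standalone sketch: each pixel of a C-channel uint8 image (C <= 4) is widened to a 4-byte RGBA slot with the unused channels zeroed, so one pixel occupies a single 32-bit lane. Unlike unpack_rgb, this sketch assumes a channels-last (interleaved) input and ignores strides:

#include <cstdint>
#include <vector>

std::vector<uint8_t> unpack_to_rgba(const uint8_t* packed,
                                    int64_t num_pixels,
                                    int64_t num_channels) {
  std::vector<uint8_t> unpacked(4 * num_pixels, 0);  // missing channels stay 0
  for (int64_t i = 0; i < num_pixels; ++i) {
    for (int64_t c = 0; c < num_channels; ++c) {
      unpacked[4 * i + c] = packed[num_channels * i + c];
    }
  }
  return unpacked;
}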
+ + const uint8_t* packed = (const uint8_t*)packed_tensor.const_data_ptr(); + auto num_pixels = packed_tensor.size(1) * packed_tensor.size(2); + auto num_channels = packed_tensor.size(0); + + constexpr int rgba_size = 4; + auto unpacked_tensor = at::empty({rgba_size, packed_tensor.size(1), packed_tensor.size(2)}, at::CPU(at::kByte)); + uint8_t* unpacked = (uint8_t*) unpacked_tensor.data_ptr(); + + auto stride_i = packed_tensor.stride(2); + auto stride_j = packed_tensor.stride(0); + + for (const auto i : c10::irange(num_pixels)) { + for (const auto j : c10::irange(rgba_size)) { + unpacked[rgba_size * i + j] = (j < num_channels) ? packed[stride_i * i + stride_j * j] : 0; + } + } + return unpacked_tensor; +} + +void pack_rgb( + const at::Tensor& unpacked_tensor, // IN + const at::Tensor& packed_tensor // OUT +) { + // Convert from unpacked channels last 3-channels or 4-channels tensor into original data layout. + + uint8_t* unpacked = (uint8_t*)unpacked_tensor.data_ptr(); + uint8_t* packed = (uint8_t*)packed_tensor.data_ptr(); + auto num_pixels = packed_tensor.size(1) * packed_tensor.size(2); + auto num_channels = packed_tensor.size(0); + + auto unpacked_increment = unpacked_tensor.size(0); + auto packed_increment = packed_tensor.stride(2); + auto packed_stride = packed_tensor.stride(0); + + TORCH_INTERNAL_ASSERT(unpacked_increment == 3 || unpacked_increment == 4); + + for (const auto i C10_UNUSED : c10::irange(num_pixels)) { + for (const auto j : c10::irange(num_channels)) { + packed[j * packed_stride] = unpacked[j]; + } + unpacked += unpacked_increment; + packed += packed_increment; + } +} + +void ImagingResampleHorizontalConvolution8u4x( + uint8_t* C10_RESTRICT lineOut0, + uint8_t* C10_RESTRICT lineOut1, + uint8_t* C10_RESTRICT lineOut2, + uint8_t* C10_RESTRICT lineOut3, + int64_t out_xsize, + const uint8_t* C10_RESTRICT lineIn0, + const uint8_t* C10_RESTRICT lineIn1, + const uint8_t* C10_RESTRICT lineIn2, + const uint8_t* C10_RESTRICT lineIn3, + int64_t in_xsize, + const int64_t* idx_ptr_xmin, + const int64_t* idx_ptr_size, + const int16_t* kk, + int kmax, + unsigned int coefs_precision, + int64_t num_channels, + bool is_last_line); + +void ImagingResampleHorizontalConvolution8u( + uint8_t* C10_RESTRICT lineOut, + int64_t out_xsize, + const uint8_t* C10_RESTRICT lineIn, + int64_t in_xsize, + const int64_t* idx_ptr_xmin, + const int64_t* idx_ptr_size, + const int16_t* kk, + int kmax, + unsigned int coefs_precision, + int64_t num_channels, + bool is_last_line); + +void ImagingResampleVerticalConvolution8u( + uint8_t* C10_RESTRICT lineOut, + const uint8_t* C10_RESTRICT lineIn, + int64_t xsize, + int64_t ids_min, + int64_t ids_size, + const int16_t* k, + unsigned int coefs_precision, + int64_t num_channels); + +template +void ImagingResampleHorizontal( + const at::Tensor & unpacked_output, + const at::Tensor & unpacked_input, + int ksize, + const std::vector& horiz_indices_weights, + unsigned int horiz_weights_precision) { + + // Interpolation horizontal pass: we compute x-axis (image width) interpolation outputs. + + // Input data is stored as + // input = [r[0], g[0], b[0], a[0], r[1], g[1], b[1], a[1], r[2], g[2], b[2], a[2], ...] + // Weights are float values computed for each output pixel and rescaled to uint16: + // weights[i] = [w[i, 0], w[i, 1], ..., w[i, K-1]] + // We want to compute the output as following: + // output = [oR[0], oG[0], oB[0], oA[0], oR[1], oG[1], oB[1], oA[1], ...] + // where + // oR[yoffset + i] = r[yoffset + xmin[i]] * w[i, 0] + ... 
+ r[yoffset + xmin[i] + K-1] * w[i, K-1] + // oG[yoffset + i] = g[yoffset + xmin[i]] * w[i, 0] + ... + g[yoffset + xmin[i] + K-1] * w[i, K-1] + // oB[yoffset + i] = b[yoffset + xmin[i]] * w[i, 0] + ... + b[yoffset + xmin[i] + K-1] * w[i, K-1] + // + + // TODO: we may want to merge that into the fallback code (currently called + // basic_loop_aa_horizontal) + // Although this may not be needed if / when we port all this code to use + // Vec.h since this would potentially give us another fall-back implem + + const int16_t* kk = (int16_t*)(horiz_indices_weights[3].const_data_ptr()); + + auto xout = unpacked_output.size(2); + auto yout = unpacked_output.size(1); + auto xin = unpacked_input.size(2); + TORCH_INTERNAL_ASSERT(num_channels == unpacked_input.size(0)); + + const int64_t* idx_ptr_xmin = horiz_indices_weights[0].const_data_ptr(); + const int64_t* idx_ptr_size = horiz_indices_weights[1].const_data_ptr(); + + uint8_t* unpacked_output_p = unpacked_output.data_ptr(); + const uint8_t* unpacked_input_p = unpacked_input.const_data_ptr(); + + int64_t yy = 0; + auto xout_stride = xout * num_channels; + auto xin_stride = xin * num_channels; + for (; yy < yout - 3; yy += 4) { + ImagingResampleHorizontalConvolution8u4x( + unpacked_output_p + yy * xout_stride, + unpacked_output_p + (yy + 1) * xout_stride, + unpacked_output_p + (yy + 2) * xout_stride, + unpacked_output_p + (yy + 3) * xout_stride, + xout, + unpacked_input_p + yy * xin_stride, + unpacked_input_p + (yy + 1) * xin_stride, + unpacked_input_p + (yy + 2) * xin_stride, + unpacked_input_p + (yy + 3) * xin_stride, + xin, + idx_ptr_xmin, + idx_ptr_size, + kk, + ksize, + horiz_weights_precision, + num_channels, + yy + 3 == yout - 1); + } + for (; yy < yout; yy++) { + ImagingResampleHorizontalConvolution8u( + unpacked_output_p + yy * xout_stride, + xout, + unpacked_input_p + yy * xin_stride, + xin, + idx_ptr_xmin, + idx_ptr_size, + kk, + ksize, + horiz_weights_precision, + num_channels, + yy == yout - 1); + } +} + +void ImagingResampleVertical( + const at::Tensor & unpacked_output, + const at::Tensor & unpacked_input, + int ksize, + const std::vector& vert_indices_weights, + unsigned int vert_weights_precision) { + + // Interpolation vertical pass: we compute y-axis interpolation outputs. + // Input data is stored as + // input = [r[0], g[0], b[0], a[0], r[1], g[1], b[1], a[1], r[2], g[2], b[2], a[2], ...] + // Weights are float values computed for each output pixel and rescaled to uint16: + // weights[i] = [w[i, 0], w[i, 1], ..., w[i, K-1]] + // We want to compute the output as following: + // output = [oR[0], oG[0], oB[0], oA[0], oR[1], oG[1], oB[1], oA[1], ...] + // where + // oR[xoffset + i] = r[xoffset + ymin[i]] * w[i, 0] + ... + r[xoffset + ymin[i] + (K-1) * xsize] * w[i, K-1] + // oG[xoffset + i] = g[xoffset + ymin[i]] * w[i, 0] + ... + g[xoffset + ymin[i] + (K-1) * xsize] * w[i, K-1] + // oB[xoffset + i] = b[xoffset + ymin[i]] * w[i, 0] + ... 
+ b[xoffset + ymin[i] + (K-1) * xsize] * w[i, K-1] + + // TODO: we may want to merge that into the fallback code (currently called + // basic_loop_aa_vertical) + // Although this may not be needed if / when we port all this code to use + // Vec.h since this would potentially give us another fall-back implem + const int16_t* kk = (int16_t*)(vert_indices_weights[3].const_data_ptr()); + + const int64_t* idx_ptr_xmin = vert_indices_weights[0].const_data_ptr(); + const int64_t* idx_ptr_size = vert_indices_weights[1].const_data_ptr(); + + uint8_t* unpacked_output_p = unpacked_output.data_ptr(); + const uint8_t* unpacked_input_p = unpacked_input.const_data_ptr(); + + auto xout = unpacked_output.size(2); + auto yout = unpacked_output.size(1); + const auto num_channels = unpacked_input.size(0); + TORCH_INTERNAL_ASSERT(num_channels == unpacked_output.size(0)); + + auto xout_stride = xout * num_channels; + for (const auto yy : c10::irange(yout)) { + const auto* k = &kk[yy * ksize]; + auto ids_min = idx_ptr_xmin[yy]; + auto ids_size = idx_ptr_size[yy]; + ImagingResampleVerticalConvolution8u( + unpacked_output_p + yy * xout_stride, + unpacked_input_p, + xout, + ids_min, + ids_size, + k, + vert_weights_precision, + num_channels); + } +} + +// This is the only public entry point in this file. It supports bilinear or bicubic +// mode for uint8 dtype when C <= 4, with or without antialias. The +// implem is based on PIL-SIMD. +// Its equivalent implementation (fallback) for when AVX isn't supported or when +// C > 4 is separable_upsample_generic_Nd_kernel_impl() There are a bunch of +// future improvement that can be done: look for the TODOs in this file. +// For details on how the weights are computed and how the multiplications are +// run on int (instead of float weights), see +// [ Weights computation for uint8_t and multiplication trick ] +// For details on how the AVX kernels are implemented, see +// https://gist.github.com/NicolasHug/47c97d731f05eaad5694c173849b86f5 +// See also [ Support for antialias=False as a subcase of antialias=True ] to +// learn more about how the antialias=False case is computed. The same holds +// here: all these kernels are general enough to handle an arbitrary number of +// weights, but when aa=False they could be optimized further. +template +void upsample_avx_bilinear_bicubic_uint8( + const at::Tensor& input_, + const at::Tensor& output, + bool align_corners, + const scale_type& scales, + bool antialias) { + auto batch_size = input_.size(0); + auto num_channels = input_.size(1); + auto xin = input_.size(3); + auto yin = input_.size(2); + auto xout = output.size(3); + auto yout = output.size(2); + + if (xin == xout && yin == yout) { + output.copy_(input_); + return; + } + + at::Tensor input = input_; + if (!(input.is_contiguous() || input.is_contiguous(at::MemoryFormat::ChannelsLast))) { + // If input is not contiguous with memory format channels first or channels last, + // we explicitly convert the input to contiguous channels last memory format. + // This simplifies the rest of the code and let us assume that the format is only contiguous channels first or channels last, + // Most tensors going through this `if` block won't need to go through unpacking, but those having C < 3 may + // have to (this means 2 copies are made). We could avoid the extra copy by handling non-contiguous input + // directly within unpack_rgb() and pack_rgb(), but initial attempts showed that this is fairly complex. 
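+ // A hypothetical example of such an input (names for illustration only): a
+ // step-sliced view like img.slice(2, 0, img.size(2), 2).slice(3, 0, img.size(3), 2)
+ // is neither contiguous nor channels_last, so it takes this branch and is
+ // repacked to channels_last before the AVX kernels run.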
+ input = input.contiguous(at::MemoryFormat::ChannelsLast); + } + + auto need_horizontal = xout != xin; + auto need_vertical = yout != yin; + + int ksize_horiz, ksize_vert; + std::vector horiz_indices_weights, vert_indices_weights; + unsigned int horiz_weights_precision, vert_weights_precision; + + bool skip_unpacking = (num_channels == 3 || num_channels == 4) && input.is_contiguous(at::MemoryFormat::ChannelsLast); + bool skip_packing = (num_channels == 3 || num_channels == 4) && output.is_contiguous(at::MemoryFormat::ChannelsLast); + + if (need_horizontal) { + int interp_dim = 3; + auto stride = (skip_unpacking) ? num_channels : 4; + std::tie(horiz_indices_weights, ksize_horiz, horiz_weights_precision) = + F::compute_index_ranges_int16_weights( + /*input_size=*/xin, + /*output_size=*/xout, + /*stride=*/stride, + /*ndims=*/4, + /*reshape_dim=*/interp_dim, + /*align_corners=*/align_corners, + /*opt_scale=*/scales[interp_dim - 2], + /*antialias=*/antialias, + /*align_i32=*/true); + } + + if (need_vertical) { + int interp_dim = 2; + auto stride = (skip_unpacking) ? num_channels * xout : 4 * xout; + std::tie(vert_indices_weights, ksize_vert, vert_weights_precision) = + F::compute_index_ranges_int16_weights( + /*input_size=*/yin, + /*output_size=*/yout, + /*stride=*/stride, + /*ndims=*/4, + /*reshape_dim=*/interp_dim, + /*align_corners=*/align_corners, + /*opt_scale=*/scales[interp_dim - 2], + /*antialias=*/antialias, + /*align_i32=*/true); + } + + at::Tensor buffer_horiz, buffer_vert; + // Minor optimization: we can avoid allocating an extra buffer if we're performing + // horizontal-only or vertical-only interpolation, and if the tensor doesn't + // need repacking + if (need_horizontal && (need_vertical || !skip_packing)) { + auto c = (skip_unpacking) ? num_channels : 4; + buffer_horiz = at::empty({c, yin, xout}, input.options()); + } + if (need_vertical && !skip_packing) { + auto c = (skip_unpacking) ? num_channels : 4; + buffer_vert = at::empty({c, yout, xout}, input.options()); + } + + for (const auto i : c10::irange(batch_size)) { + + at::Tensor unpacked_input = (skip_unpacking) ? input[i] : unpack_rgb(input[i]); + at::Tensor unpacked_output; + + if (need_horizontal) { + at::Tensor unpacked_output_temp = (need_vertical || !skip_packing) ? buffer_horiz : output[i]; + + if (skip_unpacking && num_channels == 3) { + ImagingResampleHorizontal<3>( + unpacked_output_temp, + unpacked_input, + ksize_horiz, + horiz_indices_weights, + horiz_weights_precision); + } else { + ImagingResampleHorizontal<4>( + unpacked_output_temp, + unpacked_input, + ksize_horiz, + horiz_indices_weights, + horiz_weights_precision); + } + unpacked_output = unpacked_input = unpacked_output_temp; + } + if (need_vertical) { + unpacked_output = (skip_packing) ? 
output[i] : buffer_vert; + + ImagingResampleVertical( + unpacked_output, + unpacked_input, + ksize_vert, + vert_indices_weights, + vert_weights_precision + ); + } + + TORCH_INTERNAL_ASSERT(unpacked_output.defined()); + + if (!skip_packing) { + pack_rgb(unpacked_output, output[i]); + } + } +} + +void ImagingResampleHorizontalConvolution8u4x( + uint8_t* C10_RESTRICT lineOut0, + uint8_t* C10_RESTRICT lineOut1, + uint8_t* C10_RESTRICT lineOut2, + uint8_t* C10_RESTRICT lineOut3, + int64_t out_xsize, + const uint8_t* C10_RESTRICT lineIn0, + const uint8_t* C10_RESTRICT lineIn1, + const uint8_t* C10_RESTRICT lineIn2, + const uint8_t* C10_RESTRICT lineIn3, + int64_t in_xsize, + const int64_t* idx_ptr_xmin, + const int64_t* idx_ptr_size, + const int16_t* kk, + int kmax, + unsigned int coefs_precision, + int64_t num_channels, + bool is_last_line) { + + // Interpolation horizontal pass processing together 4 vertical lines. + // - Input data format is RGBA or RGB with R,G,B,A being uint8. In case of RGBA + // we can encode 4 values as a single uint32 value. + // - We split the size of weight vector for a given output index as a sum: + // ids_size = num_blocks_4 * 4 + num_blocks_2 * 2 + num_blocks_1. + // - We load and process 4 weights values in a loop ("block 4") then we process 2 weights values + // in another loop ("block 2") and finally we process 1 weights value in the final loop ("block 1"). + + // Define shuffling masks (low/high) for num_channels 4 and 3 + // Mask low casts lower half of each lane to epi16 and reorder RGBARGBA -> RRGGBBAA: + // [r1 g1 b1 a1 r2 g2 b2 a2 ... | R1 G1 B1 A1 R2 G2 B2 A2 ... ] -> + // [r1 0 r2 0 g1 0 g2 0 b1 0 b2 0 a1 0 a2 0 | R1 0 R2 0 G1 0 G2 0 B1 0 B2 0 A1 0 A2 0] + // Mask high casts upper half of each lane to epi16 and reorder RGBARGBA -> RRGGBBAA:: + // [ ... r3 g3 b3 a3 r4 g4 b4 a4 | ... R3 G3 B3 A3 R4 G4 B4 A4 ] -> + // [r3 0 r4 0 g3 0 g4 0 b3 0 b4 0 a3 0 a4 0 | R3 0 R4 0 G3 0 G4 0 B3 0 B4 0 A3 0 A4 0] + + const auto mask_low_c4 = _mm256_set_epi8( + -1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0, + -1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0); + const auto mask_high_c4 = _mm256_set_epi8( + -1, 15, -1, 11, -1, 14, -1, 10, -1, 13, -1, 9, -1, 12, -1, 8, + -1, 15, -1, 11, -1, 14, -1, 10, -1, 13, -1, 9, -1, 12, -1, 8); + const auto mask_low_c3 = _mm256_set_epi8( + -1, -1, -1, -1, -1, 5, -1, 2, -1, 4, -1, 1, -1, 3, -1, 0, + -1, -1, -1, -1, -1, 5, -1, 2, -1, 4, -1, 1, -1, 3, -1, 0); + const auto mask_high_c3 = _mm256_set_epi8( + -1, -1, -1, -1, -1, 11, -1, 8, -1, 10, -1, 7, -1, 9, -1, 6, + -1, -1, -1, -1, -1, 11, -1, 8, -1, 10, -1, 7, -1, 9, -1, 6); + + const auto mask_low = (num_channels == 3) ? mask_low_c3 : mask_low_c4; + const auto mask_high = (num_channels == 3) ? mask_high_c3 : mask_high_c4; + + const auto stride = num_channels * sizeof(uint8_t); + + TORCH_INTERNAL_ASSERT(stride == 3 || stride == 4); + + // out_xsize = output width, out_x = output x index + // ids_min is the input offset index corresponding to out_x + // ids_size is the interpolation size for out_x + + // Let's precompute ids_size limits for block 4 and block 2. + // + // In block 4 (4 means we process 4 weight values together), we read input data + // with _mm_loadu_si128, i.e. 
16 bytes, per one line: + // lineIn0 + stride * (i + ids_min) + 16 <= lineIn0 + stride * (ids_size + ids_min) + // --> i <= ids_size - 16.0 / stride + // Strict boundary: + // --> i < ids_size + 1 - int(ceil(16.0 / stride)) = ids_size - b4_delta + // Soft boundary for reading inside the buffer except its boundaries: + // --> i < ids_size + 1 - int(16.0 / stride) = ids_size - b4_delta_soft + // RGBA: b4_delta = b4_delta_soft = 3 + // RGB : b4_delta = 5 + // RGB : b4_delta_soft = 4 + const auto b4_delta = (stride == 4) ? 3 : ((is_last_line) ? 5 : 4); + + // In block 2 (2 means we process 2 weights values together), we read input data + // with _mm_loadl_epi64, i.e. 8 bytes, per one line: + // lineIn0 + stride * (i + ids_min) + 8 <= lineIn0 + stride * (ids_size + ids_min) + // --> i <= ids_size - 8.0 / stride + // Strict boundary: + // --> i < ids_size + 1 - int(ceil(8.0 / stride)) = ids_size - b2_delta + // Soft boundary for reading inside the buffer except its boundaries: + // --> i < ids_size + 1 - int(8.0 / stride) = ids_size - b2_delta_soft + // RGBA: b2_delta = b2_delta_soft = 1 + // RGB : b2_delta = 2 + // RGB : b2_delta_soft = 1 + const auto b2_delta = (stride == 4) ? 1 : ((is_last_line) ? 2 : 1); + + const auto max_out_x_strided = out_xsize * stride; + const auto max_in_x_strided = in_xsize * stride; + + const auto zero = _mm256_setzero_si256(); + const auto initial = _mm256_set1_epi32(1 << (coefs_precision - 1)); + + for (const auto out_x : c10::irange(out_xsize)) { + const auto ids_min = idx_ptr_xmin[out_x]; + const auto ids_size = idx_ptr_size[out_x]; + const auto * k = &kk[out_x * kmax]; + int64_t i = 0; + + auto sss0 = initial; + auto sss1 = initial; + + const auto * lineIn0_min = lineIn0 + ids_min; + const auto * lineIn1_min = lineIn1 + ids_min; + const auto * lineIn2_min = lineIn2 + ids_min; + const auto * lineIn3_min = lineIn3 + ids_min; + + // block 4 + for (; i < ids_size - b4_delta; i += 4) { + // Load 4 values from weight vector + // mmk0 = [wl_0 wh_0 wl_1 wh_1 wl_0 wh_0 wl_1 wh_1 ...] + // mmk1 = [wl_2 wh_2 wl_3 wh_3 wl_2 wh_2 wl_3 wh_3 ...] 
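+ // Note: _mm256_madd_epi16 multiplies adjacent pairs of int16 values and sums
+ // each pair into one int32 lane, so broadcasting two neighbouring int16 weights
+ // as a single int32 lets one instruction apply both taps (w0*C0 + w1*C1) per channel.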
+ const auto mmk0 = _mm256_set1_epi32(*(int32_t*)&k[i]); + const auto mmk1 = _mm256_set1_epi32(*(int32_t*)&k[i + 2]); + + // RGBA: Load 8 pixels (4 per line) from input lines 0 and 1: + // source = [ + // r0 g0 b0 a0 r1 g1 b1 a1 r2 g2 b2 a2 r3 g3 b3 a3 + // R0 G0 B0 A0 R1 G1 B1 A1 R2 G2 B2 A2 R3 G3 B3 A3 + // ] + // RGB: Load 10 pixels (5 per line) + // source = [ + // r0 g0 b0 r1 g1 b1 r2 g2 b2 r3 g3 b3 r4 g4 b4 r5 + // R0 G0 B0 R1 G1 B1 R2 G2 B2 R3 G3 B3 R4 G4 B4 R5 + // ] + auto source = _mm256_inserti128_si256(_mm256_castsi128_si256( + _mm_loadu_si128((__m128i *) (lineIn0_min + stride * i))), + _mm_loadu_si128((__m128i *) (lineIn1_min + stride * i)), 1); + + // Apply mask_low: + // RGBA: + // [r0 0 r1 0 g0 0 g1 0 b0 0 b1 0 a0 0 a1 0 | R0 0 R1 0 G0 0 G1 0 B0 0 B1 0 A0 0 A1 0] + // RGB: + // [r0 0 r1 0 g0 0 g1 0 b0 0 b1 0 0 0 0 0 | R0 0 R1 0 G0 0 G1 0 B0 0 B1 0 0 0 0 0] + auto pix1 = _mm256_shuffle_epi8(source, mask_low); + // Compute output value as C += w0 * C0 + w1 * C1 for each channel in 32-bit precision + sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix1, mmk0)); + + // Apply mask_high: + // RGBA: + // [r2 0 r3 0 g2 0 g3 0 b2 0 b3 0 a2 0 a3 0 | R2 0 R3 0 G2 0 G3 0 B2 0 B3 0 A2 0 A3 0] + // RGB: + // [r2 0 r3 0 g2 0 g3 0 b2 0 b3 0 0 0 0 0 | R2 0 R3 0 G2 0 G3 0 B2 0 B3 0 0 0 0 0] + auto pix2 = _mm256_shuffle_epi8(source, mask_high); + // Compute output value as C += w2 * C2 + w3 * C3 for each channel in 32-bit precision + sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix2, mmk1)); + + // Same as above to next lines 2 and 3: + auto source2 = _mm256_inserti128_si256(_mm256_castsi128_si256( + _mm_loadu_si128((__m128i *) (lineIn2_min + stride * i))), + _mm_loadu_si128((__m128i *) (lineIn3_min + stride * i)), 1); + auto pix3 = _mm256_shuffle_epi8(source2, mask_low); + sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix3, mmk0)); + auto pix4 = _mm256_shuffle_epi8(source2, mask_high); + sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix4, mmk1)); + } + + // block 2 + for (; i < ids_size - b2_delta; i += 2) { + // Load 2 values from weight vector + // mmk = [wl_0 wh_0 wl_1 wh_1 wl_0 wh_0 wl_1 wh_1 ...] 
+ const auto mmk = _mm256_set1_epi32(*(int32_t*)&k[i]); + + // Load 4 pixels (2 per line) from input lines 0 and 1: + // RGBA: source1 = [ + // r0 g0 b0 a0 r1 g1 b1 a1 0 0 0 0 0 0 0 0 + // R0 G0 B0 A0 R1 G1 B1 A1 0 0 0 0 0 0 0 0 + // ] + // RGB: source1 = [ + // r0 g0 b0 r1 g1 b1 r2 0 0 0 0 0 0 0 0 + // R0 G0 B0 R1 G1 B1 R2 0 0 0 0 0 0 0 0 + // ] + auto source1 = _mm256_inserti128_si256(_mm256_castsi128_si256( + _mm_loadl_epi64((__m128i *) (lineIn0_min + stride * i))), + _mm_loadl_epi64((__m128i *) (lineIn1_min + stride * i)), 1); + // Apply mask_low: + // RGBA: + // [r0 0 r1 0 g0 0 g1 0 b0 0 b1 0 a0 0 a1 0 | R0 0 R1 0 G0 0 G1 0 B0 0 B1 0 A0 0 A1 0] + // RGB: + // [r0 0 r1 0 g0 0 g1 0 b0 0 b1 0 0 0 0 0 | R0 0 R1 0 G0 0 G1 0 B0 0 B1 0 0 0 0 0] + auto pix1 = _mm256_shuffle_epi8(source1, mask_low); + // Compute output value as C += w0 * C0 + w1 * C1 for each channel in 32-bit precision + sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix1, mmk)); + + // Same as above for lines 2 and 3: + auto source2 = _mm256_inserti128_si256(_mm256_castsi128_si256( + _mm_loadl_epi64((__m128i *) (lineIn2_min + stride * i))), + _mm_loadl_epi64((__m128i *) (lineIn3_min + stride * i)), 1); + auto pix2 = _mm256_shuffle_epi8(source2, mask_low); + sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix2, mmk)); + } + + // block 1 + const auto i32_aligned = num_channels == 4; + for (; i < ids_size - 1; i++) { + // Load 1 value from weight vector + // mmk = [wl_0 wh_0 0 0 wl_0 wh_0 0 0 ...] + const auto mmk = _mm256_set1_epi32(k[i]); + + // Load 2 pixels (one per line) from input lines 0 and 1: + // RGBA: pix1 = [ + // r0 0 0 0 g0 0 0 0 b0 0 0 0 a0 0 0 0 + // R0 0 0 0 G0 0 0 0 B0 0 0 0 A0 0 0 0 + // ] + // RGB: pix1 = [ + // r0 0 0 0 g0 0 0 0 b0 0 0 0 r1 0 0 0 + // R0 0 0 0 G0 0 0 0 B0 0 0 0 R1 0 0 0 + // ] + auto pix1 = _mm256_inserti128_si256(_mm256_castsi128_si256( + mm_cvtepu8_epi32(lineIn0_min + stride * i, i32_aligned)), + mm_cvtepu8_epi32(lineIn1_min + stride * i, i32_aligned), 1); + // Compute output value as C += w0 * C0 for each channel in 32-bit precision + sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix1, mmk)); + + // Same as above for lines 2 and 3 + auto pix2 = _mm256_inserti128_si256(_mm256_castsi128_si256( + mm_cvtepu8_epi32(lineIn2_min + stride * i, i32_aligned)), + mm_cvtepu8_epi32(lineIn3_min + stride * i, i32_aligned), 1); + sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix2, mmk)); + } + + if (i == ids_size - 1) { + // last element + auto mmk = _mm256_set1_epi32(k[i]); + // For num_channels == 3 (3 bytes = one pixel) we tolerate to read 4 bytes + // lines 0, 1 and 2 wont go out of allocated memory bounds + auto pix = _mm256_inserti128_si256(_mm256_castsi128_si256( + mm_cvtepu8_epi32(lineIn0_min + stride * i, i32_aligned)), + mm_cvtepu8_epi32(lineIn1_min + stride * i, i32_aligned), 1); + sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix, mmk)); + + auto p0 = mm_cvtepu8_epi32(lineIn2_min + stride * i, i32_aligned); + __m128i p1; + if (num_channels == 3 && C10_UNLIKELY(is_last_line && ids_min + stride * i + 4 >= max_in_x_strided)) { + uint8_t input[4]; + std::memcpy(input, lineIn3_min + stride * i, 3); + p1 = mm_cvtepu8_epi32(input, true); + } else { + p1 = mm_cvtepu8_epi32(lineIn3_min + stride * i, i32_aligned); + } + auto pix2 = _mm256_inserti128_si256(_mm256_castsi128_si256(p0), p1, 1); + sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix2, mmk)); + } + + // Convert fixed point values back to integers (truncating) + sss0 = _mm256_srai_epi32(sss0, coefs_precision); + sss1 = 
_mm256_srai_epi32(sss1, coefs_precision); + // Convert packed signed 32-bit integers to packed 16-bit integers using signed saturation + // (a a a a b b b b c c c c d d d d) -> (a a b b c c d d 0 0 0 0 0 0 0 0) + sss0 = _mm256_packs_epi32(sss0, zero); + sss1 = _mm256_packs_epi32(sss1, zero); + // Convert packed signed 16-bit integers to packed 8-bit integers using unsigned saturation + // (a a b b c c d d) -> (a b c d 0 0 0 0) + sss0 = _mm256_packus_epi16(sss0, zero); + sss1 = _mm256_packus_epi16(sss1, zero); + + // Write the output into single uint32 + // (a b c d) -> x_uint32 + auto o0 = _mm_cvtsi128_si32(_mm256_castsi256_si128(sss0)); + auto o1 = _mm_cvtsi128_si32(_mm256_extracti128_si256(sss0, 1)); + auto o2 = _mm_cvtsi128_si32(_mm256_castsi256_si128(sss1)); + auto o3 = _mm_cvtsi128_si32(_mm256_extracti128_si256(sss1, 1)); + + const auto out_x_strided = stride * out_x; + + if (num_channels == 3 && C10_UNLIKELY(out_x_strided + 4 >= max_out_x_strided)) { + // Memcpy 4-bytes is faster than 3-bytes and this is a boundary case when we want to write + // 4 bytes (R G B | X) to the output buffer (X1 X2 X3 | R1). + // The 4th byte in the register (X) has a garbage value and 4th byte in the output buffer (R1) has a correct + // value which was previously computed by another line. In other words, it means that we can not overwrite + // it by simply writing 4 bytes from the register to the output. We'll do the following: + // v----------| + // Output = [... X1 X2 X3 | R1 G1 B1 R2 ...] + // First, we write R1 value to the 4th byte of (R G B | X) -> (R G B | R1) + // Second, we write 4 bytes from the register to the output: (X1 X2 X3 | R1) -> (R G B | R1) + // Output = [... R G B | R1 G1 B1 R2 ...] + + _write_endline_rgb_as_uint32(lineOut0 + out_x_strided, o0); + _write_endline_rgb_as_uint32(lineOut1 + out_x_strided, o1); + _write_endline_rgb_as_uint32(lineOut2 + out_x_strided, o2); + + if (C10_UNLIKELY(is_last_line)) { + // When we handle the last line, we can not access the next 4 bytes + // as they are out of memory bounds. + std::memcpy(lineOut3 + out_x_strided, (uint8_t *) &o3, num_channels); + } else { + _write_endline_rgb_as_uint32(lineOut3 + out_x_strided, o3); + } + } else if (num_channels == 3) { + // Memcpy 4-bytes is faster than 3-bytes and here + // we simply write 4 bytes (... R G B X 0 0 0 0 0 ...) where X is a garbage value + // that we will overwrite on the next iteration: (... R G B R G B X 0 0 ...) + std::memcpy(lineOut0 + out_x_strided, (uint8_t *) &o0, 4); + std::memcpy(lineOut1 + out_x_strided, (uint8_t *) &o1, 4); + std::memcpy(lineOut2 + out_x_strided, (uint8_t *) &o2, 4); + std::memcpy(lineOut3 + out_x_strided, (uint8_t *) &o3, 4); + } else { + // num_channels = 4 -> lineOutX + out_x_strided should be uint32 aligned + *(uint32_t *)(lineOut0 + out_x_strided) = o0; + *(uint32_t *)(lineOut1 + out_x_strided) = o1; + *(uint32_t *)(lineOut2 + out_x_strided) = o2; + *(uint32_t *)(lineOut3 + out_x_strided) = o3; + } + } +} + +void ImagingResampleHorizontalConvolution8u( + uint8_t* C10_RESTRICT lineOut, + int64_t out_xsize, + const uint8_t* C10_RESTRICT lineIn, + int64_t in_xsize, + const int64_t* idx_ptr_xmin, + const int64_t* idx_ptr_size, + const int16_t* kk, + int kmax, + unsigned int coefs_precision, + int64_t num_channels, + bool is_last_line) { + + // Interpolation horizontal pass processing only one vertical line. + // - Input data format is RGBA or RGB with R,G,B,A being uint8. In case of RGBA + // we can encode 4 values as a single uint32 value. 
+ // - We split the size of weight vector for a given output index as a sum: + // ids_size = num_blocks_8 * 8 + num_blocks_4 * 4 + num_blocks_2 * 2 + num_blocks_1 + // - We load and process 8 weights values in a loop ("block 8") then 4 weights and 2 weights values in + // in another loops ("block 4" and "block 2") and finally we process 1 weight value in the final loop ("block 1"). + + // Define various shuffling masks + const auto kmask_low = _mm256_set_epi8( + 11, 10, 9, 8, 11, 10, 9, 8, 11, 10, 9, 8, 11, 10, 9, 8, + 3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0); + const auto kmask_high = _mm256_set_epi8( + 15, 14, 13, 12, 15, 14, 13, 12, 15, 14, 13, 12, 15, 14, 13, 12, + 7, 6, 5, 4, 7, 6, 5, 4, 7, 6, 5, 4, 7, 6, 5, 4); + const auto kmask_hl = _mm256_set_epi8( + 7, 6, 5, 4, 7, 6, 5, 4, 7, 6, 5, 4, 7, 6, 5, 4, + 3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0); + + const auto mask_low_c4 = _mm256_set_epi8( + -1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0, + -1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0); + const auto mask_high_c4 = _mm256_set_epi8( + -1, 15, -1, 11, -1, 14, -1, 10, -1, 13, -1, 9, -1, 12, -1, 8, + -1, 15, -1, 11, -1, 14, -1, 10, -1, 13, -1, 9, -1, 12, -1, 8); + const auto mask_low_c3 = _mm256_set_epi8( + -1, -1, -1, -1, -1, 5, -1, 2, -1, 4, -1, 1, -1, 3, -1, 0, + -1, -1, -1, -1, -1, 5, -1, 2, -1, 4, -1, 1, -1, 3, -1, 0); + const auto mask_high_c3 = _mm256_set_epi8( + -1, -1, -1, -1, -1, 11, -1, 8, -1, 10, -1, 7, -1, 9, -1, 6, + -1, -1, -1, -1, -1, 11, -1, 8, -1, 10, -1, 7, -1, 9, -1, 6); + const auto mask_hl_c3 = _mm256_set_epi8( + -1, -1, -1, -1, -1, 11, -1, 8, -1, 10, -1, 7, -1, 9, -1, 6, + -1, -1, -1, -1, -1, 5, -1, 2, -1, 4, -1, 1, -1, 3, -1, 0); + const auto mask_hl_c4 = _mm256_set_epi8( + -1, 15, -1, 11, -1, 14, -1, 10, -1, 13, -1, 9, -1, 12, -1, 8, + -1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0); + + const auto mask_low128_c3 = _mm_set_epi8( + -1, -1, -1, -1, -1, 5, -1, 2, -1, 4, -1, 1, -1, 3, -1, 0); + const auto mask_low128_c4 = _mm_set_epi8( + -1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0); + + const auto mask_low = (num_channels == 3) ? mask_low_c3 : mask_low_c4; + const auto mask_high = (num_channels == 3) ? mask_high_c3 : mask_high_c4; + const auto mask_hl = (num_channels == 3) ? mask_hl_c3 : mask_hl_c4; + const auto mask_low128 = (num_channels == 3) ? mask_low128_c3 : mask_low128_c4; + + // out_xsize = output width, out_x = output x index + // ids_min is the input offset index corresponding to out_x + // ids_size is the interpolation size for out_x + + const auto stride = num_channels * sizeof(uint8_t); + const auto zero = _mm_setzero_si128(); + + TORCH_INTERNAL_ASSERT(stride == 3 || stride == 4); + + // Let's precompute ids_size limits for block 8, block 4 and block 2 + // + // In block 8 (8 means we process 8 weight values together), we read at + // most 32 bytes input data (16 + 16 bytes for RGBA and 12 + 16 bytes for RGB) + // lineIn + stride * (i + ids_min) + 32 <= lineIn + stride * (ids_size + ids_min) + // --> i <= ids_size - 32.0 / stride + // Strict boundary: + // --> i < ids_size + 1 - int(ceil(32.0 / stride)) = ids_size - b8_delta + // Soft boundary for reading inside the buffer except its boundaries: + // --> i < ids_size + 1 - int(32.0 / stride) = ids_size - b8_delta_soft + // RGBA: b8_delta = b8_delta_soft = 7 + // RGB : b8_delta = 10 + // RGB : b8_delta_soft = 9 + const auto b8_delta = (stride == 4) ? 7 : ((is_last_line) ? 
10 : 9); + + // In block 4 (4 means we process 4 weight values together), we read + // 16 bytes of input data. + // lineIn + stride * (i + ids_min) + 16 <= lineIn0 + stride * (ids_size + ids_min) + // --> i <= ids_size - 16.0 / stride + // Strict boundary: + // --> i < ids_size + 1 - int(ceil(16.0 / stride)) = ids_size - b4_delta + // Soft boundary for reading inside the buffer except its boundaries: + // --> i < ids_size + 1 - int(16.0 / stride) = ids_size - b4_delta_soft + // RGBA: b4_delta = b4_delta_soft = 3 + // RGB : b4_delta = 5 + // RGB : b4_delta_soft = 4 + const auto b4_delta = (stride == 4) ? 3 : ((is_last_line) ? 5 : 4); + + // In block 2 (2 means we process 2 weight values together), we read + // 8 bytes of input data. + // lineIn0 + stride * (i + ids_min) + 8 <= lineIn0 + stride * (ids_size + ids_min) + // --> i <= ids_size - 8.0 / stride + // Strict boundary: + // --> i < ids_size + 1 - int(ceil(8.0 / stride)) = ids_size - b2_delta + // Soft boundary for reading inside the buffer except its boundaries: + // --> i < ids_size + 1 - int(8.0 / stride) = ids_size - b2_delta_soft + // RGBA: b2_delta = b2_delta_soft = 1 + // RGB : b2_delta = 2 + // RGB : b2_delta_soft = 1 + const auto b2_delta = (stride == 4) ? 1 : ((is_last_line) ? 2 : 1); + + const auto max_out_x_strided = out_xsize * stride; + const auto max_in_x_strided = in_xsize * stride; + + for (const auto out_x : c10::irange(out_xsize)) { + __m128i sss; + const auto ids_min = idx_ptr_xmin[out_x]; + const auto ids_size = idx_ptr_size[out_x]; + const auto * k = &kk[out_x * kmax]; + int64_t i = 0; + + const auto * lineIn_min = lineIn + ids_min; + + if (ids_size < 8) { + sss = _mm_set1_epi32(1 << (coefs_precision - 1)); + } else { + // Lower part will be added to higher, use only half of the error + auto sss256 = _mm256_set1_epi32(1 << (coefs_precision - 2)); + + // block 8 + for (; i < ids_size - b8_delta; i += 8) { + // Load 8 values from weight vector + auto tmp = _mm_loadu_si128((__m128i*)&k[i]); + // ksource = [ + // wl_0 wh_0 wl_1 wh_1 wl_2 wh_2 wl_3 wh_3 wl_4 wh_4 wl_5 wh_5 wl_6 wh_6 wl_7 wh_7 + // wl_0 wh_0 wl_1 wh_1 wl_2 wh_2 wl_3 wh_3 wl_4 wh_4 wl_5 wh_5 wl_6 wh_6 wl_7 wh_7 + // ] + auto ksource = _mm256_insertf128_si256(_mm256_castsi128_si256(tmp), tmp, 1); + + // RGBA: Load 8 pixels from input: + // source = [ + // r0 g0 b0 a0 r1 g1 b1 a1 r2 g2 b2 a2 r3 g3 b3 a3 + // r4 g4 b4 a4 r5 g5 b5 a5 r6 g6 b6 a6 r7 g7 b7 a7 + // ] + // RGB: Load 10 pixels from input (however we can process only 8 pixels): + // source = [ + // r0 g0 b0 r1 g1 b1 r2 g2 b2 r3 g3 b3 r4 g4 b4 r5 + // r4 g4 b4 r5 g5 b5 r6 g6 b6 r7 g7 b7 r8 g8 b8 r9 + // ] + auto source = _mm256_inserti128_si256(_mm256_castsi128_si256( + _mm_loadu_si128((__m128i *) (lineIn_min + stride * i))), + _mm_loadu_si128((__m128i *) (lineIn_min + stride * (i + 4))), 1); + + // Extract lower part of each lane, cast to epi16 and reoder RGBARGBA -> RRGGBBAA + // RGBA: pix1 = [ + // r0 0 r1 0 g0 0 g1 0 b0 0 b1 0 a0 0 a1 0 + // r4 0 r5 0 g4 0 g5 0 b4 0 b5 0 a4 0 a5 0 + // ] + // RGB: pix1 = [ + // r0 0 r1 0 g0 0 g1 0 b0 0 b1 0 0 0 0 0 + // r4 0 r5 0 g4 0 g5 0 b4 0 b5 0 0 0 0 0 + // ] + auto pix1 = _mm256_shuffle_epi8(source, mask_low); + // mmk1 = [ + // wl_0 wh_0 wl_1 wh_1 wl_0 wh_0 wl_1 wh_1 ... ... + // wl_4 wh_4 wl_5 wh_5 wl_4 wh_4 wl_5 wh_5 ... ... 
+ // ] + auto mmk1 = _mm256_shuffle_epi8(ksource, kmask_low); + // Compute output value as + // C += w0 * C0 + w1 * C1 + // C += w4 * C4 + w5 * C5 for each channel in 32-bit precision + sss256 = _mm256_add_epi32(sss256, _mm256_madd_epi16(pix1, mmk1)); + + // Same as above for higher part of each lane + auto pix2 = _mm256_shuffle_epi8(source, mask_high); + auto mmk2 = _mm256_shuffle_epi8(ksource, kmask_high); + // Compute output value as + // C += w2 * C2 + w3 * C3 + // C += w6 * C6 + w7 * C7 for each channel in 32-bit precision + sss256 = _mm256_add_epi32(sss256, _mm256_madd_epi16(pix2, mmk2)); + } + + // block 4 + for (; i < ids_size - b4_delta; i += 4) { + // Load 4 values from weight vector + auto tmp = _mm_loadl_epi64((__m128i *) &k[i]); + // ksource = [ + // wl_0 wh_0 wl_1 wh_1 wl_2 wh_2 wl_3 wh_3 0 0 0 0 0 0 0 0 + // wl_0 wh_0 wl_1 wh_1 wl_2 wh_2 wl_3 wh_3 0 0 0 0 0 0 0 0 + // ] + auto ksource = _mm256_insertf128_si256(_mm256_castsi128_si256(tmp), tmp, 1); + + // Load pixels from input line + tmp = _mm_loadu_si128((__m128i *) (lineIn_min + stride * i)); + // RGBA: source = [ + // r0 g0 b0 a0 r1 g1 b1 a1 r2 g2 b2 a2 r3 g3 b3 a3 + // r0 g0 b0 a0 r1 g1 b1 a1 r2 g2 b2 a2 r3 g3 b3 a3 + // ] + // RGB: source = [ + // r0 g0 b0 r1 g1 b1 r2 g2 b2 r3 g3 b3 r4 g4 b4 r5 + // r0 g0 b0 r1 g1 b1 r2 g2 b2 r3 g3 b3 r4 g4 b4 r5 + // ] + auto source = _mm256_insertf128_si256(_mm256_castsi128_si256(tmp), tmp, 1); + + // Cast source to epi16 and reorder RGBARGBA -> RRGGBBAA + // RGBA: pix = [ + // r0 0 r1 0 g0 0 g1 0 b0 0 b1 0 a0 0 a1 0 + // r2 0 r3 0 g2 0 g3 0 b2 0 b3 0 a2 0 a3 0 + // ] + // RGB: pix = [ + // r0 0 r1 0 g0 0 g1 0 b0 0 b1 0 0 0 0 0 + // r2 0 r3 0 g2 0 g3 0 b2 0 b3 0 0 0 0 0 + // ] + auto pix = _mm256_shuffle_epi8(source, mask_hl); + // mmk = [ + // wl_0 wh_0 wl_1 wh_1 wl_0 wh_0 wl_1 wh_1 ... ... + // wl_2 wh_2 wl_3 wh_3 wl_2 wh_2 wl_3 wh_3 ... ... + // ] + auto mmk = _mm256_shuffle_epi8(ksource, kmask_hl); + // Compute output value as + // C += w0 * C0 + w1 * C1 + // C += w2 * C2 + w3 * C3 for each channel in 32-bit precision + sss256 = _mm256_add_epi32(sss256, _mm256_madd_epi16(pix, mmk)); + } + + // Sum results between the lanes + sss = _mm_add_epi32( + _mm256_extracti128_si256(sss256, 0), + _mm256_extracti128_si256(sss256, 1)); + } + + // block 2 + for (; i < ids_size - b2_delta; i += 2) { + // Load 2 values from weight vector + // mmk = [wl_0 wh_0 wl_1 wh_1 wl_0 wh_0 wl_1 wh_1 ...] + auto mmk = _mm_set1_epi32(*(int32_t*)&k[i]); + // Load pixels from input line + // RGBA: source = [ + // r0 g0 b0 a0 r1 g1 b1 a1 0 0 0 0 0 0 0 0 + // ] + // RGB: source = [ + // r0 g0 b0 r1 g1 b1 r2 g2 0 0 0 0 0 0 0 0 + // ] + auto source = _mm_loadl_epi64((__m128i *) (lineIn_min + stride * i)); + // Cast source to epi16 and reorder RGBARGBA -> RRGGBBAA + auto pix = _mm_shuffle_epi8(source, mask_low128); + // Compute output value as C += w0 * C0 + w1 * C1 for each channel in 32-bit precision + sss = _mm_add_epi32(sss, _mm_madd_epi16(pix, mmk)); + } + + // block 1 + const auto i32_aligned = num_channels == 4; + for (; i < ids_size - 1; i++) { + // Load 1 value from weight vector + // mmk = [wl_0 wh_0 0 0 wl_0 wh_0 0 0 ...] 
+ auto mmk = _mm_set1_epi32(k[i]); + // Load one pixel from input line + // RGBA: pix = [ + // r0 0 0 0 g0 0 0 0 b0 0 0 0 a0 0 0 0 + // ] + // RGB: pix = [ + // r0 0 0 0 g0 0 0 0 b0 0 0 0 r1 0 0 0 + // ] + auto pix = mm_cvtepu8_epi32(lineIn_min + stride * i, i32_aligned); + // Compute output value as C += w0 * C0 for each channel in 32-bit precision + sss = _mm_add_epi32(sss, _mm_madd_epi16(pix, mmk)); + } + + if (i == ids_size - 1) { + // last element + auto mmk = _mm_set1_epi32(k[i]); + __m128i pix; + auto p = lineIn_min + stride * i; + if (num_channels == 3 && C10_UNLIKELY(is_last_line && ids_min + stride * i + 4 >= max_in_x_strided)) { + uint8_t input[4]; + std::memcpy(input, p, 3); + pix = mm_cvtepu8_epi32(input, true); + } else { + pix = mm_cvtepu8_epi32(p, i32_aligned); + } + sss = _mm_add_epi32(sss, _mm_madd_epi16(pix, mmk)); + } + + // Convert fixed point values back to integers (truncating) + sss = _mm_srai_epi32(sss, coefs_precision); + // Convert packed signed 32-bit integers to packed 16-bit integers using signed saturation + // (a a a a b b b b c c c c d d d d) -> (a a b b c c d d 0 0 0 0 0 0 0 0) + sss = _mm_packs_epi32(sss, zero); + // Convert packed signed 16-bit integers to packed 8-bit integers using unsigned saturation + // (a a b b c c d d) -> (a b c d 0 0 0 0) + sss = _mm_packus_epi16(sss, zero); + // Write the output into single uint32 + // (a b c d) -> x_uint32 + auto o = _mm_cvtsi128_si32(sss); + const auto out_x_strided = stride * out_x; + if (num_channels == 3 && C10_UNLIKELY(out_x_strided + 4 >= max_out_x_strided)) { + if (C10_UNLIKELY(is_last_line)) { + // When we handle the last line, we can not access the next 4 bytes + // as they are out of memory bounds. + std::memcpy(lineOut + out_x_strided, (uint8_t *) &o, 3); + } else { + // Memcpy 4-bytes is faster than 3-bytes and this is a boundary case when we want to write + // 4 bytes (R G B | X) to the output buffer (X1 X2 X3 | R1). + // The 4th byte in the register (X) has a garbage value and 4th byte in the output buffer (R1) has a correct + // value which was previously computed by another line. In other words, it means that we can not overwrite + // it by simply writing 4 bytes from the register to the output. We'll do the following: + // v----------| + // Output = [... X1 X2 X3 | R1 G1 B1 R2 ...] + // First, we write R1 value to the 4th byte of (R G B | X) -> (R G B | R1) + // Second, we write 4 bytes from the register to the output: (X1 X2 X3 | R1) -> (R G B | R1) + // Output = [... R G B | R1 G1 B1 R2 ...] + _write_endline_rgb_as_uint32(lineOut + out_x_strided, o); + } + } else if (num_channels == 3) { + // Memcpy 4-bytes is faster than 3-bytes and here + // we simply write 4 bytes (... R G B X 0 0 0 0 0 ...) where X is a garbage value + // that we will overwrite on the next iteration: (... R G B R G B X 0 0 ...) + std::memcpy(lineOut + out_x_strided, (uint8_t *) &o, 4); + } else { + // num_channels = 4 -> lineOut + out_x_strided should be uint32 aligned + *(uint32_t *)(lineOut + out_x_strided) = o; + } + } +} + +void ImagingResampleVerticalConvolution8u( + uint8_t* C10_RESTRICT lineOut, + const uint8_t* C10_RESTRICT lineIn, + int64_t xsize, + int64_t ids_min, + int64_t ids_size, + const int16_t* k, + unsigned int coefs_precision, + int64_t num_channels) { + + // Interpolation vertical pass processing one line. + // - We process x-axis data with blocks of 8, 2 and 1 + // - We split the size of weight vector for a given output index as a sum: K = n * 2 + m. 
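+ // For instance, with K = 5 vertical taps the paired loops below consume the
+ // weights as (w0, w1), (w2, w3) and the remainder loop handles w4, i.e. n = 2
+ // and m = 1 in the decomposition above.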
+ + // xsize = output width, also equals to input width + // ids_size = interpolation size + // ids_min = input y start index + const auto stride = num_channels * sizeof(uint8_t); + + TORCH_INTERNAL_ASSERT(stride == 3 || stride == 4); + + const int64_t data_size = xsize * stride; + const int64_t data_stride = stride; + constexpr auto vec_size = 256 / 8; + + const auto initial = _mm_set1_epi32(1 << (coefs_precision - 1)); + const auto initial_256 = _mm256_set1_epi32(1 << (coefs_precision - 1)); + const auto zero = _mm_setzero_si128(); + const auto zero_256 = _mm256_setzero_si256(); + + int64_t j = 0; + // block 8 + const auto b8_usable_vec_stride = (vec_size / data_stride) * data_stride; + for (; j < data_size - vec_size; j += b8_usable_vec_stride) { + auto sss0 = initial_256; + auto sss1 = initial_256; + auto sss2 = initial_256; + auto sss3 = initial_256; + int64_t i = 0; + const auto * lineIn_min = lineIn + j + ids_min; + + for (; i < ids_size - 1; i += 2) { + // Load 2 values from weight vector + auto mmk = _mm256_set1_epi32(*(int32_t*)&k[i]); + + // RGBA: Load 8 pixels per line + // source1 = [ + // r0 g0 b0 a0 r1 g1 b1 a1 r2 g2 b2 a2 r3 g3 b3 a3 + // r4 g4 b4 a4 r5 g5 b5 a5 r6 g6 b6 a6 r7 g7 b7 a7 + // ] + // RGB: Load 10 pixels per line (however we can process only 8 pixels): + // source1 = [ + // r0 g0 b0 r1 g1 b1 r2 g2 b2 r3 g3 b3 r4 g4 b4 r5 + // r4 g4 b4 r5 g5 b5 r6 g6 b6 r7 g7 b7 r8 g8 b8 r9 + // ] + auto source1 = + _mm256_loadu_si256((__m256i*)(lineIn_min + data_size * i)); + auto source2 = + _mm256_loadu_si256((__m256i*)(lineIn_min + data_size * (i + 1))); + + // Interleave source1 and source2 from the low half of each 128-bit lane + // and cast the result to epi16 + // RGBA: pix1 = [ + // r0 0 R0 0 g0 0 G0 0 b0 0 B0 0 a0 0 A0 0 + // r1 0 R1 0 g1 0 G1 0 b1 0 B1 0 a1 0 A1 0 + // ] + // RGB: pix1 = [ + // r0 0 R0 0 g0 0 G0 0 b0 0 B0 0 0 0 0 0 + // r1 0 R1 0 g1 0 G1 0 b1 0 B1 0 0 0 0 0 + // ] + auto source_lo = _mm256_unpacklo_epi8(source1, source2); + auto pix1 = _mm256_unpacklo_epi8(source_lo, zero_256); + // Compute output value as + // C += w0 * c0 + w1 * C0 + // C += w0 * c1 + w1 * C1 for each channel in 32-bit precision + sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix1, mmk)); + + // RGBA: pix2 = [ + // r2 0 R2 0 g2 0 G2 0 b2 0 B2 0 a2 0 A2 0 + // r3 0 R3 0 g3 0 G3 0 b3 0 B3 0 a3 0 A3 0 + // ] + // RGB: pix2 = [ + // r2 0 R2 0 g2 0 G2 0 b2 0 B2 0 0 0 0 0 + // r3 0 R3 0 g3 0 G3 0 b3 0 B3 0 0 0 0 0 + // ] + auto pix2 = _mm256_unpackhi_epi8(source_lo, zero_256); + // Compute output value as + // C += w0 * c2 + w1 * C2 + // C += w0 * c3 + w1 * C3 for each channel in 32-bit precision + sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix2, mmk)); + + // Same as above for the high half of each 128-bit lane + auto source_hi = _mm256_unpackhi_epi8(source1, source2); + auto pix3 = _mm256_unpacklo_epi8(source_hi, zero_256); + sss2 = _mm256_add_epi32(sss2, _mm256_madd_epi16(pix3, mmk)); + auto pix4 = _mm256_unpackhi_epi8(source_hi, zero_256); + sss3 = _mm256_add_epi32(sss3, _mm256_madd_epi16(pix4, mmk)); + } + // Same processing as above but with a single weight value + for (; i < ids_size; i += 1) { + auto mmk = _mm256_set1_epi32(k[i]); + + auto source1 = _mm256_loadu_si256((__m256i*)(lineIn_min + i * data_size)); + + auto source_lo = _mm256_unpacklo_epi8(source1, zero_256); + auto pix1 = _mm256_unpacklo_epi8(source_lo, zero_256); + sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix1, mmk)); + auto pix2 = _mm256_unpackhi_epi8(source_lo, zero_256); + sss1 = _mm256_add_epi32(sss1, 
_mm256_madd_epi16(pix2, mmk)); + + auto source_hi = _mm256_unpackhi_epi8(source1, zero_256); + auto pix3 = _mm256_unpacklo_epi8(source_hi, _mm256_setzero_si256()); + sss2 = _mm256_add_epi32(sss2, _mm256_madd_epi16(pix3, mmk)); + auto pix4 = _mm256_unpackhi_epi8(source_hi, _mm256_setzero_si256()); + sss3 = _mm256_add_epi32(sss3, _mm256_madd_epi16(pix4, mmk)); + } + // Convert fixed point values back to integers (truncating) + sss0 = _mm256_srai_epi32(sss0, coefs_precision); + sss1 = _mm256_srai_epi32(sss1, coefs_precision); + sss2 = _mm256_srai_epi32(sss2, coefs_precision); + sss3 = _mm256_srai_epi32(sss3, coefs_precision); + // Convert packed signed 32-bit integers to packed 16-bit integers using signed saturation + // (a a a a b b b b c c c c d d d d) -> (a a b b c c d d) + sss0 = _mm256_packs_epi32(sss0, sss1); + sss2 = _mm256_packs_epi32(sss2, sss3); + // Convert packed signed 16-bit integers to packed 8-bit integers using unsigned saturation + // (a a b b c c d d) -> (a b c d) + sss0 = _mm256_packus_epi16(sss0, sss2); + + // Stores 32 bytes + _mm256_storeu_si256((__m256i*)(lineOut + j), sss0); + } + + // TODO: Do we also need block 4 ??? + // block 2 + const auto b2_usable_vec_stride = (8 / data_stride) * data_stride; + for (; j < data_size - vec_size / 4; j += b2_usable_vec_stride) { + auto sss0 = initial; + auto sss1 = initial; + int64_t i = 0; + const auto * lineIn_min = lineIn + j + ids_min; + + for (; i < ids_size - 1; i += 2) { + // Load 2 values from weight vector + // mmk = [wl_0 wh_0 wl_1 wh_1 wl_0 wh_0 wl_1 wh_1 ... ] + auto mmk = _mm_set1_epi32(*(int32_t*)&k[i]); + + // Load 2 pixels per line + // RGBA: source1 = [ + // r0 g0 b0 a0 r1 g1 b1 a1 0 0 0 0 0 0 0 0 + // ] + // RGB: source1 = [ + // r0 g0 b0 r1 g1 b1 r2 g2 0 0 0 0 0 0 0 0 + // ] + auto source1 = _mm_loadl_epi64((__m128i *) (lineIn_min + i * data_size)); + auto source2 = _mm_loadl_epi64((__m128i *) (lineIn_min + (i + 1) * data_size)); + // Interleave source1 and source2 and cast the result to epi16 + // RGBA: pix = [ + // r0 0 R0 0 g0 0 G0 0 b0 0 B0 0 a0 0 A0 0 + // ] + // RGB: pix = [ + // r0 0 R0 0 g0 0 G0 0 b0 0 B0 0 0 0 0 0 + // ] + auto source = _mm_unpacklo_epi8(source1, source2); + auto pix = _mm_unpacklo_epi8(source, zero); + // Compute output value as C += w0 * c0 + w1 * C0 for each channel in 32-bit precision + sss0 = _mm_add_epi32(sss0, _mm_madd_epi16(pix, mmk)); + // RGBA: pix = [ + // r1 0 R1 0 g1 0 G1 0 b1 0 B1 0 a1 0 A1 0 + // ] + // RGB: pix = [ + // r1 0 R1 0 g1 0 G1 0 b1 0 B1 0 0 0 0 0 + // ] + pix = _mm_unpackhi_epi8(source, zero); + // Compute output value as C += w0 * c1 + w1 * C1 for each channel in 32-bit precision + sss1 = _mm_add_epi32(sss1, _mm_madd_epi16(pix, mmk)); + } + // Same processing as above but with a single weight value + for (; i < ids_size; i += 1) { + auto mmk = _mm_set1_epi32(k[i]); + + auto source1 = _mm_loadl_epi64((__m128i*) (lineIn_min + i * data_size)); + + auto source = _mm_unpacklo_epi8(source1, zero); + auto pix1 = _mm_unpacklo_epi8(source, zero); + sss0 = _mm_add_epi32(sss0, _mm_madd_epi16(pix1, mmk)); + auto pix2 = _mm_unpackhi_epi8(source, zero); + sss1 = _mm_add_epi32(sss1, _mm_madd_epi16(pix2, mmk)); + } + // Convert fixed point values back to integers (truncating) + sss0 = _mm_srai_epi32(sss0, coefs_precision); + sss1 = _mm_srai_epi32(sss1, coefs_precision); + // Convert packed signed 32-bit integers to packed 16-bit integers using signed saturation + // (a a a a b b b b c c c c d d d d) -> (a a b b c c d d) + sss0 = _mm_packs_epi32(sss0, sss1); + // Convert 
packed signed 16-bit integers to packed 8-bit integers using unsigned saturation + // (a a b b c c d d) -> (a b c d) + sss0 = _mm_packus_epi16(sss0, sss0); + // Store 2 pixels to the output + _mm_storel_epi64((__m128i*)(lineOut + j), sss0); + } + + // block 1 + const auto b1_usable_vec_stride = (4 / data_stride) * data_stride; + const auto i32_aligned = num_channels == 4; + for (; j < data_size - 4; j += b1_usable_vec_stride) { + auto sss = initial; + int64_t i = 0; + const auto * lineIn_min = lineIn + j + ids_min; + + for (; i < ids_size - 1; i += 2) { + // Load 2 values from weight vector + // mmk = [wl_0 wh_0 wl_1 wh_1 wl_0 wh_0 wl_1 wh_1 ... ] + auto mmk = _mm_set1_epi32(*(int32_t*)&k[i]); + + // Load one pixel per line + // RGBA: source1 = [ + // r0 g0 b0 a0 0 0 0 0 0 0 0 0 0 0 0 0 + // ] + // RGB: source1 = [ + // r0 g0 b0 r1 0 0 0 0 0 0 0 0 0 0 0 0 + // ] + auto source1 = mm_cvtsi32_si128(lineIn_min + i * data_size, i32_aligned); + auto source2 = mm_cvtsi32_si128(lineIn_min + (i + 1) * data_size, i32_aligned); + + // Interleave source1 and source2 and cast the result to epi16 + // RGBA: pix = [ + // r0 0 R0 0 g0 0 G0 0 b0 0 B0 0 a0 0 A0 0 + // ] + // RGB: pix = [ + // r0 0 R0 0 g0 0 G0 0 b0 0 B0 0 0 0 0 0 + // ] + auto source = _mm_unpacklo_epi8(source1, source2); + auto pix = _mm_unpacklo_epi8(source, zero); + // Compute output value as C += w0 * c0 + w1 * C0 for each channel in 32-bit precision + sss = _mm_add_epi32(sss, _mm_madd_epi16(pix, mmk)); + } + + for (; i < ids_size; i++) { + auto mmk = _mm_set1_epi32(k[i]); + auto pix = mm_cvtepu8_epi32(lineIn_min + i * data_size, i32_aligned); + sss = _mm_add_epi32(sss, _mm_madd_epi16(pix, mmk)); + } + sss = _mm_srai_epi32(sss, coefs_precision); + sss = _mm_packs_epi32(sss, zero); + sss = _mm_packus_epi16(sss, zero); + + auto o = _mm_cvtsi128_si32(sss); + + // Here we write 4 bytes to the output even if num_channels < 4, e.g o = {r,g,b,X} for num_channels=3 + // It is OK to write 4th byte (e.g. X) as on the next step we will overwrite it with new data. 
+ // We also wont go out of bounds of lineOut memory allocation + std::memcpy(lineOut + j, (uint8_t *) &o, 4); + } + + for (; j < data_size; j += data_stride) { + auto sss = initial; + int64_t i = 0; + const auto * lineIn_min = lineIn + j + ids_min; + // For RGBA we can use (ids_size - 1) as tighter limit but for RGB we can read outside memory boundary + // for the last remaining line + for (; i < ids_size - 2; i += 2) { + // Load two coefficients at once + auto mmk = _mm_set1_epi32(*(int32_t*)&k[i]); + + // Load 2 lines + auto source1 = mm_cvtsi32_si128(lineIn_min + i * data_size, i32_aligned); + auto source2 = mm_cvtsi32_si128(lineIn_min + (i + 1) * data_size, i32_aligned); + + auto source = _mm_unpacklo_epi8(source1, source2); + auto pix = _mm_unpacklo_epi8(source, zero); + sss = _mm_add_epi32(sss, _mm_madd_epi16(pix, mmk)); + } + + // Same processing as above but with a single weight value + for (; i < ids_size; i++) { + auto mmk = _mm_set1_epi32(k[i]); + + const uint8_t * p = lineIn_min + i * data_size; + __m128i pix; + // There is no much perf gain using more detailed condition like + // num_channels == 3 && ids_min + j + data_size * i + 4 >= in_max_size + // const int64_t in_max_size = data_size * in_ysize; + if (num_channels == 3) { + uint8_t input[4]; + std::memcpy(input, p, 3); + pix = mm_cvtepu8_epi32(input, true); + } else { + pix = mm_cvtepu8_epi32(p, true); + } + sss = _mm_add_epi32(sss, _mm_madd_epi16(pix, mmk)); + } + + // Convert fixed point values back to integers (truncating) + sss = _mm_srai_epi32(sss, coefs_precision); + // Convert packed signed 32-bit integers to packed 16-bit integers using signed saturation + // (a a a a b b b b c c c c d d d d) -> (a a b b c c d d) + sss = _mm_packs_epi32(sss, zero); + // Convert packed signed 16-bit integers to packed 8-bit integers using unsigned saturation + // (a a b b c c d d) -> (a b c d) + sss = _mm_packus_epi16(sss, zero); + // Store one pixel to the output + auto o = _mm_cvtsi128_si32(sss); + if (num_channels == 3 && C10_UNLIKELY(j + 4 >= data_size)) { + std::memcpy(lineOut + j, (uint8_t *) &o, 3); + } else { + std::memcpy(lineOut + j, (uint8_t *) &o, 4); + } + } +} + +} // anonymous namespace +#endif // CPU_CAPABILITY_AVX2 diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/WeightNormKernel.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/WeightNormKernel.h new file mode 100644 index 0000000000000000000000000000000000000000..1fd8c75cc73b30f986635f52b5a13f910e1088b7 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/WeightNormKernel.h @@ -0,0 +1,20 @@ +#pragma once +#include +#include + +namespace at { +class TensorBase; +} + +namespace at::native { + +using weight_norm_fn = void(*)( + TensorBase&, TensorBase&, const TensorBase&, const TensorBase&, int64_t); +using weight_norm_backward_fn = void(*)( + TensorBase&, TensorBase&, const TensorBase&, const TensorBase&, + const TensorBase&, const TensorBase&, int64_t); + +DECLARE_DISPATCH(weight_norm_fn, weight_norm_stub); +DECLARE_DISPATCH(weight_norm_backward_fn, weight_norm_backward_stub); + +} // namespace at::native diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/avx_mathfun.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/avx_mathfun.h new file mode 100644 index 0000000000000000000000000000000000000000..f4fd3b7bc461fbf82e8b4a16dd9453e46e124efa --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/avx_mathfun.h @@ 
-0,0 +1,522 @@ +#pragma once +/* + AVX implementation of sin, cos, sincos, exp and log + + Based on "sse_mathfun.h", by Julien Pommier + http://gruntthepeon.free.fr/ssemath/ + + Copyright (C) 2012 Giovanni Garberoglio + Interdisciplinary Laboratory for Computational Science (LISC) + Fondazione Bruno Kessler and University of Trento + via Sommarive, 18 + I-38123 Trento (Italy) + + This software is provided 'as-is', without any express or implied + warranty. In no event will the authors be held liable for any damages + arising from the use of this software. + + Permission is granted to anyone to use this software for any purpose, + including commercial applications, and to alter it and redistribute it + freely, subject to the following restrictions: + + 1. The origin of this software must not be misrepresented; you must not + claim that you wrote the original software. If you use this software + in a product, an acknowledgment in the product documentation would be + appreciated but is not required. + 2. Altered source versions must be plainly marked as such, and must not be + misrepresented as being the original software. + 3. This notice may not be removed or altered from any source distribution. + + (this is the zlib license) +*/ + +#include + +/* The original source of this file has been modified. */ +#if defined(CPU_CAPABILITY_AVX2) + +#if defined(__GNUC__) +# define ALIGN32_BEG __attribute__((aligned(32))) +#elif defined(_WIN32) +# define ALIGN32_BEG __declspec(align(32)) +#endif + +typedef __m256 v8sf; // vector of 8 float (avx2) +typedef __m256i v8si; // vector of 8 int (avx2) + +/* declare some AVX constants -- why can't I figure a better way to do that? */ +#define _PS256_CONST(Name, Val) \ + static const ALIGN32_BEG float _ps256_##Name[8] = { Val, Val, Val, Val, Val, Val, Val, Val } +#define _PI32_CONST256(Name, Val) \ + static const ALIGN32_BEG int _pi32_256_##Name[8] = { Val, Val, Val, Val, Val, Val, Val, Val } +#define _PS256_CONST_TYPE(Name, Type, Val) \ + static const ALIGN32_BEG Type _ps256_##Name[8] = { Val, Val, Val, Val, Val, Val, Val, Val } + +_PS256_CONST(1 , 1.0f); +_PS256_CONST(0p5, 0.5f); +/* the smallest non denormalized float number */ +_PS256_CONST_TYPE(min_norm_pos, int, 0x00800000); +_PS256_CONST_TYPE(mant_mask, int, 0x7f800000); +_PS256_CONST_TYPE(inv_mant_mask, int, ~0x7f800000); + +_PS256_CONST_TYPE(sign_mask, int, (int)0x80000000); +_PS256_CONST_TYPE(inv_sign_mask, int, ~0x80000000); + +_PI32_CONST256(0, 0); +_PI32_CONST256(1, 1); +_PI32_CONST256(inv1, ~1); +_PI32_CONST256(2, 2); +_PI32_CONST256(4, 4); +_PI32_CONST256(0x7f, 0x7f); + +_PS256_CONST(cephes_SQRTHF, 0.707106781186547524); +_PS256_CONST(cephes_log_p0, 7.0376836292E-2); +_PS256_CONST(cephes_log_p1, - 1.1514610310E-1); +_PS256_CONST(cephes_log_p2, 1.1676998740E-1); +_PS256_CONST(cephes_log_p3, - 1.2420140846E-1); +_PS256_CONST(cephes_log_p4, + 1.4249322787E-1); +_PS256_CONST(cephes_log_p5, - 1.6668057665E-1); +_PS256_CONST(cephes_log_p6, + 2.0000714765E-1); +_PS256_CONST(cephes_log_p7, - 2.4999993993E-1); +_PS256_CONST(cephes_log_p8, + 3.3333331174E-1); +_PS256_CONST(cephes_log_q1, -2.12194440e-4); +_PS256_CONST(cephes_log_q2, 0.693359375); + + +/* natural logarithm computed for 8 simultaneous float + return NaN for x <= 0 +*/ +inline v8sf log256_ps(v8sf x) { + v8si imm0; + v8sf one = *(v8sf*)_ps256_1; + + //v8sf invalid_mask = _mm256_cmple_ps(x, _mm256_setzero_ps()); + v8sf invalid_mask = _mm256_cmp_ps(x, _mm256_setzero_ps(), _CMP_LE_OS); + + x = _mm256_max_ps(x, *(v8sf*)_ps256_min_norm_pos); /* cut 
off denormalized stuff */ + + // can be done with AVX2 + imm0 = _mm256_srli_epi32(_mm256_castps_si256(x), 23); + + /* keep only the fractional part */ + x = _mm256_and_ps(x, *(v8sf*)_ps256_inv_mant_mask); + x = _mm256_or_ps(x, *(v8sf*)_ps256_0p5); + + // this is again another AVX2 instruction + imm0 = _mm256_sub_epi32(imm0, *(v8si*)_pi32_256_0x7f); + v8sf e = _mm256_cvtepi32_ps(imm0); + + e = _mm256_add_ps(e, one); + + /* part2: + if( x < SQRTHF ) { + e -= 1; + x = x + x - 1.0; + } else { x = x - 1.0; } + */ + //v8sf mask = _mm256_cmplt_ps(x, *(v8sf*)_ps256_cephes_SQRTHF); + v8sf mask = _mm256_cmp_ps(x, *(v8sf*)_ps256_cephes_SQRTHF, _CMP_LT_OS); + v8sf tmp = _mm256_and_ps(x, mask); + x = _mm256_sub_ps(x, one); + e = _mm256_sub_ps(e, _mm256_and_ps(one, mask)); + x = _mm256_add_ps(x, tmp); + + v8sf z = _mm256_mul_ps(x,x); + + v8sf y = *(v8sf*)_ps256_cephes_log_p0; + y = _mm256_mul_ps(y, x); + y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p1); + y = _mm256_mul_ps(y, x); + y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p2); + y = _mm256_mul_ps(y, x); + y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p3); + y = _mm256_mul_ps(y, x); + y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p4); + y = _mm256_mul_ps(y, x); + y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p5); + y = _mm256_mul_ps(y, x); + y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p6); + y = _mm256_mul_ps(y, x); + y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p7); + y = _mm256_mul_ps(y, x); + y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p8); + y = _mm256_mul_ps(y, x); + + y = _mm256_mul_ps(y, z); + + tmp = _mm256_mul_ps(e, *(v8sf*)_ps256_cephes_log_q1); + y = _mm256_add_ps(y, tmp); + + + tmp = _mm256_mul_ps(z, *(v8sf*)_ps256_0p5); + y = _mm256_sub_ps(y, tmp); + + tmp = _mm256_mul_ps(e, *(v8sf*)_ps256_cephes_log_q2); + x = _mm256_add_ps(x, y); + x = _mm256_add_ps(x, tmp); + x = _mm256_or_ps(x, invalid_mask); // negative arg will be NAN + return x; +} + +_PS256_CONST(exp_hi, 88.3762626647949f); +_PS256_CONST(exp_lo, -88.3762626647949f); + +_PS256_CONST(cephes_LOG2EF, 1.44269504088896341); +_PS256_CONST(cephes_exp_C1, 0.693359375); +_PS256_CONST(cephes_exp_C2, -2.12194440e-4); + +_PS256_CONST(cephes_exp_p0, 1.9875691500E-4); +_PS256_CONST(cephes_exp_p1, 1.3981999507E-3); +_PS256_CONST(cephes_exp_p2, 8.3334519073E-3); +_PS256_CONST(cephes_exp_p3, 4.1665795894E-2); +_PS256_CONST(cephes_exp_p4, 1.6666665459E-1); +_PS256_CONST(cephes_exp_p5, 5.0000001201E-1); + +inline v8sf exp256_ps(v8sf x) { + v8sf tmp = _mm256_setzero_ps(), fx; + v8si imm0; + v8sf one = *(v8sf*)_ps256_1; + + x = _mm256_min_ps(x, *(v8sf*)_ps256_exp_hi); + x = _mm256_max_ps(x, *(v8sf*)_ps256_exp_lo); + + /* express exp(x) as exp(g + n*log(2)) */ + fx = _mm256_mul_ps(x, *(v8sf*)_ps256_cephes_LOG2EF); + fx = _mm256_add_ps(fx, *(v8sf*)_ps256_0p5); + + /* how to perform a floorf with SSE: just below */ + //imm0 = _mm256_cvttps_epi32(fx); + //tmp = _mm256_cvtepi32_ps(imm0); + + tmp = _mm256_floor_ps(fx); + + /* if greater, subtract 1 */ + //v8sf mask = _mm256_cmpgt_ps(tmp, fx); + v8sf mask = _mm256_cmp_ps(tmp, fx, _CMP_GT_OS); + mask = _mm256_and_ps(mask, one); + fx = _mm256_sub_ps(tmp, mask); + + tmp = _mm256_mul_ps(fx, *(v8sf*)_ps256_cephes_exp_C1); + v8sf z = _mm256_mul_ps(fx, *(v8sf*)_ps256_cephes_exp_C2); + x = _mm256_sub_ps(x, tmp); + x = _mm256_sub_ps(x, z); + + z = _mm256_mul_ps(x,x); + + v8sf y = *(v8sf*)_ps256_cephes_exp_p0; + y = _mm256_mul_ps(y, x); + y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_exp_p1); + y = _mm256_mul_ps(y, x); + y = _mm256_add_ps(y, 
*(v8sf*)_ps256_cephes_exp_p2); + y = _mm256_mul_ps(y, x); + y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_exp_p3); + y = _mm256_mul_ps(y, x); + y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_exp_p4); + y = _mm256_mul_ps(y, x); + y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_exp_p5); + y = _mm256_mul_ps(y, z); + y = _mm256_add_ps(y, x); + y = _mm256_add_ps(y, one); + + /* build 2^n */ + imm0 = _mm256_cvttps_epi32(fx); + // another two AVX2 instructions + imm0 = _mm256_add_epi32(imm0, *(v8si*)_pi32_256_0x7f); + imm0 = _mm256_slli_epi32(imm0, 23); + v8sf pow2n = _mm256_castsi256_ps(imm0); + y = _mm256_mul_ps(y, pow2n); + return y; +} + +_PS256_CONST(minus_cephes_DP1, -0.78515625); +_PS256_CONST(minus_cephes_DP2, -2.4187564849853515625e-4); +_PS256_CONST(minus_cephes_DP3, -3.77489497744594108e-8); +_PS256_CONST(sincof_p0, -1.9515295891E-4); +_PS256_CONST(sincof_p1, 8.3321608736E-3); +_PS256_CONST(sincof_p2, -1.6666654611E-1); +_PS256_CONST(coscof_p0, 2.443315711809948E-005); +_PS256_CONST(coscof_p1, -1.388731625493765E-003); +_PS256_CONST(coscof_p2, 4.166664568298827E-002); +_PS256_CONST(cephes_FOPI, 1.27323954473516); // 4 / M_PI + + +/* evaluation of 8 sines at onces using AVX intrinsics + + The code is the exact rewriting of the cephes sinf function. + Precision is excellent as long as x < 8192 (I did not bother to + take into account the special handling they have for greater values + -- it does not return garbage for arguments over 8192, though, but + the extra precision is missing). + + Note that it is such that sinf((float)M_PI) = 8.74e-8, which is the + surprising but correct result. + +*/ +inline v8sf sin256_ps(v8sf x) { // any x + v8sf xmm1, xmm2 = _mm256_setzero_ps(), xmm3, sign_bit, y; + v8si imm0, imm2; + + sign_bit = x; + /* take the absolute value */ + x = _mm256_and_ps(x, *(v8sf*)_ps256_inv_sign_mask); + /* extract the sign bit (upper one) */ + sign_bit = _mm256_and_ps(sign_bit, *(v8sf*)_ps256_sign_mask); + + /* scale by 4/Pi */ + y = _mm256_mul_ps(x, *(v8sf*)_ps256_cephes_FOPI); + + /* + Here we start a series of integer operations, which are in the + realm of AVX2. 
+ If we don't have AVX, let's perform them using SSE2 directives + */ + + /* store the integer part of y in mm0 */ + imm2 = _mm256_cvttps_epi32(y); + /* j=(j+1) & (~1) (see the cephes sources) */ + // another two AVX2 instruction + imm2 = _mm256_add_epi32(imm2, *(v8si*)_pi32_256_1); + imm2 = _mm256_and_si256(imm2, *(v8si*)_pi32_256_inv1); + y = _mm256_cvtepi32_ps(imm2); + + /* get the swap sign flag */ + imm0 = _mm256_and_si256(imm2, *(v8si*)_pi32_256_4); + imm0 = _mm256_slli_epi32(imm0, 29); + /* get the polynom selection mask + there is one polynom for 0 <= x <= Pi/4 + and another one for Pi/4 +#include + +namespace at::native { + +using weight_to_int4pack_fn = void(*)(const Tensor&, const Tensor&, int, int); +using int4pack_mm_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, int, const Tensor&, int, int); +using int8pack_mm_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&); + +DECLARE_DISPATCH(weight_to_int4pack_fn, weight_to_int4pack_stub); +DECLARE_DISPATCH(int4pack_mm_fn, int4pack_mm_stub); +DECLARE_DISPATCH(int8pack_mm_fn, int8pack_mm_stub); + +} // namespace at::native diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/mixed_data_type.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/mixed_data_type.h new file mode 100644 index 0000000000000000000000000000000000000000..13244af3b34a0f36defc69fa7fc219e93aae7757 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/mixed_data_type.h @@ -0,0 +1,41 @@ +#pragma once + +#include + +namespace at::native { + +inline ScalarType first_type() { + return ScalarType::Undefined; +} + +template +inline ScalarType first_type(const Tensor& arg, const Args&... parameters) { + return arg.defined() ? arg.scalar_type() : first_type(parameters...); +} + +template +inline bool is_mixed_type(const Tensor& input, const Args&... parameters) { + const auto parameter_type = first_type(parameters...); + return ((parameter_type != ScalarType::Undefined) && + (parameter_type != input.scalar_type())); +} + +// currently on CPU, mixed data type is only supported +// when input is 'BFloat16' or 'Half' and parameters are 'Float' +inline void check_mixed_data_type(const Tensor& input) { + TORCH_CHECK(at::isReducedFloatingType(input.scalar_type()), + "mixed dtype (CPU): all inputs must share same datatype."); +} + +template +inline void check_mixed_data_type(const Tensor& input, const Tensor& parameter, const Args&... parameters) { + TORCH_CHECK(!parameter.defined() || parameter.scalar_type() == ScalarType::Float, + "mixed dtype (CPU): expect parameter to have scalar type of Float"); + check_mixed_data_type(input, parameters...); +} + +inline ScalarType param_scalar_type(const Tensor& t, bool is_mixed_type) { + return is_mixed_type ? 
ScalarType::Float : t.scalar_type(); +} + +} // namespace at::native diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/moments_utils.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/moments_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..6f403d60ea7c09849e1ea9d48625532aa51117c3 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/moments_utils.h @@ -0,0 +1,202 @@ +#pragma once + +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace at::native { +inline namespace CPU_CAPABILITY { + +template using opmath_t = at::opmath_type; + +constexpr int64_t kChunkSize = 16; + +template +void AddMoments( + int64_t m0_add, + const T& m1_add, + const T& m2_add, + int64_t& m0, + T& m1, + T& m2) { + const int64_t n = m0 + m0_add; + const T c = n == 0 ? static_cast(0) : static_cast(m0_add) / static_cast(n); + const T delta = m1_add - m1; + m1 += c * delta; + m2 += m2_add + delta * delta * c * static_cast(m0); + m0 = n; +} + +template +C10_ALWAYS_INLINE void AddMomentsVec( + int64_t m0_add, + const vec::Vectorized& m1_add, + const vec::Vectorized& m2_add, + int64_t& m0, + vec::Vectorized& m1, + vec::Vectorized& m2) { + using Vec = vec::Vectorized; + const int64_t n = m0 + m0_add; + const T c = n == 0 ? static_cast(0) : static_cast(m0_add) / static_cast(n); + const Vec c_vec(c); + const Vec delta = m1_add - m1; + m1 += c_vec * delta; + m2 += m2_add + delta * delta * c_vec * Vec(static_cast(m0)); + m0 = n; +} + +template +inline std::enable_if_t>, void> +UpdateMomentsVec( + int64_t m0, + const T* X_ptr, + const std::array>, kChunkSize>& c_vecs, + int64_t& m0_stk0, + vec::Vectorized>& m1_stk0, + vec::Vectorized>& m2_stk0) { + using Vec = vec::Vectorized>; + Vec m1_vec(0); + Vec m2_vec(0); + for (const auto j : c10::irange(m0)) { + const Vec x_vec = Vec::loadu(X_ptr + j * Vec::size()); + const Vec delta_vec = x_vec - m1_vec; + m1_vec += delta_vec * c_vecs[j]; + m2_vec += delta_vec * (x_vec - m1_vec); + } + AddMomentsVec(m0, m1_vec, m2_vec, m0_stk0, m1_stk0, m2_stk0); +} + +// each bfloat16/half vector will be converted to two float vectors, +// and accumulated successively on m1_stk0/m2_stk0. +template +inline std::enable_if_t>, void> +UpdateMomentsVec( + int64_t m0, + const T* X_ptr, + const std::array>, kChunkSize>& c_vecs, + int64_t& m0_stk0, + vec::Vectorized>& m1_stk0, + vec::Vectorized>& m2_stk0) { + using Vec = vec::Vectorized; + using fVec = vec::Vectorized>; + fVec m1_fvec0(0), m1_fvec1(0); + fVec m2_fvec0(0), m2_fvec1(0); + for (const auto j : c10::irange(m0)) { + const Vec x_bvec = Vec::loadu(X_ptr + j * Vec::size()); + auto [x_fvec0, x_fvec1] = convert_to_float(x_bvec); + const fVec delta_fvec0 = x_fvec0 - m1_fvec0; + const fVec delta_fvec1 = x_fvec1 - m1_fvec1; + m1_fvec0 += delta_fvec0 * c_vecs[j]; + m1_fvec1 += delta_fvec1 * c_vecs[j]; + m2_fvec0 += delta_fvec0 * (x_fvec0 - m1_fvec0); + m2_fvec1 += delta_fvec1 * (x_fvec1 - m1_fvec1); + } + AddMomentsVec(m0, m1_fvec0, m2_fvec0, m0_stk0, m1_stk0, m2_stk0); + AddMomentsVec(m0, m1_fvec1, m2_fvec1, m0_stk0, m1_stk0, m2_stk0); +} + +// Compute rowwise moments by Welford algorithm and cascade sum to improve +// numerical stability. 
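+// AddMoments/AddMomentsVec above merge two partial (count, mean, M2) results with
+// the standard parallel-variance update: for (n_a, mean_a, M2_a) and
+// (n_b, mean_b, M2_b), with n = n_a + n_b and delta = mean_b - mean_a,
+//   mean = mean_a + delta * n_b / n
+//   M2   = M2_a + M2_b + delta * delta * (n_b / n) * n_a
+// (in the code, c = m0_add / n plays the role of n_b / n).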
+// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance +// https://en.wikipedia.org/wiki/Pairwise_summation +template +std::pair, opmath_t> RowwiseMomentsImpl(const T* X, int64_t N, int64_t ddof = 0) { + using math_t = opmath_t; + + constexpr int64_t kVecSize = vec::Vectorized::size(); + constexpr int64_t kAccVecSize = vec::Vectorized::size(); + const int64_t n = N / kVecSize; + const int64_t m = divup(n, kChunkSize); + const int64_t depth = utils::CeilLog2(m); + + using Vec = vec::Vectorized; + const Vec kZeroVec(math_t(0)); + c10::SmallVector m0_stk(depth, 0); + c10::SmallVector m1_stk(depth, kZeroVec); + c10::SmallVector m2_stk(depth, kZeroVec); + + for (const auto i : c10::irange(m)) { + const T* X_ptr = X + i * kChunkSize * kVecSize; + const int64_t m0 = std::min(kChunkSize, n - i * kChunkSize); + static std::array c_vecs = ([]() { + std::array result; + for (const auto i : c10::irange(kChunkSize)) { + result[i] = Vec(math_t(1) / static_cast(i + 1)); + } + return result; + })(); + UpdateMomentsVec(m0, X_ptr, c_vecs, m0_stk[0], m1_stk[0], m2_stk[0]); + + int64_t mask = i + 1; + for (int64_t j = 1; j < depth && (mask & 1) == 0; ++j) { + AddMomentsVec( + m0_stk[j - 1], + m1_stk[j - 1], + m2_stk[j - 1], + m0_stk[j], + m1_stk[j], + m2_stk[j]); + m0_stk[j - 1] = 0; + m1_stk[j - 1] = kZeroVec; + m2_stk[j - 1] = kZeroVec; + mask >>= 1; + } + } + for (const auto i : c10::irange(1, depth)) { + AddMomentsVec( + m0_stk[i], m1_stk[i], m2_stk[i], m0_stk[0], m1_stk[0], m2_stk[0]); + } + + std::array m1_arr{}; + std::array m2_arr{}; + m1_stk[0].store(m1_arr.data()); + m2_stk[0].store(m2_arr.data()); + + int64_t m0 = 0; + math_t m1 = 0; + math_t m2 = 0; + for (int64_t i = n * kVecSize; i < N; ++i) { + math_t x = static_cast(X[i]); + const math_t delta = x - m1; + ++m0; + m1 += delta / static_cast(m0); + m2 += delta * (x - m1); + } + // for BFloat16, each vector in m1_arr/m2_arr holds 2*n accumulated result + int64_t m0_add = n * kVecSize / kAccVecSize; + for (const auto i : c10::irange(kAccVecSize)) { + AddMoments(m0_add, m1_arr[i], m2_arr[i], m0, m1, m2); + } + + return std::make_pair(m1, m2 / static_cast(N - ddof)); +} + +template +std::pair, opmath_t> RowwiseMoments(const T* X, int64_t N, int64_t ddof = 0) { + using Vec = vec::Vectorized; + constexpr int64_t kVecSize = Vec::size(); + const int64_t n = N / kVecSize; + const int64_t m = divup(n, kChunkSize); + const int64_t depth = utils::CeilLog2(m); + if (depth <= 4) { + return RowwiseMomentsImpl(X, N, ddof); + } else if (depth <= 8) { + return RowwiseMomentsImpl(X, N, ddof); + } else if (depth <= 16) { + return RowwiseMomentsImpl(X, N, ddof); + } else if (depth <= 32) { + return RowwiseMomentsImpl(X, N, ddof); + } else { + return RowwiseMomentsImpl(X, N, ddof); + } +} + +} // namespace CPU_CAPABILITY +} // namespace at::native diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/utils.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/utils.h new file mode 100644 index 0000000000000000000000000000000000000000..a558c1bf13139a504ce35bf58a1fd838232aa758 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/utils.h @@ -0,0 +1,212 @@ +#pragma once + +#include +#include +#include +#include + +#ifdef USE_FBGEMM +#include +#endif + +namespace at::native { + +template +inline void _store(T* dst, at::vec::Vectorized src) { + src.store(dst); +} + +inline void _store(at::BFloat16* dst, at::vec::Vectorized src) { + auto res = at::vec::convert_float_bfloat16(src, src); + 
res.store(dst, at::vec::Vectorized::size()); +} + +inline void _store(at::Half* dst, at::vec::Vectorized src) { + auto res = at::vec::convert_float_half(src, src); + res.store(dst, at::vec::Vectorized::size()); +} + +inline namespace CPU_CAPABILITY { + +template +inline T data_index_init(T offset) { + return offset; +} + +template +inline T data_index_init(T offset, T& x, const T& X, Args&&... args) { + offset = data_index_init(offset, std::forward(args)...); + x = offset % X; + return offset / X; +} + +inline bool data_index_step() { + return true; +} + +template +inline bool data_index_step(T& x, const T& X, Args&&... args) { + if (data_index_step(std::forward(args)...)) { + x = ((x + 1) == X) ? 0 : (x + 1); + return x == 0; + } + return false; +} + +// Helper struct for bfloat16/float16 vectorization +// Useful when you need float as immediate dtype or accumulate dtype +using namespace vec; +struct Vec2 { + Vectorized val0, val1; + Vec2(Vectorized v0, Vectorized v1) : val0(v0), val1(v1) {} + Vec2(float v) : val0(v), val1(v) {} + static Vec2 loadu(const BFloat16* ptr) { + auto [v0, v1] = convert_bfloat16_float(Vectorized::loadu(ptr)); + return {v0, v1}; + } + static Vec2 loadu(const Half* ptr) { + auto [v0, v1] = convert_half_float(Vectorized::loadu(ptr)); + return {v0, v1}; + } + static Vec2 loadu(const float* ptr) { + return {Vectorized::loadu(ptr), Vectorized::loadu(ptr + Vectorized::size())}; + } + void store(BFloat16* ptr) const { + Vectorized val = convert_float_bfloat16(val0, val1); + val.store(ptr); + } + void store(Half* ptr) const { + Vectorized val = convert_float_half(val0, val1); + val.store(ptr); + } + void store(float* ptr) const { + val0.store(ptr); + val1.store(ptr + Vectorized::size()); + } +}; +inline Vec2 operator+(const Vec2& a, const Vec2& b) { return {a.val0 + b.val0, a.val1 + b.val1}; } +inline Vec2 operator*(const Vec2& a, const Vec2& b) { return {a.val0 * b.val0, a.val1 * b.val1}; } +inline Vec2 operator-(const Vec2& a, const Vec2& b) { return {a.val0 - b.val0, a.val1 - b.val1}; } +inline Vec2 operator/(const Vec2& a, const Vec2& b) { return {a.val0 / b.val0, a.val1 / b.val1}; } +inline Vec2 maximum(const Vec2& a, const Vec2& b) { return {vec::maximum(a.val0, b.val0), vec::maximum(a.val1, b.val1)}; } +inline Vec2 minimum(const Vec2& a, const Vec2& b) { return {vec::minimum(a.val0, b.val0), vec::minimum(a.val1, b.val1)}; } + +template struct VectorizedType { using type = Vectorized; }; +template <> struct VectorizedType { using type = Vec2; }; +template <> struct VectorizedType { using type = Vec2; }; +template using VecType = typename VectorizedType::type; + +// Helper for mixed data type parameter Vec::load +inline std::tuple, Vectorized> load2f(const BFloat16* ptr) { + return convert_bfloat16_float(Vectorized::loadu(ptr)); +} + +inline std::tuple, Vectorized> load2f(const Half* ptr) { + return convert_half_float(Vectorized::loadu(ptr)); +} + +inline std::tuple, Vectorized> load2f(const float* ptr) { + using Vec = Vectorized; + return std::make_tuple(Vec::loadu(ptr), Vec::loadu(ptr + Vec::size())); +} + +inline std::tuple, Vectorized> load2f(const BFloat16* ptr, int64_t count) { + return convert_bfloat16_float(Vectorized::loadu(ptr, count)); +} + +inline std::tuple, Vectorized> load2f(const Half* ptr, int64_t count) { + return convert_half_float(Vectorized::loadu(ptr, count)); +} + +inline std::tuple, Vectorized> load2f(const float* ptr, int64_t count) { + using Vec = Vectorized; + if (count > Vec::size()) { + return std::make_tuple(Vec::loadu(ptr), 
Vec::loadu(ptr + Vec::size(), count - Vec::size())); + } else { + return std::make_tuple(Vec::loadu(ptr, count), Vec(0)); + } +} + +} // namespace + +namespace utils { + +template +T CeilLog2(const T& x) { + if (x <= 2) { + return 1; + } + // Last set bit is floor(log2(x)), floor + 1 is ceil + // except when x is an exact powers of 2, so subtract 1 first + return static_cast(llvm::findLastSet(static_cast(x) - 1)) + 1; +} + +// matrix transpose: +// src has shape of M by N, with leading dimension of ld_src +// dst has shape of N by M, with leading dimension of ld_dst +template +inline void transpose(int64_t M, int64_t N, const T* src, int64_t ld_src, T* dst, int64_t ld_dst) { + for (int64_t j = 0; j < N; j++) { + for (int64_t i = 0; i < M; i++) { + dst[j * ld_dst + i] = src[i * ld_src + j]; + } + } +} + +#ifdef USE_FBGEMM +template <> +inline void transpose(int64_t M, int64_t N, const float* src, int64_t ld_src, float* dst, int64_t ld_dst) { + TORCH_CHECK(fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM."); + fbgemm::transpose_simd(M, N, src, ld_src, dst, ld_dst); +} + +template <> +inline void transpose(int64_t M, int64_t N, const uint16_t* src, int64_t ld_src, uint16_t* dst, int64_t ld_dst) { + TORCH_CHECK(fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM."); + fbgemm::transpose_simd(M, N, src, ld_src, dst, ld_dst); +} +#endif + +template +inline void parallel_sparse_csr( + const TensorAccessor& crow_acc, + const int64_t M, + const int64_t nnz, + const F& f) { + TORCH_CHECK(crow_acc.size(0) == M + 1); + + // directly parallel on `M` may lead to load imbalance, + // statically determine thread partition here to average payload + // for each thread. + int num_threads = at::get_num_threads(); + std::vector thread_splits(num_threads + 1, M); + + int64_t thread_averge_payload = std::max((int64_t)1, divup(nnz, num_threads)); + + thread_splits[0] = 0; + int64_t sum = 0; + int64_t t = 1; + for (const auto m : c10::irange(M)) { + int64_t row_start = crow_acc[m]; + int64_t row_end = crow_acc[m + 1]; + sum += row_end - row_start; + if (sum > t * thread_averge_payload) { + thread_splits[t] = m; + t++; + } + } + // need to restore the last index, + // due to rounding error when calculating `thread_averge_payload`. + thread_splits[num_threads] = M; + + at::parallel_for(0, num_threads, 1, [&](int64_t cbegin, int64_t cend) { + int tid = at::get_thread_num(); + int64_t begin = thread_splits[tid]; + int64_t end = thread_splits[tid + 1]; + f(begin, end); + }); +} + +} // namespace utils + +} // namespace at::native diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/zmath.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/zmath.h new file mode 100644 index 0000000000000000000000000000000000000000..2b4f44db085c997f9fdadd49eb078f3cc67a36f2 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/zmath.h @@ -0,0 +1,250 @@ +#pragma once + +// Complex number math operations that act as no-ops for other dtypes. 
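+// Illustrative sketch (added for exposition; example_conj is a hypothetical
+// name, with std::complex standing in for c10::complex): the pattern used
+// throughout this header is a generic template that is the identity for real
+// dtypes plus explicit specializations for the complex types, so elementwise
+// kernels can call a single name for every dtype.
+#include <complex>
+
+template <typename T>
+inline T example_conj(T z) { return z; }  // real dtypes: no-op
+
+template <>
+inline std::complex<float> example_conj(std::complex<float> z) {
+  return {z.real(), -z.imag()};  // complex: negate the imaginary part
+}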
+#include +#include +#include + +namespace at::native { +inline namespace CPU_CAPABILITY { + +template +inline VALUE_TYPE zabs (SCALAR_TYPE z) { + return z; +} + +template<> +inline c10::complex zabs > (c10::complex z) { + return c10::complex(std::abs(z)); +} + +template<> +inline float zabs , float> (c10::complex z) { + return std::abs(z); +} + +template<> +inline c10::complex zabs > (c10::complex z) { + return c10::complex(std::abs(z)); +} + +template<> +inline double zabs , double> (c10::complex z) { + return std::abs(z); +} + +// This overload corresponds to non-complex dtypes. +// The function is consistent with its NumPy equivalent +// for non-complex dtypes where `pi` is returned for +// negative real numbers and `0` is returned for 0 or positive +// real numbers. +// Note: `nan` is propagated. +template +inline VALUE_TYPE angle_impl (SCALAR_TYPE z) { + if (at::_isnan(z)) { + return z; + } + return z < 0 ? c10::pi : 0; +} + +template<> +inline c10::complex angle_impl > (c10::complex z) { + return c10::complex(std::arg(z), 0.0); +} + +template<> +inline float angle_impl , float> (c10::complex z) { + return std::arg(z); +} + +template<> +inline c10::complex angle_impl > (c10::complex z) { + return c10::complex(std::arg(z), 0.0); +} + +template<> +inline double angle_impl , double> (c10::complex z) { + return std::arg(z); +} + +template +constexpr VALUE_TYPE real_impl (SCALAR_TYPE z) { + return z; //No-Op +} + +template<> +constexpr c10::complex real_impl > (c10::complex z) { + return c10::complex(z.real(), 0.0); +} + +template<> +constexpr float real_impl , float> (c10::complex z) { + return z.real(); +} + +template<> +constexpr c10::complex real_impl > (c10::complex z) { + return c10::complex(z.real(), 0.0); +} + +template<> +constexpr double real_impl , double> (c10::complex z) { + return z.real(); +} + +template +constexpr VALUE_TYPE imag_impl (SCALAR_TYPE /*z*/) { + return 0; +} + +template<> +constexpr c10::complex imag_impl > (c10::complex z) { + return c10::complex(z.imag(), 0.0); +} + +template<> +constexpr float imag_impl , float> (c10::complex z) { + return z.imag(); +} + +template<> +constexpr c10::complex imag_impl > (c10::complex z) { + return c10::complex(z.imag(), 0.0); +} + +template<> +constexpr double imag_impl , double> (c10::complex z) { + return z.imag(); +} + +template +inline TYPE conj_impl (TYPE z) { + return z; //No-Op +} + +template<> +inline c10::complex conj_impl > (c10::complex z) { + return c10::complex{z.real(), -z.imag()}; +} + +template<> +inline c10::complex conj_impl > (c10::complex z) { + return c10::complex(z.real(), -z.imag()); +} + +template<> +inline c10::complex conj_impl > (c10::complex z) { + return c10::complex(z.real(), -z.imag()); +} + +template +inline TYPE ceil_impl (TYPE z) { + return std::ceil(z); +} + +template <> +inline c10::complex ceil_impl (c10::complex z) { + return c10::complex(std::ceil(z.real()), std::ceil(z.imag())); +} + +template <> +inline c10::complex ceil_impl (c10::complex z) { + return c10::complex(std::ceil(z.real()), std::ceil(z.imag())); +} + +template +inline c10::complex sgn_impl (c10::complex z) { + if (z == c10::complex(0, 0)) { + return c10::complex(0, 0); + } else { + return z / zabs(z); + } +} + +template +inline TYPE floor_impl (TYPE z) { + return std::floor(z); +} + +template <> +inline c10::complex floor_impl (c10::complex z) { + return c10::complex(std::floor(z.real()), std::floor(z.imag())); +} + +template <> +inline c10::complex floor_impl (c10::complex z) { + return c10::complex(std::floor(z.real()), 
std::floor(z.imag())); +} + +template +inline TYPE round_impl (TYPE z) { + return std::nearbyint(z); +} + +template <> +inline c10::complex round_impl (c10::complex z) { + return c10::complex(std::nearbyint(z.real()), std::nearbyint(z.imag())); +} + +template <> +inline c10::complex round_impl (c10::complex z) { + return c10::complex(std::nearbyint(z.real()), std::nearbyint(z.imag())); +} + +template +inline TYPE trunc_impl (TYPE z) { + return std::trunc(z); +} + +template <> +inline c10::complex trunc_impl (c10::complex z) { + return c10::complex(std::trunc(z.real()), std::trunc(z.imag())); +} + +template <> +inline c10::complex trunc_impl (c10::complex z) { + return c10::complex(std::trunc(z.real()), std::trunc(z.imag())); +} + +template ::value, int> = 0> +inline TYPE max_impl (TYPE a, TYPE b) { + if (_isnan(a) || _isnan(b)) { + return std::numeric_limits::quiet_NaN(); + } else { + return std::max(a, b); + } +} + +template ::value, int> = 0> +inline TYPE max_impl (TYPE a, TYPE b) { + if (_isnan(a)) { + return a; + } else if (_isnan(b)) { + return b; + } else { + return std::abs(a) > std::abs(b) ? a : b; + } +} + +template ::value, int> = 0> +inline TYPE min_impl (TYPE a, TYPE b) { + if (_isnan(a) || _isnan(b)) { + return std::numeric_limits::quiet_NaN(); + } else { + return std::min(a, b); + } +} + +template ::value, int> = 0> +inline TYPE min_impl (TYPE a, TYPE b) { + if (_isnan(a)) { + return a; + } else if (_isnan(b)) { + return b; + } else { + return std::abs(a) < std::abs(b) ? a : b; + } +} + +} // end namespace +} //end at::native diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Activation.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Activation.h new file mode 100644 index 0000000000000000000000000000000000000000..5fbfe0d2c65569522dfbf878cc82b5ac66c3c4ad --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Activation.h @@ -0,0 +1,20 @@ +#pragma once +#include +#include + +namespace at { +struct TensorIteratorBase; +class TensorBase; +} + +namespace at { namespace native { + +void launch_glu_backward_kernel(const TensorIteratorBase& iter, + int64_t gI_stride, int64_t I_stride); + +void launch_log_sigmoid_forward_kernel(TensorIteratorBase& iter); + +void GeluCUDAKernelImpl(TensorIteratorBase& it, GeluType approximate); +void GeluBackwardCUDAKernelImpl(TensorIteratorBase& it, GeluType approximate); + +}} // namespace at::native diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/BinaryInternal.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/BinaryInternal.h new file mode 100644 index 0000000000000000000000000000000000000000..e098d32b114d604f6d9a1b5160dbe87de52c4595 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/BinaryInternal.h @@ -0,0 +1,48 @@ +// DON'T include this except from Binary*.cu files. It should not leak into +// headers. 
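+// (The functors below have __device__ member functions and the header pulls in
+// CUDA/Thrust headers, so it can only be compiled as device code by nvcc/hipcc;
+// including it from an ordinary host-only header would not compile.)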
+#pragma once +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace at { +namespace native { +namespace binary_internal { + +template +struct DivFunctor { + __device__ scalar_t operator()(scalar_t a, scalar_t b) const { + return a / b; + } +}; + +template +struct MulFunctor { + __device__ T operator()(T a, T b) const { + return a * b; + } +}; + +// Workaround for the error: '*' in boolean context, suggest '&&' instead +// [-Werror=int-in-bool-context] +template <> +struct MulFunctor { + __device__ bool operator()(bool a, bool b) const { + return a && b; + } +}; +void div_true_kernel_cuda(TensorIteratorBase& iter); +void div_trunc_kernel_cuda(TensorIteratorBase& iter); +} // namespace binary_internal +} // namespace native +} // namespace at diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/CUDAJitLoops.cuh b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/CUDAJitLoops.cuh new file mode 100644 index 0000000000000000000000000000000000000000..e764cc4ce803905a88363f00291a2066d76bb274 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/CUDAJitLoops.cuh @@ -0,0 +1,296 @@ +#pragma once +#include + +// Jiterator functions are guarded behind this macro +#if AT_USE_JITERATOR() + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include + +#include +#include +#include +#include + +namespace at { +namespace native { + +template +constexpr auto tuple_to_array_helper(Tuple& t, std::index_sequence seq) { + constexpr auto size = seq.size(); + (void)t; // warning : unused parameter when tuple is empty. + return std::array{static_cast(&std::get(t))...}; +} + +// Helper function convert tuple to std::array +// for passing the arguments to CUDA Kernel +// NOTE: We capture tuple by reference, +// so the pointers in returned array are only valid +// till tuple is alive. 
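+// A simplified standalone analogue of this helper (hypothetical names, shown
+// only to make the mechanics concrete; needs <array>, <tuple>, <utility>):
+// expand an index sequence, take the address of every tuple element, and
+// type-erase it to void*. The resulting array aliases the tuple's storage,
+// hence the lifetime caveat above.
+//
+//   template <typename... Args, std::size_t... Is>
+//   std::array<void*, sizeof...(Args)> to_ptr_array_impl(
+//       std::tuple<Args...>& t, std::index_sequence<Is...>) {
+//     return {static_cast<void*>(&std::get<Is>(t))...};
+//   }
+//
+//   template <typename... Args>
+//   std::array<void*, sizeof...(Args)> to_ptr_array(std::tuple<Args...>& t) {
+//     return to_ptr_array_impl(t, std::index_sequence_for<Args...>{});
+//   }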
+template +constexpr auto tuple_to_array(std::tuple& extra_args) { + constexpr auto tuple_size = sizeof...(Args); + return tuple_to_array_helper(extra_args, std::make_index_sequence{}); +} + +struct JittedVecKernelCache { + // Different kernels are compiled depending on what we're vectorizing up to (1, 2 or 4 elements) + at::cuda::jit::NvrtcFunction vec1; + at::cuda::jit::NvrtcFunction vec2; + at::cuda::jit::NvrtcFunction vec4; +}; + +struct JittedKernelVariantCache { + JittedVecKernelCache vec; + at::cuda::jit::NvrtcFunction noncontiguous; + at::cuda::jit::NvrtcFunction dynamic_contiguous; + at::cuda::jit::NvrtcFunction dynamic_noncontiguous; +}; + +inline c10::SmallBuffer pack_kernel_args( + std::initializer_list args, + c10::ArrayRef extra_args) { + c10::SmallBuffer ret(args.size() + extra_args.size()); + std::copy(args.begin(), args.end(), ret.data()); + std::copy(extra_args.begin(), extra_args.end(), ret.data() + args.size()); + return ret; +} + +template +void launch_jitted_unrolled_kernel( + std::mutex &jiterator_mutex, + at::cuda::jit::NvrtcFunction &fn_cache, + const at::cuda::jit::KernelDescriptor &desc, + int64_t N, + array_t data, + inp_calc_t ic, + out_calc_t oc, + loader_t l, + storer_t s, + bool contiguous, + at::cuda::jit::BinaryFuncVariant scalar_pos, + void* scalar_val, + c10::ArrayRef extra_args) { + + TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits::max()); + //casting result to int is always safe, intermediate is int64 and won't overflow + const uint32_t grid = (N + block_work_size() - 1) / block_work_size(); + + if (!fn_cache.function) { + const std::lock_guard lock{jiterator_mutex}; + if (!fn_cache.function) { + constexpr bool dynamic_casting = !std::is_same() || + !std::is_same(); + auto code = at::cuda::jit::generate_code( + desc, contiguous, dynamic_casting, scalar_pos); + fn_cache = at::cuda::jit::jit_pwise_function(code, desc.name); + } + } + + auto args = pack_kernel_args({&N, &data, &ic, &oc, &l, &s, scalar_val}, extra_args); + at::cuda::jit::launch_jitted_pwise_function(fn_cache, args.data(), {grid, 1u, 1u}, + {num_threads(), 1u, 1u}); +} + +template +void launch_jitted_vectorized_kernel( + std::mutex &jiterator_mutex, JittedVecKernelCache &fn_cache, + const at::cuda::jit::KernelDescriptor &desc, int64_t N, array_t data, + at::cuda::jit::BinaryFuncVariant scalar_pos, + void *scalar_val, c10::ArrayRef extra_args) { + TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits::max()); + // N is still int64_t for the computation, but it's always safe to cast result to int + const uint32_t grid = (N + block_work_size() - 1) / block_work_size(); + const int vec_size = at::cuda::jit::can_vectorize_up_to( + desc, c10::ArrayRef(data.data, data.size())); + + // Different kernels are compiled depending on what we're vectorizing up to (1, 2 or 4 elements) + // fn_ptr is set to the appropriate function based on the vec size and GPU used + at::cuda::jit::NvrtcFunction* fn_ptr; + if (vec_size == 4) { + fn_ptr = &fn_cache.vec4; + } else if (vec_size == 2) { + fn_ptr = &fn_cache.vec2; + } else if (vec_size ==1) { + fn_ptr = &fn_cache.vec1; + } else { + TORCH_INTERNAL_ASSERT(false, "unexpected vec_size for jitter vectorized kernel"); + } + + bool vectorized = vec_size > 1; + + if (!fn_ptr->function) { + const std::lock_guard lock{jiterator_mutex}; + if (!fn_ptr->function) { // cache miss! 
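+        // Double-checked locking: the unsynchronized check above lets calls that
+        // hit an already-compiled kernel skip the mutex entirely; re-checking
+        // under the lock ensures only one thread JIT-compiles on a real miss.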
+ + // Generates program + auto code = at::cuda::jit::generate_code( + desc, /*contiguous=*/true, /*dynamic_casting=*/false, + scalar_pos, vectorized, vec_size); + std::string kernel_name = vectorized ? desc.name + "_vectorized" + std::to_string(vec_size) : desc.name; + + // Acquires the program + *fn_ptr = at::cuda::jit::jit_pwise_function(code, kernel_name); + } + } + + if (vectorized) { + auto args = pack_kernel_args({&N, &data, scalar_val}, extra_args); + at::cuda::jit::launch_jitted_pwise_function( + *fn_ptr, args.data(), {grid, 1u, 1u}, {num_threads(), 1u, 1u}); + } else { +// NVCC complains about unused variables l and s. +// It should be false positive in most cases, so we suppress the warnings. +#pragma nv_diagnostic push +#pragma nv_diag_suppress 177 + auto ic = TrivialOffsetCalculator(); + auto oc = TrivialOffsetCalculator<1>(); + auto l = memory::LoadWithoutCast(); + auto s = memory::StoreWithoutCast(); + + auto args = pack_kernel_args( + {&N, &data, &ic, &oc, &l, &s, scalar_val}, extra_args); + at::cuda::jit::launch_jitted_pwise_function( + *fn_ptr, args.data(), {grid, 1u, 1u}, {num_threads(), 1u, 1u}); +#pragma nv_diagnostic pop + } +} + +template +void jitted_gpu_kernel_generic( + std::mutex &jiterator_mutex, + JittedKernelVariantCache &cache, + const at::cuda::jit::KernelDescriptor &desc, + at::cuda::jit::BinaryFuncVariant scalar_pos, + c10::ArrayRef extra_args, + TensorIteratorBase& iter, + const bool dynamic_casting, + void *scalar_val) { + TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing()); + TORCH_INTERNAL_ASSERT(iter.ninputs() == arity); + TORCH_INTERNAL_ASSERT(iter.noutputs() == 1); + + constexpr int ntensors = arity + 1; + at::detail::Array data; + for (auto i : c10::irange(ntensors)) { + data[i] = (char*)iter.data_ptr(i); + } + + int64_t numel = iter.numel(); + bool contiguous = iter.is_contiguous(); + + // Decides which of 4 kernel types to launch + // Variations are: + // - Case 1: no dynamic casting and contiguous + // - Case 2: no dynamic casting and noncontiguous + // - Case 3: dynamic casting and contiguous + // - Case 4: dynamic casting and noncontiguous + // These cases align with the non-jitted CUDALoops.cuh cases in gpu_kernel_impl + + if (!dynamic_casting) { + if (contiguous) { + // Case 1: no dynamic casting and contiguous + launch_jitted_vectorized_kernel( + jiterator_mutex, cache.vec, desc, + numel, data, scalar_pos, scalar_val, extra_args); + return; + } + + // Case 2: no dynamic casting and noncontiguous + auto input_offset_calculator = make_input_offset_calculator(iter); + auto output_offset_calculator = make_output_offset_calculator(iter); + auto loader = memory::LoadWithoutCast(); + auto storer = memory::StoreWithoutCast(); + launch_jitted_unrolled_kernel( + jiterator_mutex, cache.noncontiguous, desc, numel, data, + input_offset_calculator, output_offset_calculator, loader, + storer, contiguous, scalar_pos, scalar_val, extra_args); + return; + } + + // Cases 3 and 4 are handled below + // Both require construction of a storer (this asserts 1 output) and one or more loaders + + // Creates store cast to output (the zeroth tensor in TensorIterator) + auto storer = memory::StoreWithCast<1>(iter); + + // Creates load casts from inputs (note offset indexing into the iterators 1...n tensors) + auto loader = memory::LoadWithCast(iter); + + if (contiguous) { + // Case 3: dynamic casting and contiguous + auto input_offset_calculator = TrivialOffsetCalculator(); + auto output_offset_calculator = TrivialOffsetCalculator<1>(); + 
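+      // Contiguous data needs no per-element offset arithmetic, so trivial
+      // (identity) offset calculators are passed and the cast-aware
+      // loader/storer built above handle the dtype conversion.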
launch_jitted_unrolled_kernel( + jiterator_mutex, cache.dynamic_contiguous, desc, numel, data, input_offset_calculator, + output_offset_calculator, loader, storer, contiguous, scalar_pos, scalar_val, extra_args); + return; + } + + // Case 4: dynamic casting and noncontiguous + auto input_offset_calculator = make_input_offset_calculator(iter); + auto output_offset_calculator = make_output_offset_calculator(iter); + launch_jitted_unrolled_kernel( + jiterator_mutex, cache.dynamic_noncontiguous, desc, numel, data, input_offset_calculator, + output_offset_calculator, loader, storer, contiguous, scalar_pos, scalar_val, extra_args); +} + +// NOTE: static to reduce chances of name collision. +template < + char const* name, + typename result_type, + typename f_inputs_type, + int arity, + at::cuda::jit::BinaryFuncVariant scalar_pos = + at::cuda::jit::BinaryFuncVariant::NoScalar, + typename... ExtraArgs> +static void jitted_gpu_kernel_impl( + TensorIteratorBase& iter, + const std::string &f, + const bool dynamic_casting, + at::opmath_type scalar_val, + std::tuple extra_args) { + + // TODO: Memory use can probably be optimized by re-using kernels across GPUs with + // the same compute capability + static std::mutex jiterator_mutex; + static std::vector device_caches(c10::cuda::device_count()); + + constexpr int nInputs = arity; + constexpr int nOutputs = 1; // TODO: Support more than 1 output + static const auto desc = at::cuda::jit::make_kernel_descriptor< + result_type, f_inputs_type, ExtraArgs...>(name, f, nInputs, nOutputs); + + auto &cache = device_caches[iter.device().index()]; + auto extra_args_array = tuple_to_array(extra_args); + return jitted_gpu_kernel_generic( + jiterator_mutex, + cache, + desc, + scalar_pos, + extra_args_array, + iter, + dynamic_casting, + &scalar_val + ); +} + +}} // at::native + +#endif // AT_USE_JITERATOR() diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/CUDALoops.cuh b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/CUDALoops.cuh new file mode 100644 index 0000000000000000000000000000000000000000..b8eb85fd4eb2eec771759f5de11e16f934b31437 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/CUDALoops.cuh @@ -0,0 +1,348 @@ +#pragma once + +// This file provides two functions to help write GPU elementwise kernels: +// +// gpu_kernel(TensorIterator iter, ) +// gpu_kernel_with_scalars(TensorIterator iter, ) +// +// The gpu_kernel_with_scalars generates specializations that support a +// single scalar CPU argument, such as from `cuda_tensor + 5`. The CPU scalar +// is lifted to a kernel parameter instead of copying to device memory. +// This should be used in conjunction with TensorIterator::allow_cpu_scalars_, +// which is the default for TensorIterator::binary_op. Otherwise, all inputs +// and the output must be on the GPU. 
+// +// For example, to write a reciprocal kernel for GPU float Tensors: +// +// gpu_kernel(iter, []GPU_LAMBDA(float a) { +// return 1.0f / a; +// }); +// +// To write a multiplication kernel for GPU float Tensors where one argument +// may be a CPU scalar: +// +// gpu_kernel_with_scalars(iter, []GPU_LAMBDA(float a, float b) { +// return a * b; +// }); +// +// See BinaryOpsKernel.cu for the complete implementation +// + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef __NVCC__ +#define ASSERT_HOST_DEVICE_LAMBDA(type) \ + static_assert( \ + __nv_is_extended_host_device_lambda_closure_type(type), \ + #type " must be a __host__ __device__ lambda") +#else +#define ASSERT_HOST_DEVICE_LAMBDA(type) +#endif + +namespace at { +namespace native { + +template +C10_LAUNCH_BOUNDS_1(num_threads()) +__global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) { + using traits = function_traits; + int remaining = N - block_work_size() * blockIdx.x; + + if (remaining < block_work_size()) { // if this block handles the reminder, + // just do a naive unrolled loop + auto input_calc = TrivialOffsetCalculator(); + auto output_calc = TrivialOffsetCalculator<1>(); + auto loader = memory::LoadWithoutCast(); + auto storer = memory::StoreWithoutCast(); + auto policy = memory::policies::unroll< + array_t, + decltype(input_calc), + decltype(output_calc), + memory::LoadWithoutCast, + memory::StoreWithoutCast>( + data, remaining, input_calc, output_calc, loader, storer); + elementwise_kernel_helper(f, policy); + } else { // if this block has a full `block_work_size` data to handle, use + // vectorized memory access + elementwise_kernel_helper( + f, memory::policies::vectorized(data)); + } +} + +template < + typename func_t, + typename array_t, + typename inp_calc_t, + typename out_calc_t, + typename loader_t, + typename storer_t> +C10_LAUNCH_BOUNDS_1(num_threads()) +__global__ void unrolled_elementwise_kernel( + int N, + func_t f, + array_t data, + inp_calc_t ic, + out_calc_t oc, + loader_t l, + storer_t s) { + int remaining = N - block_work_size() * blockIdx.x; + auto policy = memory::policies:: + unroll( + data, remaining, ic, oc, l, s); + elementwise_kernel_helper(f, policy); +} + +// this function assume trivial 1d and no dynamic casting +template +static inline void launch_vectorized_kernel( + int64_t N, + const func_t& f, + array_t data) { + TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits::max()); + using traits = function_traits; + int64_t grid = (N + block_work_size() - 1) / block_work_size(); + auto stream = at::cuda::getCurrentCUDAStream(); + int vec_size = memory::can_vectorize_up_to(data); + + switch (vec_size) { + case 4: + vectorized_elementwise_kernel<4, func_t, array_t> + <<>>(N, f, data); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + break; + case 2: + vectorized_elementwise_kernel<2, func_t, array_t> + <<>>(N, f, data); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + break; + case 1: { + auto input_calc = TrivialOffsetCalculator(); + auto output_calc = TrivialOffsetCalculator<1>(); + auto loader = memory::LoadWithoutCast(); + auto storer = memory::StoreWithoutCast(); + unrolled_elementwise_kernel + <<>>( + N, f, data, input_calc, output_calc, loader, storer); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + break; + } + default: + TORCH_INTERNAL_ASSERT(false, "Unexpected vectorization size"); + } +} + +template < + typename func_t, + typename array_t, + typename inp_calc_t, + typename out_calc_t, + typename loader_t, + typename storer_t> 
+static inline void launch_unrolled_kernel( + int64_t N, + const func_t& f, + array_t data, + inp_calc_t ic, + out_calc_t oc, + loader_t l, + storer_t s) { + TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits::max()); + int64_t grid = (N + block_work_size() - 1) / block_work_size(); + auto stream = at::cuda::getCurrentCUDAStream(); + unrolled_elementwise_kernel + <<>>(N, f, data, ic, oc, l, s); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +template +C10_LAUNCH_BOUNDS_2(nt, 4) +__global__ void elementwise_kernel(int N, func_t f) { + int tid = threadIdx.x; + int nv = nt * vt; + int idx = nv * blockIdx.x + tid; +#pragma unroll + for (int i = 0; i < vt; i++) { + if (idx < N) { + f(idx); + idx += nt; + } + } +} + +template +static void launch_legacy_kernel(int64_t N, const func_t& f) { + TORCH_INTERNAL_ASSERT(N >= 0 && N <= std::numeric_limits::max()); + if (N == 0) { + return; + } + dim3 block(nt); + dim3 grid((N + block.x * vt - 1) / (block.x * vt)); + auto stream = at::cuda::getCurrentCUDAStream(); + elementwise_kernel<<>>(N, f); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +template +C10_HOST_DEVICE typename traits::result_type invoke_impl( + const func_t& f, + char* const C10_RESTRICT data[], + const index_t strides[], + int i, + std::index_sequence) { + (void)strides; + (void)i; + return f(c10::load::type>( + data[INDEX] + i * strides[INDEX])...); +} + +template < + typename func_t, + typename index_t, + typename traits = function_traits> +C10_HOST_DEVICE typename traits::result_type invoke( + const func_t& f, + char* const C10_RESTRICT data[], + const index_t strides[], + int i) { + using Indices = std::make_index_sequence; + return invoke_impl(f, data, strides, i, Indices{}); +} + +template +C10_HOST_DEVICE typename traits::result_type invoke_impl( + const func_t& f, + char* const C10_RESTRICT data[], + const index_t strides[], + const ScalarType dtypes[], + int i, + std::index_sequence) { + (void)strides; + (void)i; + return f(c10::fetch_and_cast::type>( + dtypes[I], data[I] + i * strides[I])...); +} + +template < + typename func_t, + typename index_t, + typename traits = function_traits> +C10_HOST_DEVICE typename traits::result_type invoke( + const func_t& f, + char* const C10_RESTRICT data[], + const index_t strides[], + const ScalarType dtypes[], + int i) { + using Indices = std::make_index_sequence; + return invoke_impl(f, data, strides, dtypes, i, Indices{}); +} + +template +void gpu_kernel_impl_nocast(TensorIteratorBase& iter, const func_t& f) { + using traits = function_traits; + using arg0_t = typename traits::result_type; + constexpr int ntensors = traits::arity + 1; + + TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing()); + TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity); + TORCH_INTERNAL_ASSERT(iter.noutputs() == 1); + TORCH_INTERNAL_ASSERT(!needs_dynamic_casting::check(iter)); + + at::detail::Array data; + for (int i = 0; i < ntensors; i++) { + data[i] = (char*)iter.data_ptr(i); + } + + int64_t numel = iter.numel(); + + bool contiguous = iter.is_contiguous(); + + if (contiguous) { + return launch_vectorized_kernel(numel, f, data); + } + auto offset_calc = ::make_offset_calculator(iter); + constexpr int unroll_factor = sizeof(arg0_t) >= 4 ? 
2 : 4; + launch_legacy_kernel<128, unroll_factor>(numel, [=] GPU_LAMBDA(int idx) { + auto offsets = offset_calc.get(idx); + arg0_t* out = (arg0_t*)(data[0] + offsets[0]); + *out = invoke(f, &data.data[1], &offsets.data[1], 1); + }); +} + +template +void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) { + if (!needs_dynamic_casting::check(iter)) { + return gpu_kernel_impl_nocast(iter, f); + } + using traits = function_traits; + using arg0_t = typename traits::result_type; + constexpr int ntensors = traits::arity + 1; + + TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing()); + TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity); + TORCH_INTERNAL_ASSERT(iter.noutputs() == 1); + + at::detail::Array data; + for (int i = 0; i < ntensors; i++) { + data[i] = (char*)iter.data_ptr(i); + } + + int64_t numel = iter.numel(); + + bool contiguous = iter.is_contiguous(); + + if (contiguous) { +#ifdef USE_ROCM + at::detail::Array dtypes; + auto inner_strides = iter.get_inner_strides(); + at::detail::Array strides; + for (int i = 0; i < ntensors; i++) { + dtypes[i] = iter.dtype(i); + strides[i] = inner_strides[i]; + } + launch_legacy_kernel<512, 1>(numel, [=]GPU_LAMBDA(int idx) { + void* out = data[0] + strides[0] * idx; + arg0_t result = invoke(f, &data.data[1], &strides.data[1], &dtypes.data[1], idx); + c10::cast_and_store(dtypes[0], out, result); + }); +#else + auto loader = memory::LoadWithCast(iter); + auto storer = memory::StoreWithCast<1>(iter); + auto input_offset_calculator = TrivialOffsetCalculator(); + auto output_offset_calculator = TrivialOffsetCalculator<1>(); + launch_unrolled_kernel( + numel, + f, + data, + input_offset_calculator, + output_offset_calculator, + loader, + storer); +#endif + } else { + at::detail::Array dtypes; + for (int i = 0; i < ntensors; i++) { + dtypes[i] = iter.dtype(i); + } + auto offset_calc = ::make_offset_calculator(iter); + launch_legacy_kernel<128, 4>(numel, [=] GPU_LAMBDA(int idx) { + auto offsets = offset_calc.get(idx); + void* out = data[0] + offsets[0]; + arg0_t result = invoke(f, &data.data[1], &offsets.data[1], &dtypes.data[1], 1); + c10::cast_and_store(dtypes[0], out, result); + }); + } +} + +} // namespace native +} // namespace at diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/CompositeRandomAccessor.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/CompositeRandomAccessor.h new file mode 100644 index 0000000000000000000000000000000000000000..d47a7fa776f1b681b26dc5ec8b4548604d359946 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/CompositeRandomAccessor.h @@ -0,0 +1,35 @@ +#pragma once + +#include +#include + +namespace at { namespace native { + +struct TupleInfoCPU { + template + using tuple = thrust::tuple; + + template + static constexpr auto tie(Types&... 
args) noexcept { + return thrust::tie(args...); + } +}; + +template +using CompositeRandomAccessorCPU = + CompositeRandomAccessor; + +template +void swap( + references_holder rh1, + references_holder rh2 +) { + return thrust::swap(rh1.data(), rh2.data()); +} + +template +auto get(references_holder rh) -> decltype(thrust::get(rh.data())) { + return thrust::get(rh.data()); +} + +}} // namespace at::native diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Copy.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Copy.h new file mode 100644 index 0000000000000000000000000000000000000000..5639567d666686dd81ca5b4b032fb44f039eb782 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Copy.h @@ -0,0 +1,10 @@ +#pragma once + +namespace at { +struct TensorIteratorBase; + +namespace native { + +void direct_copy_kernel_cuda(TensorIteratorBase &iter); + +}} // namespace at::native diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/CuFFTUtils.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/CuFFTUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..4b02f914d7e20ff914e248d203be3f9434bacb3b --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/CuFFTUtils.h @@ -0,0 +1,73 @@ +#pragma once + +#include + +#include +#include +#include +#include +#include + +namespace at { namespace native { + +// This means that max dim is 3 + 2 = 5 with batch dimension and possible +// complex dimension +constexpr int max_rank = 3; + +static inline std::string _cudaGetErrorEnum(cufftResult error) +{ + switch (error) + { + case CUFFT_SUCCESS: + return "CUFFT_SUCCESS"; + case CUFFT_INVALID_PLAN: + return "CUFFT_INVALID_PLAN"; + case CUFFT_ALLOC_FAILED: + return "CUFFT_ALLOC_FAILED"; + case CUFFT_INVALID_TYPE: + return "CUFFT_INVALID_TYPE"; + case CUFFT_INVALID_VALUE: + return "CUFFT_INVALID_VALUE"; + case CUFFT_INTERNAL_ERROR: + return "CUFFT_INTERNAL_ERROR"; + case CUFFT_EXEC_FAILED: + return "CUFFT_EXEC_FAILED"; + case CUFFT_SETUP_FAILED: + return "CUFFT_SETUP_FAILED"; + case CUFFT_INVALID_SIZE: + return "CUFFT_INVALID_SIZE"; + case CUFFT_UNALIGNED_DATA: + return "CUFFT_UNALIGNED_DATA"; + case CUFFT_INCOMPLETE_PARAMETER_LIST: + return "CUFFT_INCOMPLETE_PARAMETER_LIST"; + case CUFFT_INVALID_DEVICE: + return "CUFFT_INVALID_DEVICE"; + case CUFFT_PARSE_ERROR: + return "CUFFT_PARSE_ERROR"; + case CUFFT_NO_WORKSPACE: + return "CUFFT_NO_WORKSPACE"; + case CUFFT_NOT_IMPLEMENTED: + return "CUFFT_NOT_IMPLEMENTED"; +#if !defined(USE_ROCM) + case CUFFT_LICENSE_ERROR: + return "CUFFT_LICENSE_ERROR"; +#endif + case CUFFT_NOT_SUPPORTED: + return "CUFFT_NOT_SUPPORTED"; + default: + std::ostringstream ss; + ss << "unknown error " << error; + return ss.str(); + } +} + +static inline void CUFFT_CHECK(cufftResult error) +{ + if (error != CUFFT_SUCCESS) { + std::ostringstream ss; + ss << "cuFFT error: " << _cudaGetErrorEnum(error); + AT_ERROR(ss.str()); + } +} + +}} // at::native diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/DeviceSqrt.cuh b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/DeviceSqrt.cuh new file mode 100644 index 0000000000000000000000000000000000000000..38a7804015be1822f4012f74319a459daeb5e885 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/DeviceSqrt.cuh @@ -0,0 +1,25 @@ +#pragma once + +namespace at { namespace native { +#if defined(USE_ROCM) +// take these out 
when ROCm implements std:: math functions +#include +template +static __forceinline__ __device__ scalar_t device_sqrt(scalar_t val); + +template <> +__forceinline__ __device__ float device_sqrt(float val) { + return ::sqrtf(val); +} + +template <> +__forceinline__ __device__ double device_sqrt(double val) { + return ::sqrt(val); +} +#else +template +__forceinline__ __device__ double device_sqrt(scalar_t val) { + return std::sqrt(val); +} +#endif +}} diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/DistributionTemplates.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/DistributionTemplates.h new file mode 100644 index 0000000000000000000000000000000000000000..b30dcb60ffe564d33e0f8443a06eeec50e0d4b17 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/DistributionTemplates.h @@ -0,0 +1,671 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { +namespace { + +// launch bounds used for kernels utilizing TensorIterator +const uint32_t block_size_bound = 256; +const uint32_t grid_size_bound = 4; +// At the time of writing, there is no curand_* call that increments the offset by more than 4. +// See: https://docs.nvidia.com/cuda/archive/11.8.0/curand/group__DEVICE.html +const uint32_t max_generator_offsets_per_curand_call = 4; + +// utility function that calculates proper philox_offset +// for distributions utilizing TensorIterator. For distributions using +// TensorIterator, we are using a grid-stride loop with each +// thread yielding one element per thread. For the edge of the grid-stride +// loop, if the tensor size is large, the unroll loop will kick in and the float4 +// from curand4 will start getting utilized (for common tensor sizes, we end up +// using rand.x from each thread). The philox_offset calculation was changed to +// (number of elements per thread * maximum generator increment per "curand_*" call), which makes +// sure that philox offset increment is not less than the number of randoms used +// in each thread. 
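+// Worked example (illustrative numbers): with numel = 2^20, block_size = 256,
+// unroll_factor = 4 and the grid capped at 640 blocks, each thread runs
+// ceil(numel / (256 * 640 * 4)) = 2 grid-stride iterations, so the offset is
+// 2 * max_generator_offsets_per_curand_call = 8 -- enough even if every
+// iteration advanced the Philox state by a full curand4 call.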
+std::tuple calc_execution_policy(const int64_t total_elements, const uint32_t unroll_factor) { + const uint64_t numel = static_cast(total_elements); + const uint32_t block_size = block_size_bound; + dim3 dim_block(block_size); + dim3 grid((numel + block_size - 1) / block_size); + uint32_t blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor / block_size; + grid.x = std::min( + static_cast(at::cuda::getCurrentDeviceProperties()->multiProcessorCount) * blocks_per_sm, + grid.x); + //number of times random will be generated per thread, to offset philox counter in thc random state + uint64_t counter_offset = ((numel - 1) / (block_size * grid.x * unroll_factor) + 1) * max_generator_offsets_per_curand_call; + return std::make_tuple(counter_offset, grid, dim_block); +} + +// grid stride loop kernel for distributions +template +C10_LAUNCH_BOUNDS_2(block_size_bound, grid_size_bound) +__global__ void distribution_elementwise_grid_stride_kernel(int numel, + PhiloxCudaState philox_args, + const dist_t dist_func, + const transform_t transform_func) { + auto seeds = at::cuda::philox::unpack(philox_args); + int idx = blockIdx.x * blockDim.x + threadIdx.x; + curandStatePhilox4_32_10_t state; + curand_init(std::get<0>(seeds), + idx, + std::get<1>(seeds), + &state); + + int rounded_size = ((numel - 1)/(blockDim.x * gridDim.x * unroll_factor)+1) * + blockDim.x * gridDim.x * unroll_factor; + for(int linear_index = idx; linear_index < rounded_size; linear_index += blockDim.x * gridDim.x * unroll_factor) { + auto rand = dist_func(&state); + #pragma unroll + for (int ii = 0; ii < unroll_factor; ii++) { + int li = linear_index + blockDim.x * gridDim.x * ii; + if (li < numel) { + transform_func(li, static_cast((&rand.x)[ii])); + } + } + __syncthreads(); + } +} + +/** + * distribution_nullary_kernel is analogous to gpu_kernel in + * ATen/native/cuda/Loops.cuh. Like gpu_kernel, it uses + * TensorIterator to launch a kernel. However, the differences are + * - it launches a grid-stride loop based kernel. The kernel is not + * generic like elementwise_kernel in Loops.cuh and is specialized + * for the distribution kernels here. + * - For big size tensors, we can launch multiple kernels recursively + * (i.e. if (!iter.can_use_32bit_indexing())) and hence, the philox + * offset calculation is done in this function. + * + * FIXME: Can we specialize elementwise_kernel and launch_kernel in Loops.cuh + * to have grid-stride loop kernel and then use that to launch our distribution + * kernels? Note that we need a grid-stride loop kernel because, we found by testing + * that it achieves peak effective bandwidth. 
+ */ +template +void distribution_nullary_kernel(at::TensorIteratorBase& iter, + RNG gen, + const dist_t& dist_func, + const transform_t transform_func) { + const int unroll_factor = sizeof(dist_func_return_t) / sizeof(accscalar_t); + TORCH_CHECK(unroll_factor >= 1, "unroll_factor must be >= 1."); + int64_t numel = iter.numel(); + if (numel == 0) { + return; + } + + auto execution_policy = calc_execution_policy(numel, unroll_factor); + auto counter_offset = std::get<0>(execution_policy); + auto grid = std::get<1>(execution_policy); + auto block = std::get<2>(execution_policy); + PhiloxCudaState rng_engine_inputs; + { + // See Note [Acquire lock when using random generators] + std::lock_guard lock(gen->mutex_); + rng_engine_inputs = gen->philox_cuda_state(counter_offset); + } + + if (!iter.can_use_32bit_indexing()) { + for (auto& sub_iter : iter.with_32bit_indexing()) { + distribution_nullary_kernel(sub_iter, + gen, dist_func, transform_func); + } + return; + } + + char* out_data = (char*)iter.data_ptr(0); + + auto stream = at::cuda::getCurrentCUDAStream(); + if (iter.is_trivial_1d()) { + auto strides = iter.get_inner_strides(); + int stride0 = strides[0]; + distribution_elementwise_grid_stride_kernel<<>>( + numel, + rng_engine_inputs, + dist_func, + [=]__device__(int idx, accscalar_t rand) { + scalar_t* out = (scalar_t*)&out_data[stride0 * idx]; + *out = transform_func(rand); + } + ); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } else { + auto offset_calc = make_offset_calculator<1>(iter); + distribution_elementwise_grid_stride_kernel<<>>( + numel, + rng_engine_inputs, + dist_func, + [=]__device__(int idx, accscalar_t rand) { + auto offsets = offset_calc.get(idx); + scalar_t* out = (scalar_t*)&out_data[offsets[0]]; + *out = transform_func(rand); + } + ); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } +} + +// Binary kernel +template +__global__ void distribution_binary_elementwise_kernel( + int numel, + func_t f, + PhiloxCudaState philox_args, + typename function_traits::result_type *output_data, + const typename function_traits::template arg<1>::type *input_data_1, + const typename function_traits::template arg<2>::type *input_data_2, + inp_offset_calc_t inp_calc, + out_offset_calc_t out_calc) { + auto seeds = at::cuda::philox::unpack(philox_args); + + using input_t_1 = typename function_traits::template arg<1>::type; + using input_t_2 = typename function_traits::template arg<2>::type; + + input_t_1 inputs_1[thread_work_size()]; + input_t_2 inputs_2[thread_work_size()]; + + int base_index = block_work_size() * blockIdx.x; + int remaining = std::min(numel - base_index, block_work_size()); + + curandStatePhilox4_32_10_t state; + curand_init(std::get<0>(seeds), + blockIdx.x * blockDim.x + threadIdx.x, + std::get<1>(seeds), + &state); + + // load data into registers + int thread_idx = threadIdx.x; + #pragma unroll + for (int i = 0; i < thread_work_size(); i++) { + if (thread_idx >= remaining) { + break; + } + int input_idx = thread_idx + base_index; + auto offsets = inp_calc.get(input_idx); + inputs_1[i] = input_data_1[offsets[0]]; + inputs_2[i] = input_data_2[offsets[1]]; + + thread_idx += num_threads(); + } + + // compute and store + thread_idx = threadIdx.x; + #pragma unroll + for (int i = 0; i < thread_work_size(); i++) { + if (thread_idx >= remaining) { + break; + } + int input_idx = thread_idx + base_index; + auto offsets = out_calc.get(input_idx); + output_data[offsets[0]] = f(state, inputs_1[i], inputs_2[i]); + thread_idx += num_threads(); + } +} + +template +void 
distribution_binary_kernel(TensorIteratorBase &iter, PhiloxCudaState philox_args, const func_t &f) { + static_assert(std::is_same::template arg<0>::type, curandStatePhilox4_32_10_t&>::value, "the first argument of functor must be curandStatePhilox4_32_10_t"); + using input_t_1 = typename function_traits::template arg<1>::type; + using input_t_2 = typename function_traits::template arg<2>::type; + using output_t = typename function_traits::result_type; + + if (!iter.can_use_32bit_indexing()) { + for (auto& sub_iter : iter.with_32bit_indexing()) { + distribution_binary_kernel(sub_iter, philox_args, f); + } + return; + } + + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(iter.can_use_32bit_indexing()); + + int64_t numel = iter.numel(); + if (numel == 0) { + return; + } + + output_t *output_data = static_cast(iter.data_ptr(0)); + const input_t_1 *input_data_1 = static_cast(iter.data_ptr(1)); + const input_t_2 *input_data_2 = static_cast(iter.data_ptr(2)); + + int64_t grid = (numel + block_work_size() - 1) / block_work_size(); + auto stream = at::cuda::getCurrentCUDAStream(); + + if (iter.is_contiguous()) { + distribution_binary_elementwise_kernel<<>>( + numel, f, philox_args, output_data, input_data_1, input_data_2, + TrivialOffsetCalculator<2>(), TrivialOffsetCalculator<1>()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } else { + distribution_binary_elementwise_kernel<<>>( + numel, f, philox_args, output_data, input_data_1, input_data_2, + make_input_offset_calculator<2>(iter), make_output_offset_calculator(iter)); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } +} + +} // namespace +}} // namespace at::native + + +namespace at { +namespace native { +namespace templates { +namespace cuda { + +// ==================================================== Random ======================================================== + +template +void random_from_to_kernel(TensorIteratorBase& iter, uint64_t range, int64_t base, RNG gen) { + AT_DISPATCH_V2(iter.dtype(), "random_from_to_kernel_cuda", AT_WRAP([&] { + if (( + std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value) && range >= 1ULL << 32) + { + // define lambda to mod with range and add base + auto random_func = [range, base] __device__ (uint64_t rand) { + return transformation::uniform_int_from_to(rand, range, base); + }; + distribution_nullary_kernel(iter, + gen, + [] __device__ (curandStatePhilox4_32_10_t* state) -> ulonglong2 { + ulonglong2 ret; + uint4 rand_val = curand4(state); + ret.x = (static_cast(rand_val.x) << 32) | rand_val.y; + ret.y = (static_cast(rand_val.z) << 32) | rand_val.w; + return ret; + }, + random_func); + } else { + auto random_func = [range, base] __device__ (uint32_t rand) { + return transformation::uniform_int_from_to(rand, range, base); + }; + distribution_nullary_kernel(iter, + gen, + [] __device__ (curandStatePhilox4_32_10_t* state) -> uint4 { + return curand4(state); + }, + random_func); + } + }), AT_EXPAND(AT_ALL_TYPES), kBool, kHalf, kBFloat16, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)); +} + +// This is the special kernel to handle single specific case: +// from(inclusive) = std::numeric_limits::lowest() +// to(exclusive) = None (= std::numeric_limits::max() + 1) +template +void random_full_64_bits_range_kernel(TensorIteratorBase& iter, RNG gen) { + AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::BFloat16, iter.dtype(), "random_full_64_bits_range_kernel_cuda", [&] { + if (std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value) { + auto random_func = [] __device__ (uint64_t rand) { 
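+        // uniform_int_full_range maps the 64-bit random input onto the type's full
+        // value range (from lowest() up to max()); the 64-bit inputs themselves are
+        // produced by the curand4-based lambda below, which packs two 32-bit curand
+        // outputs into each half of a ulonglong2.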
+ return transformation::uniform_int_full_range(rand); + }; + distribution_nullary_kernel(iter, + gen, + [] __device__ (curandStatePhilox4_32_10_t* state) -> ulonglong2 { + ulonglong2 ret; + uint4 rand_val = curand4(state); + ret.x = (static_cast(rand_val.x) << 32) | rand_val.y; + ret.y = (static_cast(rand_val.z) << 32) | rand_val.w; + return ret; + }, + random_func); + } else { + TORCH_CHECK(false, "random_full_64_bits_range_kernel_cuda handles only int64, double, float and bfloat16"); + } + }); +} + +template +struct RandomFromToKernel { + void operator()(TensorIteratorBase& iter, uint64_t range, int64_t base, std::optional gen) { + random_from_to_kernel(iter, range, base, check_generator(gen)); + } + void operator()(TensorIteratorBase& iter, std::optional gen) { + random_full_64_bits_range_kernel(iter, check_generator(gen)); + } +}; + +template +void random_kernel(TensorIteratorBase& iter, RNG gen) { + AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, iter.dtype(), "random_kernel_cuda", [&] { + if (std::is_same::value || std::is_same::value) { + auto random_func = [] __device__ (uint64_t rand) { + return transformation::uniform_int(rand); + }; + distribution_nullary_kernel(iter, gen, + [] __device__ (curandStatePhilox4_32_10_t* state) -> ulonglong2 { + ulonglong2 ret; + uint4 rand_val = curand4(state); + ret.x = (static_cast(rand_val.x) << 32) | rand_val.y; + ret.y = (static_cast(rand_val.z) << 32) | rand_val.w; + return ret; + }, + random_func); + } else { + auto random_func = [] __device__ (uint32_t rand) { + return transformation::uniform_int(rand); + }; + distribution_nullary_kernel(iter, + gen, + [] __device__ (curandStatePhilox4_32_10_t* state) -> uint4 { + return curand4(state); + }, + random_func); + } + }); +} + +template +struct RandomKernel { + void operator()(TensorIteratorBase& iter, RNG gen) { + random_kernel(iter, gen); + } +}; + +// ==================================================================================================================== + +template +void uniform_and_transform(TensorIteratorBase& iter, RNG gen, transform_t transform) { + if (std::is_same::value) { + distribution_nullary_kernel(iter, + gen, + [] __device__ (curandStatePhilox4_32_10_t* state) -> double2 { return curand_uniform2_double(state); }, + transform); + } else { + distribution_nullary_kernel(iter, + gen, + [] __device__ (curandStatePhilox4_32_10_t* state) -> float4 { return curand_uniform4(state); }, + transform); + } +} + +template +void normal_and_transform(TensorIteratorBase& iter, RNG gen, transform_t transform) { + if (std::is_same::value) { + distribution_nullary_kernel(iter, + gen, + [] __device__ (curandStatePhilox4_32_10_t* state) -> double2 { return curand_normal2_double(state); }, + transform); + } else { + distribution_nullary_kernel(iter, + gen, + [] __device__ (curandStatePhilox4_32_10_t* state) -> float4 { return curand_normal4(state); }, + transform); + } +} + +// ==================================================== Normal ======================================================== + +template +void normal_kernel(const TensorBase &self, double mean_, double std_, RNG gen) { + auto iter = TensorIterator::borrowing_nullary_op(self); + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "normal_kernel_cuda", [&] { + using accscalar_t = at::acc_type; + auto mean = static_cast(mean_); + auto std = static_cast(std_); + // define lambda to multiply std and add mean + auto normal_func = [mean, 
std] __device__ (accscalar_t rand) { + return static_cast(transformation::normal(rand, mean, std)); + }; + normal_and_transform(iter, gen, normal_func); + }); +} + +template +struct NormalKernel { + void operator()(const TensorBase &self, double mean, double std, std::optional gen) { + normal_kernel(self, mean, std, check_generator(gen)); + } +}; + +// ==================================================== Uniform ======================================================== + +template +void uniform_kernel(TensorIteratorBase& iter, double from_, double to_, RNG gen) { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "uniform_kernel_cuda", [&] { + auto from = static_cast(from_); + auto to = static_cast(to_); + using opmath_t = at::opmath_type; + auto range = static_cast(to-from); + // define lambda to reverse bounds, multiply 'range' and add 'from_' + auto uniform_func = [range, from, to] __device__ (opmath_t rand) { + // Compute output value before reversing the bounds + // BEFORE TOUCHING THIS CODE READ: https://github.com/pytorch/pytorch/issues/96947 + auto value = static_cast(rand * range + from); + // reverse the bounds of curand4 from (0, 1] to [0, 1) + // Note that this method is from legacy THCTensorRandom and is likely to give + // you more 0-s, since, the probability of gettings 1-s is higher than 0-s and + // by reversing the bounds, we are flipping the probabilities of 1-s and 0-s. + // BEFORE TOUCHING THIS CODE READ: https://github.com/pytorch/pytorch/issues/16706 + auto reverse_bound_value = value == to ? from : value; + return reverse_bound_value; + }; + uniform_and_transform(iter, gen, uniform_func); + }); +} + +template +struct UniformKernel { + void operator()(TensorIteratorBase& iter, double from, double to, std::optional gen) { + uniform_kernel(iter, from, to, check_generator(gen)); + } +}; + +// ================================================== LogNormal ======================================================= + +template +void log_normal_kernel(TensorIteratorBase& iter, double mean_, double std_, RNG gen) { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "log_normal_cuda", [&] { + using accscalar_t = at::acc_type; + auto mean = static_cast(mean_); + auto std = static_cast(std_); + // define lambda for log_normal transformation + auto log_normal_func = [mean, std] __device__ (accscalar_t rand) { + return static_cast(transformation::log_normal(transformation::normal(rand, mean, std))); + }; + normal_and_transform(iter, gen, log_normal_func); + }); +} + +template +struct LogNormalKernel { + void operator()(TensorIteratorBase& iter, double mean, double std, std::optional gen) { + log_normal_kernel(iter, mean, std, check_generator(gen)); + } +}; + +// =================================================== Geometric ====================================================== + +template +void geometric_kernel(TensorIteratorBase& iter, double p, RNG gen) { + AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "geometric_cuda", [&] { + using accscalar_t = at::DiscreteDistributionType::type; + // define lambda for geometric transformation + auto geometric_func = [p] __device__ (accscalar_t rand) { + return static_cast(transformation::geometric(rand, p)); + }; + uniform_and_transform(iter, gen, geometric_func); + }); +} + +template +struct GeometricKernel { + void operator()(TensorIteratorBase& iter, double p, std::optional gen) { + geometric_kernel(iter, p, 
check_generator(gen)); + } +}; + +// ================================================== Exponential ===================================================== + +template +void exponential_kernel(TensorIteratorBase& iter, double lambda_, RNG gen) { + TORCH_CHECK(isFloatingType(iter.dtype()), "Exponential distribution is a continuous probability distribution. dtype must be a floating point but you specified ", iter.dtype()); + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "exponential_cuda", [&] { + using accscalar_t = at::acc_type; + auto lambda = static_cast(lambda_); + // define lambda for exponential transformation + auto exponential_func = [lambda] __device__ (accscalar_t rand) { + return static_cast(transformation::exponential(rand, lambda)); + }; + uniform_and_transform(iter, gen, exponential_func); + }); +} + +template +struct ExponentialKernel { + void operator()(TensorIteratorBase& iter, double lambda, std::optional gen) { + exponential_kernel(iter, lambda, check_generator(gen)); + } +}; + +// ==================================================== Cauchy ======================================================== + +template +void cauchy_kernel(TensorIteratorBase& iter, double median_, double sigma_, RNG gen) { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "cauchy_cuda", [&] { + using accscalar_t = at::acc_type; + auto median = static_cast(median_); + auto sigma = static_cast(sigma_); + // define lambda for cauchy transformation + auto cauchy_func = [median, sigma] __device__ (accscalar_t rand) { + return static_cast(transformation::cauchy(rand, median, sigma)); + }; + uniform_and_transform(iter, gen, cauchy_func); + }); +} + +template +struct CauchyKernel { + void operator()(TensorIteratorBase& iter, double median, double sigma, std::optional gen) { + cauchy_kernel(iter, median, sigma, check_generator(gen)); + } +}; + +// ==================================================== Bernoulli ===================================================== + +template +void bernoulli_tensor_cuda_kernel( + const TensorBase &ret, const at::TensorBase &p, + PhiloxCudaState philox_args) { + auto functor = [philox_args] __device__( + int n, scalar_t& v1, scalar_t& v2, scalar_t& v3, scalar_t& v4, + const prob_t& p1, const prob_t& p2, const prob_t& p3, const prob_t& p4) { + auto seeds = at::cuda::philox::unpack(philox_args); + curandStatePhilox4_32_10_t state; + curand_init(std::get<0>(seeds), + blockIdx.x * blockDim.x + threadIdx.x, + std::get<1>(seeds), + &state); + + // See Note [Register spilling in curand call for CUDA < 10] + float4 rand = curand_uniform4(&state); + switch (n) { + case 4: { + CUDA_KERNEL_ASSERT(0 <= p4 && p4 <= 1); + v4 = static_cast(rand.w <= p4); + [[fallthrough]]; + } + case 3: { + CUDA_KERNEL_ASSERT(0 <= p3 && p3 <= 1); + v3 = static_cast(rand.z <= p3); + [[fallthrough]]; + } + case 2: { + CUDA_KERNEL_ASSERT(0 <= p2 && p2 <= 1); + v2 = static_cast(rand.y <= p2); + [[fallthrough]]; + } + case 1: { + CUDA_KERNEL_ASSERT(0 <= p1 && p1 <= 1); + v1 = static_cast(rand.x <= p1); + } + } + }; + // The template argument `4` below indicates that we want to operate on four + // element at each time. See NOTE [ CUDA_tensor_applyN helpers ] for details. 
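+  // A minimal sketch of what each group of four lanes computes (names here are
+  // illustrative only): given u = curand_uniform4(&state) and probabilities
+  // p1..p4 already asserted to lie in [0, 1],
+  //   v_k = static_cast<scalar_t>(u_k <= p_k)   for k = 1..4,
+  // i.e. four independent Bernoulli(p_k) draws written back element-wise.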
+ at::cuda::CUDA_tensor_apply2(ret, p, functor); +} + +template +void bernoulli_kernel(const TensorBase &self, const TensorBase &p_, RNG gen) { + PhiloxCudaState rng_engine_inputs; + { + // See Note [Acquire lock when using random generators] + std::lock_guard lock(gen->mutex_); + rng_engine_inputs = gen->philox_cuda_state(10); + } + TORCH_CHECK(at::isFloatingType(p_.scalar_type()), "expected probabilities tensor to have floating type, got ", p_.scalar_type()); + // cast probabilities tensor to double for double `self` tensor, and to `float` for everything else + const auto p_type = self.dtype() == at::kDouble ? at::kDouble : at::kFloat; + auto p_cuda = p_.to(TensorOptions().device(self.device()).dtype(p_type)); + auto p = expand_inplace(self, p_cuda); + AT_DISPATCH_ALL_TYPES_AND3( + at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, self.scalar_type(), "bernoulli_tensor_cuda_self_", [&] { + if (std::is_same::value) { + return bernoulli_tensor_cuda_kernel(self, *p, rng_engine_inputs); + } else { + return bernoulli_tensor_cuda_kernel(self, *p, rng_engine_inputs); + } + }); +} + +template +void bernoulli_kernel(TensorIteratorBase& iter, double p, RNG gen) { + AT_DISPATCH_ALL_TYPES_AND3( + at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, iter.dtype(), "bernoulli_scalar_cuda_", [&] { + using accscalar_t = at::DiscreteDistributionType::type; + // define lambda for bernoulli transformation + auto bernoulli_func = [p] __device__ (accscalar_t rand) { + return static_cast(transformation::bernoulli(rand, p)); + }; + uniform_and_transform(iter, gen, bernoulli_func); + }); +} + +template +struct BernoulliKernel { + void operator()(TensorIteratorBase& iter, double p, std::optional gen) { + bernoulli_kernel(iter, p, check_generator(gen)); + } + void operator()(const TensorBase &self, const TensorBase &p_, std::optional gen) { + bernoulli_kernel(self, p_, check_generator(gen)); + } +}; + +}}}} diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Distributions.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Distributions.h new file mode 100644 index 0000000000000000000000000000000000000000..1a34fdfdf31494faab439544578be8aaf950dc32 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Distributions.h @@ -0,0 +1,25 @@ +#pragma once + +namespace at { +struct CUDAGeneratorImpl; +struct TensorIteratorBase; +class TensorBase; + +namespace native { + +void launch_poisson_cuda_kernel( + const TensorBase &ret, const TensorBase &lambda, CUDAGeneratorImpl *gen); + +void launch_gamma_kernel( + const TensorBase &ret, const TensorBase &alpha, CUDAGeneratorImpl *gen); + +void launch_binomial_cuda_kernel( + TensorIteratorBase &iter, CUDAGeneratorImpl *gen); + +void launch_dirichlet_kernel(TensorIteratorBase &iter); + +void launch_standard_gamma_grad_kernel(TensorIteratorBase &iter); + +void launch_dirichlet_grad_kernel(TensorIteratorBase &iter); + +}} // namespace at::native diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/ForeachFunctors.cuh b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/ForeachFunctors.cuh new file mode 100644 index 0000000000000000000000000000000000000000..55e4fd7a598907f452d033f73816c16b7c6e22b8 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/ForeachFunctors.cuh @@ -0,0 +1,681 @@ +#pragma once +#include +#include +#include +#include + +namespace at::native { + +namespace { + +// TODO(crcrpar): 
Handle version bump in codegen. +// rel: +// https://github.com/pytorch/pytorch/blob/9cf84347767c8abb8feba18a9a1baba321eeb8b9/tools/autograd/gen_inplace_or_view_type.py#L481-L482 +inline void increment_version(TensorList tensors) { + for (const auto& t : tensors) { + t.unsafeGetTensorImpl()->bump_version(); + } +} + +// Initializes args and checks if all args are aligned +template +__device__ bool init_args( + T** args, + TensorListMetadata& tl, + const int64_t chunk_idx, + const int64_t chunk_size, + const int64_t tensor_loc) { + bool all_aligned = true; + for (int i = 0; i < depth; i++) { + args[i] = (T*)tl.addresses[i][tensor_loc]; + args[i] += chunk_idx * chunk_size; + + if (!is_aligned(args[i])) { + all_aligned = false; + } + } + return all_aligned; +} + +// Initializes args and checks if all args are aligned +template +__device__ bool init_args( + T** args, + TensorListScalarListMetadata& tl, + const int64_t chunk_idx, + const int64_t chunk_size, + const int64_t tensor_loc) { + bool all_aligned = true; + for (int i = 0; i < depth; i++) { + args[i] = (T*)tl.addresses[i][tensor_loc]; + args[i] += chunk_idx * chunk_size; + + if (!is_aligned(args[i])) { + all_aligned = false; + } + } + return all_aligned; +} + +template +__device__ bool init_args( + T** args, + FusedOptimizerTensorListMetadata& tl, + const int64_t chunk_idx, + const int64_t chunk_size, + const int64_t tensor_loc) { + bool all_aligned = true; + for (int i = 0; i < depth; i++) { + args[i] = (T*)tl.addresses[i][tensor_loc]; + args[i] += chunk_idx * chunk_size; + + if (!is_aligned(args[i])) { + all_aligned = false; + } + } + return all_aligned; +} + +template +__device__ void load_args( + T r_args[][kILP], + T** args, + const int64_t i_start, + const int64_t chunk_size, + const int64_t n) { +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + const auto i = i_start + threadIdx.x + ii * blockDim.x; + for (int r_index = 0; r_index < depth; r_index++) { + r_args[r_index][ii] = 0; + if (i < n && i < chunk_size) { + r_args[r_index][ii] = args[r_index][i]; + } + } + } +} + +template +__device__ void store_args( + T* dst, + T* src, + const int64_t i_start, + const int64_t chunk_size, + const int64_t n) { +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + const int64_t i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) + dst[i] = src[ii]; + } +} + +template +__device__ __forceinline__ void binary_op_scalar( + T r_args[][kILP], + T** args, + opmath_t scalar, + const int64_t n, + const int64_t chunk_size, + const bool all_aligned, + Op op) { + // to make things simple, we put aligned case in a different code path + if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) { + for (int64_t i_start = threadIdx.x; + i_start * kILP < n && i_start * kILP < chunk_size; + i_start += blockDim.x) { + // load + load_store(r_args[0], args[0], 0, i_start); +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = static_cast( + op(static_cast(r_args[0][ii]), + static_cast(scalar))); + } + // store + load_store(args[res_arg_index], r_args[0], i_start, 0); + } + } else { + for (int64_t i_start = 0; i_start < n && i_start < chunk_size; + i_start += blockDim.x * kILP) { + // Regardless if depth is 1 (for inplace) or 2 (for out of place), r_args + // has depth 1 + load_args<1>(r_args, args, i_start, chunk_size, n); +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = static_cast( + op(static_cast(r_args[0][ii]), + static_cast(scalar))); + } + store_args(args[res_arg_index], r_args[0], 
i_start, chunk_size, n); + } + } +} + +template +__device__ __forceinline__ void pointwise_op_scalar( + T r_args[][kILP], + T** args, + opmath_t scalar, + const int64_t n, + const int64_t chunk_size, + const bool all_aligned, + Op op) { + // to make things simple, we put aligned case in a different code path + if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) { + for (int64_t i_start = threadIdx.x; + i_start * kILP < n && i_start * kILP < chunk_size; + i_start += blockDim.x) { + // load + load_store(r_args[0], args[0], 0, i_start); + load_store(r_args[1], args[1], 0, i_start); + load_store(r_args[2], args[2], 0, i_start); +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = static_cast( + static_cast(r_args[0][ii]) + + scalar * + op(static_cast(r_args[1][ii]), + static_cast(r_args[2][ii]))); + } + // store + load_store(args[res_arg_index], r_args[0], i_start, 0); + } + } else { + for (int64_t i_start = 0; i_start < n && i_start < chunk_size; + i_start += blockDim.x * kILP) { + // Regardless if depth is 3 (for inplace) or 4 (for out of place), r_args + // has depth 3 + load_args<3>(r_args, args, i_start, chunk_size, n); +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = static_cast( + static_cast(r_args[0][ii]) + + scalar * + op(static_cast(r_args[1][ii]), + static_cast(r_args[2][ii]))); + } + store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n); + } + } +} + +// +// Binary Functors +// +template +struct BinaryOpScalarFunctor { + using opmath_t = at::opmath_type; + template + __device__ __forceinline__ void operator()( + int chunk_size, + TensorListMetadata& tl, + Op op, + opmath_t scalar) { + const int tensor_loc = tl.block_to_tensor[blockIdx.x]; + const int chunk_idx = tl.block_to_chunk[blockIdx.x]; + auto n = tl.numel_for_tensor[tensor_loc]; + + T* args[depth]; + const bool all_aligned = + init_args(args, tl, chunk_idx, chunk_size, tensor_loc); + n -= chunk_idx * chunk_size; + T r_args[r_args_depth][kILP]; + + binary_op_scalar( + r_args, args, scalar, n, chunk_size, all_aligned, op); + } +}; + +template +struct BinaryOpScalarListFunctor { + using opmath_t = at::opmath_type; + template + __device__ __forceinline__ void operator()( + int chunk_size, + TensorListScalarListMetadata& tl, + Op op) { + const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; + const auto chunk_idx = tl.block_to_chunk[blockIdx.x]; + auto n = tl.numel_for_tensor[tensor_loc]; + + T* args[depth]; + const bool all_aligned = + init_args(args, tl, chunk_idx, chunk_size, tensor_loc); + opmath_t scalar = tl.scalar_vals[tensor_loc]; + n -= chunk_idx * chunk_size; + T r_args[r_args_depth][kILP]; + + binary_op_scalar( + r_args, args, scalar, n, chunk_size, all_aligned, op); + } +}; + +template +struct BinaryOpListAlphaFunctor { + using opmath_t = at::opmath_type; + template + __device__ __forceinline__ void operator()( + int chunk_size, + TensorListMetadata& tl, + Op op, + opmath_t alpha) { + const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; + const auto chunk_idx = tl.block_to_chunk[blockIdx.x]; + auto n = tl.numel_for_tensor[tensor_loc]; + + T* args[depth]; + const bool all_aligned = + init_args(args, tl, chunk_idx, chunk_size, tensor_loc); + n -= chunk_idx * chunk_size; + T r_args[r_args_depth][kILP]; + + // to make things simple, we put aligned case in a different code path + if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) { + for (int64_t i_start = threadIdx.x; + i_start * kILP < n && i_start * kILP < chunk_size; + i_start += blockDim.x) 
{ + // load + load_store(r_args[0], args[0], 0, i_start); + load_store(r_args[1], args[1], 0, i_start); +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = static_cast( + op(static_cast(r_args[0][ii]), + alpha * static_cast(r_args[1][ii]))); + } + // store + load_store(args[res_arg_index], r_args[0], i_start, 0); + } + } else { + for (int64_t i_start = 0; i_start < n && i_start < chunk_size; + i_start += blockDim.x * kILP) { + load_args(r_args, args, i_start, chunk_size, n); +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = static_cast( + op(static_cast(r_args[0][ii]), + alpha * static_cast(r_args[1][ii]))); + } + store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n); + } + } + } +}; + +template +struct BinaryOpScalarTensorFunctor { + using opmath_t = at::opmath_type; + template + __device__ __forceinline__ void operator()( + int chunk_size, + TensorListMetadata& tl, + Op op, + T* scalar, + opmath_t alpha) { + const int tensor_loc = tl.block_to_tensor[blockIdx.x]; + const int chunk_idx = tl.block_to_chunk[blockIdx.x]; + auto n = tl.numel_for_tensor[tensor_loc]; + + T* args[depth]; + const bool all_aligned = + init_args(args, tl, chunk_idx, chunk_size, tensor_loc); + n -= chunk_idx * chunk_size; + T r_args[r_args_depth][kILP]; + + // to make things simple, we put aligned case in a different code path + if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) { + for (int64_t i_start = threadIdx.x; + i_start * kILP < n && i_start * kILP < chunk_size; + i_start += blockDim.x) { + // load + load_store(r_args[0], args[0], 0, i_start); +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = static_cast(op( + static_cast(r_args[0][ii]), + static_cast(alpha) * static_cast(*scalar))); + } + // store + load_store(args[res_arg_index], r_args[0], i_start, 0); + } + } else { + for (int64_t i_start = 0; i_start < n && i_start < chunk_size; + i_start += blockDim.x * kILP) { + // Regardless if depth is 1 (for inplace) or 2 (for out of place), + // r_args has depth 1 + load_args<1>(r_args, args, i_start, chunk_size, n); +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = static_cast(op( + static_cast(r_args[0][ii]), + static_cast(alpha) * static_cast(*scalar))); + } + store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n); + } + } + } +}; + +// +// Unary Functors +// + +template +struct ZeroFunctor { + __device__ __forceinline__ void operator()( + int chunk_size, + TensorListMetadata<1>& tl) { + const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; + const auto chunk_idx = tl.block_to_chunk[blockIdx.x]; + auto n = tl.numel_for_tensor[tensor_loc]; + + T* args[depth]; + const auto all_aligned = + init_args(args, tl, chunk_idx, chunk_size, tensor_loc); + n -= chunk_idx * chunk_size; + T r_args[r_args_depth][kILP]; + + // to make things simple, we put aligned case in a different code path + if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) { + for (int64_t i_start = threadIdx.x; + i_start * kILP < n && i_start * kILP < chunk_size; + i_start += blockDim.x) { +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = 0; + } + // store + load_store(args[0], r_args[0], i_start, 0); + } + } else { + for (int64_t i_start = 0; i_start < n && i_start < chunk_size; + i_start += blockDim.x * kILP) { +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = 0; + } + store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n); + } + } + } +}; + +template +struct 
UnaryOpFunctor { + using opmath_t = at::opmath_type; + template + __device__ __forceinline__ void operator()( + int chunk_size, + TensorListMetadata& tl, + Op op) { + const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; + const auto chunk_idx = tl.block_to_chunk[blockIdx.x]; + auto n = tl.numel_for_tensor[tensor_loc]; + + T* args[depth]; + bool all_aligned = + init_args(args, tl, chunk_idx, chunk_size, tensor_loc); + n -= chunk_idx * chunk_size; + T r_args[r_args_depth][kILP]; + + // to make things simple, we put aligned case in a different code path + if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) { + for (int64_t i_start = threadIdx.x; + i_start * kILP < n && i_start * kILP < chunk_size; + i_start += blockDim.x) { + // load + load_store(r_args[0], args[0], 0, i_start); +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = + static_cast(op(static_cast(r_args[0][ii]))); + } + // store + load_store(args[res_arg_index], r_args[0], i_start, 0); + } + } else { + for (int64_t i_start = 0; i_start < n && i_start < chunk_size; + i_start += blockDim.x * kILP) { + load_args(r_args, args, i_start, chunk_size, n); +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = + static_cast(op(static_cast(r_args[0][ii]))); + } + store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n); + } + } + } +}; + +// +// Pointwise Functors +// + +template +struct PointwiseOpScalarFunctor { + using opmath_t = at::opmath_type; + template + __device__ __forceinline__ void operator()( + int chunk_size, + TensorListMetadata& tl, + Op op, + opmath_t scalar) { + const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; + const auto chunk_idx = tl.block_to_chunk[blockIdx.x]; + auto n = tl.numel_for_tensor[tensor_loc]; + + T* args[depth]; + const bool all_aligned = + init_args(args, tl, chunk_idx, chunk_size, tensor_loc); + n -= chunk_idx * chunk_size; + T r_args[r_args_depth][kILP]; + + pointwise_op_scalar( + r_args, args, scalar, n, chunk_size, all_aligned, op); + } +}; + +template +struct PointwiseOpScalarListFunctor { + using opmath_t = at::opmath_type; + template + __device__ __forceinline__ void operator()( + int chunk_size, + TensorListScalarListMetadata& tl, + Op op) { + const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; + const auto chunk_idx = tl.block_to_chunk[blockIdx.x]; + auto n = tl.numel_for_tensor[tensor_loc]; + + T* args[depth]; + const bool all_aligned = + init_args(args, tl, chunk_idx, chunk_size, tensor_loc); + opmath_t scalar = tl.scalar_vals[tensor_loc]; + n -= chunk_idx * chunk_size; + T r_args[r_args_depth][kILP]; + + pointwise_op_scalar( + r_args, args, scalar, n, chunk_size, all_aligned, op); + } +}; + +template +struct PointwiseOpListFunctor { + using opmath_t = at::opmath_type; + template + __device__ __forceinline__ void operator()( + int chunk_size, + TensorListMetadata& tl, + Op op) { + const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; + const auto chunk_idx = tl.block_to_chunk[blockIdx.x]; + auto n = tl.numel_for_tensor[tensor_loc]; + + T* args[depth]; + const bool all_aligned = + init_args(args, tl, chunk_idx, chunk_size, tensor_loc); + n -= chunk_idx * chunk_size; + T r_args[depth - 1][kILP]; + + // to make things simple, we put aligned case in a different code path + if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) { + for (int64_t i_start = threadIdx.x; + i_start * kILP < n && i_start * kILP < chunk_size; + i_start += blockDim.x) { + // load + load_store(r_args[0], args[0], 0, i_start); + load_store(r_args[1], 
args[1], 0, i_start); +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = static_cast( + op(static_cast(r_args[0][ii]), + static_cast(r_args[1][ii]))); + } + // store + load_store(args[2], r_args[0], i_start, 0); + } + } else { + for (int64_t i_start = 0; i_start < n && i_start < chunk_size; + i_start += blockDim.x * kILP) { + load_args(r_args, args, i_start, chunk_size, n); +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = static_cast( + op(static_cast(r_args[0][ii]), + static_cast(r_args[1][ii]))); + } + store_args(args[2], r_args[0], i_start, chunk_size, n); + } + } + } +}; + +template +struct TernaryOpListFunctor { + using opmath_t = at::opmath_type; + template + __device__ __forceinline__ void operator()( + int chunk_size, + TensorListMetadata& tl, + Op op) { + static_assert(depth == 3 || depth == 4, ""); + static_assert(depth >= r_args_depth, ""); + static_assert(res_arg_index == depth - 1 || res_arg_index == 0, ""); + const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; + const auto chunk_idx = tl.block_to_chunk[blockIdx.x]; + auto n = tl.numel_for_tensor[tensor_loc]; + + T* args[depth]; + const bool all_aligned = + init_args(args, tl, chunk_idx, chunk_size, tensor_loc); + n -= chunk_idx * chunk_size; + T r_args[r_args_depth][kILP]; + + if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) { + for (int64_t i_start = threadIdx.x; + i_start * kILP < n && i_start * kILP < chunk_size; + i_start += blockDim.x) { + load_store(r_args[0], args[0], 0, i_start); + load_store(r_args[1], args[1], 0, i_start); + load_store(r_args[2], args[2], 0, i_start); +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = + op(static_cast(r_args[0][ii]), + static_cast(r_args[1][ii]), + static_cast(r_args[2][ii])); + } + load_store(args[res_arg_index], r_args[0], i_start, 0); + } + } else { + for (int64_t i_start = 0; i_start < n && i_start < chunk_size; + i_start += blockDim.x * kILP) { + load_args(r_args, args, i_start, chunk_size, n); +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = + op(static_cast(r_args[0][ii]), + static_cast(r_args[1][ii]), + static_cast(r_args[2][ii])); + } + store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n); + } + } + } +}; + +template +struct TernaryOpScalarFunctor { + using opmath_t = at::opmath_type; + template + __device__ __forceinline__ void operator()( + int chunk_size, + TensorListMetadata& tl, + Op op, + opmath_t alpha) { + static_assert(depth == 2 || depth == 3, ""); + static_assert(depth >= r_args_depth, ""); + static_assert(res_arg_index == depth - 1 || res_arg_index == 0, ""); + const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; + const auto chunk_idx = tl.block_to_chunk[blockIdx.x]; + auto n = tl.numel_for_tensor[tensor_loc]; + + T* args[depth]; + const bool all_aligned = + init_args(args, tl, chunk_idx, chunk_size, tensor_loc); + n -= chunk_idx * chunk_size; + T r_args[r_args_depth][kILP]; + + // to make things simple, we put aligned case in a different code path + if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) { + for (int64_t i_start = threadIdx.x; + i_start * kILP < n && i_start * kILP < chunk_size; + i_start += blockDim.x) { + // load + load_store(r_args[0], args[0], 0, i_start); + load_store(r_args[1], args[1], 0, i_start); +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = + op(static_cast(r_args[0][ii]), + static_cast(r_args[1][ii]), + alpha); + } + // store + load_store(args[res_arg_index], r_args[0], i_start, 0); + 
} + } else { + for (int64_t i_start = 0; i_start < n && i_start < chunk_size; + i_start += blockDim.x * kILP) { + load_args(r_args, args, i_start, chunk_size, n); +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + r_args[0][ii] = + op(static_cast(r_args[0][ii]), + static_cast(r_args[1][ii]), + alpha); + } + store_args(args[res_arg_index], r_args[0], i_start, chunk_size, n); + } + } + } +}; + +template +struct power_functor { + C10_DEVICE T operator()(const T& a, const T& b) const { + return at::native::pow_(a, b); + } +}; + +template +struct reverse_power_functor { + C10_DEVICE T operator()(const T& a, const T& b) const { + return at::native::pow_(b, a); + } +}; + +} // namespace +} // namespace at::native diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/GridSampler.cuh b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/GridSampler.cuh new file mode 100644 index 0000000000000000000000000000000000000000..65cf9858b3bb38698a578f9600b60cf96f897e07 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/GridSampler.cuh @@ -0,0 +1,321 @@ +#pragma once +#include +#include + +namespace at { namespace native { + +using detail::GridSamplerInterpolation; +using detail::GridSamplerPadding; + +// Unnormalizes a coordinate from the -1 to +1 scale to its pixel index value, +// where we view each pixel as an area between (idx - 0.5) and (idx + 0.5). +// if align_corners: -1 and +1 get sent to the centers of the corner pixels +// -1 --> 0 +// +1 --> (size - 1) +// scale_factor = (size - 1) / 2 +// if not align_corners: -1 and +1 get sent to the image edges +// -1 --> -0.5 +// +1 --> (size - 1) + 0.5 == size - 0.5 +// scale_factor = size / 2 +template +__forceinline__ __device__ +scalar_t grid_sampler_unnormalize(scalar_t coord, int size, bool align_corners) { + if (align_corners) { + // unnormalize coord from [-1, 1] to [0, size - 1] + return ((coord + 1.f) / 2) * (size - 1); + } else { + // unnormalize coord from [-1, 1] to [-0.5, size - 0.5] + return ((coord + 1.f) * size - 1) / 2; + } +} + +// grid_sampler_unnormalize_set_grad works the same as grid_sampler_unnormalize +// except that it also returns the `d output / d input` via pointer argument +// `grad_in`. +// This is useful in the backward pass of grid_sampler. +template +__forceinline__ __device__ +scalar_t grid_sampler_unnormalize_set_grad(scalar_t coord, int size, + bool align_corners, scalar_t *grad_in) { + if (align_corners) { + // unnormalize coord from [-1, 1] to [0, size - 1] + *grad_in = static_cast(size - 1) / 2; + return ((coord + 1.f) / 2) * (size - 1); + } else { + // unnormalize coord from [-1, 1] to [-0.5, size - 0.5] + *grad_in = static_cast(size) / 2; + return ((coord + 1.f) * size - 1) / 2; + } +} + +// Clips coordinates to between 0 and clip_limit - 1 +template +__forceinline__ __device__ +scalar_t clip_coordinates(scalar_t in, int clip_limit) { + return ::min(static_cast(clip_limit - 1), ::max(in, static_cast(0))); +} + +// clip_coordinates_set_grad works similarly to clip_coordinates except that +// it also returns the `d output / d input` via pointer argument `grad_in`. +// This is useful in the backward pass of grid_sampler. +template +__forceinline__ __device__ +scalar_t clip_coordinates_set_grad(scalar_t in, int clip_limit, scalar_t *grad_in) { + // Note that it is important for the gradient calculation that borders + // are considered out of bounds. 
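+  // Concretely: clip(x) = min(max(x, 0), clip_limit - 1), so d clip / d x is 1
+  // strictly inside the valid range and 0 once the coordinate is clamped at
+  // either border, which is exactly what the branches below write to *grad_in.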
+ if (in <= static_cast(0)) { + *grad_in = static_cast(0); + return static_cast(0); + } else { + scalar_t max = static_cast(clip_limit - 1); + if (in >= max) { + *grad_in = static_cast(0); + return max; + } else { + *grad_in = static_cast(1); + return in; + } + } +} + +// Reflects coordinates until they fall between low and high (inclusive). +// The bounds are passed as twice their value so that half-integer values +// can be represented as ints. +template +__forceinline__ __device__ +scalar_t reflect_coordinates(scalar_t in, int twice_low, int twice_high) { + if (twice_low == twice_high) { + return static_cast(0); + } + scalar_t min = static_cast(twice_low) / 2; + scalar_t span = static_cast(twice_high - twice_low) / 2; + in = ::fabs(in - min); + // `fmod` returns same sign as `in`, which is positive after the `fabs` above. + scalar_t extra = ::fmod(in, span); + int flips = static_cast(::floor(in / span)); + if (flips % 2 == 0) { + return extra + min; + } else { + return span - extra + min; + } +} + +// reflect_coordinates_set_grad works similarly to reflect_coordinates except +// that it also returns the `d output / d input` via pointer argument +// `grad_in`. +// This is useful in the backward pass of grid_sampler. +template +__forceinline__ __device__ +scalar_t reflect_coordinates_set_grad(scalar_t in, int twice_low, int twice_high, + scalar_t *grad_in) { + if (twice_low == twice_high) { + *grad_in = static_cast(0); + return static_cast(0); + } + int grad_in_mult_; + scalar_t min = static_cast(twice_low) / 2; + scalar_t span = static_cast(twice_high - twice_low) / 2; + in = in - min; + if (in < static_cast(0)) { + grad_in_mult_ = -1; + in = -in; + } else { + grad_in_mult_ = 1; + } + // `fmod` returns same sign as `in`, which is positive after the `if` above. + scalar_t extra = ::fmod(in, span); + int flips = static_cast(::floor(in / span)); + if (flips % 2 == 0) { + *grad_in = static_cast(grad_in_mult_); + return extra + min; + } else { + *grad_in = static_cast(-grad_in_mult_); + return span - extra + min; + } +} + +template +__forceinline__ __device__ +scalar_t safe_downgrade_to_int_range(scalar_t x){ + // -100.0 does not have special meaning. This is just to make sure + // it's not within_bounds_2d or within_bounds_3d, and does not cause + // undefined behavior. See #35506. 
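+  // Put differently: any coordinate that is non-finite or would overflow an int
+  // is replaced by a sentinel that is guaranteed to fail the within_bounds_2d /
+  // within_bounds_3d checks, so the later static_cast to int stays well defined.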
+ if (x > INT_MAX-1 || x < INT_MIN || !::isfinite(static_cast(x))) + return static_cast(-100.0); + return x; +} + +template +__forceinline__ __device__ +scalar_t compute_coordinates(scalar_t coord, int size, + GridSamplerPadding padding_mode, + bool align_corners) { + if (padding_mode == GridSamplerPadding::Border) { + // clip coordinates to image borders + coord = clip_coordinates(coord, size); + } else if (padding_mode == GridSamplerPadding::Reflection) { + // reflect coordinates by image borders + if (align_corners) { + coord = reflect_coordinates(coord, 0, 2*(size - 1)); + } else { + coord = reflect_coordinates(coord, -1, 2*size - 1); + } + // clip coordinates to image borders + coord = clip_coordinates(coord, size); + } + + coord = safe_downgrade_to_int_range(coord); + return coord; +} + +// Computes the pixel source index value for a grid coordinate +template +__forceinline__ __device__ +scalar_t grid_sampler_compute_source_index( + scalar_t coord, + int size, + GridSamplerPadding padding_mode, + bool align_corners) { + coord = grid_sampler_unnormalize(coord, size, align_corners); + coord = compute_coordinates(coord, size, padding_mode, align_corners); + return coord; +} + +// grid_sampler_compute_source_index_set_grad works similarly to +// grid_sampler_compute_source_index except that it also returns the +// `d output / d input` via pointer argument `grad_in`. +// This is useful in the backward pass of grid_sampler. +template +__forceinline__ __device__ +scalar_t grid_sampler_compute_source_index_set_grad( + scalar_t coord, + int size, + GridSamplerPadding padding_mode, + bool align_corners, + scalar_t *grad_in) { + scalar_t grad_clip, grad_refl; + coord = grid_sampler_unnormalize_set_grad(coord, size, align_corners, grad_in); + if (padding_mode == GridSamplerPadding::Border) { + // clip coordinates to image borders + coord = clip_coordinates_set_grad(coord, size, &grad_clip); + *grad_in = (*grad_in) * grad_clip; + } else if (padding_mode == GridSamplerPadding::Reflection) { + // reflect coordinates by image borders + if (align_corners) { + coord = reflect_coordinates_set_grad(coord, 0, 2*(size - 1), &grad_refl); + } else { + coord = reflect_coordinates_set_grad(coord, -1, 2*size - 1, &grad_refl); + } + // clip coordinates to image borders + coord = clip_coordinates_set_grad(coord, size, &grad_clip); + *grad_in = (*grad_in) * grad_refl * grad_clip; + } + + coord = safe_downgrade_to_int_range(coord); + return coord; +} + +__forceinline__ __device__ +bool within_bounds_2d(int h, int w, int H, int W) { + return h >= 0 && h < H && w >= 0 && w < W; +} + +__forceinline__ __device__ +bool within_bounds_3d(int d, int h, int w, int D, int H, int W) { + return d >= 0 && d < D && h >= 0 && h < H && w >= 0 && w < W; +} + +template +__forceinline__ __device__ +scalar_t get_value_bounded( + const scalar_t *data, scalar_t x, scalar_t y, int W, int H, int sW, int sH, + GridSamplerPadding padding_mode, + bool align_corners) { + + x = compute_coordinates(x, W, padding_mode, align_corners); + y = compute_coordinates(y, H, padding_mode, align_corners); + + int ix = static_cast(x); + int iy = static_cast(y); + + if (within_bounds_2d(iy, ix, H, W)) { + return data[iy * sH + ix * sW]; + } + return static_cast(0); +} + +template +__forceinline__ __device__ +void safe_add_2d(scalar_t *data, int h, int w, + int sH, int sW, int H, int W, + scalar_t delta, + const index_t NC_offset, + const index_t memory_span) { + if (within_bounds_2d(h, w, H, W)) { + fastAtomicAdd(data, + NC_offset + h * sH + w * sW, + 
memory_span, + delta, + true); + } +} + +template +__forceinline__ __device__ +void safe_add_3d(scalar_t *data, int d, int h, int w, + int sD, int sH, int sW, int D, int H, int W, + scalar_t delta, + const index_t NC_offset, + const index_t memory_span) { + if (within_bounds_3d(d, h, w, D, H, W)) { + fastAtomicAdd(data, + NC_offset + d * sD + h * sH + w * sW, + memory_span, + delta, + true); + } +} + +template +__forceinline__ __device__ +void add_value_bounded( + scalar_t* data, scalar_t x, scalar_t y, int W, int H, int sW, int sH, + scalar_t delta, + GridSamplerPadding padding_mode, + bool align_corners, + const index_t NC_offset, + const index_t memory_span) { + + x = compute_coordinates(x, W, padding_mode, align_corners); + y = compute_coordinates(y, H, padding_mode, align_corners); + + int ix = static_cast(x); + int iy = static_cast(y); + + safe_add_2d(data, iy, ix, sH, sW, H, W, delta, NC_offset, memory_span); +} + +// Calculate the differential of the cubic convolution, i.e. `d coeff / d x` +template +__forceinline__ __device__ +void get_cubic_coefficients_grad( + scalar_t coeffs[4], + scalar_t t) { + + // Must be the same as forward calculation in + // aten/src/ATen/native/cuda/UpSample.cuh:get_cubic_upsample_coefficients + scalar_t A = -0.75; + + scalar_t x; + x = -1 - t; // 1 < x = |-1 - tx| < 2 + coeffs[0] = (-3 * A * x - 10 * A ) * x - 8 * A; + x = -t; // x = |0 - tx| <= 1 + coeffs[1] = (-3 * (A + 2) * x - 2 * (A + 3)) * x; + x = 1 - t; // x = |1 - tx| <= 1 + coeffs[2] = (3 * (A + 2) * x - 2 * (A + 3)) * x; + x = 2 - t; // 1 < x = |2 - tx| < 2 + coeffs[3] = (3 * A * x - 10 * A) * x + 8 * A; +} + + +}} // namespace at::native diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/IndexKernel.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/IndexKernel.h new file mode 100644 index 0000000000000000000000000000000000000000..edd9190deb0dba12979556a9f1bc12705f5801b4 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/IndexKernel.h @@ -0,0 +1,16 @@ +#pragma once +#include +#include + +namespace at { +struct TensorIteratorBase; +class TensorBase; +} + +namespace at { +namespace native { +/// @param maskPrefixSum[in,out] +void launch_masked_scatter_kernel( + const TensorBase &self, const TensorBase &mask, + const TensorBase &maskPrefixSum, const TensorBase &source); +}} diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/KernelUtils.cuh b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/KernelUtils.cuh new file mode 100644 index 0000000000000000000000000000000000000000..d07f54093e8136234cba6879f2dfff17a6d0e5a3 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/KernelUtils.cuh @@ -0,0 +1,149 @@ +#pragma once +#include + +#if !(defined(USE_ROCM) || ((defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)))) +#include +#endif + +namespace at { +namespace native { + +__device__ __forceinline__ size_t +idx(const size_t nc, + const size_t height, + const size_t width, + const size_t h, + const size_t w) { + return (nc * height + h) * width + w; +} + +// for channels-last +__device__ __forceinline__ size_t +idx_cl( + const size_t n, const size_t h, const size_t w, const size_t c, + const size_t height, const size_t width, const size_t channel +) { + return ((n * height + h) * width + w) * channel + c; +} + +// fastSpecializedAtomicAdd (and fastAtomicAdd) are an optimization +// that speed up half-precision atomics. 
The situation with half +// precision atomics is that we have a slow __half atomic, and +// a fast vectored __half2 atomic (this can be worth up to a 6x +// speedup, see https://github.com/pytorch/pytorch/pull/21879). +// We can convert a __half atomic into a __half2 atomic by simply +// pairing the __half with a zero entry on the left/right depending +// on alignment... but only if this wouldn't cause an out of bounds +// access! Thus, you must specify tensor and numel so we can check +// if you would be out-of-bounds and use a plain __half atomic if +// you would be. +template < + typename scalar_t, + typename index_t, + typename std::enable_if::value>::type* = + nullptr> +__device__ __forceinline__ void fastSpecializedAtomicAdd( + scalar_t* tensor, + index_t index, + const index_t numel, + scalar_t value) { +#if ( \ + (defined(USE_ROCM)) || \ + (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700))) + gpuAtomicAddNoReturn( + reinterpret_cast(tensor) + index, + static_cast(value)); +#else + // Accounts for the chance tensor falls on an odd 16 bit alignment (ie, not 32 bit aligned) + __half* target_addr = reinterpret_cast<__half*>(tensor + index); + bool low_byte = (reinterpret_cast(target_addr) % sizeof(__half2) == 0); + + if (low_byte && index < (numel - 1)) { + __half2 value2; + value2.x = static_cast<__half>(value); + value2.y = __int2half_rz(0); + atomicAdd(reinterpret_cast<__half2*>(target_addr), value2); + + } else if (!low_byte && index > 0) { + __half2 value2; + value2.x = __int2half_rz(0); + value2.y = static_cast<__half>(value); + atomicAdd(reinterpret_cast<__half2*>(target_addr - 1), value2); + + } else { + atomicAdd( + reinterpret_cast<__half*>(tensor) + index, static_cast<__half>(value)); + } +#endif +} + +template < + typename scalar_t, + typename index_t, + typename std::enable_if::value>::type* = + nullptr> +__device__ __forceinline__ void fastSpecializedAtomicAdd( + scalar_t* tensor, + index_t index, + const index_t numel, + scalar_t value) { +#if ( \ + (defined(USE_ROCM)) || \ + (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))) + gpuAtomicAddNoReturn( + reinterpret_cast(tensor) + index, + static_cast(value)); +#else + // Accounts for the chance tensor falls on an odd 16 bit alignment (ie, not 32 bit aligned) + __nv_bfloat16* target_addr = reinterpret_cast<__nv_bfloat16*>(tensor + index); + bool low_byte = (reinterpret_cast(target_addr) % sizeof(__nv_bfloat162) == 0); + + if (low_byte && index < (numel - 1)) { + __nv_bfloat162 value2; + value2.x = *reinterpret_cast<__nv_bfloat16*>(&value); + value2.y = __int2bfloat16_rz(0); + atomicAdd(reinterpret_cast<__nv_bfloat162*>(target_addr), value2); + + } else if (!low_byte && index > 0) { + __nv_bfloat162 value2; + value2.x = __int2bfloat16_rz(0); + value2.y = *reinterpret_cast<__nv_bfloat16*>(&value); + atomicAdd(reinterpret_cast<__nv_bfloat162*>(target_addr - 1), value2); + + } else { + atomicAdd( + reinterpret_cast<__nv_bfloat16*>(tensor) + index, *reinterpret_cast<__nv_bfloat16*>(&value)); + } +#endif +} + + +template < + typename scalar_t, + typename index_t, + typename std::enable_if::value && !std::is_same::value >::type* = + nullptr> +__device__ __forceinline__ void fastSpecializedAtomicAdd( + scalar_t* tensor, + index_t index, + const index_t numel, + scalar_t value) { + gpuAtomicAddNoReturn(tensor + index, value); +} + +template +__device__ __forceinline__ void fastAtomicAdd( + scalar_t* tensor, + index_t index, + const index_t numel, + scalar_t value, + bool fast_atomics) { + if (fast_atomics) { + 
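+    // Illustrative call site (a sketch; these names are not defined in this header):
+    // a backward kernel scattering Half gradients could invoke
+    //   fastAtomicAdd(grad_input_ptr, flat_index, grad_input_numel, g, /*fast_atomics=*/true);
+    // and this branch then forwards to the vectorized __half2 path whenever the
+    // address alignment and the numel bound make that safe.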
fastSpecializedAtomicAdd(tensor, index, numel, value); + } else { + gpuAtomicAddNoReturn(tensor + index, value); + } +} + +} // namespace native +} // namespace at diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/LaunchUtils.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/LaunchUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..c9640b15b18c8a2d6d4f3dd92379701ae1ec5164 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/LaunchUtils.h @@ -0,0 +1,18 @@ +#pragma once +#include + +namespace at { +namespace native { + +// returns 2**floor(log2(n)) +static int lastPow2(unsigned int n) { + n |= (n >> 1); + n |= (n >> 2); + n |= (n >> 4); + n |= (n >> 8); + n |= (n >> 16); + return std::max(1, n - (n >> 1)); +} + +} // namespace native +} // namespace at diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/MemoryAccess.cuh b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/MemoryAccess.cuh new file mode 100644 index 0000000000000000000000000000000000000000..0fdc813fd77707310f2bc7da62bd147f000ae726 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/MemoryAccess.cuh @@ -0,0 +1,389 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +// References: +// https://devblogs.nvidia.com/cuda-pro-tip-increase-performance-with-vectorized-memory-access/ + +namespace at { namespace native { namespace memory { + +namespace detail { + +// What does the `static_unroll` do? +// +// We want to do something like: +// +// using args_t = typename traits::ArgsTuple; +// args_t args; +// #pragma unroll +// for (int i = 0; i < traits::arity; i++) { +// std::get(args) = .... +// } +// +// but unfortunately the above code does not work because +// the template argument has to be a compile time constant +// so `static_unroll` is created to simulate `#pragma unroll` +// using template metaprogramming. + +template typename func, int end, int current=0> +struct static_unroll { + template + static inline C10_HOST_DEVICE void with_args(Args&&... args) { + func::apply(std::forward(args)...); + static_unroll::with_args(args...); + } +}; + +template typename func, int end> +struct static_unroll { + template + static inline C10_HOST_DEVICE void with_args(Args... 
args) {} +}; + +// helper structs to be used with static_unroll to load arguments +// one by one + +template +struct vectorized_load_helper { + template + static __device__ void apply(policy_t &self, args_t *args, int idx) { + using arg_t = std::tuple_element_t; + // `data` hold the data_ptr for tensors [output, input0, input1, ...], so we + // need a +1 offset to get the input + auto ptr = reinterpret_cast(self.data[arg_index + 1]) + block_work_size() * idx; + auto args_accessor = [&args] __device__ (int thread_unroll_idx) -> arg_t & { return std::get(args[thread_unroll_idx]); }; + self.load_single_arg(args_accessor, ptr); + } +}; + +template +struct unroll_load_helper { + template + static __device__ void apply(policy_t &self, args_t *args, offset_t offset, loader_t loader, int j, int num_outputs) { + using arg_t = std::tuple_element_t; + // `data` hold the data_ptr for tensors [output, input0, input1, ...], so we + // need a +1 offset to get the input + std::get(args[j]) = loader.template load(self.data[arg_index + num_outputs], offset[arg_index], arg_index); + } +}; + +template +struct multi_outputs_store_helper { + template + C10_HOST_DEVICE static void apply( + at::detail::Array data, + at::detail::Array offsets, + thrust::tuple ret) { + using T = typename thrust::tuple_element>::type; + T *to = reinterpret_cast(data[current]) + offsets[current]; + *to = thrust::get(ret); + } +}; + +} // namespace detail + +struct LoadWithoutCast { + template + __device__ scalar_t load(char *base_ptr, uint32_t offset, int arg) { + return c10::load(reinterpret_cast(base_ptr) + offset); + } +}; + +template +struct LoadWithCast { + using array_t = at::detail::Array(N, 1)>; + using size_array_t = at::detail::Array(N, 1)>; + + array_t dtypes; + size_array_t element_sizes; + + LoadWithCast(const TensorIteratorBase& iter) { + CUDA_KERNEL_ASSERT(iter.ninputs() == N); + #pragma unroll + for (auto i = 0; i < N; ++i) { + this->dtypes[i] = iter.dtype(i + iter.noutputs()); + element_sizes[i] = c10::elementSize(iter.dtype(i + iter.noutputs())); + } + } + + template + __device__ scalar_t load(char *base_ptr, uint32_t offset, int arg) { + void *ptr = base_ptr + element_sizes[arg] * offset; + return c10::fetch_and_cast(dtypes[arg], ptr); + } +}; + +struct StoreWithoutCast { + template + __device__ void store(scalar_t value, char *base_ptr, uint32_t offset, int arg = 0) { + *(reinterpret_cast(base_ptr) + offset) = value; + } +}; + +template +struct StoreWithCast { + using array_t = at::detail::Array(N, 1)>; + using size_array_t = at::detail::Array(N, 1)>; + + array_t dtypes; + size_array_t element_sizes; + + StoreWithCast(const TensorIteratorBase& iter) { + CUDA_KERNEL_ASSERT(iter.noutputs() == N); + #pragma unroll + for (auto i = 0; i < N; ++i) { + this->dtypes[i] = iter.dtype(i); + element_sizes[i] = c10::elementSize(iter.dtype(i)); + } + } + + template + __device__ void store(scalar_t value, char *base_ptr, uint32_t offset, int arg = 0) { + void *ptr = base_ptr + element_sizes[arg] * offset; + c10::cast_and_store(dtypes[arg], ptr, value); + } +}; + +// aligned vector generates vectorized load/store on CUDA +template +struct alignas(sizeof(scalar_t) * vec_size) aligned_vector { + scalar_t val[vec_size]; +}; + +template +__device__ aligned_vector load_vector(const scalar_t *base_ptr, uint32_t offset) { + using vec_t = aligned_vector; + auto *from = reinterpret_cast(base_ptr); + return from[offset]; +} + +template +__device__ aligned_vector load_vector(const bool *base_ptr, uint32_t offset) { + // See NOTE [Loading 
boolean values] + auto tmp = load_vector(reinterpret_cast(base_ptr), offset); + aligned_vector ret; + for (int i = 0; i < vec_size; ++i) { + ret.val[i] = bool(tmp.val[i]); + } + return ret; +} + +namespace policies { + +// Assumption: +// all tensors are contiguous, that is: stride == sizeof(type) for all tensors +template +struct unroll { + + data_t data; + int remaining; + inp_calc_t input_offset_calculator; + out_calc_t output_offset_calculator; + loader_t loader; + storer_t storer; + + __device__ unroll(data_t data, int remaining, inp_calc_t ic, out_calc_t oc, loader_t l, storer_t s): + data(data), remaining(remaining), input_offset_calculator(ic), output_offset_calculator(oc), loader(l), storer(s) {} + + __device__ inline bool check_inbounds(int thread_work_elem) { + return ((int)(threadIdx.x + thread_work_elem*num_threads()) < remaining); + } + + template + __device__ inline void load(args_t *args, int idx) { + constexpr int arity = std::tuple_size::value; + int thread_idx = threadIdx.x; + #pragma unroll + for (int i = 0; i < thread_work_size(); i++) { + if (thread_idx >= remaining) { + return; + } + int linear_idx = thread_idx + block_work_size() * idx; + auto offset = input_offset_calculator.get(linear_idx); + detail::static_unroll::with_args(*this, args, offset, loader, i, num_outputs); + thread_idx += num_threads(); + } + } + + template + __device__ inline void store(scalar_t *from, int idx) { + int thread_idx = threadIdx.x; + #pragma unroll + for (int i = 0; i < thread_work_size(); i++) { + if (thread_idx >= remaining) { + return; + } + int linear_idx = thread_idx + block_work_size() * idx; + int offset = output_offset_calculator.get(linear_idx)[0]; + storer.store(from[i], data[0], offset); + thread_idx += num_threads(); + } + } +}; + +// Assumption: +// all tensors are contiguous, that is: stride == sizeof(type) for all tensors +// Note: +// Functions in vectorized policy does not do boundary check. It assumes the whole block +// has its job to do. So the reminders should be handled by the caller manually. +template // vec_size: number of scalars, can be 1, 2, or 4. 
+struct vectorized { + + static_assert(thread_work_size() % vec_size == 0, "The workload per thread must be a multiple of vec_size"); + static constexpr int loop_size = thread_work_size() / vec_size; + + data_t data; + + __device__ vectorized(data_t data) : data(data) {} + + __device__ inline constexpr bool check_inbounds(int thread_work_elem) { + return true; + } + + template + __device__ inline void load_single_arg(accessor_t to, scalar_t *from) { + int thread_idx = threadIdx.x; + #pragma unroll + for (int i = 0; i < loop_size; i++) { + int index = thread_idx + i * num_threads(); + auto v = load_vector(from, index); + #pragma unroll + for (int j = 0; j < vec_size; j++) { + to(vec_size * i + j) = v.val[j]; + } + } + } + + template + __device__ inline void load(args_t *args, int idx) { + constexpr int arity = std::tuple_size::value; + detail::static_unroll::with_args(*this, args, idx); + } + + template + __device__ inline void store(scalar_t *from, int idx) { + using vec_t = aligned_vector; + scalar_t *to = reinterpret_cast(data[0]) + block_work_size() * idx; + vec_t *to_ = reinterpret_cast(to); + int thread_idx = threadIdx.x; + #pragma unroll + for (int i = 0; i < loop_size; i++) { + int index = thread_idx + i * num_threads(); + vec_t v; + for (int j = 0; j < vec_size; j++) { + v.val[j] = from[vec_size * i + j]; + } + to_[index] = v; + } + } +}; + +template +struct multi_outputs_unroll { + //multi_outputs_unroll struct members and check_inbounds and load methods are copypasted from unroll struct + //we don't use inheritance because of compiler bug in cuda 10.2+ + data_t data; + int remaining; + inp_calc_t input_offset_calculator; + out_calc_t output_offset_calculator; + LoadWithoutCast loader; + StoreWithoutCast storer; + + __device__ multi_outputs_unroll(data_t data, int remaining, inp_calc_t ic, out_calc_t oc): + data(data), remaining(remaining), input_offset_calculator(ic), output_offset_calculator(oc) {} + + __device__ inline bool check_inbounds(int thread_work_elem) { + return ((int)(threadIdx.x + thread_work_elem*num_threads()) < remaining); + } + + template + __device__ inline void load(args_t *args, int idx) { + constexpr int arity = std::tuple_size::value; + int thread_idx = threadIdx.x; + #pragma unroll + for (int i = 0; i < thread_work_size(); i++) { + if (thread_idx >= remaining) { + return; + } + int linear_idx = thread_idx + block_work_size() * idx; + auto offset = input_offset_calculator.get(linear_idx); + detail::static_unroll::with_args(*this, args, offset, loader, i, num_outputs); + thread_idx += num_threads(); + } + } + + + template + __device__ inline void store(return_t *from, int idx) { + int thread_idx = threadIdx.x; + #pragma unroll + for (int i = 0; i < thread_work_size(); i++) { + if (thread_idx >= this->remaining) { + return; + } + int linear_idx = thread_idx + block_work_size() * idx; + auto offsets = this->output_offset_calculator.get(linear_idx); + memory::detail::static_unroll::with_args(this->data, offsets, from[i]); + thread_idx += num_threads(); + } + } +}; + +} // namespace policies + +// This is only used in host, but we will wrap this into some templates +// which is C10_HOST_DEVICE, so we have to make this C10_HOST_DEVICE +// in order to compile +template +inline C10_HOST_DEVICE int can_vectorize_up_to(const char *pointer) { + uint64_t address = reinterpret_cast(pointer); + constexpr int vec2_alignment = std::alignment_of>::value; + constexpr int vec4_alignment = std::alignment_of>::value; + if (address % vec4_alignment == 0) { + return 4; + } else if 
(address % vec2_alignment == 0) { + return 2; + } + return 1; +} + +template +inline C10_HOST_DEVICE int can_vectorize_up_to(char *pointer) { + return can_vectorize_up_to(static_cast(pointer)); +} + +template +struct can_vectorize_up_to_helper { + template + static C10_HOST_DEVICE void apply(int &result, array_t pointers, traits _) { + using arg_t = typename traits::template arg::type; + // `pointers` hold the data_ptr for tensors [output, input0, input1, ...], so we + // need a +1 offset to get the input + result = std::min(result, can_vectorize_up_to(pointers[i + 1])); + } +}; + +template +inline int can_vectorize_up_to(array_t pointers) { + using traits = function_traits; + using return_t = typename traits::result_type; + constexpr int arity = traits::arity; + int result = can_vectorize_up_to(pointers[0]); + // We need to get the type for each argument of `func_t`, this can only + // be done at compile time. + detail::static_unroll::with_args(result, pointers, traits()); + return result; +} + +}}} // namespace at::native::memory diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/MiscUtils.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/MiscUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..e616a7d1fcfb8254528dccc4e6b9d0658ffe1a3c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/MiscUtils.h @@ -0,0 +1,32 @@ +#pragma once +#include +#include +#include +#include + +namespace at { +namespace native { + +static inline int cuda_int_cast(int64_t value, const char* varname) { + auto result = static_cast(value); + TORCH_CHECK(static_cast(result) == value, + "cuda_int_cast: The value of ", varname, "(", (long long)value, + ") is too large to fit into a int (", sizeof(int), " bytes)"); + return result; +} + +// Creates an array of size elements of type T, backed by pinned memory +// wrapped in a Storage +template +static inline Storage pin_memory(int64_t size) { + auto* allocator = cuda::getPinnedMemoryAllocator(); + int64_t adjusted_size = size * sizeof(T); + return Storage( + Storage::use_byte_size_t(), + adjusted_size, + allocator, + /*resizable=*/false); +} + +} // namespace native +} // namespace at diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/MultiTensorApply.cuh b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/MultiTensorApply.cuh new file mode 100644 index 0000000000000000000000000000000000000000..17f14444abd14a03de30f57d3be7254f51a957f9 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/MultiTensorApply.cuh @@ -0,0 +1,379 @@ +#pragma once +#include +#include +#include +#include +#include +#include + +namespace at::native { + +namespace { + +static constexpr int64_t kILP = 4; +static constexpr int64_t kChunkSize = 65536; +static constexpr int64_t kBlockSize = 512; + +// TODO(crcrpar): Add `n>5` for `low prec params & their higher prec copy` +// TensorListMetadata has to be < 4KB - the limit for kernel launch argument +static constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30}; +static constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320}; +static constexpr int depth_to_max_tensors_scalarlist[5] = {96, 64, 48, 36, 30}; +static constexpr int depth_to_max_tensors_scalarlist_of_complex_double[2] = { + 72, + 60}; + +template +__device__ __forceinline__ bool is_aligned(T* p) { + return ((uint64_t)p) % (kILP * sizeof(T)) == 0; +} + +template +__device__ __forceinline__ void 
load_store( + T* dst, + T* src, + int64_t dst_offset, + int64_t src_offset) { + using LT = at::native::memory::aligned_vector; + ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset]; +} + +template +struct TensorListMetadata { + const void* addresses[n][depth_to_max_tensors[n - 1]]; + int64_t numel_for_tensor[depth_to_max_tensors[n - 1]]; + unsigned char block_to_tensor[depth_to_max_blocks[n - 1]]; + int block_to_chunk[depth_to_max_blocks[n - 1]]; + int start_tensor_this_launch; +}; + +template +struct TensorListScalarListMetadata { + const void* addresses[n][depth_to_max_tensors_scalarlist[n - 1]]; + int64_t numel_for_tensor[depth_to_max_tensors_scalarlist[n - 1]]; + scalar_vals_t scalar_vals[depth_to_max_tensors_scalarlist[n - 1]]; + unsigned char block_to_tensor[depth_to_max_blocks[n - 1]]; + int block_to_chunk[depth_to_max_blocks[n - 1]]; +}; + +// note(mkozuki): `n` of 1&2 violate the limit of cuda kernel argument size of +// 4kb with `c10::complex` +template <> +struct TensorListScalarListMetadata, 1> { + const void* addresses[1] + [depth_to_max_tensors_scalarlist_of_complex_double[0]]; + int64_t + numel_for_tensor[depth_to_max_tensors_scalarlist_of_complex_double[0]]; + c10::complex + scalar_vals[depth_to_max_tensors_scalarlist_of_complex_double[0]]; + unsigned char block_to_tensor[depth_to_max_blocks[1 - 1]]; + int block_to_chunk[depth_to_max_blocks[1 - 1]]; +}; + +template <> +struct TensorListScalarListMetadata, 2> { + const void* addresses[2] + [depth_to_max_tensors_scalarlist_of_complex_double[1]]; + int64_t + numel_for_tensor[depth_to_max_tensors_scalarlist_of_complex_double[1]]; + c10::complex + scalar_vals[depth_to_max_tensors_scalarlist_of_complex_double[1]]; + unsigned char block_to_tensor[depth_to_max_blocks[2 - 1]]; + int block_to_chunk[depth_to_max_blocks[2 - 1]]; +}; + +// NOTE(crcrpar): This is a conservative resolution to handle `state_steps` +// whose each element is `at::Tensor` of 1 element representing the number of +// `step`s called so far. +template +struct FusedOptimizerTensorListMetadata { + const void* addresses[n][depth_to_max_tensors[n - 1]]; + int64_t numel_for_tensor[depth_to_max_tensors[n - 1]]; + const void* state_steps_addresses[depth_to_max_tensors_scalarlist[n - 1]]; + unsigned char block_to_tensor[depth_to_max_blocks[n - 1]]; + int block_to_chunk[depth_to_max_blocks[n - 1]]; + int start_tensor_this_launch; +}; + +template +C10_LAUNCH_BOUNDS_1(kBlockSize) +__global__ void multi_tensor_apply_kernel( + T tensorListMeta, + U callable, + ArgTypes... args) { + // Hand the chunk information to the user-supplied functor to process however + // it likes. + callable(kChunkSize, tensorListMeta, args...); +} + +} // namespace + +// multi_tensor_apply enables horizontal fusion across lists of tensors. +// For example, whereas you once had a for-loop of a + b = c, where a, b, +// and c are individual tensors in lists as, bs, and cs, you can now with +// fewer kernel launches compute as + bs = cs. +// +// You can also imagine bs to be a scalar list vs a tensor list. +// +// The function below takes in tensor lists, scalars, and a callable and +// chunks up the computation to launch as few kernels as possible by iterating +// through every "chunk" in every tensor (thus the nested for loops). In the +// simplest case, everything gets bundled into just one kernel launch, but +// due to blocksize constraints, we may need to launch multiple kernels. 
+// Each kernel launch is defined by one tensorListMeta construct, which we +// use to track and reset the necessary metadata for each launch. +template +void multi_tensor_apply( + std::vector>& tensor_lists, + at::ArrayRef scalars, + T callable, + ArgTypes... args) { + TORCH_CHECK( + tensor_lists.size() == depth, + "Number of tensor lists has to match the depth."); + const size_t n_tensors = tensor_lists[0].size(); + using scalar_vals_t = typename T::opmath_t; + TensorListScalarListMetadata tensorListMeta; + + int loc_block_info = 0; + int loc_tensor_info = 0; + for (size_t t = 0; t < n_tensors; t++) { + // short-circuit to avoid adding empty tensors to tensorListMeta + if (tensor_lists[0][t].numel() == 0) { + continue; + } + tensorListMeta.scalar_vals[loc_tensor_info] = scalars[t].to(); + tensorListMeta.numel_for_tensor[loc_tensor_info] = + tensor_lists[0][t].numel(); + for (int d = 0; d < depth; d++) { + tensorListMeta.addresses[d][loc_tensor_info] = + tensor_lists[d][t].const_data_ptr(); + } + loc_tensor_info++; + + // now we enter [chunking territory]. + // we will launch a kernel when EITHER the blocks get filled up OR + // the tensors get filled up. There will always be at least one block + // per tensor since the zero-sized ones will not enter the loop, so + // the nested forloop within represents iterating through the chunks + // of a single tensor. + const auto numel = tensor_lists[0][t].numel(); + const auto chunks = numel / kChunkSize + (numel % kChunkSize != 0); + for (auto chunk = 0; chunk < chunks; chunk++) { + tensorListMeta.block_to_tensor[loc_block_info] = loc_tensor_info - 1; + tensorListMeta.block_to_chunk[loc_block_info] = chunk; + loc_block_info++; + + // a tensor is not considered full unless all its chunks have been + // processed + const bool tensors_full = + (loc_tensor_info == depth_to_max_tensors_scalarlist[depth - 1] && + chunk == chunks - 1); + const bool blocks_full = + (loc_block_info == depth_to_max_blocks[depth - 1]); + + if (tensors_full || blocks_full) { + multi_tensor_apply_kernel<<< + loc_block_info, + kBlockSize, + 0, + at::cuda::getCurrentCUDAStream()>>>( + tensorListMeta, callable, args...); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + // Reset. + loc_block_info = 0; + // all chunks have already been handled in the kernel + if (chunk == chunks - 1) { + loc_tensor_info = 0; + } else { // blocks were full and tensor chunks remain + tensorListMeta.numel_for_tensor[0] = + tensorListMeta.numel_for_tensor[loc_tensor_info - 1]; + tensorListMeta.scalar_vals[0] = + tensorListMeta.scalar_vals[loc_tensor_info - 1]; + for (int d = 0; d < depth; d++) { + tensorListMeta.addresses[d][0] = + tensorListMeta.addresses[d][loc_tensor_info - 1]; + } + loc_tensor_info = 1; + } + } + } + } + + // note: [finishing what we started] + // if there's remaining work to be done but the tensors/blocks aren't full + // yet we are at the end, submit the kernel to do the work! + if (loc_block_info != 0) { + multi_tensor_apply_kernel<<< + loc_block_info, + kBlockSize, + 0, + at::cuda::getCurrentCUDAStream()>>>(tensorListMeta, callable, args...); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } +} + +template +void multi_tensor_apply( + std::vector>& tensor_lists, + T callable, + ArgTypes... 
args) { + TORCH_CHECK( + tensor_lists.size() == depth, + "Number of tensor lists has to match the depth."); + const size_t n_tensors = tensor_lists[0].size(); + TensorListMetadata tensorListMeta; + tensorListMeta.start_tensor_this_launch = 0; + + int loc_block_info = 0; + int loc_tensor_info = 0; + for (size_t t = 0; t < n_tensors; t++) { + // short-circuit to avoid adding empty tensors to tensorListMeta + if (tensor_lists[0][t].numel() == 0) { + continue; + } + tensorListMeta.numel_for_tensor[loc_tensor_info] = + tensor_lists[0][t].numel(); + for (int d = 0; d < depth; d++) { + tensorListMeta.addresses[d][loc_tensor_info] = + tensor_lists[d][t].const_data_ptr(); + } + loc_tensor_info++; + + // see note: [chunking territory]. + const auto numel = tensor_lists[0][t].numel(); + const auto chunks = numel / kChunkSize + (numel % kChunkSize != 0); + for (auto chunk = 0; chunk < chunks; chunk++) { + tensorListMeta.block_to_tensor[loc_block_info] = loc_tensor_info - 1; + tensorListMeta.block_to_chunk[loc_block_info] = chunk; + loc_block_info++; + + const bool tensors_full = + (loc_tensor_info == depth_to_max_tensors[depth - 1] && + chunk == chunks - 1); + const bool blocks_full = + (loc_block_info == depth_to_max_blocks[depth - 1]); + + if (tensors_full || blocks_full) { + multi_tensor_apply_kernel<<< + loc_block_info, + kBlockSize, + 0, + at::cuda::getCurrentCUDAStream()>>>( + tensorListMeta, callable, args...); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + // Reset. + loc_block_info = 0; + if (chunk == chunks - 1) { + loc_tensor_info = 0; + tensorListMeta.start_tensor_this_launch = t + 1; + } else { + tensorListMeta.numel_for_tensor[0] = + tensorListMeta.numel_for_tensor[loc_tensor_info - 1]; + for (int d = 0; d < depth; d++) { + tensorListMeta.addresses[d][0] = + tensorListMeta.addresses[d][loc_tensor_info - 1]; + } + loc_tensor_info = 1; + tensorListMeta.start_tensor_this_launch = t; + } + } + } + } + + // see note: [finishing what we started] + if (loc_block_info != 0) { + multi_tensor_apply_kernel<<< + loc_block_info, + kBlockSize, + 0, + at::cuda::getCurrentCUDAStream()>>>(tensorListMeta, callable, args...); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } +} + +template +void multi_tensor_apply_for_fused_optimizer( + std::vector>& tensor_lists, + at::TensorList state_steps, + T callable, + ArgTypes... 
args) { + TORCH_CHECK( + tensor_lists.size() == depth, + "Number of tensor lists has to match the depth"); + const auto num_tensors = tensor_lists[0].size(); + FusedOptimizerTensorListMetadata tensorListMeta; + + int loc_block_info = 0; + int loc_tensor_info = 0; + for (const auto& tensor_index : c10::irange(num_tensors)) { + // short-circuit to avoid adding empty tensors to tensorListMeta + if (tensor_lists[0][tensor_index].numel() == 0) { + continue; + } + tensorListMeta.state_steps_addresses[loc_tensor_info] = + state_steps[tensor_index].const_data_ptr(); + tensorListMeta.numel_for_tensor[loc_tensor_info] = + tensor_lists[0][tensor_index].numel(); + for (const auto& d : c10::irange(depth)) { + tensorListMeta.addresses[d][loc_tensor_info] = + tensor_lists[d][tensor_index].const_data_ptr(); + } + loc_tensor_info++; + + // see above note: [chunking territory] + const auto numel = tensor_lists[0][tensor_index].numel(); + const auto chunks = numel / kChunkSize + (numel % kChunkSize != 0); + TORCH_CHECK(chunks > -1); + for (const auto& chunk : c10::irange(chunks)) { + tensorListMeta.block_to_tensor[loc_block_info] = loc_tensor_info - 1; + tensorListMeta.block_to_chunk[loc_block_info] = chunk; + loc_block_info++; + + const auto tensor_full = + (loc_tensor_info == depth_to_max_tensors[depth - 1] && + chunk == chunks - 1); + const auto blocks_full = loc_block_info == depth_to_max_blocks[depth - 1]; + + if (tensor_full || blocks_full) { + multi_tensor_apply_kernel<<< + loc_block_info, + kBlockSize, + 0, + at::cuda::getCurrentCUDAStream()>>>( + tensorListMeta, callable, args...); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + // Reset. + loc_block_info = 0; + if (chunk == chunks - 1) { + loc_tensor_info = 0; + } else { + tensorListMeta.numel_for_tensor[0] = + tensorListMeta.numel_for_tensor[loc_tensor_info - 1]; + tensorListMeta.state_steps_addresses[0] = + tensorListMeta.state_steps_addresses[loc_tensor_info - 1]; + for (const auto& d : c10::irange(depth)) { + tensorListMeta.addresses[d][0] = + tensorListMeta.addresses[d][loc_tensor_info - 1]; + } + loc_tensor_info = 1; + } + } + } + } + + // see above note: [finishing what we've started] + if (loc_block_info != 0) { + multi_tensor_apply_kernel<<< + loc_block_info, + kBlockSize, + 0, + at::cuda::getCurrentCUDAStream()>>>(tensorListMeta, callable, args...); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } +} + +} // namespace at::native diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Normalization.cuh b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Normalization.cuh new file mode 100644 index 0000000000000000000000000000000000000000..455390a96a431d17f7221654ae5463a9537acc7a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Normalization.cuh @@ -0,0 +1,1742 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#include +#include +#endif + +namespace at { namespace native { + +// The maximum number of threads in a block +#if defined(USE_ROCM) +constexpr int MAX_BLOCK_SIZE = 256; +#else +constexpr int MAX_BLOCK_SIZE = 512; +#endif + +constexpr unsigned MAX_GRID_SIZE = 65535u; + +// Number of threads in a block given an input size up to MAX_BLOCK_SIZE +static int getNumThreads(int nElem) { +#if defined(USE_ROCM) + int threadSizes[5] = { 16, 32, 64, 128, MAX_BLOCK_SIZE }; +#else + int threadSizes[5] = { 32, 64, 128, 256, MAX_BLOCK_SIZE }; +#endif + for (int i = 
0; i != 5; ++i) { + if (nElem <= threadSizes[i]) { + return threadSizes[i]; + } + } + return MAX_BLOCK_SIZE; +} + +// Returns the index of the most significant 1 bit in `val`. +__device__ __forceinline__ int getMSB(int val) { + return 31 - __clz(val); +} + +template +struct Float2 { + accscalar_t v1, v2; + __device__ Float2() {} + __device__ Float2(scalar_t v1, scalar_t v2) : v1(static_cast(v1)), v2(static_cast(v2)) {} + __device__ Float2(int v) : v1(static_cast(v)), v2(static_cast(v)) {} + __device__ Float2& operator+=(const Float2& a) { + v1 += a.v1; + v2 += a.v2; + return *this; + } + __device__ friend Float2 operator+(Float2 a, const Float2& b) { + a += b; + return a; + } +}; + +template +struct GradOp { + __device__ GradOp(accscalar_t m, const PTA& i, const PTA& g) + : mean(m), input(i), grad_output(g) {} + __device__ __forceinline__ Float2 operator()(int batch, int plane, int n) { + accscalar_t g = grad_output[batch][plane][n]; + accscalar_t c = static_cast(input[batch][plane][n]) - mean; + return Float2(g, g * c); + } + const accscalar_t mean; + const PTA& input; + const PTA& grad_output; +}; + +template +struct SumReduceOp { + __device__ __forceinline__ acc_t combine(acc_t a, acc_t b) const { return a + b; } + + __device__ __forceinline__ acc_t warp_shfl_down(acc_t data, int offset) const { + return WARP_SHFL_DOWN(data, offset); + } +}; + +template +struct SumReduceOp> { + using acc_t = Float2; + + __device__ __forceinline__ acc_t combine(acc_t a, acc_t b) const { return a + b; } + + __device__ __forceinline__ acc_t warp_shfl_down(acc_t data, int offset) const { + return {WARP_SHFL_DOWN(data.v1, offset), WARP_SHFL_DOWN(data.v2, offset)}; + } +}; + +// Sum across (batch, x/y/z) applying Op() pointwise +// this works by first having each thread sum it's part +// of the data. Then there is a double-shuffling reduction. +// First each warp (of C10_WARP_SIZE threads) uses warpSum to reduce its +// data to the "warp leader", who writes its value into shared memory. +// Then a single warp reads the remaining (at most C10_WARP_SIZE) items +// and reduces them using another warpSum. +// The implicit assumption is that there are no more +// than C10_WARP_SIZE**2 threads. 
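// [Editor's note] A minimal standalone sketch (not part of this header) of the
// two-stage block reduction the comment above describes: each warp reduces its
// partial sums with shuffle instructions down to the warp leader, the leaders
// publish to shared memory, and the first warp reduces those partials. The
// names `warp_sum_example` / `block_sum_example` are illustrative only, and the
// sketch assumes blockDim.x is a multiple of the 32-lane warp size.
#include <cuda_runtime.h>

__inline__ __device__ float warp_sum_example(float v) {
  // tree reduction across the 32 lanes of a warp
  for (int offset = 16; offset > 0; offset >>= 1)
    v += __shfl_down_sync(0xffffffff, v, offset);
  return v;
}

__global__ void block_sum_example(const float* in, float* out, int n) {
  __shared__ float shared[32];                 // one slot per warp (<= 32 warps/block)
  float sum = 0.f;
  for (int i = threadIdx.x; i < n; i += blockDim.x)
    sum += in[i];                              // per-thread partial sum
  sum = warp_sum_example(sum);                 // stage 1: reduce within each warp
  const int lane = threadIdx.x % 32;
  const int warp = threadIdx.x / 32;
  if (lane == 0) shared[warp] = sum;           // warp leader publishes its partial
  __syncthreads();
  if (warp == 0) {                             // stage 2: first warp reduces the leaders
    const int nwarps = (blockDim.x + 31) / 32;
    sum = lane < nwarps ? shared[lane] : 0.f;
    sum = warp_sum_example(sum);
    if (lane == 0) *out = sum;
  }
}
// Hypothetical usage: block_sum_example<<<1, 256>>>(d_in, d_out, n);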
+template +__device__ scalar_t reduce(Op op, PTA tensor, int plane) { + // first the reductions each thread does separately + scalar_t sum = static_cast(0); + for (int batch = threadIdx.y; batch < tensor.size(0); batch += blockDim.y) { + for (int x = threadIdx.x; x < tensor.size(2); x += blockDim.x) { + sum += op(batch, plane, x); + } + } + __shared__ scalar_t shared[C10_WARP_SIZE]; + SumReduceOp reduce_op; + sum = cuda_utils::BlockReduce, cuda_utils::Block2D>(sum, reduce_op, 0, shared); + if (threadIdx.x == 0 && threadIdx.y == 0) { + shared[0] = sum; + } + __syncthreads(); + // Everyone picks it up, should be broadcast into the whole grad_input + return shared[0]; +} + +constexpr int ELEMENTS_PER_ITER = 4; // enables concurrency within each thread to hide latency +constexpr int ELEMENTS_PER_THREAD = 16; +constexpr int OPTIMAL_TILE_W = 32; +constexpr int MAX_H_BLOCK = 128; + +__host__ void flexible_launch_configs( + const int reduction, + const int stride, + dim3 &block, + dim3 &grid, + const bool coop_flag = false) { + int block_x = std::min(lastPow2(stride), OPTIMAL_TILE_W); + int block_y = std::min(lastPow2(at::ceil_div(reduction , ELEMENTS_PER_THREAD)), + MAX_BLOCK_SIZE / block_x); + if (block_x * block_y != MAX_BLOCK_SIZE) { + block_x = std::min(lastPow2(stride), MAX_BLOCK_SIZE / block_y); + } + + int grid_x = at::ceil_div(stride, block_x); + int grid_y = std::min(at::ceil_div(reduction, block_y * ELEMENTS_PER_THREAD), MAX_H_BLOCK); + if (coop_flag) { + // it's not worth having a grid reduction if the reduction dimension is not big enough + grid_y = grid_y < 8 ? 1 : grid_y; + } + + block.x = block_x; + block.y = block_y; + block.z = 1; + grid.x = grid_x; + grid.y = grid_y; + grid.z = 1; +} + +template +__device__ __forceinline__ void welford_merge_element(C& count, + T& mean, + T& m2n, + const C& count_new, + const T& mean_new, + const T& m2n_new) { + T factor = T(1.0) / ::max(1, (count + count_new)); + T delta0 = mean - mean_new; + mean = (mean_new * count_new + mean * count) * factor; + m2n += m2n_new + delta0 * delta0 * count_new * count * factor; + count += count_new; +} + +// merge mean/m2n among threadIdx.y within block +template +__device__ __forceinline__ void welford_merge_block_vertical(C& count, + T& mean, + T& m2n, + C* shmem_count, + T* shmem_mean, + T* shmem_m2n) { + // write to shared memory + auto address_base = threadIdx.x + threadIdx.y * blockDim.x; + +#pragma unroll + for (int offset = blockDim.y/2; offset > 0; offset >>= 1) { + if (threadIdx.y < offset*2) { + shmem_mean[address_base] = mean; + shmem_m2n[address_base] = m2n; + shmem_count[address_base] = count; + } + __syncthreads(); + if (threadIdx.y < offset && threadIdx.y + offset < blockDim.y) { + auto address = address_base + offset * blockDim.x; + // read shared memory back to register for reduction + auto count_new = shmem_count[address]; + auto mean_new = shmem_mean[address]; + auto m2n_new = shmem_m2n[address]; + + welford_merge_element(count, mean, m2n, count_new, mean_new, m2n_new); + } + } +} + +template +__global__ void batch_norm_transform_input_kernel( + const GenericPackedTensorAccessor input, + GenericPackedTensorAccessor output, + const GenericPackedTensorAccessor::type, 1, RestrictPtrTraits, index_t> mean_, + const GenericPackedTensorAccessor::type, 1, RestrictPtrTraits, index_t> var_or_invstd, + const GenericPackedTensorAccessor weight, + const GenericPackedTensorAccessor bias, + stat_accscalar_t epsilon) { + + index_t plane = blockIdx.x; + + if (plane >= input.size(1)) { + return; + } + + 
stat_accscalar_t gamma = weight.size(0) > 0 ? static_cast(weight[plane]) : static_cast(1); + stat_accscalar_t beta = bias.size(0) > 0 ? static_cast(bias[plane]) : static_cast(0); + stat_accscalar_t mean = static_cast(mean_[plane]); + stat_accscalar_t invstd; + if (train) { + invstd = var_or_invstd[plane]; + } else { + invstd = static_cast(1) / device_sqrt(static_cast(var_or_invstd[plane]) + epsilon); + } + + index_t bs = input.size(0); + index_t fs = input.size(2); + + index_t bstep = blockDim.y * gridDim.y; + for (index_t batch = threadIdx.y + blockIdx.y * blockDim.y; batch < bs; batch += bstep) { + auto o = output[batch][plane]; + auto i = input[batch][plane]; + for (index_t feature = threadIdx.x; feature < fs; feature += blockDim.x) { + o[feature] = static_cast(gamma * (i[feature] - mean) * invstd + beta); + } + } +} + +struct InvStd { + template + __device__ __forceinline__ T operator()(T var, double epsilon) const { + T invstd = 0; + if (var != static_cast(0) || epsilon != static_cast(0)) { + invstd = static_cast(1) / device_sqrt(var + epsilon); + } + return invstd; + } +}; + +struct Var { + template + __device__ __forceinline__ T operator()(T var, double epsilon) const { + return var; + } +}; + +template +__global__ void batch_norm_collect_statistics_kernel( + const GenericPackedTensorAccessor input, + const stat_accscalar_t epsilon, + const stat_accscalar_t momentum, + GenericPackedTensorAccessor save_mean, + GenericPackedTensorAccessor save_transformed_var) { + + __shared__ int shared_n[2 * 2 * C10_WARP_SIZE + C10_WARP_SIZE]; + + int plane = blockIdx.x; + int N = input.size(0) * input.size(2); + int tid = threadIdx.x + threadIdx.y * blockDim.x; + + // Compute the mean and variance across (batch, x/y/z) + // this uses the Welford (in the for loop)/parallel algorithm (to sum across the block) + // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_Online_algorithm + // and the parallel algorithm on the same page. + // We use two shuffles to reduce across the entire block. + // https://devblogs.nvidia.com/faster-parallel-reductions-kepler/ has a description. 
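// [Editor's note] Standalone host-side illustration (not part of this header) of
// the Welford update and the pairwise merge referenced in the comment above and
// implemented by welford_merge_element earlier in this file: each accumulator
// keeps (n, mean, M2 = sum of squared deviations), and two accumulators combine
// exactly. The struct and helper names are illustrative only.
#include <cstdio>

struct WelfordAcc { long long n = 0; double avg = 0.0, m2n = 0.0; };

inline void welford_update_example(WelfordAcc& w, double v) {
  ++w.n;
  const double d1 = v - w.avg;
  w.avg += d1 / w.n;               // running mean
  w.m2n += d1 * (v - w.avg);       // running sum of squared deviations
}

inline WelfordAcc welford_merge_example(const WelfordAcc& a, const WelfordAcc& b) {
  WelfordAcc r;
  r.n = a.n + b.n;
  if (r.n == 0) return r;
  const double delta  = b.avg - a.avg;
  const double factor = 1.0 / r.n;
  r.avg = (a.avg * a.n + b.avg * b.n) * factor;
  r.m2n = a.m2n + b.m2n + delta * delta * a.n * b.n * factor;
  return r;
}

int main() {
  WelfordAcc left, right;
  for (double v : {1.0, 2.0, 3.0}) welford_update_example(left, v);
  for (double v : {4.0, 5.0})      welford_update_example(right, v);
  const WelfordAcc all = welford_merge_example(left, right);
  // biased variance = m2n / n; for 1..5 this prints mean=3, var=2
  std::printf("mean=%f var=%f\n", all.avg, all.m2n / all.n);
  return 0;
}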
+ stat_accscalar_t* shared_avg_var = (stat_accscalar_t*) &shared_n[C10_WARP_SIZE]; + + // first the reductions each thread does separately + stat_accscalar_t avg = 0; + stat_accscalar_t var_n = 0; + int n = 0; + for (int batch = threadIdx.y; batch < input.size(0); batch += blockDim.y) { + for (int x = threadIdx.x; x < input.size(2); x += blockDim.x) { + stat_accscalar_t v = input[batch][plane][x]; + stat_accscalar_t d1 = v - avg; + n++; + avg += d1 / n; + var_n += d1 * (v - avg); + } + } + + // first warpSum to get one value per thread to + // one value per warp + for (int i = 0; i < getMSB(C10_WARP_SIZE); ++i) { + stat_accscalar_t o_avg = WARP_SHFL_XOR(avg, 1 << i, C10_WARP_SIZE); + int o_n = WARP_SHFL_XOR(n, 1 << i, C10_WARP_SIZE); + stat_accscalar_t factor = 1.0 / fmaxf(1.0, n+o_n); + var_n += WARP_SHFL_XOR(var_n, 1 << i, C10_WARP_SIZE) + (avg - o_avg) * (avg - o_avg) * n * o_n * factor; + avg = (n * avg + o_n * o_avg) * factor; + n += o_n; + } + + // this writes each warps item into shared memory + // there are at most C10_WARP_SIZE items left because + // there are at most C10_WARP_SIZE**2 threads at the beginning + __syncthreads(); + if (tid % C10_WARP_SIZE == 0) { + shared_n[tid / C10_WARP_SIZE] = n; + shared_avg_var[tid / C10_WARP_SIZE * 2] = avg; + shared_avg_var[tid / C10_WARP_SIZE * 2 + 1] = var_n; + } + __syncthreads(); + // now have a second warpSum to reduce the intermediate values + // from shared memory to a single number. The very first + // thread writes it to shared memory. + + if (tid < C10_WARP_SIZE) { + n = (tid < blockDim.x * blockDim.y / C10_WARP_SIZE ? shared_n[tid] : 0); + avg = (tid < blockDim.x * blockDim.y / C10_WARP_SIZE ? shared_avg_var[2 * tid] : stat_accscalar_t(0)); + var_n = (tid < blockDim.x * blockDim.y / C10_WARP_SIZE ? shared_avg_var[2 * tid + 1] : stat_accscalar_t(0)); + } + for (int i = 0; i < getMSB(C10_WARP_SIZE); ++i) { + stat_accscalar_t o_avg = WARP_SHFL_XOR(avg, 1 << i, C10_WARP_SIZE); + int o_n = WARP_SHFL_XOR(n, 1 << i, C10_WARP_SIZE); + stat_accscalar_t factor = 1.0 / fmaxf(1.0, n+o_n); + var_n += WARP_SHFL_XOR(var_n, 1 << i, C10_WARP_SIZE) + (avg - o_avg) * (avg - o_avg) * n * o_n * factor; + avg = (n * avg + o_n * o_avg) * factor; + n += o_n; + } + + // Save the mean, variance, and moving averages + if (tid == 0) { + if (save_mean.data() != NULL) { + save_mean[plane] = avg; + } + if (save_transformed_var.data() != NULL) { + save_transformed_var[plane] = VarTransform{}(var_n / N, epsilon); + } + } + +} + +template +__global__ void batch_norm_backward_kernel( + const GenericPackedTensorAccessor input, + const GenericPackedTensorAccessor grad_output, + GenericPackedTensorAccessor grad_input, + GenericPackedTensorAccessor grad_weight, + GenericPackedTensorAccessor grad_bias, + const GenericPackedTensorAccessor weight, + const GenericPackedTensorAccessor running_mean, + const GenericPackedTensorAccessor running_var, + const GenericPackedTensorAccessor save_mean, + const GenericPackedTensorAccessor save_invstd, + bool train, + stat_accscalar_t epsilon) { + + index_t plane = blockIdx.x; + index_t N = grad_output.size(0) * grad_output.size(2); + + stat_accscalar_t mean, invstd; + if (train) { + mean = save_mean[plane]; + invstd = save_invstd[plane]; + } else { + mean = static_cast(running_mean[plane]); + invstd = static_cast(1) / device_sqrt(static_cast(running_var[plane]) + epsilon); + } + + stat_accscalar_t weight_val = weight.size(0) > 0 ? 
static_cast(weight[plane]) : stat_accscalar_t(1); + stat_accscalar_t norm = stat_accscalar_t(1) / N; + + // Compute two values across (batch, x/y/z) in one pass: + // 1. Sum(grad_output) + // 2. DotProduct(input - mean, grad_output) + GradOp> g(mean, input, grad_output); + auto res = reduce>(g, grad_output, plane); + + stat_accscalar_t grad_output_sum = res.v1; + stat_accscalar_t dot_p = res.v2; + + stat_accscalar_t grad_mean = grad_output_sum * norm; + stat_accscalar_t proj_scale = dot_p * norm * invstd * invstd; + stat_accscalar_t grad_scale = invstd * weight_val; + + if (grad_input.data() != NULL) { + for (int batch = threadIdx.y; batch < grad_output.size(0); batch += blockDim.y) { + for (int x = threadIdx.x; x < grad_output.size(2); x += blockDim.x) { + input_scalar_t go = grad_output[batch][plane][x]; + if (train) { + stat_accscalar_t inp = input[batch][plane][x]; + stat_accscalar_t proj = (inp - mean) * proj_scale; + grad_input[batch][plane][x] = static_cast((go - proj - grad_mean) * grad_scale); + } else { + grad_input[batch][plane][x] = static_cast(go * grad_scale); + } + } + } + } + + if (grad_weight.size(0) > 0) { + if (threadIdx.x == 0) { + grad_weight[plane] = static_cast(dot_p * invstd); + } + } + + if (grad_bias.size(0) > 0) { + if (threadIdx.x == 0) { + grad_bias[plane] = static_cast(grad_output_sum); + } + } +} + +template +__global__ void batch_norm_reduce_statistics_kernel( + const GenericPackedTensorAccessor vec_mean, + const GenericPackedTensorAccessor vec_invstd, + GenericPackedTensorAccessor mean, + GenericPackedTensorAccessor invstd, + GenericPackedTensorAccessor running_mean, + GenericPackedTensorAccessor running_var, + const accscalar_t epsilon, + const accscalar_t momentum, + const GenericPackedTensorAccessor counts) { + + int feature_size = vec_mean.size(1); + int world_size = vec_mean.size(0); + + int bid = blockIdx.x; + int tid = threadIdx.x; + + // first the reductions each thread does separately + for (int i = bid*blockDim.x+tid; i < feature_size; i += gridDim.x*blockDim.x) { + accscalar_t avg = 0; + accscalar_t var_n = 0; + index_t n = 0; + for (int j = 0; j < world_size; j++) { + scalar_t count = counts[j]; + accscalar_t m = vec_mean[j][i]; + accscalar_t v = accscalar_t(1.0) / (vec_invstd[j][i]); + v = (v * v - epsilon) * count; + accscalar_t factor = 1.0 / (n + count); + var_n += v + (avg - m) * (avg - m) * n * count * factor; + avg = n * factor * avg + count * factor * m; + n += count; + } + mean[i] = avg; + invstd[i] = static_cast(1) / device_sqrt(var_n / n + epsilon); + if (running_mean.data() != NULL) { + running_mean[i] = static_cast((1 - momentum) * running_mean[i] + momentum * avg); + } + accscalar_t unbiasedVar = var_n / (n - 1); + if (running_var.data() != NULL) { + running_var[i] = static_cast((1 - momentum) * running_var[i] + momentum * unbiasedVar); + } + } + +} + +template +__global__ void batch_norm_backward_reduce_kernel( + const GenericPackedTensorAccessor input, + const GenericPackedTensorAccessor grad_output, + GenericPackedTensorAccessor mean, + GenericPackedTensorAccessor invstd, + GenericPackedTensorAccessor sum_dy, + GenericPackedTensorAccessor sum_dy_xmu, + GenericPackedTensorAccessor grad_weight, + GenericPackedTensorAccessor grad_bias) { + + index_t plane = blockIdx.x; + + stat_accscalar_t r_mean = mean[plane]; + stat_accscalar_t factor = invstd[plane]; + + GradOp> g(r_mean, input, grad_output); + auto res = reduce>(g, grad_output, plane); + + if (threadIdx.x == 0) { + if (grad_weight.size(0) > 0) { + grad_weight[plane] = 
static_cast(res.v2 * factor); + } + if (grad_bias.size(0) > 0) { + grad_bias[plane] = static_cast(res.v1); + } + if (sum_dy.size(0) > 0) { + sum_dy[plane] = static_cast(res.v1); + } + if (sum_dy_xmu.size(0) > 0) { + sum_dy_xmu[plane] = static_cast(res.v2); + } + } +} + +template +__device__ __forceinline__ void batch_norm_backward_elemt_kernel_impl( + const GenericPackedTensorAccessor input, + const GenericPackedTensorAccessor grad_output, + const GenericPackedTensorAccessor mean, + const GenericPackedTensorAccessor invstd, + const GenericPackedTensorAccessor weight, + const GenericPackedTensorAccessor sum_dy, + const GenericPackedTensorAccessor sum_dy_xmu, + GenericPackedTensorAccessor grad_input, + const stat_accscalar_t norm_fct) { + index_t plane = blockIdx.x; + + if (plane >= input.size(1)) { + return; + } + + stat_accscalar_t m_c = mean[plane]; + stat_accscalar_t m_dy_c = sum_dy[plane] * norm_fct; + stat_accscalar_t factor_1_c = invstd[plane]; + stat_accscalar_t factor_2_c = weight.size(0) > 0 ? static_cast(weight[plane]) : stat_accscalar_t(1); + factor_2_c *= factor_1_c; + factor_1_c = factor_1_c * factor_1_c * sum_dy_xmu[plane] * norm_fct; + + index_t bs = input.size(0); + index_t fs = input.size(2); + + index_t bstep = blockDim.y * gridDim.y; + for (index_t batch = threadIdx.y + blockIdx.y * blockDim.y; batch < bs; batch += bstep) { + auto g_i = grad_input[batch][plane]; + auto g_o = grad_output[batch][plane]; + auto i = input[batch][plane]; + for (index_t feature = threadIdx.x; feature < fs; feature += blockDim.x) { + g_i[feature] = static_cast((g_o[feature] - m_dy_c - (i[feature] - m_c) * factor_1_c) * factor_2_c); + } + } +} + +template +__global__ void batch_norm_backward_elemt_kernel( + const GenericPackedTensorAccessor input, + const GenericPackedTensorAccessor grad_output, + const GenericPackedTensorAccessor mean, + const GenericPackedTensorAccessor invstd, + const GenericPackedTensorAccessor weight, + const GenericPackedTensorAccessor sum_dy, + const GenericPackedTensorAccessor sum_dy_xmu, + GenericPackedTensorAccessor grad_input, + const int* __restrict__ numel, const int world_size) { + int64_t total_numel = 0; + for (int i = 0; i < world_size; i ++) { + total_numel += numel[i]; + } + + const stat_accscalar_t norm_fct = + static_cast(1) / static_cast(total_numel); + batch_norm_backward_elemt_kernel_impl( + input, grad_output, mean, invstd, weight, sum_dy, sum_dy_xmu, grad_input, norm_fct); +} + +template +__global__ void batch_norm_backward_elemt_kernel( + const GenericPackedTensorAccessor input, + const GenericPackedTensorAccessor grad_output, + const GenericPackedTensorAccessor mean, + const GenericPackedTensorAccessor invstd, + const GenericPackedTensorAccessor weight, + const GenericPackedTensorAccessor sum_dy, + const GenericPackedTensorAccessor sum_dy_xmu, + GenericPackedTensorAccessor grad_input, + const stat_accscalar_t norm_fct) { + batch_norm_backward_elemt_kernel_impl( + input, grad_output, mean, invstd, weight, sum_dy, sum_dy_xmu, grad_input, norm_fct); +} + +template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> +static GenericPackedTensorAccessor get_packed_accessor( + const Tensor& t, c10::string_view var_name) { + constexpr auto expect_type = c10::CppTypeToScalarType::type>::value; + const auto actual_type = t.scalar_type(); + TORCH_CHECK(actual_type == expect_type, "Expected ", var_name, + " to have type ", expect_type, " but got ", actual_type); + return t.generic_packed_accessor(); +} + +template class PtrTraits = DefaultPtrTraits, 
typename index_t = int64_t> +static GenericPackedTensorAccessor packed_accessor_or_dummy( + const Tensor& t, c10::string_view var_name) { + if (!t.defined()) { + const std::array zeros{{0}}; + return GenericPackedTensorAccessor(nullptr, zeros.data(), zeros.data()); + } + return get_packed_accessor(t, var_name); +} + +template +std::tuple batch_norm_backward_cuda_template(const Tensor& grad_out_, const Tensor& input_, const Tensor& weight_, + const Tensor& running_mean_, const Tensor& running_var_, const Tensor& save_mean_, const Tensor& save_invstd_, + bool train, double epsilon, std::array grad_input_mask) { + + using accscalar_t = at::acc_type; + Tensor grad_input_; + Tensor grad_input_reshaped; + Tensor grad_weight_; + Tensor grad_bias_; + auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); + auto grad_output_reshaped = grad_out_.reshape(input_reshaped.sizes()); + + if (grad_input_mask[0]) { + grad_input_ = at::empty_like(input_, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + grad_input_reshaped = grad_input_.view(input_reshaped.sizes()); + } + if (grad_input_mask[1]) { + grad_weight_ = at::empty_like(weight_, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + } + if (grad_input_mask[2]) { + grad_bias_ = at::empty_like(weight_, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + } + + auto input = get_packed_accessor< + const input_scalar_t, 3, DefaultPtrTraits, index_t>(input_reshaped, "input"); + auto grad_output = get_packed_accessor< + const input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_output_reshaped, "grad_output"); + auto grad_input = packed_accessor_or_dummy< + input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_input_reshaped, "grad_input"); + auto weight = packed_accessor_or_dummy< + const stat_scalar_t, 1, DefaultPtrTraits, index_t>(weight_, "weight"); + auto grad_weight = packed_accessor_or_dummy< + stat_scalar_t, 1, DefaultPtrTraits, index_t>(grad_weight_, "grad_weight"); + auto grad_bias = packed_accessor_or_dummy< + stat_scalar_t, 1, DefaultPtrTraits, index_t>(grad_bias_, "grad_bias"); + auto running_mean = packed_accessor_or_dummy< + const stat_scalar_t, 1, DefaultPtrTraits, index_t>(running_mean_, "running_mean"); + auto running_var = packed_accessor_or_dummy< + const stat_scalar_t, 1, DefaultPtrTraits, index_t>(running_var_, "running_var"); + auto save_mean = packed_accessor_or_dummy< + const accscalar_t, 1, DefaultPtrTraits, index_t>(save_mean_, "save_mean"); + auto save_invstd = packed_accessor_or_dummy< + const accscalar_t, 1, DefaultPtrTraits, index_t>(save_invstd_, "save_invstd"); + + auto stream = at::cuda::getCurrentCUDAStream(); + dim3 blocks(input.size(1)); + int tf = getNumThreads(input.size(2)); + dim3 threads(tf, std::max(1, MAX_BLOCK_SIZE/tf)); + + batch_norm_backward_kernel <<>> + (input, grad_output, grad_input, grad_weight, grad_bias, weight, running_mean, running_var, + save_mean, save_invstd, train, epsilon); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + return std::make_tuple(grad_input_, grad_weight_, grad_bias_); +} + +template +void batch_norm_stats_cuda_template( + const Tensor& out_mean, const Tensor& out_invstd, const Tensor& input_, double epsilon) { + + using accscalar_t = at::acc_type; + int64_t n_input = input_.size(1); + Tensor dummy_mean_; + Tensor dummy_var_; + auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions + + resize_output(out_mean, {n_input}); + resize_output(out_invstd, {n_input}); + auto input = get_packed_accessor< + const scalar_t, 3, RestrictPtrTraits, index_t>(input_reshaped, 
"input"); + TORCH_INTERNAL_ASSERT(out_invstd.dim() == 1 && out_invstd.is_contiguous() && + out_invstd.sizes()[0]); + TORCH_INTERNAL_ASSERT(out_mean.dim() == 1 && out_mean.is_contiguous() && + out_mean.sizes()[0]); + + auto mean = packed_accessor_or_dummy< + accscalar_t, 1, RestrictPtrTraits, index_t>(out_mean, "out_mean"); + auto invstd = packed_accessor_or_dummy< + accscalar_t, 1, RestrictPtrTraits, index_t>(out_invstd, "out_invstd"); + auto stream = at::cuda::getCurrentCUDAStream(); + + dim3 blocks(input.size(1)); + int tf = getNumThreads(input.size(2)); + dim3 threads(tf, std::max(1, MAX_BLOCK_SIZE/tf)); + batch_norm_collect_statistics_kernel <<>> + (input, epsilon, 0.0, mean, invstd); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +template +void batch_norm_elemt_cuda_template(const Tensor& output_, const Tensor& input_, const Tensor& weight_, + const Tensor& bias_, const Tensor& mean_, const Tensor& invstd_) { + + using stat_accscalar_t = at::acc_type; + int64_t n_input = input_.size(1); + auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions + auto output_reshaped = output_.view({input_.size(0), input_.size(1), -1}); + + auto input = get_packed_accessor< + const input_scalar_t, 3, RestrictPtrTraits, index_t>(input_reshaped, "input"); + auto output = get_packed_accessor< + input_scalar_t, 3, RestrictPtrTraits, index_t>(output_reshaped, "output"); + auto weight = packed_accessor_or_dummy< + const stat_scalar_t, 1, RestrictPtrTraits, index_t>(weight_, "weight"); + auto bias = packed_accessor_or_dummy< + const stat_scalar_t, 1, RestrictPtrTraits, index_t>(bias_, "bias"); + auto mean = packed_accessor_or_dummy< + stat_accscalar_t, 1, RestrictPtrTraits, index_t>(mean_, "mean"); + auto invstd = packed_accessor_or_dummy< + stat_accscalar_t, 1, RestrictPtrTraits, index_t>(invstd_, "invstd"); + auto stream = at::cuda::getCurrentCUDAStream(); + + // NOTE: We use transform_input_kernel in training mode, which ignores epsilon + const double dummy_epsilon = 1e-5; + + // The input_transform kernel is pointwise, but we need to balance reading parameters (save_var/mean, + // weight/bias) - which we only do once and have a for loop afterwards - with having many threads and blocks + // and good occupancy. Quiet likely, we could go with even more blocks than 1024. + // The various planes are independent, so we use blocks for them. 
+ int tf = std::max(getNumThreads(input.size(2)/4), + std::min(getNumThreads(input.size(2)), 64)); + int tb = std::max(64/tf, 1); + dim3 blocks_trans(input.size(1), std::max(1, std::min((256*1024)/input.size(1), + (input.size(0)+tb-1)/tb))); + blocks_trans.y = std::min(blocks_trans.y, MAX_GRID_SIZE); + dim3 threads_trans(tf, tb); + batch_norm_transform_input_kernel <<>> + (input, output, mean, invstd, weight, bias, dummy_epsilon); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +template +std::tuple batch_norm_gather_stats_cuda_template(const Tensor& mean_, const Tensor& invstd_, + const Tensor& running_mean_, const Tensor& running_var_, + double momentum, double epsilon, const Tensor& counts_) { + + Tensor save_mean_; + Tensor save_invstd_; + + auto features = mean_.size(1); + auto input_options = mean_.options(); + if (mean_.scalar_type() == at::ScalarType::Half || mean_.scalar_type() == at::ScalarType::BFloat16) { + input_options = input_options.dtype(ScalarType::Float); + } + save_mean_ = at::empty({features}, input_options); + save_invstd_ = at::empty({features}, input_options); + + auto mean = packed_accessor_or_dummy< + accscalar_t, 2, RestrictPtrTraits, index_t>(mean_, "mean"); + auto invstd = packed_accessor_or_dummy< + accscalar_t, 2, RestrictPtrTraits, index_t>(invstd_, "invstd"); + auto running_mean = packed_accessor_or_dummy< + scalar_t, 1, RestrictPtrTraits, index_t>(running_mean_, "running_mean"); + auto running_var = packed_accessor_or_dummy< + scalar_t, 1, RestrictPtrTraits, index_t>(running_var_, "running_mean"); + auto counts = packed_accessor_or_dummy< + scalar_t, 1, RestrictPtrTraits, index_t>(counts_, "counts"); + + auto save_mean = get_packed_accessor< + accscalar_t, 1, RestrictPtrTraits, index_t>(save_mean_, "save_mean"); + auto save_invstd = get_packed_accessor< + accscalar_t, 1, RestrictPtrTraits, index_t>(save_invstd_, "save_invstd"); + auto stream = at::cuda::getCurrentCUDAStream(); + + int block = getNumThreads(features); + int grid = std::max(1, features/block); + batch_norm_reduce_statistics_kernel <<>> + (mean, invstd, save_mean, save_invstd, running_mean, running_var, epsilon, momentum, counts); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + return std::make_tuple(save_mean_, save_invstd_); +} + +template +std::tuple batch_norm_backward_reduce_cuda_template(const Tensor& grad_out_, const Tensor& input_, + const Tensor& mean_, const Tensor& invstd_, const Tensor& weight_, + const bool input_g, const bool weight_g, const bool bias_g) { + + using stat_accscalar_t = at::acc_type; + int64_t n_input = input_.size(1); + Tensor sum_dy_; + Tensor sum_dy_xmu_; + Tensor grad_weight_; + Tensor grad_bias_; + auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions + auto grad_output_reshaped = grad_out_.reshape(input_reshaped.sizes()); + + if (input_g) { + sum_dy_ = at::empty_like(mean_, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + sum_dy_xmu_ = at::empty_like(mean_, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + } + if (weight_g) { + grad_weight_ = at::empty({n_input}, weight_.options()); + } + if (bias_g) { + grad_bias_ = at::empty({n_input}, weight_.options()); + } + + auto input = get_packed_accessor< + input_scalar_t, 3, DefaultPtrTraits, index_t>(input_reshaped, "input"); + auto grad_output = get_packed_accessor< + input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_output_reshaped, "grad_output"); + auto grad_weight = packed_accessor_or_dummy< + stat_scalar_t, 1, DefaultPtrTraits, index_t>(grad_weight_, "grad_weight"); + auto grad_bias = 
packed_accessor_or_dummy< + stat_scalar_t, 1, DefaultPtrTraits, index_t>(grad_bias_, "grad_bias"); + auto mean = packed_accessor_or_dummy< + stat_accscalar_t, 1, DefaultPtrTraits, index_t>(mean_, "mean"); + auto invstd = packed_accessor_or_dummy< + stat_accscalar_t, 1, DefaultPtrTraits, index_t>(invstd_, "invstd"); + auto sum_dy = packed_accessor_or_dummy< + stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_, "sum_dy"); + auto sum_dy_xmu = packed_accessor_or_dummy< + stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_xmu_, "sum_dy_xmu"); + + auto batch_size = input_reshaped.size(0); + auto feature_size = input_reshaped.size(2); + auto stream = at::cuda::getCurrentCUDAStream(); + + int warp_size = at::cuda::warp_size(); + int block_y = std::min(lastPow2(batch_size), MAX_BLOCK_SIZE/warp_size); + // We want block_x to be at least a warp width + int block_x = std::min(std::max(getNumThreads(feature_size), warp_size), MAX_BLOCK_SIZE/block_y); + const dim3 block(block_x, block_y); + const dim3 grid(n_input); + + batch_norm_backward_reduce_kernel <<>> + (input, grad_output, mean, invstd, sum_dy, sum_dy_xmu, grad_weight, grad_bias); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + return std::make_tuple(sum_dy_, sum_dy_xmu_, grad_weight_, grad_bias_); +} + +template +Tensor batch_norm_backward_elemt_cuda_template(const Tensor& grad_out_, const Tensor& input_, + const Tensor& mean_, const Tensor& invstd_, + const Tensor& weight_, const Tensor& sum_dy_, const Tensor& sum_dy_xmu_) { + + using stat_accscalar_t = at::acc_type; + int64_t n_input = input_.size(1); + auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions + auto grad_output_reshaped = grad_out_.reshape(input_reshaped.sizes()); + auto grad_input_reshaped = at::empty_like(input_reshaped, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + + auto input = get_packed_accessor< + input_scalar_t, 3, DefaultPtrTraits, index_t>(input_reshaped, "input"); + auto grad_input = get_packed_accessor< + input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_input_reshaped, "grad_input"); + auto grad_output = get_packed_accessor< + input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_output_reshaped, "grad_output"); + auto mean = packed_accessor_or_dummy< + stat_accscalar_t, 1, DefaultPtrTraits, index_t>(mean_, "mean"); + auto invstd = packed_accessor_or_dummy< + stat_accscalar_t, 1, DefaultPtrTraits, index_t>(invstd_, "invstd"); + auto weight = packed_accessor_or_dummy< + stat_scalar_t, 1, DefaultPtrTraits, index_t>(weight_, "weight"); + auto sum_dy = packed_accessor_or_dummy< + stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_, "sum_dy"); + auto sum_dy_xmu = packed_accessor_or_dummy< + stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_xmu_, "sum_dy_xmu"); + + auto stream = at::cuda::getCurrentCUDAStream(); + + // The kernel is pointwise, but we need to balance reading parameters (save_var/mean, + // weight/bias) - which we only do once and have a for loop afterwards - with having many threads and blocks + // and good occupancy. Quiet likely, we could go with even more blocks than 1024. + // The various planes are independent, so we use blocks for them. 
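// [Editor's note] Standalone host-side sketch (not part of this header) of the
// launch-configuration heuristic used just below: pick block.x from the feature
// size, fill the remainder of a 64-thread tile with block.y over the batch,
// give one grid.x block per plane, and cap grid.y. The constants and the helper
// names (`num_threads_example`, `pick_launch_config_example`) are assumptions
// for illustration, not the header's API.
#include <algorithm>
#include <cstdio>

struct Dim2 { int x, y; };

// mirrors getNumThreads above: smallest size class that covers nElem
static int num_threads_example(int nElem) {
  const int sizes[5] = {32, 64, 128, 256, 512};
  for (int s : sizes) if (nElem <= s) return s;
  return 512;
}

static void pick_launch_config_example(int planes, int batch, int features,
                                       Dim2& block, Dim2& grid) {
  const int tf = std::max(num_threads_example(features / 4),
                          std::min(num_threads_example(features), 64));
  const int tb = std::max(64 / tf, 1);                       // threads over the batch dim
  block = {tf, tb};
  const int gy = std::max(1, std::min((256 * 1024) / planes, (batch + tb - 1) / tb));
  grid = {planes, std::min(gy, 65535)};                      // one block.x per plane, cap grid.y
}

int main() {
  Dim2 block{}, grid{};
  pick_launch_config_example(/*planes=*/64, /*batch=*/32, /*features=*/1024, block, grid);
  std::printf("block=(%d,%d) grid=(%d,%d)\n", block.x, block.y, grid.x, grid.y);
  return 0;
}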
+ int tf = std::max(getNumThreads(input.size(2)/4), + std::min(getNumThreads(input.size(2)), 64)); + int tb = std::max(64/tf, 1); + dim3 blocks_trans(input.size(1), std::max(1, std::min((256*1024)/input.size(1), + (input.size(0)+tb-1)/tb))); + blocks_trans.y = std::min(blocks_trans.y, MAX_GRID_SIZE); + dim3 threads_trans(tf, tb); + auto reduction_size = input_.numel() / n_input; + auto norm_fct = static_cast(1.0 / reduction_size); + batch_norm_backward_elemt_kernel + <<>> + (input, grad_output, mean, invstd, weight, sum_dy, sum_dy_xmu, grad_input, norm_fct); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + return grad_input_reshaped.view(input_.sizes()); +} + +template +Tensor batch_norm_backward_elemt_cuda_template(const Tensor& grad_out_, const Tensor& input_, + const Tensor& mean_, const Tensor& invstd_, + const Tensor& weight_, const Tensor& sum_dy_, const Tensor& sum_dy_xmu_, const Tensor& count) { + + using stat_accscalar_t = at::acc_type; + int64_t n_input = input_.size(1); + auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions + auto grad_output_reshaped = grad_out_.reshape(input_reshaped.sizes()); + auto grad_input_reshaped = at::empty_like(input_reshaped, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + + auto input = get_packed_accessor< + input_scalar_t, 3, DefaultPtrTraits, index_t>(input_reshaped, "input"); + auto grad_input = get_packed_accessor< + input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_input_reshaped, "grad_input"); + auto grad_output = get_packed_accessor< + input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_output_reshaped, "grad_output"); + auto mean = packed_accessor_or_dummy< + stat_accscalar_t, 1, DefaultPtrTraits, index_t>(mean_, "mean"); + auto invstd = packed_accessor_or_dummy< + stat_accscalar_t, 1, DefaultPtrTraits, index_t>(invstd_, "invstd"); + auto weight = packed_accessor_or_dummy< + stat_scalar_t, 1, DefaultPtrTraits, index_t>(weight_, "weight"); + auto sum_dy = packed_accessor_or_dummy< + stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_, "sum_dy"); + auto sum_dy_xmu = packed_accessor_or_dummy< + stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_xmu_, "sum_dy_xmu"); + + auto stream = at::cuda::getCurrentCUDAStream(); + + // The kernel is pointwise, but we need to balance reading parameters (save_var/mean, + // weight/bias) - which we only do once and have a for loop afterwards - with having many threads and blocks + // and good occupancy. Quiet likely, we could go with even more blocks than 1024. + // The various planes are independent, so we use blocks for them. 
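// [Editor's note] Small standalone illustration (not part of this header) of the
// normalization used by the count-based overload launched just below: the kernel
// receives per-replica element counts, sums them, and scales the reduced
// statistics by 1/total so the gradient matches one global batch. The helper
// name is hypothetical.
#include <cstdio>

static double norm_fct_from_counts_example(const int* numel, int world_size) {
  long long total = 0;
  for (int i = 0; i < world_size; ++i) total += numel[i];   // total elements per channel
  return 1.0 / static_cast<double>(total);
}

int main() {
  const int counts[3] = {8, 8, 16};                         // e.g. per-replica N*H*W
  std::printf("norm_fct=%f\n", norm_fct_from_counts_example(counts, 3)); // 1/32
  return 0;
}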
+ int tf = std::max(getNumThreads(input.size(2)/4), + std::min(getNumThreads(input.size(2)), 64)); + int tb = std::max(64/tf, 1); + dim3 blocks_trans(input.size(1), std::max(1, std::min((256*1024)/input.size(1), + (input.size(0)+tb-1)/tb))); + blocks_trans.y = std::min(blocks_trans.y, MAX_GRID_SIZE); + dim3 threads_trans(tf, tb); + batch_norm_backward_elemt_kernel <<>> + (input, grad_output, mean, invstd, weight, sum_dy, sum_dy_xmu, grad_input, count.const_data_ptr(), count.numel()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + return grad_input_reshaped.view(input_.sizes()); +} + +// welford kernel for c last tensor calculating mean/biased_variance/unbiased_variance +// original apex name: welford_kernel_c_last +template + +__global__ void +batch_norm_collect_statistics_channels_last_kernel( + const scalar_t* __restrict__ input, + accscalar_t* __restrict__ out_mean, + accscalar_t* __restrict__ out_invstd, + volatile accscalar_t* staging_data, + int* semaphores, + const int reduction_size, + const int stride, + accscalar_t epsilon) { + // hide latency with concurrency + accscalar_t x_mean[PARALLEL_LOADS]; + accscalar_t m_2_n[PARALLEL_LOADS]; + int count[PARALLEL_LOADS]; + +#pragma unroll + for (int i = 0; i < PARALLEL_LOADS; i++) { + x_mean[i] = accscalar_t(0); + m_2_n[i] = accscalar_t(0); + count[i] = accscalar_t(0); + } + // tensor dimension (m,c) + + // loop along m dimension + int inner_loop_stride = blockDim.y * gridDim.y; + + // offset along m dimension + int m_offset = blockIdx.y * blockDim.y + threadIdx.y; + int c_offset = blockIdx.x * blockDim.x + threadIdx.x; + + int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS); + int address_base = m_offset * stride + c_offset; + int address_increment = inner_loop_stride * stride; + + for (int i = 0; i < loop_count; i++) { + accscalar_t x_math[PARALLEL_LOADS]; + accscalar_t x_count_inv[PARALLEL_LOADS]; + accscalar_t is_valid[PARALLEL_LOADS]; + + // load multiple data in +#pragma unroll + for (int j = 0; j < PARALLEL_LOADS; j++) { + if (c_offset < stride && m_offset < reduction_size) { + x_math[j] = input[address_base]; + count[j]++; + x_count_inv[j] = accscalar_t(1) / count[j]; + is_valid[j] = accscalar_t(1); + } else { + x_math[j] = accscalar_t(0); + x_count_inv[j] = accscalar_t(0); + is_valid[j] = accscalar_t(0); + } + m_offset += inner_loop_stride; + address_base += address_increment; + } + + // calculate mean/m2n with welford +#pragma unroll + for (int j = 0; j < PARALLEL_LOADS; j++) { + accscalar_t delta0 = x_math[j] - x_mean[j]; + x_mean[j] += delta0 * x_count_inv[j]; + accscalar_t delta1 = x_math[j] - x_mean[j]; + m_2_n[j] += delta0 * delta1 * is_valid[j]; + } + } + + // thread reduction to accumulate mean/m_2_n/count between PARALLEL_LOADS +#pragma unroll + for (int j = 1; j < PARALLEL_LOADS; j++) { + welford_merge_element(count[0], x_mean[0], m_2_n[0], count[j], x_mean[j], m_2_n[j]); + } + + // release x_mean / m_2_n + auto mean_th = x_mean[0]; + auto m2_th = m_2_n[0]; + auto count_th = count[0]; + + // block-wise reduction with shared memory (since reduction cannot be done within a warp) + static __shared__ accscalar_t shmem_mean[MAX_BLOCK_SIZE]; + static __shared__ accscalar_t shmem_m2n[MAX_BLOCK_SIZE]; + static __shared__ int shmem_count[MAX_BLOCK_SIZE]; + + welford_merge_block_vertical(count_th, mean_th, m2_th, shmem_count, shmem_mean, shmem_m2n); + + if (gridDim.y > 1) { + volatile accscalar_t* staging_mean = staging_data; + volatile accscalar_t* staging_m2n = &staging_data[stride*gridDim.y]; + volatile 
int* staging_count = reinterpret_cast(&staging_m2n[stride*gridDim.y]); + + address_base = c_offset + blockIdx.y * stride; + // write data to staging_data; + if (threadIdx.y == 0 && c_offset < stride) { + staging_mean[address_base] = mean_th; + staging_m2n[address_base] = m2_th; + staging_count[address_base] = count_th; + } + + __threadfence(); + __syncthreads(); // ensuring writes to staging_ is visible to all blocks + + __shared__ bool is_last_block_done; + // mark block done + if (threadIdx.x == 0 && threadIdx.y == 0) { + int old = atomicAdd(&semaphores[blockIdx.x], 1); + is_last_block_done = (old == (gridDim.y-1)); + } + + __syncthreads(); + + // check that all data is now available in global memory + if (is_last_block_done) { + count_th = 0; + mean_th = accscalar_t(0.0); + m2_th = accscalar_t(0.0); + + for (int y = threadIdx.y; y < gridDim.y; y += blockDim.y) { + address_base = c_offset + y * stride; + int count_new = c_offset < stride ? staging_count[address_base] : 0; + accscalar_t mean_new = c_offset < stride ? staging_mean[address_base] : accscalar_t(0.0); + accscalar_t m2n_new = c_offset < stride ? staging_m2n[address_base] : accscalar_t(0.0); + + welford_merge_element(count_th, mean_th, m2_th, count_new, mean_new, m2n_new); + } + + welford_merge_block_vertical(count_th, mean_th, m2_th, shmem_count, shmem_mean, shmem_m2n); + if (threadIdx.y == 0 && c_offset < stride) { + out_mean[c_offset] = static_cast(mean_th); + out_invstd[c_offset] = VarTransform{}(m2_th/count_th, epsilon); + } + } + } else { + if (blockIdx.y == 0 && threadIdx.y == 0 && c_offset < stride) { + out_mean[c_offset] = static_cast(mean_th); + out_invstd[c_offset] = VarTransform{}(m2_th/count_th, epsilon); + } + } +} + +// elementwise BN kernel +// original apex name: batchnorm_forward_c_last_kernel +template < + typename scalar_t, + typename accscalar_t, + typename layerscalar_t, + int PARALLEL_LOADS> +__global__ void batch_norm_transform_input_channels_last_kernel( + const scalar_t* __restrict__ input, + const scalar_t* __restrict__ z, + const accscalar_t* __restrict__ mean, + const accscalar_t* __restrict__ inv_std, + const layerscalar_t* __restrict__ weight, + const layerscalar_t* __restrict__ shift, + scalar_t* __restrict__ out, + const int reduction_size, + const int stride, + const bool fuse_relu) { + // tensor dimension (m,c) + // loop along m dimension + int inner_loop_stride = blockDim.y * gridDim.y; + + // offset along m dimension + int m_offset = blockIdx.y * blockDim.y + threadIdx.y; + int c_offset = blockIdx.x * blockDim.x + threadIdx.x; + + if (c_offset >= stride || m_offset >= reduction_size) { + return; + } + + auto m_c = mean[c_offset]; + auto inv_std_c = static_cast(inv_std[c_offset]); + auto w_c = weight == nullptr ? accscalar_t(1.0) : static_cast(weight[c_offset]); + auto s_c = shift == nullptr ? accscalar_t(0.0) : static_cast(shift[c_offset]); + + int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS); + int address_base = m_offset * stride + c_offset; + int address_increment = inner_loop_stride * stride; + + for (int i = 0; i < loop_count; i++) { +#pragma unroll + for (int j = 0; j < PARALLEL_LOADS; j++) { + if (c_offset < stride && m_offset < reduction_size) { + auto tmp = w_c * (static_cast(input[address_base]) - m_c ) * inv_std_c + s_c; + if (z != nullptr) { + tmp += z[address_base]; + } + out[address_base] = (fuse_relu && tmp <= accscalar_t(0.0) ? 
scalar_t(0.0) : static_cast(tmp)); + } + m_offset += inner_loop_stride; + address_base += address_increment; + } + } +} + +template +__device__ __forceinline__ void merge_block_vertical_backward(T& sum_dy, + T& sum_dy_xmu, + T* shmem_sum_dy, + T* shmem_sum_dy_xmu) { + // write to shared memory + auto address_base = threadIdx.x + threadIdx.y * blockDim.x; + +#pragma unroll + for (int offset = blockDim.y/2; offset > 0; offset >>= 1) { + if (threadIdx.y < offset*2) { + shmem_sum_dy[address_base] = sum_dy; + shmem_sum_dy_xmu[address_base] = sum_dy_xmu; + } + __syncthreads(); + if (threadIdx.y < offset && threadIdx.y + offset < blockDim.y) { + auto address = address_base + offset * blockDim.x; + + sum_dy += shmem_sum_dy[address]; + sum_dy_xmu += shmem_sum_dy_xmu[address]; + } + } +} + +// batchnorm backward kernel for c last tensor +// original apex name: reduce_bn_c_last_kernel +template < + int PARALLEL_LOADS, + typename scalar_t, + typename accscalar_t, + typename layerscalar_t> +__global__ void batch_norm_backward_reduce_channels_last_kernel( + const scalar_t* __restrict__ input, + const scalar_t* __restrict__ grad_output, + const accscalar_t* __restrict__ mean, + const accscalar_t* __restrict__ inv_std, + accscalar_t* __restrict__ sum_dy_o, + accscalar_t* __restrict__ sum_dy_xmu_o, + layerscalar_t* __restrict__ grad_weight, + layerscalar_t* __restrict__ grad_bias, + volatile accscalar_t* staging_data, + int* semaphores, + const int reduction_size, + const int stride) { + + // hide latency with concurrency + accscalar_t sum_dy[PARALLEL_LOADS]; + accscalar_t sum_dy_xmu[PARALLEL_LOADS]; + +#pragma unroll + for (int i = 0; i < PARALLEL_LOADS; i++) { + sum_dy[i] = accscalar_t(0); + sum_dy_xmu[i] = accscalar_t(0); + } + // tensor dimension (m,c) + + // loop along m dimension + int inner_loop_stride = blockDim.y * gridDim.y; + + // offset along m dimension + int m_offset = blockIdx.y * blockDim.y + threadIdx.y; + int c_offset = blockIdx.x * blockDim.x + threadIdx.x; + + if (c_offset >= stride || m_offset >= reduction_size) { + return; + } + + int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS); + int address_base = m_offset * stride + c_offset; + int address_increment = inner_loop_stride * stride; + + auto r_mean = mean[c_offset]; + auto factor = inv_std[c_offset]; + + for (int i = 0; i < loop_count; i++) { + accscalar_t x_input[PARALLEL_LOADS]; + accscalar_t x_grad_output[PARALLEL_LOADS]; + + // load multiple data in +#pragma unroll + for (int j = 0; j < PARALLEL_LOADS; j++) { + if (c_offset < stride && m_offset < reduction_size) { + x_input[j] = input[address_base]; + x_grad_output[j] = grad_output[address_base]; + } else { + x_input[j] = accscalar_t(0); + x_grad_output[j] = accscalar_t(0); + } + m_offset += inner_loop_stride; + address_base += address_increment; + } + + // calculate sum_dy / sum_dy_xmu +#pragma unroll + for (int j = 0; j < PARALLEL_LOADS; j++) { + sum_dy[j] += x_grad_output[j]; + sum_dy_xmu[j] += x_grad_output[j] * (x_input[j] - r_mean); + } + } + + // thread reduction to accumulate sum_dy / sum_dy_xmu between PARALLEL_LOADS +#pragma unroll + for (int j = 1; j < PARALLEL_LOADS; j++) { + sum_dy[0] += sum_dy[j]; + sum_dy_xmu[0] += sum_dy_xmu[j]; + } + + // release array of registers + auto sum_dy_th = sum_dy[0]; + auto sum_dy_xmu_th = sum_dy_xmu[0]; + + // block-wise reduction with shared memory (since reduction cannot be done within a warp) + static __shared__ accscalar_t shmem_sum_dy[MAX_BLOCK_SIZE]; + static __shared__ accscalar_t 
shmem_sum_dy_xmu[MAX_BLOCK_SIZE]; + + merge_block_vertical_backward(sum_dy_th, sum_dy_xmu_th, shmem_sum_dy, shmem_sum_dy_xmu); + + if (gridDim.y > 1) { + volatile accscalar_t* staging_sum_dy = staging_data; + volatile accscalar_t* staging_sum_dy_xmu = &staging_data[stride*gridDim.y]; + + address_base = c_offset + blockIdx.y * stride; + // write data to staging_data; + if (threadIdx.y == 0 && c_offset < stride) { + staging_sum_dy[address_base] = sum_dy_th; + staging_sum_dy_xmu[address_base] = sum_dy_xmu_th; + } + + __threadfence(); + __syncthreads(); // ensuring writes to staging_ is visible to all blocks + + __shared__ bool is_last_block_done; + // mark block done + if (threadIdx.x == 0 && threadIdx.y == 0) { + int old = atomicAdd(&semaphores[blockIdx.x], 1); + is_last_block_done = (old == (gridDim.y-1)); + } + + __syncthreads(); + + // check that all data is now available in global memory + if (is_last_block_done) { + sum_dy_th = accscalar_t(0.0); + sum_dy_xmu_th = accscalar_t(0.0); + + for (int y = threadIdx.y; y < gridDim.y; y += blockDim.y) { + address_base = c_offset + y * stride; + sum_dy_th += (c_offset < stride ? staging_sum_dy[address_base] : accscalar_t(0.0)); + sum_dy_xmu_th += (c_offset < stride ? staging_sum_dy_xmu[address_base] : accscalar_t(0.0)); + } + + merge_block_vertical_backward(sum_dy_th, sum_dy_xmu_th, shmem_sum_dy, shmem_sum_dy_xmu); + if (threadIdx.y == 0 && c_offset < stride) { + if (grad_bias != nullptr) { + grad_bias[c_offset] = static_cast(sum_dy_th); + } + if (grad_weight != nullptr) { + grad_weight[c_offset] = static_cast(sum_dy_xmu_th * factor); + } + //mean_dy[c_offset] = sum_dy_th / reduction_size; + //mean_dy_xmu[c_offset] = sum_dy_xmu_th / reduction_size; + sum_dy_o[c_offset] = sum_dy_th; + sum_dy_xmu_o[c_offset] = sum_dy_xmu_th; + } + } + } else { + if (blockIdx.y == 0 && threadIdx.y == 0 && c_offset < stride) { + if (grad_bias != nullptr) { + grad_bias[c_offset] = static_cast(sum_dy_th); + } + if (grad_weight != nullptr) { + grad_weight[c_offset] = static_cast(sum_dy_xmu_th * factor); + } + //mean_dy[c_offset] = sum_dy_th / reduction_size; + //mean_dy_xmu[c_offset] = sum_dy_xmu_th / reduction_size; + sum_dy_o[c_offset] = sum_dy_th; + sum_dy_xmu_o[c_offset] = sum_dy_xmu_th; + } + } +} + +// elementwise BN kernel +// original apex name: batchnorm_backward_c_last_kernel +template < + int PARALLEL_LOADS, + typename scalar_t, + typename accscalar_t, + typename layerscalar_t> +__device__ __forceinline__ void batch_norm_backward_elemt_channels_last_kernel_impl( + const scalar_t* __restrict__ grad_output, + const scalar_t* __restrict__ input, + const accscalar_t* __restrict__ mean, + const accscalar_t* __restrict__ inv_std, + const layerscalar_t* __restrict__ weight, + const accscalar_t* __restrict__ sum_dy, + const accscalar_t* __restrict__ sum_dy_xmu, + scalar_t* __restrict__ grad_input, + const accscalar_t norm_fct, + const int reduction_size, + const int stride) { + // tensor dimension (m,c) + // loop along m dimension + int inner_loop_stride = blockDim.y * gridDim.y; + + // offset along m dimension + int m_offset = blockIdx.y * blockDim.y + threadIdx.y; + int c_offset = blockIdx.x * blockDim.x + threadIdx.x; + + if (c_offset >= stride || m_offset >= reduction_size) { + return; + } + + auto m_c = mean[c_offset]; + auto m_dy_c = sum_dy[c_offset] * norm_fct; + auto factor_1_c = inv_std[c_offset]; + auto factor_2_c = (weight == nullptr? 
accscalar_t(1.0) : static_cast(weight[c_offset])) * factor_1_c; + factor_1_c = factor_1_c * factor_1_c * sum_dy_xmu[c_offset] * norm_fct; + + int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS); + int address_base = m_offset * stride + c_offset; + int address_increment = inner_loop_stride * stride; + + for (int i = 0; i < loop_count; i++) { +#pragma unroll + for (int j = 0; j < PARALLEL_LOADS; j++) { + if (c_offset < stride && m_offset < reduction_size) { + grad_input[address_base] = static_cast( + (static_cast(grad_output[address_base]) - m_dy_c - + (static_cast(input[address_base]) - m_c) * factor_1_c) + * factor_2_c); + } + m_offset += inner_loop_stride; + address_base += address_increment; + } + } +} + +template < + int PARALLEL_LOADS, + typename scalar_t, + typename accscalar_t, + typename layerscalar_t> +__global__ void batch_norm_backward_elemt_channels_last_kernel( + const scalar_t* __restrict__ grad_output, + const scalar_t* __restrict__ input, + const accscalar_t* __restrict__ mean, + const accscalar_t* __restrict__ inv_std, + const layerscalar_t* __restrict__ weight, + const accscalar_t* __restrict__ sum_dy, + const accscalar_t* __restrict__ sum_dy_xmu, + const int* __restrict__ numel, + scalar_t* __restrict__ grad_input, + const int64_t world_size, + const int reduction_size, + const int stride) { + + int64_t total_numel = 0; + for (int i = 0; i < world_size; i++) { + total_numel += numel[i]; + } + + auto norm_fct = static_cast(1) / static_cast(total_numel); + batch_norm_backward_elemt_channels_last_kernel_impl( + grad_output, input, mean, inv_std, weight, sum_dy, sum_dy_xmu, + grad_input, norm_fct, reduction_size, stride); +} + +template < + int PARALLEL_LOADS, + typename scalar_t, + typename accscalar_t, + typename layerscalar_t> +__global__ void batch_norm_backward_elemt_channels_last_kernel( + const scalar_t* __restrict__ grad_output, + const scalar_t* __restrict__ input, + const accscalar_t* __restrict__ mean, + const accscalar_t* __restrict__ inv_std, + const layerscalar_t* __restrict__ weight, + const accscalar_t* __restrict__ sum_dy, + const accscalar_t* __restrict__ sum_dy_xmu, + scalar_t* __restrict__ grad_input, + const accscalar_t norm_fct, + const int reduction_size, + const int stride) { + batch_norm_backward_elemt_channels_last_kernel_impl( + grad_output, input, mean, inv_std, weight, sum_dy, sum_dy_xmu, + grad_input, norm_fct, reduction_size, stride); +} + +template +void batch_norm_stats_channels_last_cuda_template( + const Tensor& out_mean, const Tensor& out_invstd, const Tensor& input, double epsilon) { + using accscalar_t = at::acc_type; + + const auto stride = input.sizes()[1]; + const auto reduction_size = input.numel() / stride; + + resize_output(out_mean, {stride}); + resize_output(out_invstd, {stride}); + TORCH_INTERNAL_ASSERT(out_invstd.dim() == 1 && out_invstd.is_contiguous() && + out_invstd.sizes()[0]); + TORCH_INTERNAL_ASSERT(out_mean.dim() == 1 && out_mean.is_contiguous() && + out_mean.sizes()[0]); + + dim3 block; + dim3 grid; + flexible_launch_configs(reduction_size, stride, block, grid, true); + + at::Tensor staging_data; + at::Tensor semaphores; + if (grid.y > 1) { + staging_data = at::empty({4*stride*grid.y}, out_mean.options()); + semaphores = at::zeros({grid.x}, input.options().dtype(at::kInt)); + } + + accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.mutable_data_ptr() : nullptr; + int* semaphores_ptr = grid.y > 1 ? 
semaphores.mutable_data_ptr() : nullptr; + batch_norm_collect_statistics_channels_last_kernel + <<>>( + input.const_data_ptr(), + out_mean.mutable_data_ptr(), + out_invstd.mutable_data_ptr(), + staging_data_ptr, + semaphores_ptr, + reduction_size, + stride, + epsilon); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +void batch_norm_elemt_channels_last_cuda_template( + const at::Tensor& output, + const at::Tensor& input, + const at::Tensor& weight, + const at::Tensor& shift, // bias of BN + const at::Tensor& mean, + const at::Tensor& inv_std, + const std::optional& z = std::nullopt, // bias after BN + const bool fuse_relu = false) { + const auto stride = input.sizes()[1]; + const auto reduction_size = input.numel() / stride; + + dim3 block; + dim3 grid; + flexible_launch_configs(reduction_size, stride, block, grid); + + auto stream = at::cuda::getCurrentCUDAStream(); + const auto second_dtype = weight.defined() ? weight.scalar_type() : + (shift.defined() ? shift.scalar_type() : input.scalar_type()); + + if (input.scalar_type() != second_dtype) { + AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_forward", [&] { + using accscalar_t = at::acc_type; + batch_norm_transform_input_channels_last_kernel + <<>>( + input.const_data_ptr(), + z.has_value() ? z.value().const_data_ptr() : nullptr, + mean.const_data_ptr(), + inv_std.const_data_ptr(), + weight.defined() ? weight.const_data_ptr() : nullptr, + shift.defined() ? shift.const_data_ptr() : nullptr, + output.mutable_data_ptr(), + reduction_size, + stride, + fuse_relu); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + } else { + if (weight.defined()){ + TORCH_CHECK(input.scalar_type() == weight.scalar_type(), "batchnorm_forward: input.scalar_type() ", input.scalar_type(), + " is not supported with weight.scalar_type() ", weight.scalar_type()); + } + AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_forward", [&] { + using accscalar_t = at::acc_type; + batch_norm_transform_input_channels_last_kernel + <<>>( + input.const_data_ptr(), + z.has_value() ? z.value().const_data_ptr() : nullptr, + mean.const_data_ptr(), + inv_std.const_data_ptr(), + weight.defined() ? weight.const_data_ptr() : nullptr, + shift.defined() ? 
shift.const_data_ptr(): nullptr, + output.mutable_data_ptr(), + reduction_size, + stride, + fuse_relu); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + } +} + +std::tuple +batch_norm_backward_reduce_cuda_channels_last_template(const at::Tensor& grad_output, + const at::Tensor& input, + const at::Tensor& mean, + const at::Tensor& inv_std, + const at::Tensor& weight, + const bool input_g, const bool weight_g, const bool bias_g) { + const auto stride = input.sizes()[1]; + const auto reduction_size = input.numel() / stride; + + at::Tensor sumn_dy = at::empty({stride}, mean.options()); + at::Tensor sum_dy_xmu = at::empty({stride}, mean.options()); + + at::Tensor grad_weight; + at::Tensor grad_bias; + if (weight.defined()) { + grad_weight = at::empty({stride}, weight.options()); + grad_bias = at::empty({stride}, weight.options()); + } else { + // because I cannot return an uninitialized at::Tensor + grad_weight = at::empty({0}, mean.options()); + grad_bias = at::empty({0}, mean.options()); + } + + dim3 block; + dim3 grid; + flexible_launch_configs(reduction_size, stride, block, grid, true); + + at::Tensor staging_data; + at::Tensor semaphores; + if (grid.y > 1) { + staging_data = at::empty({2*stride*grid.y}, mean.options()); + semaphores = at::zeros({grid.x}, input.options().dtype(at::kInt)); + } + auto stream = at::cuda::getCurrentCUDAStream(); + + if (weight.defined() && input.scalar_type() != weight.scalar_type()) { + AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_backward_reduce", [&] { + using accscalar_t = at::acc_type; + accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.mutable_data_ptr() : nullptr; + int* semaphores_ptr = grid.y > 1 ? semaphores.mutable_data_ptr() : nullptr; + batch_norm_backward_reduce_channels_last_kernel + <<>>( + input.const_data_ptr(), + grad_output.const_data_ptr(), + mean.const_data_ptr(), + inv_std.const_data_ptr(), + sumn_dy.mutable_data_ptr(), + sum_dy_xmu.mutable_data_ptr(), + grad_weight.mutable_data_ptr(), + grad_bias.mutable_data_ptr(), + staging_data_ptr, + semaphores_ptr, + reduction_size, + stride); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + } else { + if (weight.defined()) { + TORCH_CHECK(input.scalar_type() == weight.scalar_type(), "batchnorm_backward_reduce: input.scalar_type() ", input.scalar_type(), + " is not supported with weight.scalar_type() ", weight.scalar_type()); + } + AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_backward_reduce", [&] { + using accscalar_t = at::acc_type; + accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.mutable_data_ptr() : nullptr; + int* semaphores_ptr = grid.y > 1 ? semaphores.mutable_data_ptr() : nullptr; + batch_norm_backward_reduce_channels_last_kernel + <<>>( + input.const_data_ptr(), + grad_output.const_data_ptr(), + mean.const_data_ptr(), + inv_std.const_data_ptr(), + sumn_dy.mutable_data_ptr(), + sum_dy_xmu.mutable_data_ptr(), + weight.defined() ? grad_weight.mutable_data_ptr() : nullptr, + weight.defined() ? 
grad_bias.mutable_data_ptr() : nullptr, + staging_data_ptr, + semaphores_ptr, + reduction_size, + stride); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + } + + return std::make_tuple(sumn_dy, sum_dy_xmu, grad_weight, grad_bias); +} + +at::Tensor batch_norm_backward_elemt_channels_last_cuda_template( + const at::Tensor& grad_output, + const at::Tensor& input, + const at::Tensor& mean, + const at::Tensor& inv_std, + const at::Tensor& weight, + const at::Tensor& sum_dy, + const at::Tensor& sum_dy_xmu, + const at::Tensor& count) { + const auto stride = input.sizes()[1]; + const auto reduction_size = input.numel() / stride; + + // Input is guarunteed to be channels-last compatible + at::Tensor grad_input = at::empty_like(input); + + dim3 block; + dim3 grid; + flexible_launch_configs(reduction_size, stride, block, grid); + + auto stream = at::cuda::getCurrentCUDAStream(); + + if (weight.defined() && weight.scalar_type() != input.scalar_type()) { + AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_backward_element", [&] { + using accscalar_t = at::acc_type; + batch_norm_backward_elemt_channels_last_kernel + <<>>( + grad_output.const_data_ptr(), + input.const_data_ptr(), + mean.const_data_ptr(), + inv_std.const_data_ptr(), + weight.const_data_ptr(), + sum_dy.const_data_ptr(), + sum_dy_xmu.const_data_ptr(), + count.const_data_ptr(), + grad_input.mutable_data_ptr(), + count.numel(), + reduction_size, + stride); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + } else { + if (weight.defined()) { + TORCH_CHECK(input.scalar_type() == weight.scalar_type(), "batchnorm_backward_element: input.scalar_type() ", input.scalar_type(), + " is not supported with weight.scalar_type() ", weight.scalar_type()); + } + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, input.scalar_type(), "batchnorm_backward_element", [&] { + using accscalar_t = at::acc_type; + batch_norm_backward_elemt_channels_last_kernel + <<>>( + grad_output.const_data_ptr(), + input.const_data_ptr(), + mean.const_data_ptr(), + inv_std.const_data_ptr(), + weight.defined() ? 
weight.const_data_ptr() : nullptr, + sum_dy.const_data_ptr(), + sum_dy_xmu.const_data_ptr(), + count.const_data_ptr(), + grad_input.mutable_data_ptr(), + count.numel(), + reduction_size, + stride); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + } + + return grad_input; +} + +at::Tensor batch_norm_backward_elemt_channels_last_cuda_template( + const at::Tensor& grad_output, + const at::Tensor& input, + const at::Tensor& mean, + const at::Tensor& inv_std, + const at::Tensor& weight, + const at::Tensor& sum_dy, + const at::Tensor& sum_dy_xmu) { + const auto stride = input.sizes()[1]; + const auto reduction_size = input.numel() / stride; + auto norm_fct = 1.0 / reduction_size; + + // Input is guarunteed to be channels-last compatible + at::Tensor grad_input = at::empty_like(input); + + dim3 block; + dim3 grid; + flexible_launch_configs(reduction_size, stride, block, grid); + + auto stream = at::cuda::getCurrentCUDAStream(); + + AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_backward_element", [&] { + using accscalar_t = at::acc_type; + + if (weight.defined() && weight.scalar_type() != input.scalar_type()) { + batch_norm_backward_elemt_channels_last_kernel + <<>>( + grad_output.const_data_ptr(), + input.const_data_ptr(), + mean.const_data_ptr(), + inv_std.const_data_ptr(), + weight.const_data_ptr(), + sum_dy.const_data_ptr(), + sum_dy_xmu.const_data_ptr(), + grad_input.mutable_data_ptr(), + static_cast(norm_fct), + reduction_size, + stride); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } else { + batch_norm_backward_elemt_channels_last_kernel + <<>>( + grad_output.const_data_ptr(), + input.const_data_ptr(), + mean.const_data_ptr(), + inv_std.const_data_ptr(), + weight.defined() ? weight.const_data_ptr() : nullptr, + sum_dy.const_data_ptr(), + sum_dy_xmu.const_data_ptr(), + grad_input.mutable_data_ptr(), + static_cast(norm_fct), + reduction_size, + stride); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } + }); + + return grad_input; +} + +} } // namespace at::native diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Pow.cuh b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Pow.cuh new file mode 100644 index 0000000000000000000000000000000000000000..9530b0ede27459d33fe9c8a01b71129621da499c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Pow.cuh @@ -0,0 +1,58 @@ +#pragma once +#include +#include + +namespace at { namespace native { + +namespace { + + +// SFINAE doesn't work well with NVCC under Windows for math functions like pow and sqrt. +// So we need to define the functions with the explicit function signatures. 
+// As for pow, the following signatures are defined as the device function: +// pow(float, int) +// pow(double, int) +// pow(float, float) +// pow(double, double) +#ifdef _MSC_VER +// Functions for pow +// pow for at::Half +static inline __host__ __device__ at::Half pow_(at::Half base, at::Half exp) { + return static_cast(std::pow(static_cast(base), static_cast(exp))); +} +// pow for at::BFloat16 +static inline __host__ __device__ at::BFloat16 pow_(at::BFloat16 base, at::BFloat16 exp) { + return static_cast(std::pow(static_cast(base), static_cast(exp))); +} +// pow (floating, floating/int) +template +static inline __host__ __device__ typename std::enable_if::value && (std::is_same::value || std::is_same::value), Base_type>::type + pow_(Base_type base, Exp_type exp) { + return std::pow(base, exp); +} +// pow (Otherwise) +template +static inline __host__ __device__ typename std::enable_if::value && !std::is_same::value, Base_type>::type + pow_(Base_type base, Exp_type exp) { + return static_cast(std::pow(static_cast(base), static_cast(exp))); +} +#else +template +static inline __host__ __device__ Base_type pow_(Base_type base, Exp_type exp) { + return ::pow(base, exp); +} +#endif + +template +static inline __host__ __device__ std::enable_if_t::value, T> pow_( + T base, T exp) { + return at::native::powi(base, exp); +} + +template +static inline __host__ __device__ c10::complex pow_(c10::complex base, c10::complex exp) { + return c10_complex_math::pow(base, exp); +} + +} // namespace +}} // namespace at::native diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Randperm.cuh b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Randperm.cuh new file mode 100644 index 0000000000000000000000000000000000000000..c1dca45feae82915bc7dbd107c2eb391380044e5 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Randperm.cuh @@ -0,0 +1,58 @@ +#include +#include +#include + +#include +#include +#include + +namespace { + +// See note [Algorithm of randperm] +template +__global__ void randperm_handle_duplicate_keys_kernel(T *keys, scalar_t *data, T mask, int n, at::PhiloxCudaState philox_args) { + int tid = threadIdx.x + blockDim.x * blockIdx.x; + + // find the beginning of islands + if (tid >= n - 1) return; // out of range + if ((keys[tid] & mask) != (keys[tid + 1] & mask)) return; // not in an island + if (tid != 0 && (keys[tid] & mask) == (keys[tid - 1] & mask)) return; // not the beginning of an island + + // find the size of islands + int island_size = 0; + do { island_size++; } + while ((tid + island_size < n) && (keys[tid + island_size] & mask) == (keys[tid] & mask)); + + // do random permutation inside each island. 
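+  // The island shuffle below is an in-place Fisher-Yates pass; a host-side sketch of
+  // the same idea (illustrative only, assuming <random> and a std::mt19937 rng):
+  //   for (int i = island_size - 1; i > 0; i--) {
+  //     int r = std::uniform_int_distribution<int>(0, i)(rng);
+  //     std::swap(data[i], data[r]);
+  //   }
+  // The device code replaces the host RNG with a per-thread Philox state seeded from
+  // philox_args so each island draws from an independent stream.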
+ data += tid; + auto seeds = at::cuda::philox::unpack(philox_args); + curandStatePhilox4_32_10_t state; + curand_init(std::get<0>(seeds), tid, std::get<1>(seeds), &state); + for (int i = island_size - 1; i > 0; i--) { + unsigned int r = curand(&state) % (i + 1); + if (i != r) { + scalar_t tmp = data[i]; + data[i] = data[r]; + data[r] = tmp; + } + } +} + +// See note [Algorithm of randperm] +template +void randperm_handle_duplicate_keys(T *keys, scalar_t *data, int bits, int64_t n, std::optional &gen_) { + auto gen = at::get_generator_or_default(gen_, at::cuda::detail::getDefaultCUDAGenerator()); + int64_t counter_offset = n; + at::PhiloxCudaState rng_engine_inputs; + { + // See Note [Acquire lock when using random generators] + std::lock_guard lock(gen->mutex_); + rng_engine_inputs = gen->philox_cuda_state(counter_offset); + } + T mask = static_cast((1UL << bits) - 1); + randperm_handle_duplicate_keys_kernel<<<(n + 511) / 512, 512, 0, at::cuda::getCurrentCUDAStream()>>>( + keys, data, mask, n, rng_engine_inputs); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +} diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Resize.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Resize.h new file mode 100644 index 0000000000000000000000000000000000000000..d5de128cac1d2014cfa8274facbf9cea6f43bf4a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Resize.h @@ -0,0 +1,53 @@ +#pragma once + +#include +#include + +#include + +namespace at { namespace native { + +TORCH_CUDA_CPP_API void resize_bytes_cuda(StorageImpl* storage, size_t size_bytes); + +static inline void maybe_resize_storage_cuda(TensorImpl* self, size_t new_size_bytes) { + // It does not make sense to try to resize a storage + // to hold 0 elements, and this can break + // if storage_offset is positive but + // new_size is 0, so just bail in that case + // (same comment is in Resize.h) + if (self->numel() == 0) { + return; + } + + const Storage &storage = self->unsafe_storage(); + TORCH_CHECK(storage, "Tensor: invalid null storage"); + if (new_size_bytes > storage.nbytes()) { + resize_bytes_cuda(storage.unsafeGetStorageImpl(), new_size_bytes); + } +} + +inline TensorImpl* resize_impl_cuda_( + TensorImpl* self, + IntArrayRef size, + at::OptionalIntArrayRef stride) { + if (self->sizes() == size && (!stride || self->strides() == stride)) { + return self; + } + const auto itemsize = self->dtype().itemsize(); + const auto storage_offset = self->storage_offset(); + size_t storage_size = 1; + if (stride) { + self->set_sizes_and_strides(size, *stride); + storage_size = at::detail::computeStorageNbytes( + size, *stride, itemsize, storage_offset); + } else { + self->set_sizes_contiguous(size); + storage_size = at::detail::computeStorageNbytesContiguous( + size, itemsize, storage_offset); + } + maybe_resize_storage_cuda(self, storage_size); + + return self; +} + +}} diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/RowwiseScaledMM.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/RowwiseScaledMM.h new file mode 100644 index 0000000000000000000000000000000000000000..4dee144d24659c89c4cda79fdfe953acd07d867c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/RowwiseScaledMM.h @@ -0,0 +1,15 @@ +#pragma once +#include +#include + + +namespace at::cuda::detail { +TORCH_API void f8f8bf16_rowwise( + at::Tensor XQ, // FP8 + at::Tensor WQ, // FP8 + at::Tensor x_scale, // FP32 + at::Tensor w_scale, // FP32 
+ std::optional bias, // BF16 + bool use_fast_accum, + at::Tensor& out); +} // at::cuda::detail diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/ScanUtils.cuh b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/ScanUtils.cuh new file mode 100644 index 0000000000000000000000000000000000000000..f9de15fdf912b47635b2cb1388001750eec50959 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/ScanUtils.cuh @@ -0,0 +1,459 @@ +#pragma once +#include +#include +#include +#include + +#include +#include +#include + +namespace at { +namespace native { + +template +constexpr inline integer ceil_div(integer n, integer m) { + return (n + m - 1) / m; +} + +template +constexpr inline integer get_log_num_threads_x_inner_scan(integer num_rows, integer row_size) { + integer log_num_threads_x = 0; + integer log_num_threads_y = 0; + while (((integer)1 << log_num_threads_x) < row_size) { + ++log_num_threads_x; + } + while (((integer)1 << log_num_threads_y) < num_rows) { + ++log_num_threads_y; + } + // we want to keep the ratio between the x-threads and y-threads about the same as + // the ratio between the row_size and num_rows, but the total number of threads in + // a block should be about 512 + integer diff = log_num_threads_x - log_num_threads_y; + // 9 is from log2(512) + log_num_threads_x = ((integer)9 + diff) / (integer)2; + // I found that in having larger log_num_threads_x can give significant speed up in some cases, + // but detrimental in another case, so just keep the lower bound to be log2(16) == 4 to make it + // similar to the previous implementation + // Keeping the upper bound to be log2(512) == 9 as the maximum number of threads in a block. + log_num_threads_x = std::min(std::max((integer)4, log_num_threads_x), (integer)9); + return log_num_threads_x; +} + +template +__device__ void binary_op_update(const scalar_t lhs, scalar_t& rhs, const idx_t lhs_idx, idx_t& rhs_idx, BinaryOperation binary_op) { + if(!at::_isnan(rhs) && (at::_isnan(lhs) || !binary_op(rhs, lhs))) { + rhs = lhs; + rhs_idx = lhs_idx; + } +} +/* Perform an inclusive scan along the innermost dimension of a tensor. + * + * - num_rows is the size of the flattened outer dimensions; + * - row_size is the size of the innermost dimension; + * + * The outer dimensions of the tensor are considered as a single dimension, i.e. the tensor is + * considered as having 'num_rows' rows of size 'row_size'. + * Each thread block processes one or more sets of contiguous rows (processing multiple rows + * per thread block is quicker than processing a single row, especially for short rows). 
+ */ +template +__global__ void tensor_kernel_scan_innermost_dim_with_indices(const scalar_t *self_, scalar_t *values_, int64_t *indices_, + int num_rows, int row_size, + const uint32_t num_threads, const uint32_t log_num_threads_x, + scalar_t init, BinaryFunction binary_op) { + // dynamic memory allocation for vbuf and ibuf + alignas(sizeof(double)) extern __shared__ char buf[]; + scalar_t* vbuf = reinterpret_cast(buf); // the size is num_threads * 2 + int64_t* ibuf = reinterpret_cast(vbuf + num_threads * 2); + const uint32_t num_threads_x = 1 << log_num_threads_x; + scalar_t* row_buf = vbuf + 2 * num_threads_x * threadIdx.y; + int64_t* row_idx_buf = ibuf + 2 * num_threads_x * threadIdx.y; + + for (int block_row = blockIdx.x * blockDim.y; + block_row < num_rows; + block_row += blockDim.y * gridDim.x) { + int row = block_row + threadIdx.y; + const scalar_t *row_self = self_ + row * row_size; + scalar_t *row_values = values_ + row * row_size; + int64_t *row_indices = indices_ + row * row_size; + scalar_t block_total = init; + int64_t block_idx_final = 0; + const bool row_exists = row < num_rows; + // Perform scan on one block at a time, keeping track of the total value of + // all blocks processed so far. + for (int block_col = 0; block_col < row_size; block_col += 2 * num_threads_x) { + // Load data into shared memory (two values per thread). + int col1 = block_col + threadIdx.x; + int col2 = block_col + num_threads_x + threadIdx.x; + if (row_exists) { + if (col1 < row_size) { + row_buf[threadIdx.x] = c10::load(&row_self[col1]); + row_idx_buf[threadIdx.x] = col1; + } else { + row_buf[threadIdx.x] = init; + // No need to set the index here as the value in init will never be selected + } + + if (col2 < row_size) { + row_buf[num_threads_x + threadIdx.x] = c10::load(&row_self[col2]); + row_idx_buf[num_threads_x + threadIdx.x] = col2; + } else { + row_buf[num_threads_x + threadIdx.x] = init; + // No need to set the index here as the value in init will never be selected + } + + // Add the total value of all previous blocks to the first value of this block. + if (threadIdx.x == 0) { + binary_op_update(block_total, row_buf[0], block_idx_final, row_idx_buf[0], binary_op); + } + } + __syncthreads(); + + // Parallel reduction with Sklansky method. The diagram can be seen on this paper: + // https://research.nvidia.com/publication/single-pass-parallel-prefix-scan-decoupled-look-back + for (uint32_t s = 1; s <= num_threads_x; s <<= 1) { + if (row_exists) { + uint32_t a = (threadIdx.x / s) * (2 * s) + s; + uint32_t ti = a + (threadIdx.x % s); + uint32_t si = a - 1; + binary_op_update(row_buf[si], row_buf[ti], row_idx_buf[si], row_idx_buf[ti], binary_op); + } + __syncthreads(); + } + + // Write back to output. + if (row_exists) { + if (col1 < row_size){ + row_values[col1] = row_buf[threadIdx.x]; + row_indices[col1] = row_idx_buf[threadIdx.x]; + } + if (col2 < row_size) { + row_values[col2] = row_buf[num_threads_x + threadIdx.x]; + row_indices[col2] = row_idx_buf[num_threads_x + threadIdx.x]; + } + } + block_total = row_buf[2 * num_threads_x - 1]; + block_idx_final = row_idx_buf[2 * num_threads_x - 1]; + __syncthreads(); + } + } +} + +/* Perform an inclusive scan along an outer dimension of a tensor. 
+ * + * - num_orows is the size of the flattened outer dimensions; + * - num_irows is the size of the flattened inner dimensions; + * - row_size is the size of the dimension along which to compute the variance; + * + * The dimensions to the outside and inside of the specified dimension are considered as flattened. + * Thread blocks with the same blockIdx.y process an "outer row" (i.e. an element of the flattened + * outer dimensions, which contains several "inner rows"). + * Each thread processes a single inner row at a time. + */ +template +__global__ void tensor_kernel_scan_outer_dim_with_indices(const scalar_t *self_, scalar_t *values_, int64_t *indices_, + const uint32_t num_orows, const uint32_t num_irows, const uint32_t row_size, scalar_t init, BinaryFunction binary_op) { + for (uint32_t orow = blockIdx.x; orow < num_orows; orow += gridDim.x) { + for (uint32_t irow = blockIdx.y * blockDim.x + threadIdx.x; irow < num_irows; irow += gridDim.y * blockDim.x) { + const scalar_t *self = self_ + orow * row_size * num_irows + irow; + scalar_t *values = values_ + orow * row_size * num_irows + irow; + int64_t *indices = indices_ + orow * row_size * num_irows + irow; + scalar_t out = init; + int64_t out_idx = 0; + + for (auto col = decltype(row_size){0}; col < row_size; ++col) { + const auto val = c10::load(self); + if(at::_isnan(val) || (!at::_isnan(out) && binary_op(val, out))) { + out = val; + out_idx = col; + } + *values = out; + *indices = out_idx; + self += num_irows; + values += num_irows; + indices += num_irows; + } + } + } +} + +inline void check_fits_in_unsigned(int64_t val, const char* name) { + constexpr auto umax = std::numeric_limits::max(); + TORCH_CHECK( + val >= 0 && val <= umax, name, " must fit in a 32-bit uint32_t value"); +} + + +template +__host__ void scan_outer_dim_with_indices( + const TensorBase& self, const TensorBase& values, const TensorBase& indices, + int dim, scalar_t init, BinaryFunction binary_op) { + int64_t row_size = self.size(dim); + auto sizes = self.sizes(); + + // Treat all outer dimensions (i.e. dim_ < dim) as one. + const int64_t num_orows = c10::multiply_integers(sizes.begin(), sizes.begin() + dim); + + // Treat all inner dimensions (i.e. dim > dimension) as one. + const int64_t num_irows = c10::multiply_integers(sizes.begin() + dim + 1, sizes.end()); + //for performance reasons, cuda kernels use uint32_t for loops over irows, orows and row, + //make sure that input is not bigger than supported by uint32_t + check_fits_in_unsigned(num_irows, "num_irows"); + check_fits_in_unsigned(num_orows, "num_orows"); + check_fits_in_unsigned(row_size, "row_size"); + + + dim3 threads(std::min(512, int(num_irows))); + int64_t maxGridDim = at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; + dim3 grid(std::min(maxGridDim, num_orows), std::min(maxGridDim, ceil_div(num_irows, int64_t{threads.x}))); + tensor_kernel_scan_outer_dim_with_indices<<>>( + self.const_data_ptr(), values.mutable_data_ptr(), indices.mutable_data_ptr(), + num_orows, num_irows, row_size, init, binary_op); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +template +__host__ void scan_innermost_dim_with_indices( + const TensorBase& self, const TensorBase& values, const TensorBase& indices, + scalar_t init, BinaryFunction binary_op) { + int ndim = self.dim(); + // Treat all outer dimensions as a single dimension. 
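+  // For example (illustrative only): a contiguous tensor of shape (A, B, C) scanned
+  // over its last dimension is handled as num_rows = A*B independent rows of
+  // row_size = C contiguous elements:
+  //   row_size = self.size(ndim - 1);
+  //   num_rows = self.numel() / row_size;
+  // and each thread block then processes blockDim.y such rows at a time.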
+ int row_size = self.size(ndim - 1); + int num_rows = self.numel() / row_size; + + // assuming max_num_threads per block is 512 + const uint32_t num_threads = 512; + const uint32_t log_num_threads_x = get_log_num_threads_x_inner_scan(num_rows, row_size); + const uint32_t num_threads_x = (1 << log_num_threads_x); + const uint32_t num_threads_y = num_threads / num_threads_x; + dim3 threads(num_threads_x, num_threads_y); + dim3 grid(std::min(at::cuda::getCurrentDeviceProperties()->maxGridSize[0], ceil_div(num_rows, int(threads.y)))); + + const uint32_t mem_size = 2 * num_threads * (sizeof(scalar_t) + sizeof(int64_t)); + tensor_kernel_scan_innermost_dim_with_indices<<>>( + self.const_data_ptr(), values.mutable_data_ptr(), indices.mutable_data_ptr(), + num_rows, row_size, num_threads, log_num_threads_x, init, binary_op); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +template +void scan_dim_with_indices(const TensorBase& self, const TensorBase& values, const TensorBase& indices, //int64_t dim) { + int64_t dim, scalar_t init, BinaryFunction binary_op) { + int ndim = self.dim(); + auto self_ = self.expect_contiguous(); + TORCH_INTERNAL_ASSERT(values.is_contiguous() && indices.is_contiguous()); + if (dim == ndim - 1) { + scan_innermost_dim_with_indices(*self_, values, indices, init, binary_op); + } else { + scan_outer_dim_with_indices(*self_, values, indices, dim, init, binary_op); + } +} + +// TODO: The implementation of `tensor_kernel_scan_outer_dim` and +// `tensor_kernel_scan_innermost_dim` is similar to +// `tensor_kernel_scan_outer_dim_with_indices` +// `tensor_kernel_scan_outer_dim_with_indices` and should be refactored to +// remove the duplication. + +/* Perform an inclusive scan along an outer dimension of a tensor. + * + * - num_orows is the size of the flattened outer dimensions; + * - num_irows is the size of the flattened inner dimensions; + * - row_size is the size of the dimension along which to scan; + * + * The dimensions to the outside and inside of the specified dimension are considered as flattened. + * Thread blocks with the same blockIdx.y process an "outer row" (i.e. an element of the flattened + * outer dimensions, which contains several "inner rows"). + * Each thread processes a single inner row at a time. + */ +template +__global__ void tensor_kernel_scan_outer_dim(scalar_t *tgt_, const scalar_t *src_, + const uint32_t num_orows, const uint32_t num_irows, const uint32_t row_size, + const scalar_t init, BinaryOp binary_op) +{ + for (uint32_t orow = blockIdx.x; orow < num_orows; orow += gridDim.x) { + for (uint32_t irow = blockIdx.y * blockDim.x + threadIdx.x; irow < num_irows; irow += gridDim.y * blockDim.x) { + const scalar_t *src = src_ + orow * row_size * num_irows + irow; + scalar_t *tgt = tgt_ + orow * row_size * num_irows + irow; + scalar_t acc = init; + + for (uint32_t col = 0; col < row_size; ++col) { + acc = binary_op(acc, c10::load(src)); + *tgt = acc; + + src += num_irows; + tgt += num_irows; + } + } + } +} + +/* Perform an inclusive scan along the innermost dimension of a tensor. + * + * - num_rows is the size of the flattened outer dimensions; + * - row_size is the size of the innermost dimension; + * + * The outer dimensions of the tensor are considered as a single dimension, i.e. the tensor is + * considered as having 'num_rows' rows of size 'row_size'. + * Each thread block processes one or more sets of contiguous rows (processing multiple rows + * per thread block is quicker than processing a single row, especially for short rows). 
+ */ +template +__device__ void tensor_kernel_scan_innermost_dim_impl(T* row_buf, T *tgt_, const T *src_, + const uint32_t num_rows, const uint32_t row_size, + const uint32_t log_num_threads_x, + T init, BinaryFunction binary_op){ + const uint32_t num_threads_x = 1 << log_num_threads_x; + for (uint32_t block_row = blockIdx.x * blockDim.y; + block_row < num_rows; + block_row += blockDim.y * gridDim.x) { + uint32_t row = block_row + threadIdx.y; + T block_total = init; + + const T *row_src = src_ + row * row_size; + T *row_tgt = tgt_ + row * row_size; + const bool row_exists = row < num_rows; + + // Perform scan on one block at a time, keeping track of the total value of + // all blocks processed so far. + for (uint32_t block_col = 0; block_col < row_size; block_col += 2 * num_threads_x) { + // Load data into shared memory (two values per thread). + uint32_t col1 = block_col + threadIdx.x; + uint32_t col2 = block_col + num_threads_x + threadIdx.x; + if (row_exists) { + if (col1 < row_size) { + row_buf[threadIdx.x] = row_src[col1]; + } else { + row_buf[threadIdx.x] = init; + } + + if (col2 < row_size) { + row_buf[num_threads_x + threadIdx.x] = row_src[col2]; + } else { + row_buf[num_threads_x + threadIdx.x] = init; + } + + // Add the total value of all previous blocks to the first value of this block. + if (threadIdx.x == 0) { + row_buf[0] = binary_op(row_buf[0], block_total); + } + } + __syncthreads(); + + // Parallel reduction with Sklansky method. The diagram can be seen on this paper: + // https://research.nvidia.com/publication/single-pass-parallel-prefix-scan-decoupled-look-back + for (uint32_t m = 0; m <= log_num_threads_x; ++m) { + if (row_exists) { + uint32_t s = 1 << m; // s = 2 ^ m + uint32_t a = ((threadIdx.x >> m) << (m + 1)) | s; // a = (threadIdx.x / s) * (2 * s) + s + uint32_t ti = a + (threadIdx.x % s); + uint32_t si = a - 1; + row_buf[ti] = binary_op(row_buf[ti], row_buf[si]); + } + __syncthreads(); + } + + // Write back to output. + if (row_exists) { + if (col1 < row_size) row_tgt[col1] = row_buf[threadIdx.x]; + if (col2 < row_size) row_tgt[col2] = row_buf[num_threads_x + threadIdx.x]; + } + block_total = row_buf[2 * num_threads_x - 1]; + __syncthreads(); + } + } +} + +template < + typename T, + class BinaryFunction> +__global__ void tensor_kernel_scan_innermost_dim( + T* tgt_, + const T* src_, + const uint32_t num_rows, + const uint32_t row_size, + const uint32_t log_num_threads_x, + T init, + BinaryFunction binary_op) { + alignas(sizeof(double)) extern __shared__ char sbuf[]; + T* sbuf2 = reinterpret_cast(sbuf); + const uint32_t num_threads_x = 1 << log_num_threads_x; + T* row_buf = reinterpret_cast(sbuf2 + num_threads_x * 2 * threadIdx.y); + + tensor_kernel_scan_innermost_dim_impl( + row_buf, tgt_, src_, num_rows, row_size, log_num_threads_x, init, binary_op); +} + + +template +__host__ void scan_outer_dim(const TensorBase& self, const TensorBase& result, + int dim, scalar_t init, BinaryFunction binary_op) { + const int64_t row_size = self.size(dim); + auto sizes = self.sizes(); + + // Treat all outer dimensions (i.e. dim_ < dim) as one. + const int64_t num_orows = c10::multiply_integers(sizes.begin(), sizes.begin() + dim); + + // Treat all inner dimensions (i.e. dim > dimension) as one. 
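+  // Worked example (illustrative only): for a contiguous tensor of shape (A, B, C, D)
+  // scanned over dim = 1, num_orows = A, row_size = B and num_irows = C*D; element
+  // (orow, col, irow) then sits at offset
+  //   orow * row_size * num_irows + col * num_irows + irow
+  // which is the addressing used by tensor_kernel_scan_outer_dim above, each thread
+  // walking one inner row and stepping by num_irows per scan element.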
+ const int64_t num_irows = c10::multiply_integers(sizes.begin() + dim + 1, sizes.end()); + + dim3 threads(std::min(512, int(num_irows))); + int64_t maxGridDim = at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; + dim3 grid(std::min(maxGridDim, num_orows), std::min(maxGridDim, ceil_div(num_irows, int64_t{threads.x}))); + + check_fits_in_unsigned(num_irows, "num_irows"); + check_fits_in_unsigned(num_orows, "num_orows"); + check_fits_in_unsigned(row_size, "row_size"); + + tensor_kernel_scan_outer_dim<<>>( + result.mutable_data_ptr(), self.const_data_ptr(), + num_orows, num_irows, row_size, init, binary_op); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +template +void scan_innermost_dim(const TensorBase& self, const TensorBase& result, + scalar_t init, BinaryFunction binary_op) { + int64_t ndim = self.dim(); + // Treat all outer dimensions as a single dimension. + int64_t row_size = self.size(ndim - 1); + int64_t num_rows = self.numel() / row_size; + + // assuming max_num_threads per block is 512 + const uint32_t num_threads = 512; + const uint32_t log_num_threads_x = get_log_num_threads_x_inner_scan(num_rows, row_size); + const uint32_t num_threads_x = (1 << log_num_threads_x); + const uint32_t num_threads_y = num_threads / num_threads_x; + dim3 threads(num_threads_x, num_threads_y); + int64_t maxGridDim = at::cuda::getCurrentDeviceProperties()->maxGridSize[0]; + dim3 grid(std::min(maxGridDim, ceil_div(num_rows, int64_t{threads.y}))); + + check_fits_in_unsigned(num_rows, "Number of rows (self.numel()/self.size(self.dim()-1))"); + check_fits_in_unsigned(row_size, "row_size"); + + tensor_kernel_scan_innermost_dim<<>>( + result.mutable_data_ptr(), self.const_data_ptr(), + num_rows, row_size, log_num_threads_x, init, binary_op); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +template +void scan_dim(const TensorBase& self, const TensorBase& result, + int64_t dim, scalar_t init, BinaryFunction binary_op) { + int ndim = self.dim(); + auto self_ = self.expect_contiguous(); + TORCH_INTERNAL_ASSERT(result.is_contiguous()); + + if (self.numel() == self.size(dim)) { + cuda::cub::inclusive_scan(self_->const_data_ptr(), result.mutable_data_ptr(), binary_op, self.numel()); + } else if (dim == ndim - 1) { + scan_innermost_dim(*self_, result, init, binary_op); + } else { + scan_outer_dim(*self_, result, dim, init, binary_op); + } +} + +}} // namespace at::native diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/SortStable.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/SortStable.h new file mode 100644 index 0000000000000000000000000000000000000000..039c4307c522c9f81bf88554483f67a26127561a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/SortStable.h @@ -0,0 +1,19 @@ +#pragma once +#include +#include + +namespace at { +namespace native { + +// Stable-sort self into values, and set indices to the +// inverse-permutation from values back to self. +// Output tensors must be pre-allocated and contiguous. 
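+// One plausible call pattern (illustrative sketch, assuming self is a contiguous
+// at::Tensor; the helpers shown are standard ATen factory functions):
+//   auto values  = at::empty_like(self);                                     // same dtype/shape
+//   auto indices = at::empty(self.sizes(), self.options().dtype(at::kLong)); // int64 indices
+//   launch_stable_sort_kernel(self, dim, /*descending=*/false, values, indices);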
+void launch_stable_sort_kernel( + const TensorBase& self, + int64_t dim, + bool descending, + const TensorBase& values, + const TensorBase& indices); + +} // namespace native +} // namespace at diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/SortUtils.cuh b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/SortUtils.cuh new file mode 100644 index 0000000000000000000000000000000000000000..646045e930e73a68cb5fa79fc39f1671b1854bd7 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/SortUtils.cuh @@ -0,0 +1,343 @@ +#pragma once +#include + +#include +#include +#include +#include +#include +#include +#include + +#define HAS_WARP_MERGE_SORT() (CUDA_VERSION >= 110600) + + +namespace at { namespace native { + +template +__device__ inline void swapVars(T& t1, T& t2) { + T tmp = t1; + t1 = t2; + t2 = tmp; +} + +template +__device__ inline void bitonicSwap(K& kA, V& vA, bool& validA, + K& kB, V& vB, bool& validB, + bool dir, + const Comparator& comp) { + // Invalid entries always sort to the end + bool swap = (comp(kA, kB) && validA) || !validB; + if (swap == dir) { + swapVars(kA, kB); + swapVars(vA, vB); + swapVars(validA, validB); + } +}; + +template +__device__ inline void bitonicSort(K *keys, + V *values, + bool *valid, + const Comparator& comp) { +#if !defined(USE_ROCM) +#pragma unroll +#endif + for (unsigned int size = 2; size < Power2SortSize; size *= 2) { + bool flag = ((threadIdx.x & (size / 2)) != 0); + +#if !defined(USE_ROCM) +#pragma unroll +#endif + for (unsigned int stride = size / 2; stride > 0; stride /= 2) { + + __syncthreads(); + + unsigned int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1)); + bitonicSwap( + keys[pos], values[pos], valid[pos], + keys[pos + stride], values[pos + stride], valid[pos + stride], + flag, comp); + } + } + +#if !defined(USE_ROCM) +#pragma unroll +#endif + for (unsigned int stride = Power2SortSize / 2; stride > 0; stride /= 2) { + + __syncthreads(); + + unsigned int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1)); + bitonicSwap( + keys[pos], values[pos], valid[pos], + keys[pos + stride], values[pos + stride], valid[pos + stride], + false, comp); + } + + __syncthreads(); + +} + +// at::cuda::detail::TensorInfo version +// Sorts (key, value) pairs (in different tensors) in-place; i.e., +// modifies the input `keys` and `values` +template +C10_LAUNCH_BOUNDS_1(block_dim_x * max_block_dim_y) +__global__ void +bitonicSortKVInPlace(at::cuda::detail::TensorInfo keys, + IndexType keySlices, + IndexType keySliceSize, + IndexType keySliceStride, + at::cuda::detail::TensorInfo values, + IndexType valueSliceStride, + Comparator comp) { + // Find the slice of the tensor that we are sorting + // NOTE: blockDim.y may be less max_block_dim_y + const IndexType blockIndex = getLinearBlockId(); + const IndexType linearIndex = blockIndex * blockDim.y + threadIdx.y; + + // If the entire block is out of bounds exit early + if (blockIndex * blockDim.y >= keySlices) { + return; + } + // It's also possible for some rows of a block to be out of bounds + // but all thread need to run for __syncthreads to work. 
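+  // The usual shape of that pattern (illustrative sketch):
+  //   bool valid = index < bound;
+  //   if (valid) { /* load into shared memory */ }
+  //   __syncthreads();              // every thread of the block must reach this
+  //   if (valid) { /* consume shared memory / store results */ }
+  // i.e. out-of-range threads are masked per step instead of returning early, which
+  // is exactly what row_valid does below.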
+ const bool row_valid = linearIndex < keySlices; + + constexpr int items_per_thread = 2; + constexpr int Power2SortSize = block_dim_x * items_per_thread; + + // Storage for max_block_dim_y sorts performed in parallel + __shared__ K blockSharedKeys[max_block_dim_y][Power2SortSize]; + __shared__ V blockSharedValues[max_block_dim_y][Power2SortSize]; + __shared__ bool blockSharedValid[max_block_dim_y][Power2SortSize]; + + auto sharedKeys = blockSharedKeys[threadIdx.y]; + auto sharedValues = blockSharedValues[threadIdx.y]; + auto sharedValid = blockSharedValid[threadIdx.y]; + + const IndexType keyStartOffset = + at::cuda::detail::IndexToOffset::get(linearIndex, keys); + const IndexType valueStartOffset = + at::cuda::detail::IndexToOffset::get(linearIndex, values); + + // Load 2 values per thread into the shared workspace + #pragma unroll + for (int k = 0; k < items_per_thread; ++k) { + auto idx = threadIdx.x + k * blockDim.x; + bool valid = row_valid && idx < keySliceSize; + + sharedKeys[idx] = valid ? + keys.data[idx * keySliceStride + keyStartOffset] : K{}; + sharedValues[idx] = valid ? + values.data[idx * valueSliceStride + valueStartOffset] : V{}; + sharedValid[idx] = valid; + } + + // Sort! + bitonicSort( + sharedKeys, sharedValues, sharedValid, comp); + + if (!row_valid) { + return; + } + + // Store outputs + #pragma unroll + for (int k = 0; k < items_per_thread; ++k) { + auto idx = threadIdx.x + k * blockDim.x; + if (idx < keySliceSize) { + keys.data[idx * keySliceStride + keyStartOffset] = sharedKeys[idx]; + values.data[idx * valueSliceStride + valueStartOffset] = sharedValues[idx]; + } + } +} + +#if HAS_WARP_MERGE_SORT() + +template +C10_LAUNCH_BOUNDS_1(C10_WARP_SIZE * max_block_dim_y) +__global__ void +warpMergeSortKVInPlace( + at::cuda::detail::TensorInfo keys, + IndexType keySlices, + IndexType keySliceSize, + IndexType keySliceStride, + at::cuda::detail::TensorInfo values, + IndexType valueSliceStride, + Comparator comp, + K invalid_key) { + // Find the slice of the tensor that we are sorting + // NOTE: blockDim.y may be less max_block_dim_y + const IndexType blockIndex = getLinearBlockId(); + const IndexType linearIndex = blockIndex * blockDim.y + threadIdx.y; + + // If this row is out of bounds exit early + if (linearIndex >= keySlices) { + return; + } + + const IndexType keyStartOffset = + at::cuda::detail::IndexToOffset::get(linearIndex, keys); + const IndexType valueStartOffset = + at::cuda::detail::IndexToOffset::get(linearIndex, values); + + K *keys_slice = &keys.data[keyStartOffset]; + V *values_slice = &values.data[valueStartOffset]; + + StridedRandomAccessor keys_iter(keys_slice, keySliceStride); + StridedRandomAccessor values_iter(values_slice, valueSliceStride); + + namespace cub = ROCM_HIPCUB(at_cuda_detail::cub); + + CUDA_KERNEL_ASSERT(blockDim.x == C10_WARP_SIZE); + CUDA_KERNEL_ASSERT(blockDim.y <= max_block_dim_y); + constexpr int items_per_thread = sort_size / C10_WARP_SIZE; + static_assert( + items_per_thread * C10_WARP_SIZE == sort_size, + "sort_size must be a multiple of C10_WARP_SIZE"); + + + using LoadKeys = cub::WarpLoad; + using LoadValues = cub::WarpLoad; + using Sort = cub::WarpMergeSort; + using StoreKeys = cub::WarpStore; + using StoreValues = cub::WarpStore; + + __shared__ union { + typename LoadKeys::TempStorage load_keys; + typename LoadValues::TempStorage load_values; + typename Sort::TempStorage sort; + typename StoreKeys::TempStorage store_keys; + typename StoreValues::TempStorage store_values; + } tmp_storage[max_block_dim_y]; + + auto& 
warp_storage = tmp_storage[threadIdx.y]; + + // Load inputs + K local_keys[items_per_thread]; + V local_values[items_per_thread]; + + const auto invalid_value = V{}; + LoadKeys(warp_storage.load_keys).Load(keys_iter, local_keys, keySliceSize, invalid_key); + WARP_SYNC(); + LoadValues(warp_storage.load_values).Load(values_iter, local_values, keySliceSize, invalid_value); + WARP_SYNC(); + + // Sort! We use stable sort to ensure that invalid values are never + // sorted before valid values. In testing it performed the same as + // .Sort, so there is no down-side. + Sort(warp_storage.sort).StableSort( + local_keys, local_values, comp, keySliceSize, invalid_key); + WARP_SYNC(); + + // Store outputs + StoreKeys(warp_storage.store_keys).Store(keys_iter, local_keys, keySliceSize); + WARP_SYNC(); + StoreValues(warp_storage.store_values).Store(values_iter, local_values, keySliceSize); +} + +#endif // HAS_WARP_MERGE_SORT() + +template +C10_LAUNCH_BOUNDS_1(block_size) +__global__ void +radixSortKVInPlace(at::cuda::detail::TensorInfo keys, + IndexType keySlices, + IndexType keySliceSize, + IndexType keySliceStride, + at::cuda::detail::TensorInfo values, + IndexType valueSliceStride, + bool descending) { + static_assert(block_size > 0, ""); + + // Find the slice of the tensor that we are sorting + const IndexType linearIndex = getLinearBlockId(); + // Tiling the slices could have us be out of bounds, if there are a + // lot of slices to sort + if (linearIndex >= keySlices) { + return; + } + + const IndexType keyStartOffset = + at::cuda::detail::IndexToOffset::get(linearIndex, keys); + const IndexType valueStartOffset = + at::cuda::detail::IndexToOffset::get(linearIndex, values); + + K *keys_slice = &keys.data[keyStartOffset]; + V *values_slice = &values.data[valueStartOffset]; + + StridedRandomAccessor keys_iter(keys_slice, keySliceStride); + StridedRandomAccessor values_iter(values_slice, valueSliceStride); + + namespace cub = ROCM_HIPCUB(at_cuda_detail::cub); + + using key_t = typename at::cuda::cub::detail::cuda_type::type; + using LoadKeys = cub::BlockLoad; + using LoadValues = cub::BlockLoad; + using Sort = cub::BlockRadixSort; + using StoreKeys = cub::BlockStore; + using StoreValues = cub::BlockStore; + + __shared__ union { + typename LoadKeys::TempStorage load_keys; + typename LoadValues::TempStorage load_values; + typename Sort::TempStorage sort; + typename StoreKeys::TempStorage store_keys; + typename StoreValues::TempStorage store_values; + } tmp_storage; + + // cub's Block operations operate on a fixed number of items, but the + // actual slice we are sorting might be smaller. So, we need to make + // up the difference with keys that will always sort higher. + const K invalid_key = [descending] { + using radix_t = typename cub::Traits::UnsignedBits; + union { + K key; + radix_t radix; + } tmp; + tmp.radix = descending ? + cub::Traits::LOWEST_KEY : + cub::Traits::MAX_KEY; + return tmp.key; + }(); + const V invalid_value = static_cast(0); + + // Load inputs + K local_keys[items_per_thread]; + V local_values[items_per_thread]; + + LoadKeys(tmp_storage.load_keys).Load(keys_iter, local_keys, keySliceSize, invalid_key); + __syncthreads(); + LoadValues(tmp_storage.load_values).Load(values_iter, local_values, keySliceSize, invalid_value); + __syncthreads(); + + // Sort! 
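+  // cub's BlockRadixSort orders keys by their unsigned radix bit pattern, so the K
+  // values are reinterpreted as key_t for the call below (sketch, illustrative only):
+  //   Sort(tmp_storage.sort).Sort(
+  //       reinterpret_cast<key_t (&)[items_per_thread]>(local_keys), local_values);
+  // SortDescending is used instead for descending order, and the invalid_key padding
+  // chosen above guarantees that out-of-range slots sort to the end either way.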
+ if (descending) { + Sort(tmp_storage.sort).SortDescending( + reinterpret_cast(local_keys), + local_values); + } else { + Sort(tmp_storage.sort).Sort( + reinterpret_cast(local_keys), + local_values); + } + __syncthreads(); + + // Store outputs + StoreKeys(tmp_storage.store_keys).Store(keys_iter, local_keys, keySliceSize); + __syncthreads(); + StoreValues(tmp_storage.store_values).Store(values_iter, local_values, keySliceSize); +} + +}} // at::native diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Sorting.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Sorting.h new file mode 100644 index 0000000000000000000000000000000000000000..bd10ffb1a0274182c77bebe1097169f891dad3d3 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Sorting.h @@ -0,0 +1,18 @@ +#pragma once +#include + +namespace at { +class TensorBase; +} + +namespace at { +namespace native { + +void launch_kthvalue_kernel( + const TensorBase &values, const TensorBase &indices, + const TensorBase &self, int64_t dim, int64_t k); +void launch_median_kernel( + const TensorBase &vals, const TensorBase &inds, + const TensorBase &in, int64_t dim, bool ignore_nan); + +}} // namespace at::native diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/SortingCommon.cuh b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/SortingCommon.cuh new file mode 100644 index 0000000000000000000000000000000000000000..30e03f4b43e567e3f031fd6c372908e8e41b8db6 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/SortingCommon.cuh @@ -0,0 +1,193 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { + +// Is this questionable namespace pollution? +#if defined(USE_ROCM) +constexpr int MAX_BLOCK_SIZE = 256; + +#else +constexpr int MAX_BLOCK_SIZE = 1024; +#endif + +// Maximum size per grid dimension that we assume (compute capability >= 2.0) +constexpr int64_t MAX_GRID_SIZE = 65535LL; + +inline bool getGridFromTiles(int64_t gridTiles, dim3& grid) { + if (gridTiles > MAX_GRID_SIZE * MAX_GRID_SIZE * MAX_GRID_SIZE) { + return false; + } + + int64_t gridX = gridTiles > MAX_GRID_SIZE ? MAX_GRID_SIZE : gridTiles; + int64_t gridY = 1; + int64_t gridZ = 1; + + if (gridTiles > MAX_GRID_SIZE) { + gridTiles = ceil_div(gridTiles, MAX_GRID_SIZE); + gridY = gridTiles > MAX_GRID_SIZE ? MAX_GRID_SIZE : gridTiles; + + if (gridTiles > MAX_GRID_SIZE) { + gridTiles = ceil_div(gridTiles, MAX_GRID_SIZE); + gridZ = gridTiles > MAX_GRID_SIZE ? 
MAX_GRID_SIZE : gridTiles; + } + } + + grid = dim3(gridX, gridY, gridZ); + return true; +} + +template +struct GTOp { + __device__ bool operator()(const scalar_t& lhs, const scalar_t& rhs) const { + return (handleNaN && at::_isnan(lhs) && !at::_isnan(rhs)) || (lhs > rhs); + } +}; + +template +struct LTOp { + __device__ bool operator()(const scalar_t& lhs, const scalar_t& rhs) const { + return (handleNaN && at::_isnan(rhs) && !at::_isnan(lhs)) || (lhs < rhs); + } +}; + +template +__device__ __forceinline__ index_t getLinearBlockId() { + return blockIdx.z * gridDim.y * gridDim.x + blockIdx.y * gridDim.x + + blockIdx.x; +} + +// For slice sorting in Thrust; extracts a slice index from a linear +// index and uses that for comparison +struct SliceComp { + SliceComp(int64_t size) : sliceSize(size) {} + + __device__ bool operator()(const int64_t& a, const int64_t& b) const { + // Since the slices are guaranteed to be innermost, + // the segment is just via int64_t division + int64_t segA = a / sliceSize; + int64_t segB = b / sliceSize; + return segA < segB; + } + + const int64_t sliceSize; +}; + +// For sorting in Thurst; extracts a within-slice index from a linear index +struct GlobalIndexToPerSliceIndex { + GlobalIndexToPerSliceIndex(int64_t size) : sliceSize(size) {} + + __device__ inline void operator()(int64_t& v) const { + v = v % sliceSize; + } + + const int64_t sliceSize; +}; + +// Returns 2^(ceil(lg(n)) from Stanford bit twiddling hacks +inline uint64_t nextHighestPowerOf2(uint64_t n) { + n--; + n |= n >> 1; + n |= n >> 2; + n |= n >> 4; + n |= n >> 8; + n |= n >> 16; +#ifndef _MSC_VER + n |= n >> 32; +#endif + n++; + + return n; +} + + +// WARNING: This function assumes input tensors are contiguous +template +void run_launcher( + const TensorBase &values, + const TensorBase &indices, + const TensorBase &self, + int64_t dim, + Launcher l) { + auto self_info = cuda::detail::getTensorInfo(self); + auto values_info = cuda::detail::getTensorInfo(values); + auto indices_info = cuda::detail::getTensorInfo(indices); + + int64_t slice_size = self.size(dim); + /* We use these structures solely to find the offset to */ + /* each slice we are operating on */ + self_info.reduceDim(dim); + values_info.reduceDim(dim); + indices_info.reduceDim(dim); + + /* Collapse all other dims */ + int collapse_self_dim = self_info.collapseDims(dim); + int collapse_values_dim = values_info.collapseDims(dim); + int collapse_indices_dim = indices_info.collapseDims(dim); + + int64_t num_slices = 1; + for (int i = 0; i < self_info.dims; ++i) { + num_slices *= self_info.sizes[i]; + } + + /* This is used as a template parameter to calculate indices. 
*/ + /* We only specialize it if all collapsed dim sizes are the */ + /* same; otherwise, we use -1 which is the specialization */ + /* parameter for arbitrary dimensions */ + int all_dims = self_info.dims; + if (values_info.dims != all_dims || indices_info.dims != all_dims) { + all_dims = -1; + } + + if (all_dims == 1) { + l.template launch( + values_info, + collapse_values_dim, + indices_info, + collapse_indices_dim, + self_info, + collapse_self_dim, + num_slices, + slice_size); + } else if (all_dims == 2) { + l.template launch( + values_info, + collapse_values_dim, + indices_info, + collapse_indices_dim, + self_info, + collapse_self_dim, + num_slices, + slice_size); + } else if (all_dims == 3) { + l.template launch( + values_info, + collapse_values_dim, + indices_info, + collapse_indices_dim, + self_info, + collapse_self_dim, + num_slices, + slice_size); + } else { + l.template launch( + values_info, + collapse_values_dim, + indices_info, + collapse_indices_dim, + self_info, + collapse_self_dim, + num_slices, + slice_size); + } +} + +} // namespace native +} // namespace at diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/SortingRadixSelect.cuh b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/SortingRadixSelect.cuh new file mode 100644 index 0000000000000000000000000000000000000000..1aeaca19652a652db6ff3aded81e2bdec8b3a4af --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/SortingRadixSelect.cuh @@ -0,0 +1,429 @@ +#include +#include +#include +#include +#include + +namespace at { +namespace native { + +template +struct TopKTypeConfig {}; + +template <> +struct TopKTypeConfig { + typedef uint32_t RadixType; + + // Converts a float to an integer representation with the same + // sorting; i.e., for floats f1, f2: + // if f1 < f2 then convert(f1) < convert(f2) + // We use this to enable radix selection of floating-point values. + // This also gives a relative order for NaNs, but that's ok, as they + // will all be adjacent + // neg inf: signbit=1 exp=ff fraction=0 --> radix = 0 00 ff.. + // pos inf: signbit=0 exp=ff fraction=0 --> radix = 1 ff 00.. 
+ // pos nan: signbit=0 exp=ff fraction>0 --> radix = 1 ff x>0 + // neg nan: signbit=1 exp=ff fraction>0 --> radix = 0 00 x +struct TopKTypeConfig { + typedef uint32_t RadixType; + + static inline __device__ RadixType convert(uint8_t v) { + return v; + } + + static inline __device__ uint8_t deconvert(RadixType v) { + return v; + } +}; + +template <> +struct TopKTypeConfig { + typedef uint32_t RadixType; + + static inline __device__ RadixType convert(int8_t v) { + return 128u + v; + } + + static inline __device__ int8_t deconvert(RadixType v) { + return v - 128; + } +}; + +template <> +struct TopKTypeConfig { + typedef uint32_t RadixType; + + static inline __device__ RadixType convert(int16_t v) { + static_assert(sizeof(short) == 2, ""); + return 32768u + v; + } + + static inline __device__ int16_t deconvert(RadixType v) { + return v - 32768; + } +}; + +template <> +struct TopKTypeConfig { + typedef uint32_t RadixType; + + static inline __device__ RadixType convert(int32_t v) { + static_assert(sizeof(int) == 4, ""); + return 2147483648u + v; + } + + static inline __device__ int32_t deconvert(RadixType v) { + return v - 2147483648u; + } +}; + +template <> +struct TopKTypeConfig { + typedef uint64_t RadixType; + + static inline __device__ RadixType convert(int64_t v) { + static_assert(sizeof(int64_t) == 8, ""); + return 9223372036854775808ull + v; + } + + static inline __device__ int64_t deconvert(RadixType v) { + return v - 9223372036854775808ull; + } +}; + +template <> +struct TopKTypeConfig { + typedef uint64_t RadixType; + + static inline __device__ RadixType convert(double v) { + RadixType x = __double_as_longlong(v); + RadixType mask = -((x >> 63)) | 0x8000000000000000; + return (v == v) ? (x ^ mask) : 0xffffffffffffffff; + } + + static inline __device__ double deconvert(RadixType v) { + RadixType mask = ((v >> 63) - 1) | 0x8000000000000000; + return __longlong_as_double(v ^ mask); + } +}; + +template <> +struct TopKTypeConfig { + typedef uint32_t RadixType; + + static inline __device__ RadixType convert(at::Half v) { +#if defined(__CUDA_ARCH__) || defined(USE_ROCM) + RadixType x = __half_as_ushort(v); + RadixType mask = (x & 0x00008000) ? 0x0000ffff : 0x00008000; + return (v == v) ? (x ^ mask) : 0xffff; +#else + CUDA_KERNEL_ASSERT(false); + return 0u; +#endif + } + + static inline __device__ at::Half deconvert(RadixType v) { +#if defined(__CUDA_ARCH__) || defined(USE_ROCM) + RadixType mask = (v & 0x00008000) ? 0x00008000 : 0x0000ffff; + return __ushort_as_half(v ^ mask); +#else + CUDA_KERNEL_ASSERT(false); + return static_cast(0); +#endif + } +}; + +template <> +struct TopKTypeConfig { + typedef uint32_t RadixType; + + static inline __device__ RadixType convert(at::BFloat16 v) { + RadixType x = v.x; + RadixType mask = (x & 0x00008000) ? 0x0000ffff : 0x00008000; + return (v == v) ? (x ^ mask) : 0xffff; + } + + static inline __device__ at::BFloat16 deconvert(RadixType v) { + RadixType mask = (v & 0x00008000) ? 0x00008000 : 0x0000ffff; + at::BFloat16 r; + r.x = (v ^ mask); + return r; + } +}; + +// This function counts the distribution of all input values in a +// slice we are selecting by radix digit at `radixDigitPos`, but only +// those that pass the filter `((v & desiredMask) == desired)`. +// This produces and broadcasts the seen counts for a single block only. +// `smem` must have at least `RadixSize` elements. 
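Editor's note: TopKTypeConfig<float>::convert above maps a float to an unsigned integer whose unsigned ordering matches the float ordering, which is what makes radix selection on floating-point values possible. Below is a standalone host sketch of that same bit transform, using memcpy for the type pun instead of the device intrinsics; helper names are hypothetical.

#include <cassert>
#include <cstdint>
#include <cstring>

// Monotone float -> uint32 mapping: negative values have all bits flipped,
// non-negative values only have the sign bit flipped, so unsigned comparison
// of the results orders the same way as float comparison of the inputs.
uint32_t to_radix(float f) {
  uint32_t x;
  std::memcpy(&x, &f, sizeof(x));
  uint32_t mask = (x & 0x80000000u) ? 0xFFFFFFFFu : 0x80000000u;
  return x ^ mask;
}

float from_radix(uint32_t r) {
  uint32_t mask = (r & 0x80000000u) ? 0x80000000u : 0xFFFFFFFFu;
  uint32_t x = r ^ mask;
  float f;
  std::memcpy(&f, &x, sizeof(f));
  return f;
}

int main() {
  float vals[] = {-2.5f, -0.0f, 0.0f, 1.0f, 100.f};
  for (int i = 0; i + 1 < 5; ++i) {
    assert(to_radix(vals[i]) <= to_radix(vals[i + 1]));   // order preserved
    assert(from_radix(to_radix(vals[i])) == vals[i]);     // round-trips
  }
}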
+template < + typename scalar_t, + typename bitwise_t, + typename index_t, + typename CountType, + int RadixSize, + int RadixBits> +__device__ void countRadixUsingMask( + CountType counts[RadixSize], + CountType* smem, + bitwise_t desired, + bitwise_t desiredMask, + int radixDigitPos, + index_t sliceSize, + index_t withinSliceStride, + const scalar_t* data) { + // Clear out per-thread counts from a previous round +#pragma unroll + for (int i = 0; i < RadixSize; ++i) { + counts[i] = 0; + } + + if (threadIdx.x < RadixSize) { + smem[threadIdx.x] = 0; + } + __syncthreads(); + + // Scan over all the data. Upon a read, the warp will accumulate + // counts per each digit in the radix using warp voting. +#if !defined(USE_ROCM) + // Must be called outside of loop to ensure all threads participate + unsigned mask = WARP_BALLOT(threadIdx.x < sliceSize); +#endif + for (index_t i = threadIdx.x; i < sliceSize;) { + bitwise_t val = + TopKTypeConfig::convert(doLdg(&data[i * withinSliceStride])); + + bool hasVal = ((val & desiredMask) == desired); + bitwise_t digitInRadix = at::cuda::Bitfield::getBitfield( + val, radixDigitPos, RadixBits); + +#pragma unroll + for (uint32_t j = 0; j < RadixSize; ++j) { + bool vote = hasVal && (digitInRadix == j); +#if defined(USE_ROCM) + counts[j] += __popcll(WARP_BALLOT(vote)); +#else + counts[j] += __popc(WARP_BALLOT(vote, mask)); +#endif + } + i += blockDim.x; +#if !defined(USE_ROCM) + mask = WARP_BALLOT(i < sliceSize, mask); +#endif + } + + // Now, for each warp, sum values + if (at::cuda::getLaneId() == 0) { +#pragma unroll + for (uint32_t i = 0; i < RadixSize; ++i) { + gpuAtomicAddNoReturn(&smem[i], counts[i]); + } + } + + __syncthreads(); + + // For each thread, read in the total counts +#pragma unroll + for (uint32_t i = 0; i < RadixSize; ++i) { + counts[i] = smem[i]; + } + + __syncthreads(); +} + +// Over what radix we are selecting values +constexpr int RADIX_BITS = 2; // digits are base-(2 ^ RADIX_BITS) +constexpr int RADIX_SIZE = 4; // 2 ^ RADIX_BITS +constexpr int RADIX_MASK = (RADIX_SIZE - 1); + +// This finds the unique value `v` that matches the pattern +// ((v & desired) == desiredMask) in our sorted int format +template +__device__ scalar_t findPattern( + scalar_t* smem, + const scalar_t* data, + index_t sliceSize, + index_t withinSliceStride, + bitwise_t desired, + bitwise_t desiredMask) { + if (threadIdx.x < 2) { + smem[threadIdx.x] = static_cast(0); + } + __syncthreads(); + + // All threads participate in the loop, in order to sync on the flag + index_t numIterations = + round_up(sliceSize, static_cast(blockDim.x)); + for (index_t i = threadIdx.x; i < numIterations; i += blockDim.x) { + bool inRange = (i < sliceSize); + scalar_t v = inRange ? 
doLdg(&data[i * withinSliceStride]) + : static_cast(0); + + if (inRange && + ((TopKTypeConfig::convert(v) & desiredMask) == desired)) { + // There should not be conflicts if we are using findPattern, + // since the result is unique + smem[0] = static_cast(1); + smem[1] = v; // can't use val as the flag, since it could be 0 + } + + __syncthreads(); + + scalar_t found = smem[0]; + scalar_t val = smem[1]; + + __syncthreads(); + + // Check to see if a thread found the value + if (found != static_cast(0)) { + // all threads return this value + return val; + } + } + + // should not get here + CUDA_KERNEL_ASSERT(false); + return static_cast(0); +} + +// Returns the top-Kth element found in the data using radix selection +template +__device__ void radixSelect( + const scalar_t* data, + index_t k, + bool largest, + index_t sliceSize, + index_t withinSliceStride, + int* smem, + scalar_t* topK) { + // Per-thread buckets into which we accumulate digit counts in our + // radix + int counts[RADIX_SIZE]; + + // We only consider elements x such that (x & desiredMask) == desired + // Initially, we consider all elements of the array, so the above + // statement is true regardless of input. + bitwise_t desired = 0; + bitwise_t desiredMask = 0; + + // We are looking for the top kToFind-th element when iterating over + // digits; this count gets reduced by elimination when counting + // successive digits + int kToFind = k; + + // We start at the most significant digit in our radix, scanning + // through to the least significant digit + for (int digitPos = sizeof(scalar_t) * 8 - RADIX_BITS; digitPos >= 0; + digitPos -= RADIX_BITS) { + // Count radix distribution for the current position and reduce + // across all threads + countRadixUsingMask< + scalar_t, + bitwise_t, + index_t, + int, + RADIX_SIZE, + RADIX_BITS>( + counts, + smem, + desired, + desiredMask, + digitPos, + sliceSize, + withinSliceStride, + data); + + auto found_unique = [&](int i, int count) -> bool { + /* All threads have the same value in counts here, so all */ + /* threads will return from the function. */ + if (count == 1 && kToFind == 1) { + /* There is a unique answer. */ + desired = at::cuda::Bitfield::setBitfield( + desired, i, digitPos, RADIX_BITS); + desiredMask = at::cuda::Bitfield::setBitfield( + desiredMask, RADIX_MASK, digitPos, RADIX_BITS); + + /* The answer is now the unique element v such that: */ + /* (v & desiredMask) == desired */ + /* However, we do not yet know what the actual element is. We */ + /* need to perform a search through the data to find the */ + /* element that matches this pattern. 
*/ + *topK = findPattern( + (scalar_t*)smem, + data, + sliceSize, + withinSliceStride, + desired, + desiredMask); + return true; + } + return false; + }; + auto found_non_unique = [&](int i, int count) -> bool { + if (count >= kToFind) { + desired = + at::cuda::Bitfield::setBitfield( + desired, i, digitPos, RADIX_BITS); + desiredMask = at::cuda::Bitfield::setBitfield( + desiredMask, RADIX_MASK, digitPos, RADIX_BITS); + + /* The top-Kth element v must now be one such that: */ + /* (v & desiredMask == desired) */ + /* but we haven't narrowed it down; we must check the next */ + /* least-significant digit */ + return true; + } + kToFind -= count; + return false; // continue the loop + }; + + // All threads participate in the comparisons below to know the + // final result + if (largest) { + // Process in descending order +#pragma unroll + for (int i = RADIX_SIZE - 1; i >= 0; --i) { + int count = counts[i]; + if (found_unique(i, count)) { + return; + } + if (found_non_unique(i, count)) { + break; + } + } + } else { + // Process in ascending order +#pragma unroll + for (int i = 0; i < RADIX_SIZE; ++i) { + int count = counts[i]; + if (found_unique(i, count)) { + return; + } + if (found_non_unique(i, count)) { + break; + } + } + } + } // end digitPos for + + // There is no unique result, but there is a non-unique result + // matching `desired` exactly + *topK = TopKTypeConfig::deconvert(desired); +} +} // namespace native +} // namespace at diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/TensorTopK.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/TensorTopK.h new file mode 100644 index 0000000000000000000000000000000000000000..9eebf2cd6040c4f2df9ad64599910ba0e0cee58f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/TensorTopK.h @@ -0,0 +1,14 @@ +#pragma once +#include + +namespace at { +class TensorBase; +} + +namespace at { +namespace native { +void launch_gather_topk_kernel( + const TensorBase& self, + int64_t k, int64_t dim, bool largest, + const TensorBase& values, const TensorBase& indices); +}} diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/UpSample.cuh b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/UpSample.cuh new file mode 100644 index 0000000000000000000000000000000000000000..f2310dd33c4cae8772ffa3071a2439e424ac9b9c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/UpSample.cuh @@ -0,0 +1,370 @@ +#pragma once +#include +#include + +#include +#include +#include + +#include +#include + +namespace at { +namespace native { + +namespace upsample { +// TODO: Remove duplicate declaration. +TORCH_API c10::SmallVector compute_output_size( + c10::IntArrayRef input_size, // Full input tensor size. + at::OptionalIntArrayRef output_size, + std::optional> scale_factors); +} // namespace upsample + +namespace upsample_cuda { + +// TODO: Remove duplication with Upsample.h (CPU). +inline std::optional get_scale_value(std::optional> scales, int idx) { + if (!scales) { + return std::nullopt; + } + return scales->at(idx); +} + +} // namespace upsample_cuda + + +/* TODO: move this to a common place */ +template +__device__ inline scalar_t min(scalar_t a, scalar_t b) { + return a < b ? a : b; +} + +template +__device__ inline scalar_t max(scalar_t a, scalar_t b) { + return a > b ? 
a : b; +} + +// NOTE [ Nearest neighbor upsampling kernel implementation ] +// +// The nearest neighbor upsampling kernel implementation is symmetrical as +// expected. We launch kernels with threads mapping to destination tensors where +// kernels write data to, each thread reads data from the source tensor, this +// means: +// 1. In the forward kernel, +// src_xxx refers to properties of input tensors; +// dst_xxx refers to properties of output tensors; +// scale_factor is the ratio of src_size to dst_size; +// 2. In the backward kernel, +// src_xxx refers to properties of grad_output tensors; +// dst_xxx refers to properties of grad_input tensors; +// scale_factor is the ratio of src_size to dst_size; +// +// Because of this, we need to take the reciprocal of the scale defined by +// upsample layer during forward path. The motivation is to avoid slow +// division in the kernel code, so we can use faster multiplication instead. +// This is not necessary during backward path, since the scale_factor is already +// the reciprocal of corresponding scale_factor used in the forward path due to +// the swap of source and destination tensor. +// +// Similarly, since the mapping from grad_input to grad_output during backward +// is the reverse of the mapping of output to input, we need to have opposite +// mapping functions to compute the source index. + +// see NOTE [ Nearest neighbor upsampling kernel implementation ] +template +__host__ __forceinline__ accscalar_t compute_scales_value( + const std::optional scale, + int64_t src_size, + int64_t dst_size) { + // FIXME: remove magic > 0 after we ensure no models were serialized with -1 defaults. + return (scale.has_value() && scale.value() > 0.) ? (accscalar_t)(1.0 / scale.value()) + : (accscalar_t)src_size / dst_size; +} + +// see NOTE [ Nearest neighbor upsampling kernel implementation ] +template +__host__ __forceinline__ accscalar_t compute_scales_value_backwards( + const std::optional scale, + int64_t src_size, + int64_t dst_size) { + // FIXME: remove magic > 0 after we ensure no models were serialized with -1 defaults. + return (scale.has_value() && scale.value() > 0.) ? (accscalar_t)scale.value() + : (accscalar_t)src_size / dst_size; +} + +template +__host__ __forceinline__ accscalar_t area_pixel_compute_scale( + int input_size, + int output_size, + bool align_corners, + const std::optional scale) { + if(align_corners) { + if(output_size > 1) { + return (accscalar_t)(input_size - 1) / (output_size - 1); + } + else { + return static_cast(0); + } + } + else{ + return compute_scales_value(scale, input_size, output_size); + } +} + +template +__device__ __forceinline__ accscalar_t area_pixel_compute_source_index( + accscalar_t scale, + int dst_index, + bool align_corners, + bool cubic) { + if (align_corners) { + return scale * dst_index; + } else { + accscalar_t src_idx = scale * (dst_index + static_cast(0.5)) - + static_cast(0.5); + // See Note[Follow Opencv resize logic] + return (!cubic && src_idx < static_cast(0)) + ? static_cast(0) + : src_idx; + } +} + +// see NOTE [ Nearest neighbor upsampling kernel implementation ] +__device__ __forceinline__ int nearest_neighbor_compute_source_index( + const float scale, + int dst_index, + int input_size) { + // index_f32 = (output_index) * scale + // input_index = round(index_f32) + // Same as a buggy OpenCV INTER_NEAREST + // We keep this method for BC and consider as deprecated. 
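Editor's note: a small host sketch of the two source-index mappings defined by area_pixel_compute_scale and area_pixel_compute_source_index above, showing how align_corners changes which input coordinates an upsampled output samples. Plain C++, hypothetical function names, non-cubic clamping path only.

#include <algorithm>
#include <cstdio>

// Mirrors area_pixel_compute_scale / area_pixel_compute_source_index above
// for a 1-D resize from input_size to output_size.
float pixel_scale(int input_size, int output_size, bool align_corners) {
  if (align_corners) {
    return output_size > 1 ? float(input_size - 1) / (output_size - 1) : 0.f;
  }
  return float(input_size) / output_size;
}

float source_index(float scale, int dst_index, bool align_corners) {
  if (align_corners) {
    return scale * dst_index;
  }
  // half-pixel centers; clamp to 0 like the non-cubic path above
  float src = scale * (dst_index + 0.5f) - 0.5f;
  return std::max(src, 0.f);
}

int main() {
  const int in = 4, out = 8;
  for (bool ac : {false, true}) {
    float s = pixel_scale(in, out, ac);
    std::printf("align_corners=%d:", ac);
    for (int i = 0; i < out; ++i) std::printf(" %.3f", source_index(s, i, ac));
    std::printf("\n");
  }
}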
+ // See nearest_neighbor_exact_compute_source_index as replacement + const int src_index = + min(static_cast(floorf((dst_index) * scale)), input_size - 1); + return src_index; +} + +__device__ __forceinline__ int nearest_neighbor_exact_compute_source_index( + const float scale, + int dst_index, + int input_size) { + // index_f32 = (output_index + 0.5) * scale - 0.5 + // input_index = round(index_f32) + // Same as Pillow and Scikit-Image/Scipy ndi.zoom + const int src_index = + min(static_cast(floorf((dst_index + static_cast(0.5)) * scale)), input_size - 1); + return src_index; +} + +// see NOTE [ Nearest neighbor upsampling kernel implementation ] +__device__ __forceinline__ int nearest_neighbor_bw_compute_source_index( + const float scale, + int dst_index, + int output_size) { + // Equivalent to buggy OpenCV INTER_NEAREST + // We keep this method for BC and consider as deprecated. + // See nearest_neighbor_exact_bw_compute_source_index as replacement + const int src_index = + min(static_cast(ceilf(dst_index * scale)), output_size); + return src_index; +} + +// see NOTE [ Nearest neighbor upsampling kernel implementation ] +__device__ __forceinline__ int nearest_neighbor_exact_bw_compute_source_index( + const float scale, + int dst_index, + int output_size) { + // Equivalent to Pillow and Scikit-Image/Scipy ndi.zoom + const int src_index = + min(static_cast(ceilf(dst_index * scale - static_cast(0.5))), output_size); + return src_index; +} + +/* Used by UpSampleBicubic2d.cu */ +template +__device__ __forceinline__ scalar_t upsample_get_value_bounded( + const PackedTensorAccessor64& data, + int batch, + int channel, + int height, + int width, + int y, + int x) { + int access_y = max(min(y, height - 1), 0); + int access_x = max(min(x, width - 1), 0); + return data[batch][channel][access_y][access_x]; +} + +/* Used by UpSampleBicubic2d.cu */ +template +__device__ __forceinline__ void upsample_increment_value_bounded( + PackedTensorAccessor64& data, + int batch, + int channel, + int height, + int width, + int y, + int x, + accscalar_t value) { + int access_y = max(min(y, height - 1), 0); + int access_x = max(min(x, width - 1), 0); + /* TODO: result here is truncated to scalar_t, + check: https://github.com/pytorch/pytorch/pull/19630#discussion_r281426912 + */ + gpuAtomicAddNoReturn( + &data[batch][channel][access_y][access_x], static_cast(value)); +} + +// Based on +// https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm +template +__device__ __forceinline__ accscalar_t cubic_convolution1( + accscalar_t x, + accscalar_t A) { + return ((A + 2) * x - (A + 3)) * x * x + 1; +} + +template +__device__ __forceinline__ accscalar_t cubic_convolution2( + accscalar_t x, + accscalar_t A) { + return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A; +} + +template +__device__ __forceinline__ void get_cubic_upsampling_coefficients( + accscalar_t coeffs[4], + accscalar_t t) { + accscalar_t A = -0.75; + + accscalar_t x1 = t; + coeffs[0] = cubic_convolution2(x1 + 1.0, A); + coeffs[1] = cubic_convolution1(x1, A); + + // opposite coefficients + accscalar_t x2 = 1.0 - t; + coeffs[2] = cubic_convolution1(x2, A); + coeffs[3] = cubic_convolution2(x2 + 1.0, A); +} + +template +__device__ __forceinline__ accscalar_t cubic_interp1d( + scalar_t x0, + scalar_t x1, + scalar_t x2, + scalar_t x3, + accscalar_t t) { + accscalar_t coeffs[4]; + get_cubic_upsampling_coefficients(coeffs, t); + + return x0 * coeffs[0] + x1 * coeffs[1] + x2 * coeffs[2] + x3 * coeffs[3]; +} + +namespace upsample_antialias { + 
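Editor's note: the cubic_convolution1/cubic_convolution2 helpers above implement the Keys cubic kernel with A = -0.75; for any fractional offset t the four coefficients produced by get_cubic_upsampling_coefficients form a partition of unity. A host-side sketch checking that property, for illustration only:

#include <cassert>
#include <cmath>
#include <cstdio>

// Same polynomials as cubic_convolution1/2 above, evaluated on the host.
double cubic1(double x, double A) { return ((A + 2) * x - (A + 3)) * x * x + 1; }        // |x| <= 1
double cubic2(double x, double A) { return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A; }  // 1 < |x| < 2

int main() {
  const double A = -0.75;
  for (double t = 0.0; t < 1.0; t += 0.125) {
    // Coefficients for the four neighbouring samples at distances t+1, t, 1-t, 2-t.
    double c0 = cubic2(t + 1.0, A);
    double c1 = cubic1(t, A);
    double c2 = cubic1(1.0 - t, A);
    double c3 = cubic2(2.0 - t, A);
    double sum = c0 + c1 + c2 + c3;
    std::printf("t=%.3f  coeffs sum to %.6f\n", t, sum);
    assert(std::fabs(sum - 1.0) < 1e-9);
  }
}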
+// taken from +// https://github.com/python-pillow/Pillow/blob/6812205f18ca4ef54372e87e1a13ce4a859434df/ +// src/libImaging/Resample.c#L20-L29 +struct BilinearFilterFunctor { + + template + __device__ accscalar_t operator()(accscalar_t x) const { + if (x < 0) { + x = -x; + } + if (x < 1) { + return 1 - x; + } + return 0; + } + + static const int size = 2; +}; + +// taken from +// https://github.com/python-pillow/Pillow/blob/6812205f18ca4ef54372e87e1a13ce4a859434df/ +// src/libImaging/Resample.c#L46-L62 +struct BicubicFilterFunctor { + + template + __device__ accscalar_t operator()(accscalar_t x) const { + // https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm + const accscalar_t a = -0.5; + if (x < 0) { + x = -x; + } + if (x < 1) { + return ((a + 2) * x - (a + 3)) * x * x + 1; + } + if (x < 2) { + return (((x - 5) * x + 8) * x - 4) * a; + } + return 0; + } + + static const int size = 4; +}; + +template +__device__ __forceinline__ void _compute_weights_span( + const int i, + const int input_size, + const accscalar_t scale, + const accscalar_t support, + int& xmin, + int& xsize, + accscalar_t& center) { + center = scale * (i + static_cast(0.5)); + xmin = max(static_cast(center - support + static_cast(0.5)), static_cast(0)); + xsize = min(static_cast(center + support + static_cast(0.5)), input_size) - xmin; +} + +template +__device__ __forceinline__ void _compute_weights( + scalar_t* wt_ptr, + const accscalar_t scale, + int interp_size, + const interp_filter_t& interp_filter, + accscalar_t xmin_m_center, + int xsize) { + + accscalar_t invscale = (scale >= 1.0) ? 1.0 / scale : 1.0; + accscalar_t total_w = 0.0; + int j = 0; + for (j = 0; j < xsize; j++) { + accscalar_t w = interp_filter((j + xmin_m_center + static_cast(0.5)) * invscale); + wt_ptr[j] = static_cast(w); + total_w += w; + } + for (j = 0; j < xsize; j++) { + if (total_w != 0.0) { + wt_ptr[j] /= total_w; + } + } + for (; j < interp_size; j++) { + wt_ptr[j] = static_cast(0.0); + } +} + +template +__device__ __forceinline__ accscalar_t interpolate_aa_single_dim( + const scalar_t* src, + const scalar_t* weights, + int size) { + scalar_t t = static_cast(*src); + scalar_t wts = static_cast(weights[0]); + accscalar_t output = t * wts; + + int j = 1; + for (; j < size; j++) { + wts = static_cast(weights[j]); + t = static_cast(*(src + j)); + output += t * wts; + } + return output; +} + +} + +} // namespace native +} // namespace at diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/block_reduce.cuh b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/block_reduce.cuh new file mode 100644 index 0000000000000000000000000000000000000000..df757a11761bba517afa1d19531fa04d74313872 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/block_reduce.cuh @@ -0,0 +1,143 @@ +#pragma once + +#include + +#include +#include + +namespace at { +namespace native { +namespace cuda_utils { + +constexpr int kCUDABlockReduceNumThreads = 512; +// Algorithmic limitation: BlockReduce does two WarpReduce calls, each +// of which reduces C10_WARP_SIZE elements. So, at most +// C10_WARP_SIZE**2 elements can be reduced at a time. +// NOTE: This is >= the max block size on current hardware anyway (1024). +constexpr int kCUDABlockReduceMaxThreads = C10_WARP_SIZE * C10_WARP_SIZE; + +// Sums `val` across all threads in a warp. 
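Editor's note: a host-side sketch of the antialiased resize weights computed by _compute_weights_span/_compute_weights above, using the triangle (bilinear) filter. It shows how the filter support widens with the downscale factor and how the weights are normalized. Plain C++, simplified to one dimension; the support formula for the caller is an assumption, since the calling code is not part of this header.

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Triangle filter, same shape as BilinearFilterFunctor above (radius 1).
double bilinear_filter(double x) {
  x = std::fabs(x);
  return x < 1.0 ? 1.0 - x : 0.0;
}

int main() {
  const int input_size = 10, output_size = 4;
  const double scale = double(input_size) / output_size;       // > 1 means downscale
  const double support = scale >= 1.0 ? scale : 1.0;           // assumed: radius stretched when downscaling
  const double invscale = scale >= 1.0 ? 1.0 / scale : 1.0;

  for (int i = 0; i < output_size; ++i) {
    // Same span computation as _compute_weights_span.
    const double center = scale * (i + 0.5);
    const int xmin = std::max(int(center - support + 0.5), 0);
    const int xsize = std::min(int(center + support + 0.5), input_size) - xmin;

    // Same weight computation as _compute_weights: evaluate, then normalize.
    std::vector<double> w(xsize);
    double total = 0.0;
    for (int j = 0; j < xsize; ++j) {
      w[j] = bilinear_filter((j + xmin - center + 0.5) * invscale);
      total += w[j];
    }
    std::printf("out %d uses in[%d..%d):", i, xmin, xmin + xsize);
    for (int j = 0; j < xsize; ++j) std::printf(" %.3f", total != 0.0 ? w[j] / total : 0.0);
    std::printf("\n");
  }
}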
+// +// Assumptions: +// - The size of each block should be a multiple of `C10_WARP_SIZE` +template +__inline__ __device__ T WarpReduceSum(T val) { +#pragma unroll + for (int offset = (C10_WARP_SIZE >> 1); offset > 0; offset >>= 1) { + val += WARP_SHFL_DOWN(val, offset); + } + return val; +} + +// Picks the maximum `val` across all threads in a warp. +// +// Assumptions: +// - The size of each block should be a multiple of `C10_WARP_SIZE` +template +__inline__ __device__ T WarpReduceMax(T val) { +#pragma unroll + for (int offset = (C10_WARP_SIZE >> 1); offset > 0; offset >>= 1) { + val = max_propagate_nan(val, WARP_SHFL_DOWN(val, offset)); + } + return val; +} + +struct Block1D { + static __forceinline__ __device__ int Tid() { return threadIdx.x; } + + static __forceinline__ __device__ int Warps() { + return blockDim.x / C10_WARP_SIZE; + } +}; + +struct Block2D { + static __forceinline__ __device__ int Tid() { + return threadIdx.x + threadIdx.y * blockDim.x; + } + + static __forceinline__ __device__ int Warps() { + return blockDim.x * blockDim.y / C10_WARP_SIZE; + } +}; + +// Sums `val` across all threads in a block. +// +// Warning: the return value is only valid for thread 0. +// Assumptions: +// - The size of each block should be a multiple of `C10_WARP_SIZE` +// - `shared` should be a pointer to shared memory with size of, at least, +// `sizeof(T) * number_of_warps` +template +__inline__ __device__ T BlockReduceSum(T val, T* shared) { + const int tid = B::Tid(); + const int lid = tid % C10_WARP_SIZE; + const int wid = tid / C10_WARP_SIZE; + val = WarpReduceSum(val); + __syncthreads(); // prevent races when BlockReduces are called in a row. + if (lid == 0) { + shared[wid] = val; + } + __syncthreads(); + val = (tid < B::Warps()) ? shared[lid] : T(0); + if (wid == 0) { + val = WarpReduceSum(val); + } + return val; +} + +// Picks out the maximum `val` across all threads in a block. +// +// Warning: the return value is only valid for thread 0. +// Assumptions: +// - The size of each block should be a multiple of `C10_WARP_SIZE` +// - `shared` should be a pointer to shared memory with size of, at least, +// `sizeof(T) * number_of_warps` +template +__inline__ __device__ T BlockReduceMax(T val, T* shared) { + const int tid = B::Tid(); + const int lid = tid % C10_WARP_SIZE; + const int wid = tid / C10_WARP_SIZE; + val = WarpReduceMax(val); + __syncthreads(); // prevent races when BlockReduces are called in a row. + if (lid == 0) { + shared[wid] = val; + } + __syncthreads(); + val = (tid < B::Warps()) ? shared[lid] : T(std::numeric_limits::lowest()); + if (wid == 0) { + val = WarpReduceMax(val); + } + return val; +} + +template +__inline__ __device__ T WarpReduce(T val, const ReduceOp& op) { +#pragma unroll + for (int offset = (C10_WARP_SIZE >> 1); offset > 0; offset >>= 1) { + val = op.combine(val, op.warp_shfl_down(val, offset)); + } + return val; +} + +template +__inline__ __device__ T +BlockReduce(T val, const ReduceOp& op, const T& identity_element, T* shared) { + const int tid = B::Tid(); + const int lid = tid % C10_WARP_SIZE; + const int wid = tid / C10_WARP_SIZE; + val = WarpReduce(val, op); + __syncthreads(); // prevent races when BlockReduces are called in a row. + if (lid == 0) { + shared[wid] = val; + } + __syncthreads(); + val = (tid < B::Warps()) ? 
shared[lid] : identity_element; + if (wid == 0) { + val = WarpReduce(val, op); + } + return val; +} + +} // namespace cuda_utils +} // namespace native +} // namespace at diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/fused_adam_utils.cuh b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/fused_adam_utils.cuh new file mode 100644 index 0000000000000000000000000000000000000000..182195969ed9a32770876aed7ac9e060e8157e1f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/fused_adam_utils.cuh @@ -0,0 +1,202 @@ +#pragma once +#include +#include +#include +#include +#include + +namespace at { +namespace native { + +enum class ADAM_MODE : uint8_t { ORIGINAL = 0, ADAMW = 1 }; + +namespace { + +constexpr uint8_t kParamIdx = 0; +constexpr uint8_t kGradIdx = 1; +constexpr uint8_t kExpAvgIdx = 2; +constexpr uint8_t kExpAvgSqIdx = 3; +constexpr uint8_t kMaxExpAvgSqIdx = 4; + +template < + typename scalar_type, + typename opmath_t, + int depth, + ADAM_MODE adam_mode, + bool amsgrad> +C10_DEVICE inline void adam_math( + scalar_type r_args[depth][kILP], + const double& lr, + const double& beta1, + const double& beta2, + const double& weight_decay, + const double& eps, + const bool& maximize, + const float* grad_scale_ptr, + const float* found_inf_ptr, + const opmath_t& bias_correction1, + const opmath_t& bias_correction2_sqrt) { + static_assert(depth == 4 || depth == 5); +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + // Load values. + opmath_t param = static_cast(r_args[kParamIdx][ii]); + opmath_t grad = static_cast(r_args[kGradIdx][ii]); + if (grad_scale_ptr) { + grad /= (static_cast(*grad_scale_ptr)); + } + const opmath_t grad_to_store = grad; + if (maximize) { + grad = -grad; + } + opmath_t exp_avg = static_cast(r_args[kExpAvgIdx][ii]); + opmath_t exp_avg_sq = static_cast(r_args[kExpAvgSqIdx][ii]); + opmath_t max_exp_avg_sq; + if (amsgrad) { + max_exp_avg_sq = static_cast(r_args[kMaxExpAvgSqIdx][ii]); + } + // Update param, grad, 1st and 2nd order momentum. + if (weight_decay != 0) { + if constexpr (adam_mode == ADAM_MODE::ORIGINAL) { + grad += param * weight_decay; + } else if constexpr (adam_mode == ADAM_MODE::ADAMW) { + param -= lr * weight_decay * param; + } + } + // todo(crcrpar): use lerp + // ref: https://developer.nvidia.com/blog/lerp-faster-cuda/ + exp_avg = beta1 * exp_avg + (1 - beta1) * grad; + exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad * grad; + const opmath_t step_size = lr / bias_correction1; + opmath_t denom; + if (amsgrad) { + max_exp_avg_sq = std::max(max_exp_avg_sq, exp_avg_sq); + denom = (std::sqrt(max_exp_avg_sq) / bias_correction2_sqrt) + eps; + } else { + denom = (std::sqrt(exp_avg_sq) / bias_correction2_sqrt) + eps; + } + param -= step_size * exp_avg / denom; + + // Store results. + r_args[kParamIdx][ii] = param; + if (grad_scale_ptr) { + r_args[kGradIdx][ii] = grad_to_store; + } + r_args[kExpAvgIdx][ii] = exp_avg; + r_args[kExpAvgSqIdx][ii] = exp_avg_sq; + if (amsgrad) { + r_args[kMaxExpAvgSqIdx][ii] = max_exp_avg_sq; + } + } +} + +// [note: Conditional Gradient Store when `optimizer.step` is called by +// GradScaler] When a user is training their model(s) with an FP16 AMP recipe, +// parameter updates are done via `grad_scaler.step(optimizer)` instead of +// `optimizer.step()`. For most optimizers, GradScaler unscales gradients on +// behalf of those optimizers. Also, before `.step`, it makes sure that all the +// gradients involved are finite, which incurs a device sync. 
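Editor's note: a scalar host-side restatement of the update performed by adam_math above, for a single parameter and one step. It is only meant to make the order of operations easy to follow (decoupled weight decay, moment updates, bias correction, denominator), not to mirror the vectorized CUDA kernel; AMSGrad and gradient scaling are omitted.

#include <cmath>
#include <cstdio>

// One Adam/AdamW step for a single scalar parameter, following the same
// formulas as adam_math above (no AMSGrad, no grad scaling).
void adam_step(double& param, double grad, double& exp_avg, double& exp_avg_sq,
               long step, double lr, double beta1, double beta2,
               double weight_decay, double eps, bool adamw) {
  if (weight_decay != 0) {
    if (adamw) {
      param -= lr * weight_decay * param;   // decoupled weight decay (ADAMW)
    } else {
      grad += weight_decay * param;         // L2 penalty folded into the gradient (ORIGINAL)
    }
  }
  exp_avg = beta1 * exp_avg + (1 - beta1) * grad;
  exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad * grad;

  const double bias_correction1 = 1 - std::pow(beta1, step);
  const double bias_correction2_sqrt = std::sqrt(1 - std::pow(beta2, step));
  const double step_size = lr / bias_correction1;
  const double denom = std::sqrt(exp_avg_sq) / bias_correction2_sqrt + eps;
  param -= step_size * exp_avg / denom;
}

int main() {
  double p = 1.0, m = 0.0, v = 0.0;
  for (long t = 1; t <= 3; ++t) {
    adam_step(p, /*grad=*/0.5, m, v, t, 1e-3, 0.9, 0.999, 1e-2, 1e-8, /*adamw=*/true);
    std::printf("step %ld: param=%.6f\n", t, p);
  }
}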
On the other hand, +// fused optimizers set their member variable of `_step_supports_amp_scaling` to +// `True` in order to remove the device sync above. This means that fused +// optimizers have to have their CUDA kernels (a) unscale gradients and (b) skip +// parameter updates accordingly. To be functionally on par with `torch.optim` +// optimizers and `_multi_tensor` ones, the kernel below writes out gradients +// only when `grad_scale_ptr != nullptr. +template +struct FusedAdamMathFunctor { + static_assert( + depth == 4 || depth == 5, + "depth of 4 for Adam, depth of 5 for Adam with AMSGrad."); + using opmath_t = at::opmath_type; + C10_DEVICE __forceinline__ void operator()( + int chunk_size, + FusedOptimizerTensorListMetadata& tl, + const float* lr_ptr, + const double& lr, + const double& beta1, + const double& beta2, + const double& weight_decay, + const double& eps, + const bool& maximize, + const float* grad_scale_ptr, + const float* found_inf_ptr) { + const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; + const auto chunk_idx = tl.block_to_chunk[blockIdx.x]; + const double lr_double = lr_ptr ? *lr_ptr : lr; + + if (found_inf_ptr && *found_inf_ptr == 1) { + return; + } + const auto [bias_correction1, bias_correction2_sqrt] = + [&]() -> std::pair { + auto* step_count = + reinterpret_cast(tl.state_steps_addresses[tensor_loc]); + const auto bias_correction1 = 1 - at::native::pow_(beta1, *step_count); + const auto bias_correction2 = 1 - at::native::pow_(beta2, *step_count); + const auto bias_correction2_sqrt = std::sqrt(bias_correction2); + return {bias_correction1, bias_correction2_sqrt}; + }(); + + scalar_type* args[depth]; + scalar_type r_args[depth][kILP]; + const auto n = tl.numel_for_tensor[tensor_loc] - chunk_idx * chunk_size; + + const bool all_aligned{ + init_args(args, tl, chunk_idx, chunk_size, tensor_loc)}; + if ((n % kILP == 0) && (chunk_size % kILP == 0) && all_aligned) { + for (int64_t i_start = threadIdx.x; + i_start * kILP < n && i_start * kILP < chunk_size; + i_start += blockDim.x) { +#pragma unroll + for (int i = 0; i < depth; i++) { + load_store(r_args[i], args[i], 0, i_start); + } + adam_math( + r_args, + lr_double, + beta1, + beta2, + weight_decay, + eps, + maximize, + grad_scale_ptr, + found_inf_ptr, + bias_correction1, + bias_correction2_sqrt); +#pragma unroll + for (int i = 0; i < depth; i++) { + if (i != kGradIdx || grad_scale_ptr) { + load_store(args[i], r_args[i], i_start, 0); + } + } + } + } else { + for (int64_t i_start = 0; i_start < n && i_start < chunk_size; + i_start += blockDim.x * kILP) { + load_args(r_args, args, i_start, chunk_size, n); + adam_math( + r_args, + lr_double, + beta1, + beta2, + weight_decay, + eps, + maximize, + grad_scale_ptr, + found_inf_ptr, + bias_correction1, + bias_correction2_sqrt); +#pragma unroll + for (int i = 0; i < depth; i++) { + if (i != kGradIdx || grad_scale_ptr) { + store_args(args[i], r_args[i], i_start, chunk_size, n); + } + } + } + } + } +}; +} // namespace + +} // namespace native +} // namespace at diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/im2col.cuh b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/im2col.cuh new file mode 100644 index 0000000000000000000000000000000000000000..ec74617de34a1060fda7e9449171edcc9769dc93 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/im2col.cuh @@ -0,0 +1,345 @@ +#pragma once + +#include +#include +#include + +#include + +namespace at { +namespace native { + +using namespace 
at::cuda::detail; + +// Kernel for fast unfold+copy +// (borrowed from Caffe: +// https://github.com/BVLC/caffe/blob/master/src/caffe/layers/conv_layer.cu) +// CUDA_NUM_THREADS = 1024 + +template +C10_LAUNCH_BOUNDS_1(1024) +__global__ void im2col_kernel( + const int64_t n, + const dt* data_im, + const int64_t height, + const int64_t width, + const int64_t kernel_height, + const int64_t kernel_width, + const int64_t pad_height, + const int64_t pad_width, + const int64_t stride_height, + const int64_t stride_width, + const int64_t dilation_height, + const int64_t dilation_width, + const int64_t height_col, + const int64_t width_col, + dt* data_col) { + CUDA_KERNEL_LOOP_TYPE(index, n, int64_t) { + int64_t w_out = index % width_col; + + int64_t idx = index / width_col; + + int64_t h_out = idx % height_col; + int64_t channel_in = idx / height_col; + int64_t channel_out = channel_in * kernel_height * kernel_width; + int64_t h_in = h_out * stride_height - pad_height; + int64_t w_in = w_out * stride_width - pad_width; + + dt* col = data_col + (channel_out * height_col + h_out) * width_col + w_out; + const dt* im = data_im + (channel_in * height + h_in) * width + w_in; + + for (int64_t i = 0; i < kernel_height; ++i) { + for (int64_t j = 0; j < kernel_width; ++j) { + int64_t h = h_in + i * dilation_height; + int64_t w = w_in + j * dilation_width; + *col = (h >= 0 && w >= 0 && h < height && w < width) + ? im[i * dilation_height * width + j * dilation_width] + : static_cast
(0); + col += height_col * width_col; + } + } + } +} + +template +void im2col( + cudaStream_t stream, + const dt* data_im, + const int64_t channels, + const int64_t height, + const int64_t width, + const int64_t height_col, + const int64_t width_col, + const int64_t kernel_height, + const int64_t kernel_width, + const int64_t pad_height, + const int64_t pad_width, + const int64_t stride_height, + const int64_t stride_width, + const int64_t dilation_height, + const int64_t dilation_width, + dt* data_col) { + // We are going to launch channels * height_col * width_col kernels, each + // kernel responsible for copying a single-channel grid. + int64_t num_kernels = channels * height_col * width_col; + // Launch CUDA_NUM_THREADS = 1024 + im2col_kernel<<>>( + num_kernels, + data_im, + height, + width, + kernel_height, + kernel_width, + pad_height, + pad_width, + stride_height, + stride_width, + dilation_height, + dilation_width, + height_col, + width_col, + data_col); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +template +__forceinline__ __device__ void col2im_device( + const int64_t index, + const dt* data_col, + const int64_t height, + const int64_t width, + const int64_t channels, + const int64_t kernel_h, + const int64_t kernel_w, + const int64_t pad_height, + const int64_t pad_width, + const int64_t stride_height, + const int64_t stride_width, + const int64_t dilation_height, + const int64_t dilation_width, + const int64_t height_col, + const int64_t width_col, + dt* data_im) { + accT val = static_cast(0); + const int64_t w_im = index % width + pad_width; + const int64_t h_im = (index / width) % height + pad_height; + const int64_t c_im = index / (width * height); + int64_t kernel_extent_w = (kernel_w - 1) * dilation_width + 1; + int64_t kernel_extent_h = (kernel_h - 1) * dilation_height + 1; + // compute the start and end of the output + const int64_t w_col_start = (w_im < kernel_extent_w) + ? 0 + : (w_im - kernel_extent_w) / stride_width + 1; + const int64_t w_col_end = ::min(w_im / stride_width + 1, width_col); + const int64_t h_col_start = (h_im < kernel_extent_h) + ? 0 + : (h_im - kernel_extent_h) / stride_height + 1; + const int64_t h_col_end = ::min(h_im / stride_height + 1, height_col); + + // TODO: use LCM of stride and dilation to avoid unnecessary loops + for (int64_t h_col = h_col_start; h_col < h_col_end; h_col += 1) { + for (int64_t w_col = w_col_start; w_col < w_col_end; w_col += 1) { + int64_t h_k = (h_im - h_col * stride_height); + int64_t w_k = (w_im - w_col * stride_width); + if (h_k % dilation_height == 0 && w_k % dilation_width == 0) { + h_k /= dilation_height; + w_k /= dilation_width; + int64_t data_col_index = + (((c_im * kernel_h + h_k) * kernel_w + w_k) * height_col + + h_col) * + width_col + + w_col; + val += data_col[data_col_index]; + } + } + } + data_im[index] = static_cast
(val); +} + +template +C10_LAUNCH_BOUNDS_1(512) +__global__ void col2im_kernel( + const int64_t n, + const dt* data_col, + const int64_t height, + const int64_t width, + const int64_t channels, + const int64_t kernel_h, + const int64_t kernel_w, + const int64_t pad_height, + const int64_t pad_width, + const int64_t stride_height, + const int64_t stride_width, + const int64_t dilation_height, + const int64_t dilation_width, + const int64_t height_col, + const int64_t width_col, + dt* data_im) { + CUDA_KERNEL_LOOP(index, n) { + col2im_device( + index, + data_col, + height, + width, + channels, + kernel_h, + kernel_w, + pad_height, + pad_width, + stride_height, + stride_width, + dilation_height, + dilation_width, + height_col, + width_col, + data_im); + } +} + +template +void col2im( + cudaStream_t stream, + const dt* data_col, + const int64_t channels, + const int64_t height, + const int64_t width, + const int64_t height_col, + const int64_t width_col, + const int64_t patch_height, + const int64_t patch_width, + const int64_t pad_height, + const int64_t pad_width, + const int64_t stride_height, + const int64_t stride_width, + const int64_t dilation_height, + const int64_t dilation_width, + dt* data_im) { + int64_t num_kernels = channels * height * width; + // To avoid involving atomic operations, we will launch one kernel per + // bottom dimension, and then in the kernel add up the top dimensions. + // CUDA_NUM_THREADS = 1024 + col2im_kernel + <<>>( + num_kernels, + data_col, + height, + width, + channels, + patch_height, + patch_width, + pad_height, + pad_width, + stride_height, + stride_width, + dilation_height, + dilation_width, + height_col, + width_col, + data_im); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +template +C10_LAUNCH_BOUNDS_1(512) +__global__ void col2im_batched_kernel( + const int64_t n, + const dt* data_col, + const int64_t col_batch_stride, + const int64_t nbatch, + const int64_t height, + const int64_t width, + const int64_t channels, + const int64_t kernel_h, + const int64_t kernel_w, + const int64_t pad_height, + const int64_t pad_width, + const int64_t stride_height, + const int64_t stride_width, + const int64_t dilation_height, + const int64_t dilation_width, + const int64_t height_col, + const int64_t width_col, + dt* data_im, + const int64_t im_batch_stride) { + using accT = at::acc_type; + const auto im_numel = n * nbatch; + + CUDA_KERNEL_LOOP_TYPE(index, im_numel, int64_t) { + const auto ibatch = index / n; + const auto slice_index = index % n; + + col2im_device( + slice_index, + data_col + ibatch * col_batch_stride, + height, + width, + channels, + kernel_h, + kernel_w, + pad_height, + pad_width, + stride_height, + stride_width, + dilation_height, + dilation_width, + height_col, + width_col, + data_im + ibatch * im_batch_stride); + } +} + +template +void col2im_batched( + cudaStream_t stream, + const dt* data_col, + const int64_t col_batch_stride, + const int64_t nbatch, + const int64_t channels, + const int64_t height, + const int64_t width, + const int64_t height_col, + const int64_t width_col, + const int64_t patch_height, + const int64_t patch_width, + const int64_t pad_height, + const int64_t pad_width, + const int64_t stride_height, + const int64_t stride_width, + const int64_t dilation_height, + const int64_t dilation_width, + dt* data_im, + const int64_t im_batch_stride) { + const int64_t num_kernels = channels * height * width; + const int64_t output_numel = nbatch * num_kernels; + if (output_numel == 0) { + return; // No work to do + } + + // To avoid 
involving atomic operations, we will launch one kernel per + // bottom dimension, and then in the kernel add up the top dimensions. + // CUDA_NUM_THREADS = 1024 + col2im_batched_kernel<<>>( + num_kernels, + data_col, + col_batch_stride, + nbatch, + height, + width, + channels, + patch_height, + patch_width, + pad_height, + pad_width, + stride_height, + stride_width, + dilation_height, + dilation_width, + height_col, + width_col, + data_im, + im_batch_stride); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +} // namespace native +} // namespace at diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/jit_utils.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/jit_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..575c51c96db36e48f7be2f03336cbf9beaed17fc --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/jit_utils.h @@ -0,0 +1,215 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include +#include + +namespace at { namespace cuda { namespace jit { + +enum class BinaryFuncVariant {NoScalar, RhsScalar, LhsScalar}; + +struct NvrtcFunction { + CUmodule module = CUmodule(); + CUfunction function = nullptr; +}; + +struct KernelDescriptor { + std::string name; + std::string f; + c10::ScalarType f_inputs_type; + c10::ScalarType result_type; + c10::SmallVector extra_args_types; + int nInputs, nOutputs; +}; + +// Helper function to return a vector +// corresponding to the type of the arguments in parameter pack. +template +c10::SmallVector get_extra_args_types() { + return {c10::CppTypeToScalarType::value ...}; +} + +template < + typename result_type, + typename f_inputs_type, + typename... ExtraArgs> +KernelDescriptor make_kernel_descriptor( + std::string name, + std::string f, + int nInputs, + int nOutputs) { + KernelDescriptor ret; + ret.name = std::move(name); + ret.f = std::move(f); + ret.f_inputs_type = c10::CppTypeToScalarType::value; + ret.result_type = c10::CppTypeToScalarType::value; + ret.extra_args_types = get_extra_args_types(); + ret.nInputs = nInputs; + ret.nOutputs = nOutputs; + return ret; +} + +inline int can_vectorize_up_to(size_t default_alignment, void *pointer) { + auto ip = reinterpret_cast(pointer); + if (ip % (4 * default_alignment) == 0) { + return 4; + } + if (ip % (2 * default_alignment) == 0) { + return 2; + } + return 1; +} + +inline int can_vectorize_up_to(const KernelDescriptor &desc, c10::ArrayRef pointers) { + TORCH_INTERNAL_ASSERT(desc.nOutputs == 1); + TORCH_INTERNAL_ASSERT(static_cast(pointers.size()) == 1 + desc.nInputs); + + // Deals with output + auto result_size = c10::scalarTypeToTypeMeta(desc.result_type).itemsize(); + int result = can_vectorize_up_to(result_size, pointers[0]); + + // Incorporates input(s) + auto input_size = c10::scalarTypeToTypeMeta(desc.f_inputs_type).itemsize(); + for (auto i : c10::irange(1, pointers.size())) { + result = std::min(result, can_vectorize_up_to(input_size, pointers[i])); + } + + return result; +} + +std::string generate_code( + int nInputs, + int nOutputs, + const std::string& func, + const std::string& name, + const std::string& f_input_type, + const std::string& compute_type, + const std::string& result_type, + bool contiguous, + bool dynamic_casting, + BinaryFuncVariant scalar_pos, + c10::SmallVector& extra_args_typenames, + bool vectorized=false, + int vec_size=0, + bool return_by_ref=false); + +std::string generate_code( + const KernelDescriptor &desc, + bool contiguous, + bool dynamic_casting, + 
BinaryFuncVariant scalar_pos, + bool vectorized=false, + int vec_size=0, + bool return_by_ref=false); + +std::string generate_reduction_code( + int nOutputs, + const std::string& func, + const std::string& name, + const int vt0, + const std::string& f_inputs_type, + const std::string& reduction_accum_type, + const std::string& result_type, + bool contiguous, + bool vectorized, + int vec_size, + int max_threads_codegen); + +std::string generate_reduction_code( + const KernelDescriptor &desc, + const int vt0, + bool contiguous, + bool vectorized, + int vec_size, + int max_threads_codegen); + +NvrtcFunction jit_pwise_function( + const std::string& code, + const std::string& kernel_name); + +void launch_jitted_pwise_function( + NvrtcFunction function, + void* args[], + const dim3 nBlocks, + const dim3 kBlockSize, + const int smem=0); + +template +struct delayed_false : std::false_type { +}; + +// Defines type names +// NOTE: General case is instantiated only for invalid types. +// All the valid types have specialization using the TYPE_NAME_FN +// macro below. +template +inline std::string typeName() { + // we can't use static_assert(false) directly as the + // program will be not compiled even if the template is not + // instantiated, so we use `delayed_false` + // to make sure compiler doesn't eagerly raise + // fail this assertion. + static_assert(delayed_false::value, "invalid type for jiterator"); + return "void"; +} + +#define TYPE_NAME_FN(ctype, name) \ +template <> inline std::string typeName(){ \ + return std::string(#ctype); \ +} + +AT_FORALL_SCALAR_TYPES(TYPE_NAME_FN) +#undef TYPE_NAME_FN +// JIT uses std::complex directly, because nvRTC compile programs +// with -default-device, so there is no such issue like: +// "std::sin(complex) is __host__ only" +template <> inline std::string typeName(){ + return "bool"; +} +template <> inline std::string typeName>(){ + return "std::complex"; +} +template <> inline std::string typeName>(){ + return "std::complex"; +} +template <> inline std::string typeName>(){ + return "std::complex"; +} +template <> inline std::string typeName(){ + return "at::Half"; +} +template <> inline std::string typeName(){ + return "at::BFloat16"; +} +template <> inline std::string typeName(){ + return "at::Float8_e5m2"; +} +template <> inline std::string typeName(){ + return "at::Float8_e4m3fn"; +} +template <> inline std::string typeName() { + return "at::Float8_e5m2fnuz"; +} +template <> inline std::string typeName() { + return "at::Float8_e4m3fnuz"; +} + +#define TYPE_NAME_CASE(ctype, scalartype) \ + case ScalarType::scalartype: return typeName(); +inline std::string typeName(ScalarType t) { + switch (t) { + AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(TYPE_NAME_CASE) + default: + TORCH_CHECK(false, "invalid type for jiterator"); + } +} +#undef TYPE_NAME_CASE + +TORCH_CUDA_CPP_API void initializeCudaContext(); + +}}} // namespace at::cuda::jit diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/thread_constants.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/thread_constants.h new file mode 100644 index 0000000000000000000000000000000000000000..651053d663e4c204753cdfa4ae31ed60fed34152 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/thread_constants.h @@ -0,0 +1,22 @@ +#pragma once +#include + +// Marks a lambda as executable on both the host and device. 
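Editor's note: the typeName template above relies on the `delayed_false` trick so that the static_assert fires only when the primary template is actually instantiated with an unsupported type, rather than being rejected at definition time. A minimal standalone illustration of the same pattern, with hypothetical names unrelated to the jiterator itself:

#include <string>
#include <type_traits>

// Depends on T, so the assert is only evaluated when the primary template is
// instantiated; a plain static_assert(false, ...) may be rejected eagerly.
template <typename T>
struct dependent_false : std::false_type {};

template <typename T>
std::string type_name() {
  static_assert(dependent_false<T>::value, "type_name not specialized for this type");
  return "";
}

template <> std::string type_name<float>()  { return "float"; }
template <> std::string type_name<double>() { return "double"; }

int main() {
  std::string s = type_name<float>();   // OK: uses the specialization
  // type_name<char>();                 // would fail to compile with the message above
  return s.empty();
}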
The __host__ +// attribute is important so that we can access static type information from +// the host, even if the function is typically only executed on the device. +#ifndef GPU_LAMBDA +#define GPU_LAMBDA __host__ __device__ +#endif + +#if defined(USE_ROCM) +constexpr int num_threads() { + return 256; +} +#else +constexpr uint32_t num_threads() { + return C10_WARP_SIZE * 4; +} +#endif + +constexpr int thread_work_size() { return 4; } +constexpr int block_work_size() { return thread_work_size() * num_threads(); } diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/vol2col.cuh b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/vol2col.cuh new file mode 100644 index 0000000000000000000000000000000000000000..222270e8621606e5212b0d1eee5dd5eda330b781 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/vol2col.cuh @@ -0,0 +1,264 @@ +#pragma once + +#include +#include +#include +#include + +#include + +namespace at { +namespace native { + +using namespace at::cuda::detail; + +// Kernel for fast unfold+copy on volumes +template +C10_LAUNCH_BOUNDS_1(1024) +__global__ void vol2col_kernel( + const int64_t n, + const T* data_vol, + const int depth, + const int height, + const int width, + const int ksize_t, + const int ksize_h, + const int ksize_w, + const int pad_t, + const int pad_h, + const int pad_w, + const int stride_t, + const int stride_h, + const int stride_w, + const int dilation_t, + const int dilation_h, + const int dilation_w, + const int depth_col, + const int height_col, + const int width_col, + T* data_col) { + CUDA_KERNEL_LOOP_TYPE(index, n, int64_t) { + auto w_out = index % width_col; + index /= width_col; + auto h_out = index % height_col; + index /= height_col; + auto t_out = index % depth_col; + auto channel_in = index / depth_col; + auto channel_out = channel_in * ksize_t * ksize_h * ksize_w; + auto t_in = t_out * stride_t - pad_t; + auto h_in = h_out * stride_h - pad_h; + auto w_in = w_out * stride_w - pad_w; + data_col += + ((channel_out * depth_col + t_out) * height_col + h_out) * width_col + + w_out; + data_vol += ((channel_in * depth + t_in) * height + h_in) * width + w_in; + for (int i = 0; i < ksize_t; ++i) { + for (int j = 0; j < ksize_h; ++j) { + for (int k = 0; k < ksize_w; ++k) { + auto t = t_in + i * dilation_t; + auto h = h_in + j * dilation_h; + auto w = w_in + k * dilation_w; + *data_col = (t >= 0 && h >= 0 && w >= 0 && t < depth && h < height && + w < width) + ? data_vol + [i * dilation_t * height * width + j * dilation_h * width + + k * dilation_w] + : static_cast(0); + data_col += depth_col * height_col * width_col; + } + } + } + } +} + +template +void vol2col( + cudaStream_t stream, + const T* data_vol, + const int channels, + const int depth, + const int height, + const int width, + const int depth_col, + const int height_col, + const int width_col, + const int ksize_t, + const int ksize_h, + const int ksize_w, + const int pad_t, + const int pad_h, + const int pad_w, + const int stride_t, + const int stride_h, + const int stride_w, + const int dilation_t, + const int dilation_h, + const int dilation_w, + T* data_col) { + // We are going to launch channels * depth_col * height_col * width_col + // kernels, each kernel responsible for copying a single-channel grid. 
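Editor's note: the vol2col/im2col kernels above implement the classic "unfold + copy" lowering of convolution to matrix multiplication. Below is a tiny host-side 2-D reference (single channel, no batching) that produces the same column layout described in the comments; the kernel, stride, and padding values are chosen purely for illustration.

#include <cstdio>
#include <vector>

// Unfold a (height x width) image into a (kh*kw) x (out_h*out_w) column matrix,
// writing zeros where the kernel window falls outside the image.
std::vector<float> im2col_ref(const std::vector<float>& im, int height, int width,
                              int kh, int kw, int pad, int stride) {
  const int out_h = (height + 2 * pad - kh) / stride + 1;
  const int out_w = (width + 2 * pad - kw) / stride + 1;
  std::vector<float> col(size_t(kh) * kw * out_h * out_w, 0.f);
  for (int i = 0; i < kh; ++i)
    for (int j = 0; j < kw; ++j)
      for (int ho = 0; ho < out_h; ++ho)
        for (int wo = 0; wo < out_w; ++wo) {
          const int h = ho * stride - pad + i;
          const int w = wo * stride - pad + j;
          const float v = (h >= 0 && w >= 0 && h < height && w < width)
              ? im[size_t(h) * width + w] : 0.f;
          col[((size_t(i) * kw + j) * out_h + ho) * out_w + wo] = v;
        }
  return col;
}

int main() {
  std::vector<float> im(4 * 4);
  for (int i = 0; i < 16; ++i) im[i] = float(i);
  auto col = im2col_ref(im, 4, 4, /*kh=*/3, /*kw=*/3, /*pad=*/1, /*stride=*/1);
  std::printf("col matrix is %d x %d\n", 3 * 3, 4 * 4);   // 9 kernel taps, 16 output positions
  return col.empty();
}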
+ // We cast an operand to int64 so that the product will not overflow + const auto num_kernels = static_cast(channels) * depth_col * height_col * width_col; + // Launch + vol2col_kernel<<>>( + num_kernels, + data_vol, + depth, + height, + width, + ksize_t, + ksize_h, + ksize_w, + pad_t, + pad_h, + pad_w, + stride_t, + stride_h, + stride_w, + dilation_t, + dilation_h, + dilation_w, + depth_col, + height_col, + width_col, + data_col); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +template +__global__ void vol2im_kernel( + const int64_t n, + const T* data_col, + const unsigned depth, + const unsigned height, + const unsigned width, + const unsigned channels, + const unsigned kernel_t, + const unsigned kernel_h, + const unsigned kernel_w, + const unsigned pad_t, + const unsigned pad_h, + const unsigned pad_w, + const unsigned stride_t, + const unsigned stride_h, + const unsigned stride_w, + const unsigned dilation_t, + const unsigned dilation_h, + const unsigned dilation_w, + const unsigned depth_col, + const unsigned height_col, + const unsigned width_col, + T* data_vol) { + CUDA_KERNEL_LOOP(index, n) { + accT val = static_cast(0); + const auto w_im = index % width + pad_w; + const auto h_im = (index / width) % height + pad_h; + const auto t_im = (index / width / height) % depth + pad_t; + const auto c_im = index / (width * height * depth); + auto kernel_extent_w = (kernel_w - 1) * dilation_w + 1; + auto kernel_extent_h = (kernel_h - 1) * dilation_h + 1; + auto kernel_extent_t = (kernel_t - 1) * dilation_t + 1; + // compute the start and end of the output + const auto w_col_start = + (w_im < kernel_extent_w) ? 0 : (w_im - kernel_extent_w) / stride_w + 1; + const auto w_col_end = std::min(w_im / stride_w + 1, width_col); + const auto h_col_start = + (h_im < kernel_extent_h) ? 0 : (h_im - kernel_extent_h) / stride_h + 1; + const auto h_col_end = std::min(h_im / stride_h + 1, height_col); + const auto t_col_start = + (t_im < kernel_extent_t) ? 
0 : (t_im - kernel_extent_t) / stride_t + 1; + const auto t_col_end = std::min(t_im / stride_t + 1, depth_col); + // TODO: use LCM of stride and dilation to avoid unnecessary loops + for (unsigned t_col = t_col_start; t_col < t_col_end; t_col += 1) { + for (unsigned h_col = h_col_start; h_col < h_col_end; h_col += 1) { + for (unsigned w_col = w_col_start; w_col < w_col_end; w_col += 1) { + uint64_t t_k = (t_im - t_col * stride_t); + uint64_t h_k = (h_im - h_col * stride_h); + uint64_t w_k = (w_im - w_col * stride_w); + if (t_k % dilation_t == 0 && h_k % dilation_h == 0 && + w_k % dilation_w == 0) { + t_k /= dilation_t; + h_k /= dilation_h; + w_k /= dilation_w; + const int64_t idx_k = + ((c_im * kernel_t + t_k) * kernel_h + h_k) * kernel_w + w_k; + const int64_t data_col_index = + ((idx_k * depth_col + t_col) * + height_col + h_col) * + width_col + w_col; + val += data_col[data_col_index]; + } + } + } + } + data_vol[index] = static_cast(val); + } +} + +template +void col2vol( + cudaStream_t stream, + const T* data_col, + const int64_t channels, + const int64_t depth, + const int64_t height, + const int64_t width, + const int64_t output_depth, + const int64_t output_height, + const int64_t output_width, + const int64_t patch_t, + const int64_t patch_h, + const int64_t patch_w, + const int64_t pad_t, + const int64_t pad_h, + const int64_t pad_w, + const int64_t stride_t, + const int64_t stride_h, + const int64_t stride_w, + const int64_t dilation_t, + const int64_t dilation_h, + const int64_t dilation_w, + T* data_vol) { + const auto num_kernels = channels * depth * height * width; + + auto check_fits_in_unsigned = + [](int64_t val, const char * name) { + constexpr auto umax = std::numeric_limits::max(); + TORCH_CHECK(val >= 0 && val <= umax, + name, " must fit in a 32-bit unsigned value"); + }; + check_fits_in_unsigned(num_kernels, "input size"); + check_fits_in_unsigned( + channels * patch_t * patch_h * patch_w, "channels x kernel size"); + + // To avoid involving atomic operations, we will launch one kernel per + // bottom dimension, and then in the kernel add up the top dimensions. + vol2im_kernel + <<>>( + num_kernels, + data_col, + depth, + height, + width, + channels, + patch_t, + patch_h, + patch_w, + pad_t, + pad_h, + pad_w, + stride_t, + stride_h, + stride_w, + dilation_t, + dilation_h, + dilation_w, + output_depth, + output_height, + output_width, + data_vol); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +} // namespace native +} // namespace at diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/mps/Copy.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/mps/Copy.h new file mode 100644 index 0000000000000000000000000000000000000000..cd65d8ae00e655d05f15ef1f744d771fe0d4eadc --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/mps/Copy.h @@ -0,0 +1,14 @@ +// Copyright © 2022 Apple Inc. 
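// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the upstream headers): one way the
// vol2col<T>() launcher from vol2col.cuh above might be driven from host
// code. The output-extent formula is the standard convolution output size;
// the wrapper function and its argument names are hypothetical, and
// `input`/`columns` are assumed to already live in GPU memory.
// Assumes <ATen/native/cuda/vol2col.cuh> (above) and <cuda_runtime.h> are included.
static inline void unfold_volume_sketch(
    const float* input,   // [channels, depth, height, width]
    float* columns,       // [channels * kt * kh * kw, depth_col * height_col * width_col]
    int channels, int depth, int height, int width,
    int kt, int kh, int kw,
    int pad_t, int pad_h, int pad_w,
    int stride_t, int stride_h, int stride_w,
    int dil_t, int dil_h, int dil_w,
    cudaStream_t stream) {
  // Standard convolution output extent for each of the three spatial dims.
  const int depth_col  = (depth  + 2 * pad_t - (dil_t * (kt - 1) + 1)) / stride_t + 1;
  const int height_col = (height + 2 * pad_h - (dil_h * (kh - 1) + 1)) / stride_h + 1;
  const int width_col  = (width  + 2 * pad_w - (dil_w * (kw - 1) + 1)) / stride_w + 1;
  at::native::vol2col<float>(
      stream, input, channels, depth, height, width,
      depth_col, height_col, width_col,
      kt, kh, kw,
      pad_t, pad_h, pad_w,
      stride_t, stride_h, stride_w,
      dil_t, dil_h, dil_w,
      columns);
}
// ---------------------------------------------------------------------------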
+ +#pragma once +#include + +namespace at::native::mps { + +at::Tensor& mps_copy_( + at::Tensor& dst, + const at::Tensor& src, + bool non_blocking); +void copy_blit_mps(void* dst, const void* src, size_t size); + +} // namespace at::native::mps diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/mps/MPSGraphSequoiaOps.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/mps/MPSGraphSequoiaOps.h new file mode 100644 index 0000000000000000000000000000000000000000..4ec62e33bfb03dc619595831cb4a31398c32cfc4 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/mps/MPSGraphSequoiaOps.h @@ -0,0 +1,38 @@ +#pragma once + +#include + +#if !defined(__MAC_15_0) && \ + (!defined(MAC_OS_X_VERSION_15_0) || (MAC_OS_X_VERSION_MIN_REQUIRED < MAC_OS_X_VERSION_15_0)) + +@interface MPSNDArrayIdentity : MPSNDArrayUnaryKernel +-(MPSNDArray * __nullable) reshapeWithCommandBuffer: (__nullable id ) cmdBuf + sourceArray: (MPSNDArray * __nonnull) sourceArray + shape: (MPSShape * __nonnull) shape + destinationArray: (MPSNDArray * __nullable) destinationArray; +@end + +@interface MPSNDArrayDescriptor() +@property (readwrite, nonatomic) BOOL preferPackedRows; +@end + +@interface MPSNDArray() +-(nonnull instancetype) initWithBuffer:(id _Nonnull) buffer + offset:(NSUInteger) offset + descriptor:(MPSNDArrayDescriptor * _Nonnull) descriptor; +-(MPSNDArray * __nullable) arrayViewWithShape:(MPSShape * _Nullable) shape + strides:(MPSShape * _Nonnull) strides; +@end + +typedef NS_ENUM(NSInteger, MTLMathMode) +{ + MTLMathModeSafe = 0, + MTLMathModeRelaxed = 1, + MTLMathModeFast = 2, +}; + +@interface MTLCompileOptions() +@property (readwrite, nonatomic) MTLMathMode mathMode; +@end + +#endif diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/mps/MPSGraphSonomaOps.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/mps/MPSGraphSonomaOps.h new file mode 100644 index 0000000000000000000000000000000000000000..b4cf3ad5dbcc8200198a1483a7caa4c275e4e2f4 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/mps/MPSGraphSonomaOps.h @@ -0,0 +1,53 @@ +#pragma once + +#include + +#if !defined(__MAC_14_0) && \ + (!defined(MAC_OS_X_VERSION_14_0) || (MAC_OS_X_VERSION_MIN_REQUIRED < MAC_OS_X_VERSION_14_0)) + +typedef NS_ENUM(NSUInteger, MPSGraphFFTScalingMode) +{ + MPSGraphFFTScalingModeNone = 0L, + MPSGraphFFTScalingModeSize = 1L, + MPSGraphFFTScalingModeUnitary = 2L, +}; + +@interface FakeMPSGraphFFTDescriptor : NSObject +@property (readwrite, nonatomic) BOOL inverse; +@property (readwrite, nonatomic) MPSGraphFFTScalingMode scalingMode; +@property (readwrite, nonatomic) BOOL roundToOddHermitean; ++(nullable instancetype) descriptor; +@end + +@compatibility_alias MPSGraphFFTDescriptor FakeMPSGraphFFTDescriptor; + +@interface MPSGraph (SonomaOps) +-(MPSGraphTensor * _Nonnull) conjugateWithTensor:(MPSGraphTensor * _Nonnull) tensor + name:(NSString * _Nullable) name; + +-(MPSGraphTensor * _Nonnull) realPartOfTensor:(MPSGraphTensor * _Nonnull) tensor + name:(NSString * _Nullable) name; + + +-(MPSGraphTensor * _Nonnull) fastFourierTransformWithTensor:(MPSGraphTensor * _Nonnull) tensor + axes:(NSArray * _Nonnull) axes + descriptor:(MPSGraphFFTDescriptor * _Nonnull) descriptor + name:(NSString * _Nullable) name; + +-(MPSGraphTensor * _Nonnull) realToHermiteanFFTWithTensor:(MPSGraphTensor * _Nonnull) tensor + axes:(NSArray * _Nonnull) axes + descriptor:(MPSGraphFFTDescriptor * _Nonnull) descriptor + name:(NSString * _Nullable) 
name; + +-(MPSGraphTensor * _Nonnull) HermiteanToRealFFTWithTensor:(MPSGraphTensor * _Nonnull) tensor + axes:(NSArray * _Nonnull) axes + descriptor:(MPSGraphFFTDescriptor * _Nonnull) descriptor + name:(NSString * _Nullable) name; +@end + +// define BFloat16 enums for MacOS13 +#define MPSDataTypeBFloat16 ((MPSDataType) (MPSDataTypeAlternateEncodingBit | MPSDataTypeFloat16)) + +// define Metal version +#define MTLLanguageVersion3_1 ((MTLLanguageVersion) ((3 << 16) + 1)) +#endif diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/mps/MPSGraphVenturaOps.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/mps/MPSGraphVenturaOps.h new file mode 100644 index 0000000000000000000000000000000000000000..a1525470d11a66d251c6fcf8a0c163ff2b18047e --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/mps/MPSGraphVenturaOps.h @@ -0,0 +1,197 @@ +#pragma once +#include + +// TODO: Remove me when moved to MacOS 13 +#if !defined(__MAC_13_2) && \ + (!defined(MAC_OS_X_VERSION_13_2) || (MAC_OS_X_VERSION_MIN_REQUIRED < MAC_OS_X_VERSION_13_2)) + +@interface FakeMPSGraphConvolution3DOpDescriptor : NSObject + +@property (readwrite, nonatomic) NSUInteger strideInX; +@property (readwrite, nonatomic) NSUInteger strideInY; +@property (readwrite, nonatomic) NSUInteger strideInZ; +@property (readwrite, nonatomic) NSUInteger dilationRateInX; +@property (readwrite, nonatomic) NSUInteger dilationRateInY; +@property (readwrite, nonatomic) NSUInteger dilationRateInZ; + +@property (readwrite, nonatomic) NSUInteger paddingLeft; +@property (readwrite, nonatomic) NSUInteger paddingRight; +@property (readwrite, nonatomic) NSUInteger paddingTop; +@property (readwrite, nonatomic) NSUInteger paddingBottom; +@property (readwrite, nonatomic) NSUInteger paddingFront; +@property (readwrite, nonatomic) NSUInteger paddingBack; + +@property (readwrite, nonatomic) MPSGraphPaddingStyle paddingStyle; +@property (readwrite, nonatomic) MPSGraphTensorNamedDataLayout dataLayout; +@property (readwrite, nonatomic) MPSGraphTensorNamedDataLayout weightsLayout; + +@property (readwrite, nonatomic) NSUInteger groups; + +@end + +@compatibility_alias MPSGraphConvolution3DOpDescriptor FakeMPSGraphConvolution3DOpDescriptor; + +#endif + +@interface MPSGraph (VenturaOps) + +#if !defined(__MAC_13_0) && \ + (!defined(MAC_OS_X_VERSION_13_0) || (MAC_OS_X_VERSION_MIN_REQUIRED < MAC_OS_X_VERSION_13_0)) + +typedef NS_ENUM(NSUInteger, MPSGraphResizeNearestRoundingMode) +{ + MPSGraphResizeNearestRoundingModeRoundPreferCeil = 0L, + MPSGraphResizeNearestRoundingModeRoundPreferFloor = 1L, + MPSGraphResizeNearestRoundingModeCeil = 2L, + MPSGraphResizeNearestRoundingModeFloor = 3L, + MPSGraphResizeNearestRoundingModeRoundToEven = 4L, + MPSGraphResizeNearestRoundingModeRoundToOdd = 5L, +}; + +// Define complex enums for MacOS 12 +#define MPSDataTypeComplexBit 0x01000000 +#define MPSDataTypeComplexFloat32 ((MPSDataType) (MPSDataTypeFloatBit | MPSDataTypeComplexBit | 64)) +#define MPSDataTypeComplexFloat16 ((MPSDataType) (MPSDataTypeFloatBit | MPSDataTypeComplexBit | 32)) +#endif + +- (MPSGraphTensor * _Nonnull) convolution3DWithSourceTensor:(MPSGraphTensor * _Nonnull) source + weightsTensor:(MPSGraphTensor * _Nonnull) weights + descriptor:(MPSGraphConvolution3DOpDescriptor * _Nonnull) descriptor + name:(NSString * _Nullable) name; + +- (MPSGraphTensor * _Nonnull) convolution3DDataGradientWithIncomingGradientTensor:(MPSGraphTensor * _Nonnull) incomingGradient + weightsTensor:(MPSGraphTensor * _Nonnull) weights + 
outputShape:(MPSShape * _Nonnull) outputShape + forwardConvolutionDescriptor:(MPSGraphConvolution3DOpDescriptor * _Nonnull) forwardConvolutionDescriptor + name:(NSString * _Nullable) name; + +- (MPSGraphTensor * _Nonnull) convolution3DWeightsGradientWithIncomingGradientTensor:(MPSGraphTensor * _Nonnull) incomingGradient + sourceTensor:(MPSGraphTensor * _Nonnull) source + outputShape:(MPSShape * _Nonnull) outputShape + forwardConvolutionDescriptor:(MPSGraphConvolution3DOpDescriptor * _Nonnull) forwardConvolutionDescriptor + name:(NSString * _Nullable) name; + +- (MPSGraphTensor * _Nonnull)cumulativeSumWithTensor:(MPSGraphTensor * _Nonnull)tensor + axis:(NSInteger)axis + name:(NSString * _Nullable)name; + +- (MPSGraphTensor * _Nonnull)sortWithTensor:(MPSGraphTensor * _Nonnull)tensor + axis:(NSInteger)axis + name:(NSString * _Nullable)name; + +- (MPSGraphTensor * _Nonnull) sortWithTensor:(MPSGraphTensor * _Nonnull) tensor + axis:(NSInteger) axis + descending:(BOOL) descending + name:(NSString * _Nullable) name; + +- (MPSGraphTensor * _Nonnull) sortWithTensor:(MPSGraphTensor * _Nonnull) tensor + axisTensor:(MPSGraphTensor * _Nonnull) axisTensor + descending:(BOOL) descending + name:(NSString * _Nullable) name; + +- (MPSGraphTensor * _Nonnull) sortWithTensor:(MPSGraphTensor * _Nonnull) tensor + axisTensor:(MPSGraphTensor * _Nonnull) axisTensor + name:(NSString * _Nullable) name; + +- (MPSGraphTensor * _Nonnull)argSortWithTensor:(MPSGraphTensor * _Nonnull)tensor + axis:(NSInteger)axis + name:(NSString * _Nullable)name; + +- (MPSGraphTensor * _Nonnull) argSortWithTensor:(MPSGraphTensor * _Nonnull) tensor + axis:(NSInteger) axis + descending:(BOOL) descending + name:(NSString * _Nullable) name; + +- (MPSGraphTensor * _Nonnull) argSortWithTensor:(MPSGraphTensor * _Nonnull) tensor + axisTensor:(MPSGraphTensor * _Nonnull) axisTensor + descending:(BOOL) descending + name:(NSString * _Nullable) name; + +- (MPSGraphTensor * _Nonnull) argSortWithTensor:(MPSGraphTensor * _Nonnull) tensor + axisTensor:(MPSGraphTensor * _Nonnull) axisTensor + name:(NSString * _Nullable) name; + +- (MPSGraphTensor * _Nonnull)inverseOfTensor:(MPSGraphTensor * _Nonnull) inputTensor + name:(NSString * _Nullable)name; + +- (MPSGraphTensor * _Nonnull) resizeNearestWithTensor:(MPSGraphTensor * _Nonnull) imagesTensor + sizeTensor:(MPSGraphTensor * _Nonnull) size + nearestRoundingMode:(MPSGraphResizeNearestRoundingMode) nearestRoundingMode + centerResult:(BOOL) centerResult + alignCorners:(BOOL) alignCorners + layout:(MPSGraphTensorNamedDataLayout) layout + name:(NSString * _Nullable) name; + +- (MPSGraphTensor * _Nonnull) resizeNearestWithTensor:(MPSGraphTensor * _Nonnull) imagesTensor + sizeTensor:(MPSGraphTensor * _Nonnull) size + scaleOffsetTensor:(MPSGraphTensor * _Nonnull) scaleOffset + nearestRoundingMode:(MPSGraphResizeNearestRoundingMode) nearestRoundingMode + layout:(MPSGraphTensorNamedDataLayout) layout + name:(NSString * _Nullable) name; + +- (MPSGraphTensor * _Nonnull) resizeBilinearWithTensor:(MPSGraphTensor * _Nonnull) imagesTensor + sizeTensor:(MPSGraphTensor * _Nonnull) size + centerResult:(BOOL) centerResult + alignCorners:(BOOL) alignCorners + layout:(MPSGraphTensorNamedDataLayout) layout + name:(NSString * _Nullable) name; + +- (MPSGraphTensor * _Nonnull) resizeBilinearWithTensor:(MPSGraphTensor * _Nonnull) imagesTensor + sizeTensor:(MPSGraphTensor * _Nonnull) size + scaleOffsetTensor:(MPSGraphTensor * _Nonnull) scaleOffset + layout:(MPSGraphTensorNamedDataLayout) layout + name:(NSString * _Nullable) name; + +- 
(MPSGraphTensor * _Nonnull) resizeNearestWithGradientTensor:(MPSGraphTensor * _Nonnull) gradient + input:(MPSGraphTensor * _Nonnull) input + nearestRoundingMode:(MPSGraphResizeNearestRoundingMode) nearestRoundingMode + centerResult:(BOOL) centerResult + alignCorners:(BOOL) alignCorners + layout:(MPSGraphTensorNamedDataLayout) layout + name:(NSString * _Nullable) name; + +- (MPSGraphTensor * _Nonnull) resizeNearestWithGradientTensor:(MPSGraphTensor * _Nonnull) gradient + input:(MPSGraphTensor * _Nonnull) input + scaleOffsetTensor:(MPSGraphTensor * _Nonnull) scaleOffset + nearestRoundingMode:(MPSGraphResizeNearestRoundingMode) nearestRoundingMode + layout:(MPSGraphTensorNamedDataLayout) layout + name:(NSString * _Nullable) name; + +- (MPSGraphTensor * _Nonnull) resizeBilinearWithGradientTensor:(MPSGraphTensor * _Nonnull) gradient + input:(MPSGraphTensor * _Nonnull) input + centerResult:(BOOL) centerResult + alignCorners:(BOOL) alignCorners + layout:(MPSGraphTensorNamedDataLayout) layout + name:(NSString * _Nullable) name; + +- (MPSGraphTensor * _Nonnull) resizeBilinearWithGradientTensor:(MPSGraphTensor * _Nonnull) gradient + input:(MPSGraphTensor * _Nonnull) input + scaleOffsetTensor:(MPSGraphTensor * _Nonnull) scaleOffset + layout:(MPSGraphTensorNamedDataLayout) layout + name:(NSString * _Nullable) name; + +- (MPSGraphTensor * _Nonnull) sampleGridWithSourceTensor:(MPSGraphTensor * _Nonnull) source + coordinateTensor:(MPSGraphTensor * _Nonnull) coordinates + layout:(MPSGraphTensorNamedDataLayout) layout + normalizeCoordinates:(BOOL) normalizeCoordinates + relativeCoordinates:(BOOL) relativeCoordinates + alignCorners:(BOOL) alignCorners + paddingMode:(MPSGraphPaddingMode) paddingMode + samplingMode:(MPSGraphResizeMode) samplingMode + constantValue:(double) constantValue + name:(NSString * _Nullable) name; + +- (MPSGraphTensor * _Nonnull) sampleGridWithSourceTensor:(MPSGraphTensor * _Nonnull) source + coordinateTensor:(MPSGraphTensor * _Nonnull) coordinates + layout:(MPSGraphTensorNamedDataLayout) layout + normalizeCoordinates:(BOOL) normalizeCoordinates + relativeCoordinates:(BOOL) relativeCoordinates + alignCorners:(BOOL) alignCorners + paddingMode:(MPSGraphPaddingMode) paddingMode + nearestRoundingMode:(MPSGraphResizeNearestRoundingMode) nearestRoundingMode + constantValue:(double) constantValue + name:(NSString * _Nullable) name; +- (MPSGraphTensor * _Nonnull) truncateWithTensor:(MPSGraphTensor * _Nonnull) tensor + name:(NSString * _Nullable) name; + +@end diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/mps/OperationUtils.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/mps/OperationUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..7016292e6efaa38d8f8828a2a0896b6b27f34e25 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/mps/OperationUtils.h @@ -0,0 +1,463 @@ +// Copyright © 2022 Apple Inc. 
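// (Illustrative note, not part of the upstream headers:) the MPSGraph*Ops.h
// shims above all follow the same availability pattern -- if the build SDK
// predates the macOS release that introduced an API, a stand-in declaration
// (typically a Fake* interface plus an @compatibility_alias, or a category on
// MPSGraph) is provided so the code still compiles, and the whole block drops
// out once the SDK and minimum deployment target catch up, e.g.:
//
//   #if !defined(__MAC_14_0) && \
//       (!defined(MAC_OS_X_VERSION_14_0) || \
//        (MAC_OS_X_VERSION_MIN_REQUIRED < MAC_OS_X_VERSION_14_0))
//   ...stand-in declarations...
//   #endif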
+ +#pragma once + +#include +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#endif + +#include + +// Fwd declarations +namespace at { + struct TensorIteratorBase; +} +using namespace at::mps; + +namespace at::native::mps { + +void dispatch_sync_with_rethrow(dispatch_queue_t queue, void (^block)()); + +struct MPSScalar { + id getMTLBuffer() const { return __builtin_bit_cast(id, buffer.get()); } + + size_t size = 0; + ScalarType type = ScalarType::Undefined; + c10::DataPtr buffer; // stores MTLBuffer (frees buffer if MPSScalar instance goes out of scope) + union { + float f; // MPS doesn't support 'double' + at::Half h; + int64_t i; + bool b; + c10::complex cf; + c10::complex ch; + at::BFloat16 bf16; + } value {}; +}; + +void runMPSGraph(MPSStream* mpsStream, + MPSGraph* mpsGraph, + NSDictionary* feeds, + NSDictionary* results); + +MPSDataType getMPSDataType(ScalarType scalar_type); +static inline MPSDataType getMPSDataType(const Tensor& t) { + return getMPSDataType(t.scalar_type()); +} +MPSDataType getMPSScalarType(ScalarType scalar_type); +static inline MPSDataType getMPSScalarType(const Tensor& t) { + return getMPSScalarType(t.scalar_type()); +} +MPSScalar getMPSScalar(const Scalar& scalar, ScalarType type); +std::string getMPSTypeString(ScalarType scalar_type, bool short_name = false); +static inline std::string getMPSTypeString(const Tensor& t, bool short_name = false) { + return getMPSTypeString(t.scalar_type(), short_name); +} +std::string scalarToMetalTypeString(const c10::ScalarType& scalar_type); +static inline std::string scalarToMetalTypeString(const Tensor& t) { + return scalarToMetalTypeString(t.scalar_type()); +} +NSArray* getTensorAxes(const Tensor& t); +NSArray* getTensorAxes(const IntArrayRef& sizes, at::OptionalIntArrayRef dim); +std::string getMPSShapeString(MPSShape* shape); +std::string getTensorsStringKey(const TensorList& tensors, bool short_dtype = true, bool exclude_shape = false); +std::string getArrayRefString(const IntArrayRef s); +// use has_storage() on the returned tensor to determine if src actually is a view +Tensor gatherViewTensor(const at::Tensor& src, at::Tensor& dst); +Tensor& scatterViewTensor(const at::Tensor& src, at::Tensor& output); +bool canSliceViewTensor(const Tensor& src, MPSShape *mpsShape); +MPSGraphTensorData* getMPSGraphTensorDataForView(const Tensor& src, MPSShape *mpsShape, const MPSDataType mpsDataType); +MPSGraphTensor* castToIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const Tensor& input, bool includesInt64 = false); +MPSGraphTensor* castFromIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const Tensor& input, bool includesInt64 = false); + +MPSNDArray* getMPSNDArray(const at::Tensor& t, const IntArrayRef& sizes = {}, const IntArrayRef& strides = {}); +MPSNDArray* getMPSNDArray(const at::Tensor& t, MPSShape* sizes = nil, MPSShape* strides = nil); +// The MPSShape could vary based on memory format +Tensor getTensorView(const Tensor& t, MPSShape* shape); +MPSShape* getMPSShape(const Tensor& t, c10::MemoryFormat memory_format = MemoryFormat::Contiguous); +MPSShape* getMPSShape(IntArrayRef sizes, c10::MemoryFormat memory_format = MemoryFormat::Contiguous); + +static inline id getMTLBufferStorage(const at::Tensor& tensor) { + return __builtin_bit_cast(id, tensor.storage().data()); +} + +class Placeholder { + public: + Placeholder() : 
_placeholder(nullptr), _value(nullptr), _tensor(Tensor()) {} + Placeholder(MPSGraphTensor* mpsGraphTensor) : _placeholder(mpsGraphTensor), _value(nullptr), _tensor(Tensor()) {} + Placeholder(MPSGraphTensor* mpsGraphTensor, MPSNDArray* mpsNDArray); + Placeholder(MPSGraphTensor* mpsGraphTensor, const Tensor& self, MPSShape *mpsShape = nullptr, + bool gatherTensorData = true, MPSDataType dataType = MPSDataTypeInvalid, bool useMPSStridedAPI = true); + MPSGraphTensor* getMPSGraphTensor() { + return _placeholder; + } + MPSGraphTensorData* getMPSGraphTensorData() { + return _value; + } + bool isIntermediate() { + return _value == nullptr; + } + + private: + MPSGraphTensor* _placeholder; + MPSGraphTensorData* _value; + Tensor _tensor; +}; + +void resize_tensor(Tensor* output); +Tensor wrapped_scalar_tensor_mps(const Scalar& scalar, const Device device); +MPSGraphTensor* trunc_tensor(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor); +MPSGraphTensor* convertNHWCtoNCHW(MPSGraph *mpsGraph, MPSGraphTensor* tensor); +MPSGraphTensor* castMPSTensor(MPSGraph *mpsGraph, MPSGraphTensor* tensor, ScalarType toType); +MPSGraphTensor* castMPSTensor(MPSGraph *mpsGraph, MPSGraphTensor* tensor, MPSDataType toType); +MPSGraphTensorData *getMPSGraphTensorData(MPSGraph* mpsGraph, MPSStream* mpsStream, const Tensor& tensor); +MPSGraphTensorData* getMPSGraphTensorFromScalar(MPSStream* mpsStream, MPSScalar& scalar); + +MPSGraph* make_mps_graph(); +void printTensorNDArray(const Tensor& t); +MPSNDArray* ndArrayFromTensor(const Tensor& tensor, MPSShape *shape, MPSDataType mpsType); + +MPSGraphTensor* mpsGraphUnrankedPlaceHolder(MPSGraph *mpsGraph, MPSDataType dataType); +MPSGraphTensor* mpsGraphRankedPlaceHolder(MPSGraph *mpsGraph, MPSDataType dataType, MPSShape* mpsShape); +MPSGraphTensor* mpsGraphRankedPlaceHolder(MPSGraph *mpsGraph, const Tensor& tensor); +MPSGraphTensor* mpsGraphScalarPlaceHolder(MPSGraph *mpsGraph, MPSDataType dataType); +MPSGraphTensor* mpsGraphScalarPlaceHolder(MPSGraph *mpsGraph, const Scalar& scalar); + +string get_mem_format_string(c10::MemoryFormat memory_format); + +using MPSCacheKey = uint64_t; + +// derive this class to cache a graph and its inputs/outputs +// can be used to store any NSObject +struct MPSCachedGraph +{ + MPSCachedGraph(NSObject *object) : _object([object retain]) {} + virtual ~MPSCachedGraph() { + [_object release]; + _object = nullptr; + } + + template + inline T* as() { + return static_cast(this); + } + + MPSGraph *graph() const { return (MPSGraph *)_object; } + NSObject *object() const { return _object; } +private: + NSObject *_object = nullptr; +}; + +struct MPSUnaryCachedGraph : public MPSCachedGraph +{ + MPSUnaryCachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} + MPSGraphTensor *inputTensor_ = nil; + MPSGraphTensor *outputTensor_ = nil; +}; + +struct MPSUnaryGradCachedGraph : public MPSCachedGraph +{ + MPSUnaryGradCachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} + MPSGraphTensor *gradOutputTensor_ = nil; + MPSGraphTensor *inputTensor_ = nil; + MPSGraphTensor *outputTensor_ = nil; // some backward input is actually the forward's output + MPSGraphTensor *gradInputTensor_ = nil; +}; + +struct MPSBinaryCachedGraph : public MPSCachedGraph +{ + MPSBinaryCachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} + MPSGraphTensor *inputTensor_ = nil; + MPSGraphTensor *otherTensor_ = nil; + MPSGraphTensor *outputTensor_ = nil; +}; + +struct MPSBinaryGradCachedGraph : public MPSCachedGraph +{ + MPSBinaryGradCachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} + 
MPSGraphTensor *gradOutputTensor_ = nil; + MPSGraphTensor *inputTensor_ = nil; + MPSGraphTensor *otherTensor_ = nil; + MPSGraphTensor *gradInputTensor_ = nil; +}; + +// TODO: Improve the overall design of MPSGraphCache. +// https://github.com/pytorch/pytorch/issues/77176 +// Cache holding various keys mapped to graphs +struct MPSGraphCache +{ + typedef MPSCachedGraph * (^CreateCachedGraphBlock)(); + + struct CacheEntry { + CacheEntry(const std::string& key, MPSCachedGraph *cachedGraph) : cachedGraph_(cachedGraph), key_(key) {} + MPSCachedGraph* cachedGraph_ = nullptr; + std::string key_; + }; + + public: + + static MPSGraphCache* getInstance() { + if(_instance_cache == nullptr) { + _instance_cache = new MPSGraphCache(); + } + return _instance_cache; + } + + ~MPSGraphCache() { + dispatch_release(serialQueue_); + + for (const auto& i : cache_) { + delete i.second.cachedGraph_; + } + } + + // Disallow the copy constructor and operator= functions + MPSGraphCache(const MPSGraphCache&) = delete; + void operator=(const MPSGraphCache&) = delete; + + MPSCachedGraph* CreateCachedGraph(const std::string& key, CreateCachedGraphBlock createCacheBlock) { + + __block MPSCachedGraph* cachedGraph = nil; + + MPSCacheKey hash = std::hash{}(key); + + dispatch_sync_with_rethrow(serialQueue_, ^() { + // verify the cached entry doesn't already exist + if (cache_.count(hash) != 0) { + auto& entry = cache_.at(hash); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(key == entry.key_, "Key collision in the MPS cached graph!\n"); + cachedGraph = entry.cachedGraph_; + } else { + cachedGraph = createCacheBlock(); + CacheEntry entry(key, cachedGraph); + cache_.emplace(hash, entry); + profileCachedGraph(entry); + } + }); + return cachedGraph; + } + + template + inline T* CreateCachedGraphAs(const std::string& key, CreateCachedGraphBlock createCacheBlock) { + return static_cast(CreateCachedGraph(key, createCacheBlock)); + } + + MPSCachedGraph* LookUp(const std::string& key) const { + + __block MPSCachedGraph* cachedGraph = nullptr; + + MPSCacheKey hash = std::hash{}(key); + + dispatch_sync(serialQueue_, ^() { + + if (cache_.count(hash) != 0) { + auto& entry = cache_.at(hash); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(key == entry.key_, "Key collision in the MPS cached graph!\n"); + cachedGraph = entry.cachedGraph_; + profileCachedGraph(entry); + } + }); + return cachedGraph; + } + + template + inline T* LookUpAs(const std::string& key) const { + return static_cast(LookUp(key)); + } + + private: + MPSGraphCache() { + serialQueue_ = dispatch_queue_create("cache queue", DISPATCH_QUEUE_SERIAL); + } + // this is defined in OperationUtils.mm to not include + // MPSProfiler.h in header OperationUtils.h + void profileCachedGraph(const CacheEntry& cacheEntry) const; + + static MPSGraphCache* _instance_cache; + std::unordered_map cache_; + dispatch_queue_t serialQueue_ = nullptr; + +}; + +// Common template for creating graph with a specified cache if missing +template +inline T* LookUpOrCreateCachedGraph(const std::string& key, std::function instantiate) { + auto cache_ = MPSGraphCache::getInstance(); + if (auto rc = cache_->LookUpAs(key)) { + return rc; + } + return cache_->CreateCachedGraphAs(key, ^mps::MPSCachedGraph*() { + T* newCachedGraph = nil; + @autoreleasepool { + // Initialize graph + auto mpsGraph = mps::make_mps_graph(); + newCachedGraph = new T(mpsGraph); + instantiate(mpsGraph, newCachedGraph); + } + return newCachedGraph; + }); +} + +// Common math operations +MPSGraphTensor* log1p(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor); + 
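// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the upstream header): how the utilities
// above are typically combined in an MPS kernel -- look up (or lazily build)
// a cached MPSGraph keyed on the tensor signature, bind the ATen tensors
// through Placeholders, and run the graph on a stream. The wrapper function
// is hypothetical; trunc_tensor() is used only as a convenient example op,
// and `stream` is assumed to be the current MPSStream obtained by the caller.
static inline void run_trunc_sketch(const Tensor& self, Tensor& output, MPSStream* stream) {
  @autoreleasepool {
    std::string key = "trunc_sketch:" + getTensorsStringKey({self});
    auto cachedGraph = LookUpOrCreateCachedGraph<MPSUnaryCachedGraph>(
        key, [&](MPSGraph* mpsGraph, MPSUnaryCachedGraph* newCachedGraph) {
          newCachedGraph->inputTensor_ = mpsGraphRankedPlaceHolder(mpsGraph, self);
          newCachedGraph->outputTensor_ = trunc_tensor(mpsGraph, newCachedGraph->inputTensor_);
        });
    Placeholder inputPlaceholder(cachedGraph->inputTensor_, self);
    Placeholder outputPlaceholder(cachedGraph->outputTensor_, output);
    // Equivalent to dictionaryFromPlaceholders(...), declared further below.
    NSDictionary* feeds = @{inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData()};
    NSDictionary* results = @{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
    runMPSGraph(stream, cachedGraph->graph(), feeds, results);
  }
}
// ---------------------------------------------------------------------------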
+#define MPS_CHECK_INT64_OP_SUPPORTED(input_tensor, mac_os_13_3_plus, op_name) \ + if (!mac_os_13_3_plus && input_tensor.scalar_type() == kLong) { \ + TORCH_WARN_ONCE("MPS: no support for int64 for ", op_name, \ + ", downcasting to a smaller data type (int32/float32). Native support for int64 has been added in macOS 13.3."); \ + } + +/** + * Returns distance from lowest to highest element offset in given tensor. + */ +size_t compute_storage_numel_distance(const at::Tensor& t); + +/** + * Checks whether tensor is mapped to a contiguous area in the storage. + */ +inline bool is_dense_in_storage(const at::Tensor& t) { + return compute_storage_numel_distance(t) == static_cast(t.numel()); +} + + +class MetalShaderLibrary { +public: + MetalShaderLibrary(const std::string& src): shaderSource(src), nparams(0), compile_options(nullptr){} + MetalShaderLibrary(const std::string& src, unsigned nparams_): shaderSource(src), nparams(nparams_), compile_options(nullptr){} + MetalShaderLibrary(const std::string& src, unsigned nparams_, MTLCompileOptions* compile_options_): shaderSource(src), nparams(nparams_), compile_options(compile_options_) {} + MetalShaderLibrary(const MetalShaderLibrary&) = delete; + inline id getPipelineStateForFunc(const std::string& fname) { + return getLibraryPipelineState(getLibrary(), fname).first; + } + id getPipelineStateForFunc(const std::string& fname, const std::initializer_list& params) { + return getLibraryPipelineState(getLibrary(params), fname).first; + } + inline id getMTLFunction(const std::string& fname) { + return getLibraryPipelineState(getLibrary(), fname).second; + } + id getMTLFunction(const std::string& fname, const std::initializer_list& params) { + return getLibraryPipelineState(getLibrary(params), fname).second; + } +private: + std::pair, id> getLibraryPipelineState(id lib, const std::string& fname); + id getLibrary(); + id getLibrary(const std::initializer_list& params); + + id compileLibrary(const std::string& src); + std::string shaderSource; + unsigned nparams; + MTLCompileOptions* compile_options; + id library = nil; + std::unordered_map> libMap; + std::unordered_map, id>> cplMap; +}; + +template, encoder_t> || std::is_same_v, encoder_t>>> +static inline void mtl_setBuffer(encoder_t encoder, const Tensor& t, unsigned idx) { + [encoder setBuffer:getMTLBufferStorage(t) + offset:t.storage_offset() * t.element_size() + atIndex:idx]; +} + +template || std::is_same_v>> +static inline void mtl_setBytes(id encoder, const T val, unsigned idx) { + [encoder setBytes:&val length:sizeof(T) atIndex: idx]; +} + +template>> +static inline void mtl_setBytes(id encoder, const Container& values, unsigned idx) { + [encoder setBytes:values.data() length:sizeof(typename Container::value_type) * values.size() atIndex: idx]; +} + +static inline void mtl_dispatch1DJob(id encoder, + id cplState, + uint32_t length) { + const uint32_t maxThreadsPerGroup = [cplState maxTotalThreadsPerThreadgroup]; + auto size = MTLSizeMake(length, 1, 1); + auto threadGroupSize = MTLSizeMake(std::min(maxThreadsPerGroup, length), 1, 1); + [encoder dispatchThreads:size threadsPerThreadgroup:threadGroupSize]; +} + +id generateKernelDataOffsets(id commandEncoder, const TensorIteratorBase& iter, bool use_64bit_index = false); + +inline NSDictionary* dictionaryFromPlaceholders(Placeholder& p1) { + return @{ p1.getMPSGraphTensor(): p1.getMPSGraphTensorData() }; +} + +inline NSDictionary* dictionaryFromPlaceholders(Placeholder& p1, Placeholder& p2) { + return @{ + p1.getMPSGraphTensor(): 
p1.getMPSGraphTensorData(), + p2.getMPSGraphTensor(): p2.getMPSGraphTensorData(), + }; +} + +inline NSDictionary* dictionaryFromPlaceholders(Placeholder& p1, Placeholder& p2, Placeholder& p3) { + return @{ + p1.getMPSGraphTensor(): p1.getMPSGraphTensorData(), + p2.getMPSGraphTensor(): p2.getMPSGraphTensorData(), + p3.getMPSGraphTensor(): p3.getMPSGraphTensorData(), + }; +} + +inline NSDictionary* dictionaryFromPlaceholders(Placeholder& p1, Placeholder& p2, Placeholder& p3, Placeholder& p4) { + return @{ + p1.getMPSGraphTensor(): p1.getMPSGraphTensorData(), + p2.getMPSGraphTensor(): p2.getMPSGraphTensorData(), + p3.getMPSGraphTensor(): p3.getMPSGraphTensorData(), + p4.getMPSGraphTensor(): p4.getMPSGraphTensorData(), + }; +} + +inline void runMPSGraph(MPSStream* stream, MPSGraph* graph, NSDictionary* feeds, Placeholder& result) { + runMPSGraph(stream, graph, feeds, dictionaryFromPlaceholders(result)); +} + +inline bool supportsComplex() { + return is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS); +} + +// MPS yet to support double types, but starting from MacOS 14, supports bfloat16 +inline bool supportedFloatingType(ScalarType dtype) { + return dtype == kFloat || dtype == kHalf || dtype == kBFloat16; +} + +inline bool supportedFloatingType(const Tensor& t) { + return supportedFloatingType(t.scalar_type()); +} + +inline bool supportedFloatingOrComplexType(ScalarType dtype) { + if (dtype == kComplexFloat || dtype == kComplexHalf) { + return supportsComplex(); + } + return supportedFloatingType(dtype); +} +inline bool supportedFloatingOrComplexType(const Tensor& t) { + return supportedFloatingOrComplexType(t.scalar_type()); +} + + +inline bool needsGather(const Tensor& t) { + static const bool is_macOS_15_0_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS); + return !is_macOS_15_0_or_newer && (!t.is_contiguous() || t.storage_offset()) ; +} + +} // namespace at::native::mps diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/mps/TensorFactory.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/mps/TensorFactory.h new file mode 100644 index 0000000000000000000000000000000000000000..e6c9da0babbbedc71e41820aabf7c1c71274bd44 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/mps/TensorFactory.h @@ -0,0 +1,12 @@ +// Copyright © 2022 Apple Inc. + +#define AT_DISPATCH_MPS_TYPES(TYPE, NAME, ...) 
\ + AT_DISPATCH_SWITCH( \ + TYPE, NAME, \ + AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)) diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/mps/UnaryConstants.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/mps/UnaryConstants.h new file mode 100644 index 0000000000000000000000000000000000000000..8a9a66846449f09852ad05682f8f21e7be90716b --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/mps/UnaryConstants.h @@ -0,0 +1,80 @@ +#pragma once + +const char* UNARY_KERNEL_TEMPLATE = R"METAL( +#include +using namespace metal; + +constant float a[4] = {{0.886226899, -1.645349621, 0.914624893, -0.140543331}}; +constant float b[4] = {{-2.118377725, 1.442710462, -0.329097515, 0.012229801}}; +constant float c[4] = {{-1.970840454, -1.624906493, 3.429567803, 1.641345311}}; +constant float d[2] = {{3.543889200, 1.637067800}}; + +kernel void erfinv_kernel( device {0} *output [[buffer(0)]], + device {1} *input [[buffer(1)]], + uint index [[thread_position_in_grid]]) {{ + + float y = input[index]; + float x, z, num, dem; /*working variables */ + /* coefficients in rational expansion */ + + float y_abs = abs(y); + if (y_abs >= 1.0f) {{ + output[index] = {0}( y_abs > 1.0f ? NAN : copysign(INFINITY, y)); + return; + }} + if (y_abs <= 0.7f) {{ + z = y * y; + num = ((a[3] * z + a[2]) * z + a[1])*z + a[0]; + dem = (((b[3] * z + b[2]) * z + b[1]) * z +b[0]) * z + 1.0f; + x = y * num / dem; + }} else {{ + z = sqrt(-1.0f*log((1.0-y_abs)/2.0)); + num = ((c[3] * z + c[2]) * z + c[1]) * z + c[0]; + dem = (d[1] * z + d[0]) * z + 1.0f; + x = copysign(num, y) / dem; + }} + + output[index] = {0}(x); +}} + +kernel void exp_kernel( device {0} *output [[buffer(0)]], + device {1} *input [[ buffer(1)]], + uint index [[thread_position_in_grid]]) {{ + output[index] = {0}(precise::exp(input[index])); +}} + +kernel void exp_complex_kernel( device {0}2 *output [[buffer(0)]], + device {0}2 *input [[ buffer(1)]], + uint index [[thread_position_in_grid]]) {{ + output[index].x = {0}(precise::exp(input[index].x)*precise::cos(input[index].y)); + output[index].y = {0}(precise::exp(input[index].x)*precise::sin(input[index].y)); +}} + +kernel void tanh_kernel( device {0} *output [[buffer(0)]], + device {1} *input [[ buffer(1)]], + uint index [[thread_position_in_grid]]) {{ + output[index] = {0}(precise::tanh(input[index])); +}} + + +#if __METAL_VERSION__ >= 310 +bfloat dot(bfloat2 a, bfloat2 b) {{ + return a.x * b.x + a.y * b.y; +}} +#endif + +template +T complex_div(T a, T b) {{ + auto denom = dot(b, b); + return T(dot(a, b), a.y * b.x - a.x * b.y)/denom; +}} + +kernel void tanh_complex_kernel( device {0}2 *output [[buffer(0)]], + device {0}2 *input [[ buffer(1)]], + uint index [[thread_position_in_grid]]) {{ + //tanh(x+iy)=(tanh(x)+itan(y))/(1+itahnh(x)*tan(y)); + auto tanh_x = {0}(precise::tanh(input[index].x)); + auto tan_y = {0}(precise::tan(input[index].y)); + output[index] = complex_div({0}2(tanh_x, tan_y), {0}2({0}(1), tanh_x * tan_y)); +}} +)METAL"; diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/nested/NestedTensorBinaryOps.h 
b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/nested/NestedTensorBinaryOps.h new file mode 100644 index 0000000000000000000000000000000000000000..298c1822418cc86a8530805bb6d7ab515176f815 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/nested/NestedTensorBinaryOps.h @@ -0,0 +1,18 @@ +#pragma once + +#include +#include + +namespace at::native { + +enum class NESTED_DENSE_OP : uint8_t { ADD, MUL }; + +using nested_dense_elementwise_fn = void (*)( + Tensor& result, + const Tensor& self, + const Tensor& other, + const NESTED_DENSE_OP& op); + +DECLARE_DISPATCH(nested_dense_elementwise_fn, nested_dense_elementwise_stub); + +} // namespace at::native diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/nested/NestedTensorMath.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/nested/NestedTensorMath.h new file mode 100644 index 0000000000000000000000000000000000000000..b96c4d72288221116253991ee9f22d43d9a1cfb7 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/nested/NestedTensorMath.h @@ -0,0 +1,79 @@ +#pragma once + +#include +#include +#include + +namespace at::native { + +TORCH_API Tensor NestedTensor_to_padded_tensor_generic( + const Tensor& t, + double padding, + OptionalIntArrayRef output_size); + +template +Tensor map_nt(const Tensor& nt, Func f) { + auto* nt_impl = get_nested_tensor_impl(nt); + const auto& sizes = nt_impl->get_nested_sizes(); + return at::detail::make_tensor(f(nt_impl->get_buffer()), sizes); +} +template +Tensor map_nt_binary(const Tensor& nt_1, const Tensor& nt_2, Func f){ + auto* nt_impl_1 = get_nested_tensor_impl(nt_1); + auto* nt_impl_2 = get_nested_tensor_impl(nt_2); + const auto& sizes = nt_impl_1->get_nested_sizes(); + return at::detail::make_tensor(f(nt_impl_1->get_buffer(), nt_impl_2->get_buffer()), sizes); +} + +C10_ALWAYS_INLINE std::pair _check_nested_layer_norm_inputs( + const NestedTensorImpl& input, + IntArrayRef normalized_shape, + const Tensor& weight /* optional */, + const Tensor& bias /* optional */) { + + const size_t normalized_ndim = normalized_shape.size(); + TORCH_CHECK( + normalized_ndim >= 1, + "Expected normalized_shape to be at least 1-dimensional, i.e., ", + "containing at least one element, but got normalized_shape = ", + normalized_shape); + TORCH_CHECK( + !weight.defined() || weight.sizes().equals(normalized_shape), + "Expected weight to be of same shape as normalized_shape, but got ", + "weight of shape ", + weight.sizes(), + " and normalized_shape = ", + normalized_shape); + TORCH_CHECK( + !bias.defined() || bias.sizes().equals(normalized_shape), + "Expected bias to be of same shape as normalized_shape, but got ", + "bias of shape ", + bias.sizes(), + " and normalized_shape = ", + normalized_shape); + + // Check that the normalized_shape has the exact same sizes as the last dimensions from the NestedTensor input + // Also, compute M and N considering the idiosyncracies of NestedTensors + int64_t N = 1; + for (const auto i: c10::irange(normalized_ndim)) { + TORCH_CHECK( + input.opt_size(-normalized_ndim + i) != std::nullopt, + "normalized_shape extends into irregular dimensions for the nested tensor" + ); + TORCH_CHECK( + normalized_shape[i] == *input.opt_size(-normalized_ndim + i), + "The shape at dimension ", + i, + "of normalized_shape doesn't match the input" + ); + N *= normalized_shape[i]; + } + + const int64_t M = input.numel() / N; + + return std::make_pair(M, N); +} + +Tensor reshape_nested(const Tensor& self, 
IntArrayRef proposed_shape); + +} // namespace at::native diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/nested/NestedTensorTransformerFunctions.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/nested/NestedTensorTransformerFunctions.h new file mode 100644 index 0000000000000000000000000000000000000000..45e1b98c943bc40903048f25542dfcbf1954c85c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/nested/NestedTensorTransformerFunctions.h @@ -0,0 +1,103 @@ +/** + * Transformer-specific NestedTensor utility functions. + * + * Not co-located with NestedTensor core code yet because they only + * support specific cases needed in transformers. + */ +#pragma once + +#include + +#include +#include + +namespace c10 { +class Scalar; +} // namespace c10 + +namespace at { +class Tensor; +namespace native { +struct NestedTensorImpl; + +// Requires that self is a contiguous NestedTensor, other is not a +// NestedTensor, self.dim() == 3, and other.dim() == 2. Also, self +// must have a consistent last dimension across its included Tensors +// and that dimension must match other.size(0). +Tensor NestedTensor_matmul(const Tensor& self, const Tensor& other); + +// Requires that mat1 is a contiguous NestedTensor, self & mat2 are +// not NestedTensors, mat1.dim() == 3, mat2.dim() == 2, and that mat1 +// has a consistent last dimension across its included Tensors that +// matches mat2.size(0). +Tensor NestedTensor_times_Tensor_plus_Tensor_addmm( + const Tensor& self, + const Tensor& mat1, + const Tensor& mat2, + const c10::Scalar& beta, + const c10::Scalar& alpha, + std::optional use_gelu = std::nullopt); + +Tensor NestedTensor_add_NestedTensor_in_place( + const Tensor& self, + const Tensor& other); + +TORCH_API Tensor NestedTensor_batch_offsets_from_size_tensor( + const Tensor& sizes, + int64_t extra_elements); + +Tensor NestedTensor_from_padded_tensor_cpu( + const Tensor& padded, + const NestedTensorImpl& nt); + +Tensor NestedTensor_to_mask(const Tensor& nt, std::optional mask_dim, std::optional mask_dim_length); + +template +void remove_padding_kernelLauncher( + const T* input, + T* output, + const int* offsets, + const int* input_sizes, + const int* output_sizes, + int output_dim, + const int batch_size); + +template +void remove_padding_transform0213_kernelLauncher( + const T* input, + T* output, + const int* offsets, + const int* input_sizes, + const int* output_sizes, + int output_dim, + const int batch_size); + +template +void add_padding_kernelLauncher( + T* input, + T* output, + T padding_value, + const int* offsets, + const int* input_sizes, + int input_dim, + const std::vector& output_sizes, + const int batch_size, + const int output_batch_size); + +TORCH_API Tensor flash_attention_helper( + const Tensor& query, + const Tensor& key, + const Tensor& value, + double dropout_p, + bool need_attn_weights, + bool is_causal); + +TORCH_API std::tuple mem_efficient_helper_nested_unpacked( + const Tensor& query, + const Tensor& key, + const Tensor& value, + double dropout_p, + bool need_attn_weights, + bool is_causal); +} // namespace native +} // namespace at diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/nested/NestedTensorTransformerUtils.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/nested/NestedTensorTransformerUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..d3acf229a2383e1967f0c5502498620d502b6b52 --- /dev/null +++ 
b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/nested/NestedTensorTransformerUtils.h @@ -0,0 +1,39 @@ +#include + +namespace at::native::preprocessing { + +/** + * This function will take nested query, key, and value + * and will preprocess it in order to run with either + * the flash-attention or efficient-attention kernels. + * @return A tuple containing all the necessary data for running the fused + * kernels + */ +std::tuple +sdpa_nested_preprocessing( + const Tensor& query, + const Tensor& key, + const Tensor& value); + +/** + * This function will take nested query, key, and value, grad_out, and out + * and will preprocess it in order to run with either + * the flash-attention or efficient-attention kernels backwards. + * We use both functions to avoid having to do the same preprocessing + * for cumulative_sequence_length_q and cumulative_sequence_length_kv + * @return A tuple containing all the necessary data for running the fused + * kernels + */ +std::tuple +sdpa_nested_preprocessing_backward( + const at::Tensor& grad_out_, + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + const at::Tensor& out, + const Tensor& cumulative_sequence_length_q, + const Tensor& cumulative_sequence_length_kv, + const int64_t max_seqlen_batch_q, + const int64_t max_seqlen_batch_kv); + +} // namespace at::native::preprocessing diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/nested/NestedTensorUtils.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/nested/NestedTensorUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..0dd89e74eaa1495dcddb620b8a995334da0b862e --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/nested/NestedTensorUtils.h @@ -0,0 +1,447 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS + +#include +#include +#else +#include +#include +#include +#include +#include +#include +#endif + +#include +#include + +namespace at::native { +struct NestedTensorImpl; + +// The following functions are used to construct nested tensors from buffers and +// metadata. + +inline at::Tensor wrap_buffer(at::Tensor buffer, at::Tensor nested_sizes) { + TORCH_CHECK( + buffer.dim() == 1, + "Expected given buffer to be 1dim, but got ", + buffer.dim(), + " instead."); + TORCH_CHECK( + buffer.is_contiguous(), "Expected given buffer to be contiguous."); + return at::detail::make_tensor( + std::move(buffer), std::move(nested_sizes)); +} + +// TODO: Figure out if we need a non-moving wrap_buffer() +inline at::Tensor wrap_buffer( + at::Tensor buffer, + at::Tensor nested_sizes, + at::Tensor nested_strides, + at::Tensor storage_offsets) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + buffer.is_contiguous(), "Given buffer must be contiguous."); + return at::detail::make_tensor( + std::move(buffer), + std::move(nested_sizes), + std::move(nested_strides), + std::move(storage_offsets)); +} + +inline at::Tensor get_buffer(const at::Tensor& tensor) { + return get_nested_tensor_impl(tensor)->get_buffer(); +} + +/** + * Create a new nested tensor that is a view of a base nested tensor + * + * create_view_tensor calls a specialized constructor that copys the + * the keys from base onto the new view tensor being created. 
+ * The storage is shared between the base and the returned view tensor + * + * All callers of this helper must: + * - Only return a view of the input + * - Must be explicit and define a derivative + * + * @param base Base tensor to construct view from. + * @param nested_sizes View tensors' sizes. + * @param nested_strides View tensors' strides. + * @param storage_offsets View tensors' offsets. + * @return A newly constructed view tensor + */ +inline at::Tensor create_nested_view_tensor( + const at::Tensor& base, + at::Tensor nested_sizes, + at::Tensor nested_strides, + at::Tensor storage_offsets) { + TORCH_INTERNAL_ASSERT( + base.is_nested(), + "This function can only be used to create nested tensor views"); + TORCH_INTERNAL_ASSERT( + c10::impl::tls_local_dispatch_key_set().excluded_.has( + c10::DispatchKey::AutogradFunctionality), + "Creating a non differentiable nested tensor view in a CompositeImplicit function is not allowed."); + return at::detail::make_tensor( + c10::TensorImpl::VIEW, + base, + nested_sizes, + nested_strides, + storage_offsets); +} +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// Helper functions for getting information about a nested tensor's shape. + +int64_t get_consistent_last_dim_of_nested_tensor(const NestedTensorImpl& nt); + +// The sizes of the underlying tensors +inline std::vector NestedTensor_get_sizes( + const NestedTensorImpl* self_ptr) { + int64_t ntensors = self_ptr->size(0); + std::vector sizes(ntensors); + if (ntensors == 0) { + return sizes; + } + const Tensor& sizemat = self_ptr->get_nested_sizes(); + int64_t orig_dim = sizemat.size(1); + // nesting scalars has empty sizes + if (orig_dim == 0) { + return sizes; + } + const int64_t* sizemat_ptr = sizemat.const_data_ptr(); + + for (const auto i : c10::irange(ntensors)) { + sizes[i] = IntArrayRef(sizemat_ptr, sizemat_ptr + orig_dim); + sizemat_ptr += orig_dim; + } + return sizes; +} + +TORCH_API std::vector NestedTensor_get_max_size( + const NestedTensorImpl& nt); + +std::vector NestedTensor_get_max_size_from_size_tensor( + const Tensor& sizes); + +inline std::vector NestedTensor_get_sizes(const at::Tensor& self) { + const NestedTensorImpl* self_ptr = get_nested_tensor_impl(self); + return NestedTensor_get_sizes(self_ptr); +} +// The strides of the underlying tensors +inline std::vector NestedTensor_get_strides( + const NestedTensorImpl* self_ptr) { + int64_t ntensors = self_ptr->size(0); + std::vector strides(ntensors); + if (ntensors == 0) { + return strides; + } + const Tensor& stridemat = self_ptr->get_nested_strides(); + int64_t orig_dim = stridemat.size(1); + // nesting scalars has empty strides + if (orig_dim == 0) { + return strides; + } + const int64_t* stridemat_ptr = stridemat.const_data_ptr(); + for (const auto i : c10::irange(ntensors)) { + strides[i] = IntArrayRef(stridemat_ptr, stridemat_ptr + orig_dim); + stridemat_ptr += orig_dim; + } + return strides; +} + +inline std::vector NestedTensor_get_strides( + const at::Tensor& self) { + const NestedTensorImpl* self_ptr = get_nested_tensor_impl(self); + return NestedTensor_get_strides(self_ptr); +} + +inline void check_numel_equals_buffer_size(const at::Tensor& self) { + auto self_impl = get_nested_tensor_impl(self); + TORCH_CHECK( + self.numel() == static_cast(self_impl->get_buffer_size()), + "Number of elements in nested tensor must match number of elements in buffer."); +} + +inline void check_numel_equals_buffer_size(const NestedTensorImpl* self_ptr) { + TORCH_CHECK( + self_ptr->numel() == 
static_cast(self_ptr->get_buffer_size()), + "Number of elements in nested tensor must match number of elements in buffer."); +} + +// Helper function to get size / stride / offset for a nested/normal tensor. +inline IntArrayRef get_size_for_index(const Tensor& tensor, int64_t i) { + if (tensor.is_nested()) { + std::vector tensor_sizes = + NestedTensor_get_sizes(get_nested_tensor_impl(tensor)); + return tensor_sizes[i]; + } else { + return tensor.sizes().slice(1); + } +} + +inline IntArrayRef get_stride_for_index(const Tensor& tensor, int64_t i) { + if (tensor.is_nested()) { + std::vector tensor_strides = + NestedTensor_get_strides(get_nested_tensor_impl(tensor)); + return tensor_strides[i]; + } else { + return tensor.strides().slice(1); + } +} + +inline int64_t get_offset_for_index(const Tensor& tensor, int64_t i) { + if (tensor.is_nested()) { + int64_t* offsets_ptr = get_nested_tensor_impl(tensor) + ->get_storage_offsets() + .data_ptr(); + return offsets_ptr[i]; + + } else { + int64_t offset = tensor.storage_offset(); + return offset + tensor.strides()[0] * i; + } +} +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// Data structures and functions for generically applying a function on a nested +// tensor. +namespace impl { + +template +struct NestedNode { + NestedNode() = delete; + explicit NestedNode(std::vector children) + : _is_leaf(false), _children(std::move(children)) {} + explicit NestedNode(TensorList children) + : _is_leaf(false), _children(children.vec()) {} + explicit NestedNode(T payload) + : _is_leaf(true), _payload(std::move(payload)) {} + NestedNode(const NestedNode&) = delete; + NestedNode& operator=(const NestedNode&) = delete; + NestedNode(NestedNode&&) noexcept = default; + NestedNode& operator=(NestedNode&&) noexcept = default; + inline bool is_leaf() const { + return _is_leaf; + } + inline size_t degree() const { + return _children.size(); + } + inline const std::vector unbind() const { + return _children; + } + inline T children(size_t i) const { + return _children[i]; + } + inline const T& payload() const { + return _payload; + } + inline T& payload() { + return _payload; + } + + private: + bool _is_leaf; + std::vector _children; + T _payload{}; +}; + +using TensorNode = NestedNode; + +template +class _map; + +template +class _map> { + public: + static A function_one(F&& fn, const Args&... nested_node) { + return std::forward(fn)(nested_node...); + } + static NestedNode function( + const F& fn, + const NestedNode&... nested_node) { + size_t degree = 0; + bool all_leaf = true; + c10::guts::tuple_map( + std::forward_as_tuple(nested_node...), [&all_leaf, °ree](auto n) { + all_leaf = all_leaf && (n.is_leaf()); + if (degree > 1 && n.degree() > 1) { + TORCH_CHECK( + degree == n.degree(), "NestedNodes must match in degree."); + } + if (n.degree() > degree) { + degree = n.degree(); + } + return nullptr; + }); + // All NestedNodes just wrap regular objects. + if (all_leaf) { + return NestedNode(std::forward(fn)(nested_node.payload()...)); + } + // Some NestedNodes wrap regular Tensors, some NestedTensors and some other + // types. + std::vector result; + for (size_t i = 0; i < degree; i++) { + auto children = c10::guts::tuple_map( + std::forward_as_tuple(nested_node...), [&i](auto a) { + static_assert( + c10::guts::is_instantiation_of::value, + "Internal error."); + // Broadcast regular arguments across NestedTensor constituents. + // This could be a Tensor, integer or anything else really. 
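              // (Illustrative note, not part of the upstream comment:) three
              // broadcast cases follow -- a leaf argument is reused verbatim
              // for every constituent, a nested argument with exactly one
              // constituent is broadcast by reusing children(0), and any
              // other nested argument (whose degree was checked above) simply
              // contributes its i-th constituent.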
+ if (a.is_leaf()) { + return a.payload(); + } + // Broadcast NestedTensors with one constituent. + if (a.degree() == 1 && !a.is_leaf()) { + return a.children(0); + } + TORCH_CHECK(a.degree() > 0, "Internal assert."); + return a.children(i); + }); + c10::guts::apply( + [&result, &fn](Args... filtered) { + result.emplace_back(function_one(fn, filtered...)); + }, + std::move(children)); + } + return NestedNode(std::move(result)); + } +}; + +// TODO: Add static assert to verify lambda arguments match nested_node types +template +static inline NestedNode< + typename c10::guts::infer_function_traits::type::return_type> +map(F&& fn, const NestedNode&... nested_node) { + return _map< + F, + typename c10::guts::infer_function_traits::type::return_type, + typename c10::guts::infer_function_traits::type::parameter_types>:: + function(std::forward(fn), nested_node...); +} + +inline TensorNode get_nested_tensor_structure(at::Tensor tensor) { + if (get_nested_tensor_impl_or_null(tensor) == nullptr) { + return TensorNode(std::move(tensor)); + } + return TensorNode(tensor.unbind()); +} + +inline Tensor wrap_tensor_node( + TensorNode tensor_node, + std::optional dtype, + std::optional layout, + std::optional device, + std::optional pin_memory) { + TORCH_CHECK( + !tensor_node.is_leaf(), "Expected TensorNode to wrap a list of Tensors."); + TensorOptions options_ = + TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory( + pin_memory); + if (tensor_node.degree() == 0) { + return wrap_buffer(ones({0}, dtype, layout, device), ones({})); + } + + // Fast path: if all tensors are on CPU, have contiguous memory, and the same + // dtype, copying can be done much faster. + bool all_tensors_cpu = true; + bool all_tensors_contiguous = true; + bool all_tensors_same_dtype = true; + auto first_dtype = tensor_node.children(0).dtype(); + std::vector start_offsets(tensor_node.degree()); + start_offsets[0] = 0; + long total_size = 0; + for (const auto i : c10::irange(tensor_node.degree())) { + all_tensors_cpu = all_tensors_cpu && tensor_node.children(i).is_cpu(); + all_tensors_contiguous = + all_tensors_contiguous && tensor_node.children(i).is_contiguous(); + all_tensors_same_dtype = all_tensors_same_dtype && + (first_dtype == tensor_node.children(i).dtype()); + if (!(all_tensors_cpu && all_tensors_contiguous && + all_tensors_same_dtype)) { + break; + } + if (i > 0) { + start_offsets[i] = + start_offsets[i - 1] + tensor_node.children(i - 1).numel(); + } + total_size += tensor_node.children(i).numel(); + } + + TensorOptions options; + Tensor nt_buffer, nt_sizes; + if (all_tensors_cpu && all_tensors_contiguous && all_tensors_same_dtype) { + nt_buffer = at::empty({total_size}, tensor_node.children(0).options()); + nt_sizes = at::empty( + {static_cast(tensor_node.degree()), + static_cast(tensor_node.children(0).sizes().size())}, + TensorOptions().dtype(kLong)); + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3( + at::ScalarType::Half, + at::ScalarType::Bool, + at::ScalarType::BFloat16, + c10::typeMetaToScalarType(first_dtype), + "create_nt_buffer", + [&]() { + at::parallel_for( + 0, tensor_node.degree(), 1, [&](int64_t begin, int64_t end) { + for (int64_t i = begin; i < end; ++i) { + // Only try copying memory if there is more than 0 elements + // for a certain tensor + if (tensor_node.children(i).numel() > 0) { + memcpy( + nt_buffer.mutable_data_ptr() + start_offsets[i], + tensor_node.children(i).const_data_ptr(), + tensor_node.children(i).numel() * sizeof(scalar_t)); + } + } + }); + }); + long sizes_offset = 0; + for 
(size_t i = 0; i < tensor_node.degree(); ++i) { + auto tensor_sizes = tensor_node.children(i).sizes(); + for (int64_t tensor_size : tensor_sizes) { + nt_sizes.mutable_data_ptr()[sizes_offset++] = tensor_size; + } + } + options = nt_buffer.options().merge_in(options_); + } else { // Slow path + std::vector flat_tensors; + std::vector sizes; + for (const auto i : c10::irange(tensor_node.degree())) { + flat_tensors.push_back(tensor_node.children(i).reshape(-1).contiguous()); + sizes.push_back( + tensor(c10::IntArrayRef(tensor_node.children(i).sizes()))); + } + options = flat_tensors[0].options().merge_in(options_); + nt_buffer = at::cat(flat_tensors); + nt_sizes = at::native::stack(sizes); + } + + return wrap_buffer(nt_buffer.to(options), nt_sizes); +} + +} // namespace impl + +// This function is meant to ease rapid operator coverage for +// NestedTensor kernels. It is not meant to be efficient. Use it judiciously. +template +inline at::Tensor map_nested_tensor(F&& fn, A... a) { + return wrap_tensor_node( + impl::map(std::forward(fn), impl::get_nested_tensor_structure(a)...), + std::nullopt, + std::nullopt, + std::nullopt, + std::nullopt); +} + +} // namespace at::native
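// ---------------------------------------------------------------------------
// Illustrative sketch (appended for exposition, not part of the upstream
// header): using map_nested_tensor() above for quick NestedTensor coverage of
// a unary op. Each constituent is unbound via get_nested_tensor_structure(),
// transformed by the lambda, and the results are re-packed into a fresh
// buffer by wrap_tensor_node(). The wrapper name is hypothetical.
#include <ATen/ATen.h>  // assumed available; provides at::relu for this sketch

namespace at::native {

inline Tensor NestedTensor_relu_sketch(const Tensor& self) {
  // The lambda takes its argument by value, so the mapped node type is
  // NestedNode<Tensor>.
  return map_nested_tensor(
      [](Tensor t) { return at::relu(t); },
      self);
}

} // namespace at::native
// ---------------------------------------------------------------------------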