diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/ArrayRef.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/ArrayRef.h new file mode 100644 index 0000000000000000000000000000000000000000..0461d5953ed8a7783c82402ca4523b0b0a1ad465 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/ArrayRef.h @@ -0,0 +1,2 @@ +#pragma once +#include diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/DimVector.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/DimVector.h new file mode 100644 index 0000000000000000000000000000000000000000..cb652fffcb14819d8ca5292daa012ad47f4c3fad --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/DimVector.h @@ -0,0 +1,2 @@ +#pragma once +#include diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/Dimname.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/Dimname.h new file mode 100644 index 0000000000000000000000000000000000000000..71836a9e25d3d82d9cd5024b2f33e147e14bf87e --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/Dimname.h @@ -0,0 +1 @@ +#include diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/Formatting.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/Formatting.h new file mode 100644 index 0000000000000000000000000000000000000000..392e2a27b0130c7ba55621d6ac1d6fd4e989db02 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/Formatting.h @@ -0,0 +1 @@ +#include diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/MetaFunctions_inl.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/MetaFunctions_inl.h new file mode 100644 index 0000000000000000000000000000000000000000..46dcd458ff1c4a37c9d65f240a491679cbfd17c1 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/MetaFunctions_inl.h @@ -0,0 +1,325 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunctions_inl.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS) +#error This change adds a dependency on all pytorch operators, meaning the \ + file will need to be re-compiled every time an operator is changed or added. \ + Consider including a specific operator from \ + . \ + See NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS]. +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/PTThreadPool.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/PTThreadPool.h new file mode 100644 index 0000000000000000000000000000000000000000..1c910dfb97dce44748c054b457b400b13b1b9fda --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/PTThreadPool.h @@ -0,0 +1,17 @@ +#pragma once + +#include +#include + +namespace at { + +class TORCH_API PTThreadPool : public c10::ThreadPool { + public: + explicit PTThreadPool(int pool_size, int numa_node_id = -1) + : c10::ThreadPool(pool_size, numa_node_id, []() { + c10::setThreadName("PTThreadPool"); + at::init_num_threads(); + }) {} +}; + +} // namespace at diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/SequenceNumber.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/SequenceNumber.h new file mode 100644 index 0000000000000000000000000000000000000000..41b7b97cf6abbdcf987c020e14b09a64f7729bfc --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/SequenceNumber.h @@ -0,0 +1,13 @@ +#pragma once + +#include +#include + +// A simple thread local enumeration, used to link forward and backward pass +// ops and is used by autograd and observers framework +namespace at::sequence_number { + +TORCH_API uint64_t peek(); +TORCH_API uint64_t get_and_increment(); + +} // namespace at::sequence_number diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/ThreadLocalPythonObjects.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/ThreadLocalPythonObjects.h new file mode 100644 index 0000000000000000000000000000000000000000..c45de86db3abeffe22cb8db559f602d88b35be9c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/ThreadLocalPythonObjects.h @@ -0,0 +1,21 @@ +#pragma once + +#include +#include +#include + +namespace at::impl { + +struct TORCH_API ThreadLocalPythonObjects { + static void set(const std::string& key, std::shared_ptr value); + static const std::shared_ptr& get(const std::string& key); + static bool contains(const std::string& key); + + static const ThreadLocalPythonObjects& get_state(); + static void set_state(ThreadLocalPythonObjects state); + + private: + std::unordered_map> obj_dict_; +}; + +} // namespace at::impl diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/ThreadLocalState.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/ThreadLocalState.h new file mode 100644 index 0000000000000000000000000000000000000000..721ea9957513bf95be86aaa6a13733775283413a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/ThreadLocalState.h @@ -0,0 +1,120 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace at { + +// Thread local state contains values that are preserved across +// thread boundaries (e.g. at::launch/JIT fork, autograd). +// Note at::parallel_for doesn't preserve TLS across thread boundaries. +class TORCH_API ThreadLocalState { + public: + // Saves the thread local variables' values and + // returns them as a ThreadLocalState + ThreadLocalState(); + + // set_grad_mode - force the value of the grad mode TLS in + // the current state object. This is used for example in the + // autograd engine. + void set_grad_mode(bool enabled); + + // set_multithreading_enabled - force the value of the multithreadinmaximum + // threads TLS in + // the current state object. This is used for example in the + // autograd engine. + void set_multithreading_enabled(bool enabled); + + // Sets thread local variables in the current thread, + // according to the thread boundary specified + static void setThreadLocalState(const ThreadLocalState& state); + + private: + c10::impl::LocalDispatchKeySet dispatch_key_; + + // ThreadLocalDebugInfo does not change after being created + // with DebugInfoGuard + std::shared_ptr debug_info_; + + // RecordFunction TLS + RecordFunctionTLS rf_tls_; + + // TLS for out-of-tree functorch + // See NOTE [functorch TLS in pytorch/pytorch] for why this needs to be a + // pointer (spoiler alert: it's due to the indirection) + // This needs to be a shared_ptr instead of a unique_ptr because + // ThreadLocalState is copy-able and does indeed get copied. Maybe we can + // consider adding an explicit copy constructor for ThreadLocalState in the + // future but I didn't want to add one just for this. + std::shared_ptr functorch_tls_; + + // TLS for AutogradModes + AutogradState autograd_tls_; + + // TLS for enable_torch_dispatch_mode + c10::impl::TorchDispatchModeTLS torch_dispatch_mode_state_; + + // TLS for enable_python_dispatcher + c10::impl::PyInterpreter* python_dispatcher_state_; + + // TLS for __torch_function__ (mode and disable_torch_function) + at::impl::PythonTorchFunctionTLS python_torch_function_state_; + + // TLS for saved tensors default hooks + at::impl::SavedTensorDefaultHooksTLS saved_tensors_default_hooks_state_; + + bool functionalization_reapply_views_state_; + + // TLS for arbitrary python objects that is registered via hooks + at::impl::ThreadLocalPythonObjects saved_objects_; + +#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) && \ + !defined(BUILD_LITE_INTERPRETER) + // TLS for autocast dtypes + std::array + autocast_dtypes_; +#endif + + friend class ThreadLocalStateGuard; +}; + +// Guard to set and reset the thread local state +class TORCH_API ThreadLocalStateGuard { + public: + explicit ThreadLocalStateGuard(const ThreadLocalState& state) + : prev_state_(ThreadLocalState()) { + // set the given state across the thread boundary + ThreadLocalState::setThreadLocalState(state); + } + + ~ThreadLocalStateGuard() { + // restore previously set variables + ThreadLocalState::setThreadLocalState(prev_state_); + } + + private: + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) + const ThreadLocalState prev_state_; +}; + +template +auto wrapPropagateTLSState(T callback) { + return [tls_state = ThreadLocalState(), + callback = std::move(callback)](auto&&... args) { + ThreadLocalStateGuard g(tls_state); + // Propagate value returned by callback(). + return callback(std::forward(args)...); + }; +} + +} // namespace at diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/Version.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/Version.h new file mode 100644 index 0000000000000000000000000000000000000000..706da58a5da01c35fda7f2c6374c8f5868f1b642 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/Version.h @@ -0,0 +1,18 @@ +#include + +namespace at { + +/// Returns a detailed string describing the configuration PyTorch. +TORCH_API std::string show_config(); + +TORCH_API std::string get_mkl_version(); + +TORCH_API std::string get_mkldnn_version(); + +TORCH_API std::string get_openmp_version(); + +TORCH_API std::string get_cxx_flags(); + +TORCH_API std::string get_cpu_capability(); + +} // namespace at diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/AsmUtils.cuh b/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/AsmUtils.cuh new file mode 100644 index 0000000000000000000000000000000000000000..8bd897e64c4fdcdf6bd32c0b176ce2414ebe9438 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/AsmUtils.cuh @@ -0,0 +1,149 @@ +#pragma once +#include + +// Collection of direct PTX functions + +namespace at::cuda { + +template +struct Bitfield {}; + +template <> +struct Bitfield { + static __device__ __host__ __forceinline__ + unsigned int getBitfield(unsigned int val, int pos, int len) { +#if !defined(__CUDA_ARCH__) + pos &= 0xff; + len &= 0xff; + + unsigned int m = (1u << len) - 1u; + return (val >> pos) & m; +#else + unsigned int ret; + asm("bfe.u32 %0, %1, %2, %3;" : "=r"(ret) : "r"(val), "r"(pos), "r"(len)); + return ret; +#endif + } + + static __device__ __host__ __forceinline__ + unsigned int setBitfield(unsigned int val, unsigned int toInsert, int pos, int len) { +#if !defined(__CUDA_ARCH__) + pos &= 0xff; + len &= 0xff; + + unsigned int m = (1u << len) - 1u; + toInsert &= m; + toInsert <<= pos; + m <<= pos; + + return (val & ~m) | toInsert; +#else + unsigned int ret; + asm("bfi.b32 %0, %1, %2, %3, %4;" : + "=r"(ret) : "r"(toInsert), "r"(val), "r"(pos), "r"(len)); + return ret; +#endif + } +}; + +template <> +struct Bitfield { + static __device__ __host__ __forceinline__ + uint64_t getBitfield(uint64_t val, int pos, int len) { +#if !defined(__CUDA_ARCH__) + pos &= 0xff; + len &= 0xff; + + uint64_t m = (1u << len) - 1u; + return (val >> pos) & m; +#else + uint64_t ret; + asm("bfe.u64 %0, %1, %2, %3;" : "=l"(ret) : "l"(val), "r"(pos), "r"(len)); + return ret; +#endif + } + + static __device__ __host__ __forceinline__ + uint64_t setBitfield(uint64_t val, uint64_t toInsert, int pos, int len) { +#if !defined(__CUDA_ARCH__) + pos &= 0xff; + len &= 0xff; + + uint64_t m = (1u << len) - 1u; + toInsert &= m; + toInsert <<= pos; + m <<= pos; + + return (val & ~m) | toInsert; +#else + uint64_t ret; + asm("bfi.b64 %0, %1, %2, %3, %4;" : + "=l"(ret) : "l"(toInsert), "l"(val), "r"(pos), "r"(len)); + return ret; +#endif + } +}; + +__device__ __forceinline__ int getLaneId() { +#if defined(USE_ROCM) + return __lane_id(); +#else + int laneId; + asm("mov.s32 %0, %%laneid;" : "=r"(laneId) ); + return laneId; +#endif +} + +#if defined(USE_ROCM) +__device__ __forceinline__ unsigned long long int getLaneMaskLt() { + const std::uint64_t m = (1ull << getLaneId()) - 1ull; + return m; +} +#else +__device__ __forceinline__ unsigned getLaneMaskLt() { + unsigned mask; + asm("mov.u32 %0, %%lanemask_lt;" : "=r"(mask)); + return mask; +} +#endif + +#if defined (USE_ROCM) +__device__ __forceinline__ unsigned long long int getLaneMaskLe() { + std::uint64_t m = UINT64_MAX >> (sizeof(std::uint64_t) * CHAR_BIT - (getLaneId() + 1)); + return m; +} +#else +__device__ __forceinline__ unsigned getLaneMaskLe() { + unsigned mask; + asm("mov.u32 %0, %%lanemask_le;" : "=r"(mask)); + return mask; +} +#endif + +#if defined(USE_ROCM) +__device__ __forceinline__ unsigned long long int getLaneMaskGt() { + const std::uint64_t m = getLaneMaskLe(); + return m ? ~m : m; +} +#else +__device__ __forceinline__ unsigned getLaneMaskGt() { + unsigned mask; + asm("mov.u32 %0, %%lanemask_gt;" : "=r"(mask)); + return mask; +} +#endif + +#if defined(USE_ROCM) +__device__ __forceinline__ unsigned long long int getLaneMaskGe() { + const std::uint64_t m = getLaneMaskLt(); + return ~m; +} +#else +__device__ __forceinline__ unsigned getLaneMaskGe() { + unsigned mask; + asm("mov.u32 %0, %%lanemask_ge;" : "=r"(mask)); + return mask; +} +#endif + +} // namespace at::cuda diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDAContextLight.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDAContextLight.h new file mode 100644 index 0000000000000000000000000000000000000000..dc33cb541370f54f8b8b03baadc52709017ca527 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDAContextLight.h @@ -0,0 +1,99 @@ +#pragma once +// Light-weight version of CUDAContext.h with fewer transitive includes + +#include + +#include +#include +#include + +// cublasLT was introduced in CUDA 10.1 but we enable only for 11.1 that also +// added bf16 support +#include + +#ifdef CUDART_VERSION +#include +#endif + +#if defined(USE_CUDSS) +#include +#endif + +#if defined(USE_ROCM) +#include +#endif + +#include +#include + +namespace c10 { +struct Allocator; +} + +namespace at::cuda { + +/* +A common CUDA interface for ATen. + +This interface is distinct from CUDAHooks, which defines an interface that links +to both CPU-only and CUDA builds. That interface is intended for runtime +dispatch and should be used from files that are included in both CPU-only and +CUDA builds. + +CUDAContext, on the other hand, should be preferred by files only included in +CUDA builds. It is intended to expose CUDA functionality in a consistent +manner. + +This means there is some overlap between the CUDAContext and CUDAHooks, but +the choice of which to use is simple: use CUDAContext when in a CUDA-only file, +use CUDAHooks otherwise. + +Note that CUDAContext simply defines an interface with no associated class. +It is expected that the modules whose functions compose this interface will +manage their own state. There is only a single CUDA context/state. +*/ + +/** + * DEPRECATED: use device_count() instead + */ +inline int64_t getNumGPUs() { + return c10::cuda::device_count(); +} + +/** + * CUDA is available if we compiled with CUDA, and there are one or more + * devices. If we compiled with CUDA but there is a driver problem, etc., + * this function will report CUDA is not available (rather than raise an error.) + */ +inline bool is_available() { + return c10::cuda::device_count() > 0; +} + +TORCH_CUDA_CPP_API cudaDeviceProp* getCurrentDeviceProperties(); + +TORCH_CUDA_CPP_API int warp_size(); + +TORCH_CUDA_CPP_API cudaDeviceProp* getDeviceProperties(c10::DeviceIndex device); + +TORCH_CUDA_CPP_API bool canDeviceAccessPeer( + c10::DeviceIndex device, + c10::DeviceIndex peer_device); + +TORCH_CUDA_CPP_API c10::Allocator* getCUDADeviceAllocator(); + +/* Handles */ +TORCH_CUDA_CPP_API cusparseHandle_t getCurrentCUDASparseHandle(); +TORCH_CUDA_CPP_API cublasHandle_t getCurrentCUDABlasHandle(); +TORCH_CUDA_CPP_API cublasLtHandle_t getCurrentCUDABlasLtHandle(); + +TORCH_CUDA_CPP_API void clearCublasWorkspaces(); + +#if defined(CUDART_VERSION) || defined(USE_ROCM) +TORCH_CUDA_CPP_API cusolverDnHandle_t getCurrentCUDASolverDnHandle(); +#endif + +#if defined(USE_CUDSS) +TORCH_CUDA_CPP_API cudssHandle_t getCurrentCudssHandle(); +#endif + +} // namespace at::cuda diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDADataType.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDADataType.h new file mode 100644 index 0000000000000000000000000000000000000000..1696bb3a0f4430a21de8261cf066a895d014156d --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDADataType.h @@ -0,0 +1,105 @@ +#pragma once + +#include + +#include +#include + +namespace at::cuda { + +template +cudaDataType getCudaDataType() { + static_assert(false && sizeof(scalar_t), "Cannot convert type to cudaDataType."); + return {}; +} + +template<> inline cudaDataType getCudaDataType() { + return CUDA_R_16F; +} +template<> inline cudaDataType getCudaDataType() { + return CUDA_R_32F; +} +template<> inline cudaDataType getCudaDataType() { + return CUDA_R_64F; +} +template<> inline cudaDataType getCudaDataType>() { + return CUDA_C_16F; +} +template<> inline cudaDataType getCudaDataType>() { + return CUDA_C_32F; +} +template<> inline cudaDataType getCudaDataType>() { + return CUDA_C_64F; +} + +template<> inline cudaDataType getCudaDataType() { + return CUDA_R_8U; +} +template<> inline cudaDataType getCudaDataType() { + return CUDA_R_8I; +} +template<> inline cudaDataType getCudaDataType() { + return CUDA_R_32I; +} + +template<> inline cudaDataType getCudaDataType() { + return CUDA_R_16I; +} +template<> inline cudaDataType getCudaDataType() { + return CUDA_R_64I; +} +template<> inline cudaDataType getCudaDataType() { + return CUDA_R_16BF; +} + +inline cudaDataType ScalarTypeToCudaDataType(const c10::ScalarType& scalar_type) { + switch (scalar_type) { + case c10::ScalarType::Byte: + return CUDA_R_8U; + case c10::ScalarType::Char: + return CUDA_R_8I; + case c10::ScalarType::Int: + return CUDA_R_32I; + case c10::ScalarType::Half: + return CUDA_R_16F; + case c10::ScalarType::Float: + return CUDA_R_32F; + case c10::ScalarType::Double: + return CUDA_R_64F; + case c10::ScalarType::ComplexHalf: + return CUDA_C_16F; + case c10::ScalarType::ComplexFloat: + return CUDA_C_32F; + case c10::ScalarType::ComplexDouble: + return CUDA_C_64F; + case c10::ScalarType::Short: + return CUDA_R_16I; + case c10::ScalarType::Long: + return CUDA_R_64I; + case c10::ScalarType::BFloat16: + return CUDA_R_16BF; +#if defined(CUDA_VERSION) && CUDA_VERSION >= 11080 + case c10::ScalarType::Float8_e4m3fn: + return CUDA_R_8F_E4M3; + case c10::ScalarType::Float8_e5m2: + return CUDA_R_8F_E5M2; +#endif +#if defined(USE_ROCM) +#if defined(HIP_NEW_TYPE_ENUMS) + case c10::ScalarType::Float8_e4m3fnuz: + return HIP_R_8F_E4M3_FNUZ; + case c10::ScalarType::Float8_e5m2fnuz: + return HIP_R_8F_E5M2_FNUZ; +#else + case c10::ScalarType::Float8_e4m3fnuz: + return static_cast(1000); + case c10::ScalarType::Float8_e5m2fnuz: + return static_cast(1001); +#endif +#endif + default: + TORCH_INTERNAL_ASSERT(false, "Cannot convert ScalarType ", scalar_type, " to cudaDataType.") + } +} + +} // namespace at::cuda diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDADevice.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDADevice.h new file mode 100644 index 0000000000000000000000000000000000000000..ba9a5eb849a091be6e86658a8c7af87a0a3fbb8f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDADevice.h @@ -0,0 +1,23 @@ +#pragma once + +#include + +#include +#include + +namespace at::cuda { + +inline Device getDeviceFromPtr(void* ptr) { + cudaPointerAttributes attr{}; + + AT_CUDA_CHECK(cudaPointerGetAttributes(&attr, ptr)); + +#if !defined(USE_ROCM) + TORCH_CHECK(attr.type != cudaMemoryTypeUnregistered, + "The specified pointer resides on host memory and is not registered with any CUDA device."); +#endif + + return {c10::DeviceType::CUDA, static_cast(attr.device)}; +} + +} // namespace at::cuda diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDAUtils.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDAUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..d5f65dd6a572f9cab15dbae9983df8037b724a99 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDAUtils.h @@ -0,0 +1,20 @@ +#pragma once + +#include + +namespace at::cuda { + +// Check if every tensor in a list of tensors matches the current +// device. +inline bool check_device(ArrayRef ts) { + if (ts.empty()) { + return true; + } + Device curDevice = Device(kCUDA, current_device()); + for (const Tensor& t : ts) { + if (t.device() != curDevice) return false; + } + return true; +} + +} // namespace at::cuda diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/NumericLimits.cuh b/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/NumericLimits.cuh new file mode 100644 index 0000000000000000000000000000000000000000..7081e94837caa7d5050128e0bfe19aa67f93cd39 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/NumericLimits.cuh @@ -0,0 +1,121 @@ +#pragma once + +#include +#include +#include +#include + +// NumericLimits.cuh is a holder for numeric limits definitions of commonly used +// types. This header is very specific to ROCm HIP and may be removed in the future. +// This header is derived from the legacy THCNumerics.cuh. + +// The lower_bound and upper_bound constants are same as lowest and max for +// integral types, but are -inf and +inf for floating point types. They are +// useful in implementing min, max, etc. + +namespace at { + +template +struct numeric_limits { +}; + +// WARNING: the following at::numeric_limits definitions are there only to support +// HIP compilation for the moment. Use std::numeric_limits if you are not +// compiling for ROCm. +// from @colesbury: "The functions on numeric_limits aren't marked with +// __device__ which is why they don't work with ROCm. CUDA allows them +// because they're constexpr." + +namespace { + // ROCm doesn't like INFINITY too. + constexpr double inf = INFINITY; +} + +template <> +struct numeric_limits { + static inline __host__ __device__ bool lowest() { return false; } + static inline __host__ __device__ bool max() { return true; } + static inline __host__ __device__ bool lower_bound() { return false; } + static inline __host__ __device__ bool upper_bound() { return true; } +}; + +template <> +struct numeric_limits { + static inline __host__ __device__ uint8_t lowest() { return 0; } + static inline __host__ __device__ uint8_t max() { return UINT8_MAX; } + static inline __host__ __device__ uint8_t lower_bound() { return 0; } + static inline __host__ __device__ uint8_t upper_bound() { return UINT8_MAX; } +}; + +template <> +struct numeric_limits { + static inline __host__ __device__ int8_t lowest() { return INT8_MIN; } + static inline __host__ __device__ int8_t max() { return INT8_MAX; } + static inline __host__ __device__ int8_t lower_bound() { return INT8_MIN; } + static inline __host__ __device__ int8_t upper_bound() { return INT8_MAX; } +}; + +template <> +struct numeric_limits { + static inline __host__ __device__ int16_t lowest() { return INT16_MIN; } + static inline __host__ __device__ int16_t max() { return INT16_MAX; } + static inline __host__ __device__ int16_t lower_bound() { return INT16_MIN; } + static inline __host__ __device__ int16_t upper_bound() { return INT16_MAX; } +}; + +template <> +struct numeric_limits { + static inline __host__ __device__ int32_t lowest() { return INT32_MIN; } + static inline __host__ __device__ int32_t max() { return INT32_MAX; } + static inline __host__ __device__ int32_t lower_bound() { return INT32_MIN; } + static inline __host__ __device__ int32_t upper_bound() { return INT32_MAX; } +}; + +template <> +struct numeric_limits { +#ifdef _MSC_VER + static inline __host__ __device__ int64_t lowest() { return _I64_MIN; } + static inline __host__ __device__ int64_t max() { return _I64_MAX; } + static inline __host__ __device__ int64_t lower_bound() { return _I64_MIN; } + static inline __host__ __device__ int64_t upper_bound() { return _I64_MAX; } +#else + static inline __host__ __device__ int64_t lowest() { return INT64_MIN; } + static inline __host__ __device__ int64_t max() { return INT64_MAX; } + static inline __host__ __device__ int64_t lower_bound() { return INT64_MIN; } + static inline __host__ __device__ int64_t upper_bound() { return INT64_MAX; } +#endif +}; + +template <> +struct numeric_limits { + static inline __host__ __device__ at::Half lowest() { return at::Half(0xFBFF, at::Half::from_bits()); } + static inline __host__ __device__ at::Half max() { return at::Half(0x7BFF, at::Half::from_bits()); } + static inline __host__ __device__ at::Half lower_bound() { return at::Half(0xFC00, at::Half::from_bits()); } + static inline __host__ __device__ at::Half upper_bound() { return at::Half(0x7C00, at::Half::from_bits()); } +}; + +template <> +struct numeric_limits { + static inline __host__ __device__ at::BFloat16 lowest() { return at::BFloat16(0xFF7F, at::BFloat16::from_bits()); } + static inline __host__ __device__ at::BFloat16 max() { return at::BFloat16(0x7F7F, at::BFloat16::from_bits()); } + static inline __host__ __device__ at::BFloat16 lower_bound() { return at::BFloat16(0xFF80, at::BFloat16::from_bits()); } + static inline __host__ __device__ at::BFloat16 upper_bound() { return at::BFloat16(0x7F80, at::BFloat16::from_bits()); } +}; + +template <> +struct numeric_limits { + static inline __host__ __device__ float lowest() { return -FLT_MAX; } + static inline __host__ __device__ float max() { return FLT_MAX; } + static inline __host__ __device__ float lower_bound() { return -static_cast(inf); } + static inline __host__ __device__ float upper_bound() { return static_cast(inf); } +}; + +template <> +struct numeric_limits { + static inline __host__ __device__ double lowest() { return -DBL_MAX; } + static inline __host__ __device__ double max() { return DBL_MAX; } + static inline __host__ __device__ double lower_bound() { return -inf; } + static inline __host__ __device__ double upper_bound() { return inf; } +}; + +} // namespace at diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/PeerToPeerAccess.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/PeerToPeerAccess.h new file mode 100644 index 0000000000000000000000000000000000000000..1abf1dcfc12447d7226df19701486cd6bc2affee --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/PeerToPeerAccess.h @@ -0,0 +1,11 @@ +#include +#include + +namespace at::cuda { +namespace detail { +void init_p2p_access_cache(int64_t num_devices); +} + +TORCH_CUDA_CPP_API bool get_p2p_access(int source_dev, int dest_dev); + +} // namespace at::cuda diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/Sleep.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/Sleep.h new file mode 100644 index 0000000000000000000000000000000000000000..ef5e83a832f739e19f13837500824c984013812e --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/Sleep.h @@ -0,0 +1,13 @@ +#pragma once +#include +#include + +namespace at::cuda { + +// enqueues a kernel that spins for the specified number of cycles +TORCH_CUDA_CU_API void sleep(int64_t cycles); + +// flushes instruction cache for ROCm; no-op for CUDA +TORCH_CUDA_CU_API void flush_icache(); + +} // namespace at::cuda diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/cub_definitions.cuh b/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/cub_definitions.cuh new file mode 100644 index 0000000000000000000000000000000000000000..55b2f21dd1b48c77ccb96ce7de82f9a28ba3725a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/cub_definitions.cuh @@ -0,0 +1,53 @@ +#pragma once + +#if !defined(USE_ROCM) +#include // for CUDA_VERSION +#endif + +#if !defined(USE_ROCM) +#include +#else +#define CUB_VERSION 0 +#endif + +// cub sort support for __nv_bfloat16 is added to cub 1.13 in: +// https://github.com/NVIDIA/cub/pull/306 +#if CUB_VERSION >= 101300 +#define CUB_SUPPORTS_NV_BFLOAT16() true +#else +#define CUB_SUPPORTS_NV_BFLOAT16() false +#endif + +// cub support for CUB_WRAPPED_NAMESPACE is added to cub 1.13.1 in: +// https://github.com/NVIDIA/cub/pull/326 +// CUB_WRAPPED_NAMESPACE is defined globally in cmake/Dependencies.cmake +// starting from CUDA 11.5 +#if defined(CUB_WRAPPED_NAMESPACE) || defined(THRUST_CUB_WRAPPED_NAMESPACE) +#define USE_GLOBAL_CUB_WRAPPED_NAMESPACE() true +#else +#define USE_GLOBAL_CUB_WRAPPED_NAMESPACE() false +#endif + +// cub support for UniqueByKey is added to cub 1.16 in: +// https://github.com/NVIDIA/cub/pull/405 +#if CUB_VERSION >= 101600 +#define CUB_SUPPORTS_UNIQUE_BY_KEY() true +#else +#define CUB_SUPPORTS_UNIQUE_BY_KEY() false +#endif + +// cub support for scan by key is added to cub 1.15 +// in https://github.com/NVIDIA/cub/pull/376 +#if CUB_VERSION >= 101500 +#define CUB_SUPPORTS_SCAN_BY_KEY() 1 +#else +#define CUB_SUPPORTS_SCAN_BY_KEY() 0 +#endif + +// cub support for cub::FutureValue is added to cub 1.15 in: +// https://github.com/NVIDIA/cub/pull/305 +#if CUB_VERSION >= 101500 +#define CUB_SUPPORTS_FUTURE_VALUE() true +#else +#define CUB_SUPPORTS_FUTURE_VALUE() false +#endif diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/CUDAHooks.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/CUDAHooks.h new file mode 100644 index 0000000000000000000000000000000000000000..c23998fda56b678c45d354c57196387178cf4753 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/CUDAHooks.h @@ -0,0 +1,58 @@ +#pragma once + +#include + +#include +#include + +// TODO: No need to have this whole header, we can just put it all in +// the cpp file + +namespace at::cuda::detail { + +// Set the callback to initialize Magma, which is set by +// torch_cuda_cu. This indirection is required so magma_init is called +// in the same library where Magma will be used. +TORCH_CUDA_CPP_API void set_magma_init_fn(void (*magma_init_fn)()); + + +// The real implementation of CUDAHooksInterface +struct CUDAHooks : public at::CUDAHooksInterface { + CUDAHooks(at::CUDAHooksArgs) {} + void initCUDA() const override; + Device getDeviceFromPtr(void* data) const override; + bool isPinnedPtr(const void* data) const override; + const Generator& getDefaultCUDAGenerator(DeviceIndex device_index = -1) const override; + bool hasCUDA() const override; + bool hasMAGMA() const override; + bool hasCuDNN() const override; + bool hasCuSOLVER() const override; + bool hasCuBLASLt() const override; + bool hasROCM() const override; + const at::cuda::NVRTC& nvrtc() const override; + DeviceIndex current_device() const override; + bool hasPrimaryContext(DeviceIndex device_index) const override; + Allocator* getCUDADeviceAllocator() const override; + Allocator* getPinnedMemoryAllocator() const override; + bool compiledWithCuDNN() const override; + bool compiledWithMIOpen() const override; + bool supportsDilatedConvolutionWithCuDNN() const override; + bool supportsDepthwiseConvolutionWithCuDNN() const override; + bool supportsBFloat16ConvolutionWithCuDNNv8() const override; + bool hasCUDART() const override; + long versionCUDART() const override; + long versionCuDNN() const override; + std::string showConfig() const override; + double batchnormMinEpsilonCuDNN() const override; + int64_t cuFFTGetPlanCacheMaxSize(DeviceIndex device_index) const override; + void cuFFTSetPlanCacheMaxSize(DeviceIndex device_index, int64_t max_size) const override; + int64_t cuFFTGetPlanCacheSize(DeviceIndex device_index) const override; + void cuFFTClearPlanCache(DeviceIndex device_index) const override; + int getNumGPUs() const override; +#ifdef USE_ROCM + bool isGPUArch(DeviceIndex device_index, const std::vector& archs) const override; +#endif + void deviceSynchronize(DeviceIndex device_index) const override; +}; + +} // at::cuda::detail diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/DeviceThreadHandles.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/DeviceThreadHandles.h new file mode 100644 index 0000000000000000000000000000000000000000..1f80c863b63944a25aacb3aa8b95d0b82b6c110b --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/DeviceThreadHandles.h @@ -0,0 +1,151 @@ +// Some stateful GPU libraries, such as cuDNN, cuBLAS, use handles to store states. +// These handles are tied to device, and these libraries requires/recommends not to +// share handles across host threads. +// +// These libraries recommend using one handle per host thread. We may not want to do +// this because threads are relatively light-weight, but creating and destroying +// handles is expensive (destroying the handle causes synchronizations). DataParallel, +// for example, creates new threads for each forward pass. +// +// This file implements a handle pool mechanism. The handle pool returns handles on +// demand as threads request them. If all existing handles in the pool are in use, +// it creates a new one. As threads terminate, they release handles back into the pool. +// In this way, the handle pool never creates more handles than the high-water mark of +// active threads, so it's efficient with DataParallel. + +#pragma once + +#include +#include +#include +#include +#include + +#include + +namespace at::cuda { namespace { + +template +struct DeviceThreadHandlePool : public std::enable_shared_from_this> { + + struct Handle { + Handle_t handle; + Handle(bool create = false) : handle(nullptr) + { + if(create) Create(&handle); + } + // std::vector.emplace() and push_back() may route through temporaries and call + // copy/move constructors along the way. If this is the case, we don't want + // the destructors of temporaries to call cudnnDestroy on the handle. + // We can achieve safety (for the narrow case of stashing within std::vectors) + // by making Handle moveable but not copyable, and transferring handle ownership + // to the latest constructed object. This is not a substitute for full-blown + // reference counting, but reference counting may be overkill here. + // Another alternative is to wrap the saved Handles in unique_ptrs, i.e., + // unordered_map>> created_handles; + Handle(const Handle& rhs) = delete; + // Following https://stackoverflow.com/questions/3279543/what-is-the-copy-and-swap-idiom + Handle(Handle&& rhs) noexcept : Handle() { std::swap(handle, rhs.handle); } + // operator= takes argument by value + Handle& operator=(Handle rhs) { std::swap(handle, rhs.handle); return *this; } + ~Handle() { + if(handle) Destroy(handle); + } + }; + + std::mutex mutex; + + // Handles are lazily created as different threads request them, + // but are never destroyed until the end of the process. + // The maximum number of handles this process will create for each device is equal + // to the high-water mark of the number of concurrently active threads that request + // handles for that device. + // When threads terminate, they release their handles back into the pool for reuse. + // Otherwise, new handles would be created every time new threads were spawned, + // resulting in poor performance for Python modules that repeatedly or frequently + // spawned new sets of threads (like DataParallel, which creates a new set of threads + // for each forward pass). + // + // To prevent potential deadlocks, we explicitly choose not to cap the number + // of handles that are created per device. + // Example of danger: If we cap the max handles at 4, and 5 threads are sharing a device, + // only 4 can make forward progress at any time. The other 4 will not release their + // handles until they exit, so the fifth cannot make progress until then. This is + // not a problem...UNLESS all 5 threads attempt some sort of synchronization at an + // intermediate point (ie, before any of them have exited). We have no way to anticipate + // or enforce that user threads will not attempt such intermediate synchronization. + // The only way to ensure safety is to avoid imposing a cap on the number of handles. + std::unordered_map> created_handles; + std::unordered_map> available_handles; + + // PoolWindow lazily creates and caches the handles that a particular thread is using, + // so in the common case handle access doesn't incur either handle creation or a mutex lock. + class PoolWindow + { + public: + PoolWindow(std::shared_ptr parent): weak_parent(std::move(parent)) {} + ~PoolWindow(){ release(); } + + Handle_t reserve(int device) + { + // If this thread already has a handle for this device, return it + if(my_handles.find(device) != my_handles.end()) + return my_handles[device]; + + // otherwise, either grab a handle from the pool if one is available, + // or if not, create a new one. + auto parent = weak_parent.lock(); + TORCH_CHECK(parent, "Cannot create handle during program termination"); + std::lock_guard guard(parent->mutex); + + if(parent->available_handles[device].size() > 0) + { + my_handles[device] = parent->available_handles[device].back(); + parent->available_handles[device].pop_back(); + } + else + { + // In local testing, I do observe that emplace_back sometimes routes through temporaries + // that incur move-constructor and destructor calls. See comments in Handle above. + parent->created_handles[device].emplace_back(true /*create*/); + my_handles[device] = parent->created_handles[device].back().handle; + } + + return my_handles[device]; + } + + private: + // Stores the per-device handles currently owned by this thread + std::unordered_map my_handles; + + std::weak_ptr weak_parent; + + // Called by the destructor. Releases this thread's handles back into the pool. + void release() { + if(my_handles.size() > 0) { + auto parent = weak_parent.lock(); + if (!parent) { + // If this thread exits after atexit handlers have completed, the + // cuda context itself may be invalid, so we must leak the handles. + return; + } + + std::lock_guard guard(parent->mutex); + for(auto d_h : my_handles) + parent->available_handles[d_h.first].push_back(d_h.second); + } + } + }; + + // Warning: + // If you want to change this function, be aware that this function will be called + // by multiple threads and there is no mutex guarding the call of this function, so + // make sure your implementation is thread-safe. + PoolWindow *newPoolWindow() { + // The returned pointer will be owned by a thread local variable + // so that different threads does not share the same PoolWindow. + return new PoolWindow(this->shared_from_this()); + } +}; + +}} // namespace at::cuda::detail:: diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/IntegerDivider.cuh b/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/IntegerDivider.cuh new file mode 100644 index 0000000000000000000000000000000000000000..e8a26b5e06a6ffeeeaf5df26f09175a2c903aa01 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/IntegerDivider.cuh @@ -0,0 +1,124 @@ +#pragma once + +#include +#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) +#include +#endif + +namespace at::cuda::detail { + +// A utility class to implement integer division by multiplication, given a fixed +// divisor. +// +// WARNING: The fast divider algorithm is only implemented for unsigned int; +// otherwise we default to plain integer division. For unsigned int, +// we further assume that the dividend is at most INT32_MAX. Thus, +// IntDivider must NOT be used for general integer division. +// +// This reduced range is enough for our purpose, and it allows us to +// slightly simplify the computation. +// +// (NOTE: Below, "2^k" denotes exponentiation, i.e., 1< 0), we can find a "magic number" m (2^N +// <= m < 2^(N+1)) and shift s such that: +// +// \floor(n / d) = \floor((m * n) / 2^(N+s)). +// +// Given such m and s, the integer division can be then implemented as: +// +// let m' = m - 2^N // 0 <= m' < 2^N +// +// fast_integer_division(n): +// // Multiply two N-bit unsigned integers: the result is a 2N-bit unsigned +// // integer. Then take the higher N bits. +// t = (m' * n) >> N +// +// // Here we use the fact that n is less than 2^(N-1): otherwise the value +// // of (t + n) may not fit in an N-bit integer. +// return (t + n) >> s +// +// Finding such a magic number is surprisingly easy: +// +// s = \ceil(\log_2 d) +// m' = \floor(2^N * (2^s - d) / d) + 1 // Need 2N-bit integer arithmetic. +// +// See also: +// - Division by Invariant Integers Using Multiplication, +// Torbjörn Granlund and Peter L. Montgomery, 1994. +// +// - http://www.hackersdelight.org/magic.htm +// +// - http://ridiculousfish.com/blog/posts/labor-of-division-episode-i.html + +// Result of div/mod operation stored together. +template +struct DivMod { + Value div, mod; + + C10_HOST_DEVICE DivMod(Value div, Value mod) : div(div), mod(mod) { } +}; + +// Base case: we only have an implementation for uint32_t for now. For +// everything else, we use plain division. +template +struct IntDivider { + IntDivider() = default; + IntDivider(Value d) : divisor(d) { } + + C10_HOST_DEVICE inline Value div(Value n) const { return n / divisor; } + C10_HOST_DEVICE inline Value mod(Value n) const { return n % divisor; } + C10_HOST_DEVICE inline DivMod divmod(Value n) const { + return DivMod(n / divisor, n % divisor); + } + + Value divisor; +}; + +// Implement fast integer division. +template <> +struct IntDivider { + static_assert(sizeof(unsigned int) == 4, "Assumes 32-bit unsigned int."); + + IntDivider() = default; + + IntDivider(unsigned int d) : divisor(d) { + assert(divisor >= 1 && divisor <= INT32_MAX); + + // TODO: gcc/clang has __builtin_clz() but it's not portable. + for (shift = 0; shift < 32; shift++) if ((1U << shift) >= divisor) break; + + uint64_t one = 1; + uint64_t magic = ((one << 32) * ((one << shift) - divisor)) / divisor + 1; + m1 = magic; + assert(m1 > 0 && m1 == magic); // m1 must fit in 32 bits. + } + + C10_HOST_DEVICE inline unsigned int div(unsigned int n) const { +#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) + // 't' is the higher 32-bits of unsigned 32-bit multiplication of 'n' and + // 'm1'. + unsigned int t = __umulhi(n, m1); + return (t + n) >> shift; +#else + // Using uint64_t so that the addition does not overflow. + uint64_t t = ((uint64_t) n * m1) >> 32; + return (t + n) >> shift; +#endif + } + + C10_HOST_DEVICE inline unsigned int mod(unsigned int n) const { + return n - div(n) * divisor; + } + + C10_HOST_DEVICE inline DivMod divmod(unsigned int n) const { + unsigned int q = div(n); + return DivMod(q, n - q * divisor); + } + + unsigned int divisor; // d above. + unsigned int m1; // Magic number: m' above. + unsigned int shift; // Shift amounts. +}; + +} // namespace at::cuda::detail diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/KernelUtils.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/KernelUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..61f576368c3286a3a4eb233b93513f8c3b560a79 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/KernelUtils.h @@ -0,0 +1,37 @@ +#pragma once + +#include +#include + +namespace at::cuda::detail { + +// CUDA: grid stride looping +// +// int64_t _i_n_d_e_x specifically prevents overflow in the loop increment. +// If input.numel() < INT_MAX, _i_n_d_e_x < INT_MAX, except after the final +// iteration of the loop where _i_n_d_e_x += blockDim.x * gridDim.x can be +// greater than INT_MAX. But in that case _i_n_d_e_x >= n, so there are no +// further iterations and the overflowed value in i=_i_n_d_e_x is not used. +#define CUDA_KERNEL_LOOP_TYPE(i, n, index_type) \ + int64_t _i_n_d_e_x = blockIdx.x * blockDim.x + threadIdx.x; \ + for (index_type i=_i_n_d_e_x; _i_n_d_e_x < (n); _i_n_d_e_x+=blockDim.x * gridDim.x, i=_i_n_d_e_x) + +#define CUDA_KERNEL_LOOP(i, n) CUDA_KERNEL_LOOP_TYPE(i, n, int) + + +// Use 1024 threads per block, which requires cuda sm_2x or above +constexpr int CUDA_NUM_THREADS = 1024; + +// CUDA: number of blocks for threads. +inline int GET_BLOCKS(const int64_t N, const int64_t max_threads_per_block=CUDA_NUM_THREADS) { + TORCH_INTERNAL_ASSERT(N > 0, "CUDA kernel launch blocks must be positive, but got N=", N); + constexpr int64_t max_int = std::numeric_limits::max(); + + // Round up division for positive number that cannot cause integer overflow + auto block_num = (N - 1) / max_threads_per_block + 1; + TORCH_INTERNAL_ASSERT(block_num <= max_int, "Can't schedule too many blocks on CUDA device"); + + return static_cast(block_num); +} + +} // namespace at::cuda::detail diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/LazyNVRTC.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/LazyNVRTC.h new file mode 100644 index 0000000000000000000000000000000000000000..95e52c94377bf568060c25f887454dbbaf2054b7 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/LazyNVRTC.h @@ -0,0 +1,11 @@ +#pragma once +#include +namespace at::cuda { +// Forward-declares at::cuda::NVRTC +struct NVRTC; + +namespace detail { +extern NVRTC lazyNVRTC; +} // namespace detail + +} // namespace at::cuda diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/PhiloxCudaStateRaw.cuh b/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/PhiloxCudaStateRaw.cuh new file mode 100644 index 0000000000000000000000000000000000000000..231cd167cacb4f9b4f2f48e159431b30f3d6dc28 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/PhiloxCudaStateRaw.cuh @@ -0,0 +1,43 @@ +// No "#pragma once" because this is a raw definition that can be copied by jit codegen. +// Eager mode clients should not include this file directly, instead, +// they should #include , which has a #pragma once. + +// Stores RNG state values. Passed as a kernel argument. +// See Note [CUDA Graph-safe RNG states]. +// +// The raw definition lives in its own file so jit codegen can easily copy it. +namespace at { + +struct PhiloxCudaState { + PhiloxCudaState() = default; + // Called if graph capture is not underway + PhiloxCudaState(uint64_t seed, + uint64_t offset) { + seed_.val = seed; + offset_.val = offset; + } + // Called if graph capture is underway + PhiloxCudaState(int64_t* seed, + int64_t* offset_extragraph, + uint32_t offset_intragraph) { + seed_.ptr = seed; + offset_.ptr = offset_extragraph; + offset_intragraph_ = offset_intragraph; + captured_ = true; + } + + // Public members, directly accessible by at::cuda::philox::unpack. + // If we made them private with getters/setters, the getters/setters + // would have to be __device__, and we can't declare __device__ in ATen. + union Payload { + uint64_t val; + int64_t* ptr; + }; + + Payload seed_{}; + Payload offset_{}; + uint32_t offset_intragraph_ = 0; + bool captured_ = false; +}; + +} // namespace at diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/TensorInfo.cuh b/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/TensorInfo.cuh new file mode 100644 index 0000000000000000000000000000000000000000..a320000ae881faa7416bae6ed1f37793f357a73a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/TensorInfo.cuh @@ -0,0 +1,116 @@ +#pragma once + +#include + +namespace at::cuda::detail { + +#define MAX_TENSORINFO_DIMS 25 + +// CUDA kernel argument that defines tensor layout +template +struct TensorInfo { + TensorInfo(); + TensorInfo(T* p, + int dim, + IndexType sz[MAX_TENSORINFO_DIMS], + IndexType st[MAX_TENSORINFO_DIMS]); + + // Set the size of the given dimension to 1, as if it were a + // reduction dim (allows you to calculate offsets of the reduction + // slice) + void reduceDim(int dim); + + // See note on [collapse dims]. + int collapseDims(const int excludeDim = -1); + + // Contiguous tensors of more than one dimension are collapsed down + // to one tensor + __host__ __device__ inline bool isContiguous() const { + return (dims == 1 && strides[0] == 1); + } + + T* data; + IndexType sizes[MAX_TENSORINFO_DIMS]; + IndexType strides[MAX_TENSORINFO_DIMS]; + int dims; +}; + +template +TensorInfo::TensorInfo() { + data = nullptr; + dims = 0; +} + +template +TensorInfo::TensorInfo(T* p, + int dim, + IndexType sz[MAX_TENSORINFO_DIMS], + IndexType st[MAX_TENSORINFO_DIMS]) { + data = p; + dims = dim; + TORCH_CHECK(dims < MAX_TENSORINFO_DIMS, "CUDA Tensors cannot have more than 25 dimensions"); + + for (int i = 0; i < dim; ++i) { + sizes[i] = sz[i]; + strides[i] = st[i]; + } +} + +template +void +TensorInfo::reduceDim(int dim) { + TORCH_CHECK(dim < dims && dim >= 0, "expected dim between 0 and dims - 1"); + sizes[dim] = 1; +} + +template +int +TensorInfo::collapseDims(const int excludeDim) { + auto result = at::collapse_dims(sizes, strides, dims, excludeDim); + dims = std::get<1>(result); + return std::get<0>(result); +} + +// Translate a linear index for the apply to a T* offset; +// specialized on `Dims` to reduce nvcc compilation time +template +struct IndexToOffset { + static __host__ __device__ IndexType get( + IndexType linearId, + const TensorInfo& info) { + + IndexType offset = 0; + + // Uses static dims + for (int i = Dims - 1; i > 0; --i) { + IndexType curDimIndex = linearId % info.sizes[i]; + IndexType curDimOffset = curDimIndex * info.strides[i]; + offset += curDimOffset; + linearId /= info.sizes[i]; + } + + return offset + linearId * info.strides[0]; + } +}; + +// Uses dynamic (runtime) instead of static (compiletime) dims +template +struct IndexToOffset { + static inline __host__ __device__ IndexType get( + IndexType linearId, + const TensorInfo& info) { + + IndexType offset = 0; + + for (int i = info.dims - 1; i > 0; --i) { + IndexType curDimIndex = linearId % info.sizes[i]; + IndexType curDimOffset = curDimIndex * info.strides[i]; + offset += curDimOffset; + linearId /= info.sizes[i]; + } + + return offset + linearId * info.strides[0]; + } +}; + +} // namespace at::cuda::detail diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/UnpackRaw.cuh b/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/UnpackRaw.cuh new file mode 100644 index 0000000000000000000000000000000000000000..70cd222a484844cdf4f4cb222af2ac6408598cbd --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/detail/UnpackRaw.cuh @@ -0,0 +1,28 @@ +// No "#pragma once" because this is a raw definition that can be copied by jit codegen. +// Eager mode clients should not include this file directly, instead, +// they should #include , which has a #pragma once. + +namespace at::cuda::philox { + +// In-kernel call to retrieve philox seed and offset from a PhiloxCudaState instance whether +// that instance was created with graph capture underway or not. +// See Note [CUDA Graph-safe RNG states]. +// +// We can't write a __device__ function in CUDAGeneratorImpl.h, because it's in ATen. +// Also, whatever call unpacks PhiloxCudaState in consumer kernels must be inlineable. +// Easiest thing that comes to mind is, define a __device__ unpack helper here, in ATen/cuda. +// +// The raw definition lives in its own file so jit codegen can easily copy it. +__host__ __device__ __forceinline__ std::tuple +unpack(at::PhiloxCudaState arg) { + if (arg.captured_) { + // static_cast avoids "warning: invalid narrowing conversion from "long" to "unsigned long". + // *(arg.offset_.ptr) is a broadcast load of a single int64_t to the entire kernel. + // For most threads' reads it will hit in cache, so it shouldn't hurt performance. + return std::make_tuple(static_cast(*arg.seed_.ptr), static_cast(*(arg.offset_.ptr) + arg.offset_intragraph_)); + } else { + return std::make_tuple(arg.seed_.val, arg.offset_.val); + } +} + +} // namespace at::cuda::philox diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/llvm_jit_strings.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/llvm_jit_strings.h new file mode 100644 index 0000000000000000000000000000000000000000..aba40d4f42eac74ee9435d904ac3fb82d1e988c4 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/llvm_jit_strings.h @@ -0,0 +1,14 @@ +#pragma once + +#include +#include + +namespace at::cuda { + +TORCH_CUDA_CPP_API const std::string &get_traits_string(); +TORCH_CUDA_CPP_API const std::string &get_cmath_string(); +TORCH_CUDA_CPP_API const std::string &get_complex_body_string(); +TORCH_CUDA_CPP_API const std::string &get_complex_half_body_string(); +TORCH_CUDA_CPP_API const std::string &get_complex_math_string(); + +} // namespace at::cuda diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/GemmCommon.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/GemmCommon.h new file mode 100644 index 0000000000000000000000000000000000000000..c3a171e8d9251a7f873818108253bd24c4423336 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/GemmCommon.h @@ -0,0 +1,397 @@ +// Original TunableOp is from onnxruntime. +// https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h +// https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. +// +// Adapting TunableOp into PyTorch +// Copyright (c) Advanced Micro Devices, Inc. +// +#pragma once + +#include + +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#endif + +namespace at::cuda::tunable { + +enum class BlasOp { + N = 0, + T = 1 +}; + +inline std::string BlasOpToString(BlasOp op) { + switch (op) { + case BlasOp::N: + return "N"; + case BlasOp::T: + return "T"; + } + TORCH_CHECK(false, "unrecognized BlasOp"); + return "N"; +} + +namespace detail { + +static bool NumericalCheck(ScalarType dtype, void* c, void* other_c, int64_t size) { + auto options = at::TensorOptions().dtype(dtype).device(at::kCUDA); + // comparison done as 1D tensor + at::Tensor ref = at::from_blob(c, {size}, options); + at::Tensor oth = at::from_blob(other_c, {size}, options); + at::Tensor ref_float = ref.to(at::kFloat); + at::Tensor oth_float = oth.to(at::kFloat); + std::vector atols{1e-1, 1e-2, 1e-3, 1e-4, 1e-5}; + std::vector rtols{1e-1, 1e-2, 1e-3, 1e-4, 1e-5}; + double last_succeed_atol = 1; + double last_succeed_rtol = 1; + for (auto& atol : atols) { + for (auto& rtol : rtols) { + if (at::allclose(ref_float, oth_float, rtol, atol)) { + last_succeed_atol = atol; + last_succeed_rtol = rtol; + } + } + } + if (last_succeed_atol == 1) { + return false; + } + else { + TUNABLE_LOG3("├──verify numerics: atol=", last_succeed_atol, ", rtol=", last_succeed_rtol); + } + + return true; +} + +} + +template +struct GemmParams : OpParams { + GemmParams() { + duplicate_inputs_ = false; + } + + std::string Signature() const override { + return c10::str(transa, transb, "_", m, "_", n, "_", k); + } + + size_t GetSizeA() const { + return sizeof(T) * lda * ((transa == 'n' || transa == 'N') ? k : m); + } + + size_t GetSizeB() const { + return sizeof(T) * ldb * ((transb == 'n' || transb == 'N') ? n : k); + } + + size_t GetSizeC() const { + return sizeof(T) * ldc * n; + } + + size_t GetSize(bool duplicate_inputs) const { + size_t size = GetSizeC(); + if (duplicate_inputs) { + size += GetSizeA(); + size += GetSizeB(); + } + return size; + } + + GemmParams* DeepCopy(bool duplicate_inputs) const { + GemmParams* copy = new GemmParams; + *copy = *this; + c10::DeviceIndex device = 0; + AT_CUDA_CHECK(c10::cuda::GetDevice(&device)); + size_t c_size = GetSizeC(); + copy->c = static_cast(c10::cuda::CUDACachingAllocator::raw_alloc(c_size)); + AT_CUDA_CHECK(c10::cuda::CUDACachingAllocator::memcpyAsync( + copy->c, device, c, device, c_size, getCurrentCUDAStream(device), true)); + if (duplicate_inputs) { + size_t a_size = GetSizeA(); + size_t b_size = GetSizeB(); + copy->a = static_cast(c10::cuda::CUDACachingAllocator::raw_alloc(a_size)); + copy->b = static_cast(c10::cuda::CUDACachingAllocator::raw_alloc(b_size)); + copy->duplicate_inputs_ = true; + } + return copy; + } + + // only call on object returned by DeepCopy + void Delete() { + c10::cuda::CUDACachingAllocator::raw_delete(c); + if (duplicate_inputs_) { + c10::cuda::CUDACachingAllocator::raw_delete(const_cast(a)); + c10::cuda::CUDACachingAllocator::raw_delete(const_cast(b)); + } + } + + TuningStatus NumericalCheck(GemmParams *other) { + auto c_dtype = c10::CppTypeToScalarType::value; + return detail::NumericalCheck(c_dtype, c, other->c, ldc*n) ? OK : FAIL; + } + + char transa; + char transb; + int64_t m; + int64_t n; + int64_t k; + at::opmath_type alpha; + const T* a; + int64_t lda; + const T* b; + int64_t ldb; + at::opmath_type beta; + T* c; + int64_t ldc; +private: + bool duplicate_inputs_; +}; + +template +struct GemmAndBiasParams : OpParams { + std::string Signature() const override { + return c10::str(transa, transb, "_", m, "_", n, "_", k); + } + + size_t GetSize(bool duplicate_inputs) const { + size_t size = sizeof(T) * ldc * n; + if (duplicate_inputs) { + size += sizeof(T) * lda * ((transa == 'n' || transa == 'N') ? k : m); + size += sizeof(T) * ldb * ((transb == 'n' || transb == 'N') ? n : k); + } + return size; + } + + GemmAndBiasParams* DeepCopy(bool duplicate_inputs) const { + GemmAndBiasParams* copy = new GemmAndBiasParams; + *copy = *this; + c10::DeviceIndex device = 0; + AT_CUDA_CHECK(c10::cuda::GetDevice(&device)); + size_t c_size = ldc * n * sizeof(T); + copy->c = static_cast(c10::cuda::CUDACachingAllocator::raw_alloc(c_size)); + AT_CUDA_CHECK(c10::cuda::CUDACachingAllocator::memcpyAsync( + copy->c, device, c, device, c_size, getCurrentCUDAStream(device), true)); + if (duplicate_inputs) { + size_t a_size = sizeof(T) * lda * ((transa == 'n' || transa == 'N') ? k : m); + size_t b_size = sizeof(T) * ldb * ((transb == 'n' || transb == 'N') ? n : k); + copy->a = static_cast(c10::cuda::CUDACachingAllocator::raw_alloc(a_size)); + copy->b = static_cast(c10::cuda::CUDACachingAllocator::raw_alloc(b_size)); + copy->duplicate_inputs_ = true; + } + return copy; + } + + // only call on object returned by DeepCopy + void Delete() { + c10::cuda::CUDACachingAllocator::raw_delete(c); + if (duplicate_inputs_) { + c10::cuda::CUDACachingAllocator::raw_delete(const_cast(a)); + c10::cuda::CUDACachingAllocator::raw_delete(const_cast(b)); + } + } + + TuningStatus NumericalCheck(GemmAndBiasParams *other) { + auto c_dtype = c10::CppTypeToScalarType::value; + return detail::NumericalCheck(c_dtype, c, other->c, ldc*n) ? OK : FAIL; + } + + char transa; + char transb; + int64_t m; + int64_t n; + int64_t k; + at::opmath_type alpha; + const T* a; + int64_t lda; + const T* b; + int64_t ldb; + T* c; + int64_t ldc; + const T* bias; + at::cuda::blas::GEMMAndBiasActivationEpilogue activation; +private: + bool duplicate_inputs_; +}; + +template +struct GemmStridedBatchedParams : OpParams { + GemmStridedBatchedParams() { + duplicate_inputs_ = false; + } + + std::string Signature() const override { + return c10::str(transa, transb, "_", m, "_", n, "_", k, "_B_", batch); + } + + size_t GetSizeA() const { + return sizeof(T) * std::min(lda, stride_a) * ((transa == 'n' || transa == 'N') ? k : m) * batch; + } + + size_t GetSizeB() const { + return sizeof(T) * std::min(ldb, stride_b) * ((transb == 'n' || transb == 'N') ? n : k) * batch; + } + + size_t GetSizeC() const { + return sizeof(T) * std::min(ldc, stride_c) * n * batch; + } + + size_t GetSize(bool duplicate_inputs) const { + size_t size = GetSizeC(); + if (duplicate_inputs) { + size += GetSizeA(); + size += GetSizeB(); + } + return size; + } + + GemmStridedBatchedParams* DeepCopy(bool duplicate_inputs) const { + GemmStridedBatchedParams* copy = new GemmStridedBatchedParams; + *copy = *this; + c10::DeviceIndex device = 0; + AT_CUDA_CHECK(c10::cuda::GetDevice(&device)); + size_t c_size = GetSizeC(); + copy->c = static_cast(c10::cuda::CUDACachingAllocator::raw_alloc(c_size)); + AT_CUDA_CHECK(c10::cuda::CUDACachingAllocator::memcpyAsync( + copy->c, device, c, device, c_size, getCurrentCUDAStream(device), true)); + if (duplicate_inputs) { + size_t a_size = GetSizeA(); + size_t b_size = GetSizeB(); + copy->a = static_cast(c10::cuda::CUDACachingAllocator::raw_alloc(a_size)); + copy->b = static_cast(c10::cuda::CUDACachingAllocator::raw_alloc(b_size)); + copy->duplicate_inputs_ = true; + } + return copy; + } + + // only call on object returned by DeepCopy + void Delete() { + c10::cuda::CUDACachingAllocator::raw_delete(c); + if (duplicate_inputs_) { + c10::cuda::CUDACachingAllocator::raw_delete(const_cast(a)); + c10::cuda::CUDACachingAllocator::raw_delete(const_cast(b)); + } + } + + TuningStatus NumericalCheck(GemmStridedBatchedParams *other) { + auto c_dtype = c10::CppTypeToScalarType::value; + return detail::NumericalCheck(c_dtype, c, other->c, batch*stride_c) ? OK : FAIL; + } + + char transa; + char transb; + int64_t m; + int64_t n; + int64_t k; + at::opmath_type alpha; + const T* a; + int64_t lda; + int64_t stride_a; + const T* b; + int64_t ldb; + int64_t stride_b; + at::opmath_type beta; + T* c; + int64_t ldc; + int64_t stride_c; + int64_t batch; +private: + bool duplicate_inputs_; +}; + +template +struct ScaledGemmParams : OpParams { + ScaledGemmParams() { + duplicate_inputs_ = false; + } + + std::string Signature() const override { + return c10::str(transa, transb, "_", m, "_", n, "_", k); + } + + size_t GetSizeA() const { + return sizeof(T) * lda * ((transa == 'n' || transa == 'N') ? k : m); + } + + size_t GetSizeB() const { + return sizeof(T) * ldb * ((transb == 'n' || transb == 'N') ? n : k); + } + + size_t GetSizeC() const { + return sizeof(T) * ldc * n; + } + + size_t GetSize(bool duplicate_inputs) const { + size_t size = GetSizeC(); + if (duplicate_inputs) { + size += GetSizeA(); + size += GetSizeB(); + } + return size; + } + + ScaledGemmParams* DeepCopy(bool duplicate_inputs) const { + ScaledGemmParams* copy = new ScaledGemmParams; + *copy = *this; + c10::DeviceIndex device = 0; + AT_CUDA_CHECK(c10::cuda::GetDevice(&device)); + size_t c_size = GetSizeC(); + copy->c = c10::cuda::CUDACachingAllocator::raw_alloc(c_size); + AT_CUDA_CHECK(c10::cuda::CUDACachingAllocator::memcpyAsync( + copy->c, device, c, device, c_size, getCurrentCUDAStream(device), true)); + if (duplicate_inputs) { + size_t a_size = GetSizeA(); + size_t b_size = GetSizeB(); + copy->a = c10::cuda::CUDACachingAllocator::raw_alloc(a_size); + copy->b = c10::cuda::CUDACachingAllocator::raw_alloc(b_size); + copy->duplicate_inputs_ = true; + } + return copy; + } + + // only call on object returned by DeepCopy + void Delete() { + c10::cuda::CUDACachingAllocator::raw_delete(c); + if (duplicate_inputs_) { + c10::cuda::CUDACachingAllocator::raw_delete(const_cast(a)); + c10::cuda::CUDACachingAllocator::raw_delete(const_cast(b)); + } + } + + TuningStatus NumericalCheck(ScaledGemmParams *other) { + return detail::NumericalCheck(c_dtype, c, other->c, ldc*n) ? OK : FAIL; + } + + char transa; + char transb; + int64_t m; + int64_t n; + int64_t k; + const void* a; + const void* a_scale_ptr; + int64_t lda; + ScalarType a_dtype; + const void* b; + const void* b_scale_ptr; + int64_t ldb; + ScalarType b_dtype; + const void* bias_ptr; + ScalarType bias_dtype; + void* c; + const void* c_scale_ptr; + int64_t ldc; + ScalarType c_dtype; + void* amax_ptr; + bool use_fast_accum; +private: + bool duplicate_inputs_; +}; + +} // namespace at::cuda::tunable diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/GemmHipblaslt.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/GemmHipblaslt.h new file mode 100644 index 0000000000000000000000000000000000000000..483b4fb7a91a07b871c804646da0a8e3863b51e7 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/GemmHipblaslt.h @@ -0,0 +1,611 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include +#include + +#define TORCH_HIPBLASLT_CHECK(EXPR) \ + do { \ + hipblasStatus_t __err = EXPR; \ + TORCH_CHECK(__err == HIPBLAS_STATUS_SUCCESS, \ + "hipblaslt error: ", \ + hipblasStatusToString(__err), \ + " when calling `" #EXPR "`"); \ + } while (0) + +namespace at::cuda::tunable { + +template +constexpr hipblasDatatype_t HipDataTypeFor(); + +template <> +constexpr hipblasDatatype_t HipDataTypeFor() { + return HIP_R_32F; +} + +template <> +constexpr hipblasDatatype_t HipDataTypeFor() { + return HIP_R_16F; +} + +template <> +constexpr hipblasDatatype_t HipDataTypeFor() { + return HIP_R_16BF; +} + +template <> +constexpr hipblasDatatype_t HipDataTypeFor() { + return HIP_R_64F; +} + +template <> +constexpr hipblasDatatype_t HipDataTypeFor() { + return HIP_R_8F_E4M3_FNUZ; +} + +template <> +constexpr hipblasDatatype_t HipDataTypeFor() { + return HIP_R_8F_E5M2_FNUZ; +} + +template +int GetBatchFromParams(const GemmParams* params) { + return 1; +} + +template +int GetBatchFromParams(const GemmAndBiasParams* params) { + return 1; +} + +template +int GetBatchFromParams(const GemmStridedBatchedParams* params) { + return params->batch; +} + +template +int GetBatchFromParams(const ScaledGemmParams* params) { + return 1; +} + +template +int GetStrideAFromParams(const GemmParams* params) { + return 1; +} + +template +int GetStrideAFromParams(const GemmAndBiasParams* params) { + return 1; +} + +template +int GetStrideAFromParams(const GemmStridedBatchedParams* params) { + return params->stride_a; +} + +template +int GetStrideAFromParams(const ScaledGemmParams* params) { + return 1; +} + +template +int GetStrideBFromParams(const GemmParams* params) { + return 1; +} + +template +int GetStrideBFromParams(const GemmAndBiasParams* params) { + return 1; +} + +template +int GetStrideBFromParams(const GemmStridedBatchedParams* params) { + return params->stride_b; +} + +template +int GetStrideBFromParams(const ScaledGemmParams* params) { + return 1; +} + +template +int GetStrideCFromParams(const GemmParams* params) { + return 1; +} + +template +int GetStrideCFromParams(const GemmAndBiasParams* params) { + return 1; +} + +template +int GetStrideCFromParams(const GemmStridedBatchedParams* params) { + return params->stride_c; +} + +template +int GetStrideCFromParams(const ScaledGemmParams* params) { + return 1; +} + +template +float GetAlphaFromParams(const GemmParams* params) { + return params->alpha; +} + +template +float GetAlphaFromParams(const GemmAndBiasParams* params) { + return params->alpha; +} + +template +float GetAlphaFromParams(const GemmStridedBatchedParams* params) { + return params->alpha; +} + +template +float GetAlphaFromParams(const ScaledGemmParams* params) { + return 1.0; +} + +template +float GetBetaFromParams(const GemmParams* params) { + return params->beta; +} + +template +float GetBetaFromParams(const GemmAndBiasParams* params) { + return 0.0; +} + +template +float GetBetaFromParams(const GemmStridedBatchedParams* params) { + return params->beta; +} + +template +float GetBetaFromParams(const ScaledGemmParams* params) { + return 0.0; +} + +template +const void* GetAScalePointerFromParams(const GemmParams* params) { + return nullptr; +} + +template +const void* GetAScalePointerFromParams(const GemmAndBiasParams* params) { + return nullptr; +} + +template +const void* GetAScalePointerFromParams(const GemmStridedBatchedParams* params) { + return nullptr; +} + +template +const void* GetAScalePointerFromParams(const ScaledGemmParams* params) { + return params->a_scale_ptr; +} + +template +const void* GetBScalePointerFromParams(const GemmParams* params) { + return nullptr; +} + +template +const void* GetBScalePointerFromParams(const GemmAndBiasParams* params) { + return nullptr; +} + +template +const void* GetBScalePointerFromParams(const GemmStridedBatchedParams* params) { + return nullptr; +} + +template +const void* GetBScalePointerFromParams(const ScaledGemmParams* params) { + return params->b_scale_ptr; +} + +template +const void* GetDScalePointerFromParams(const GemmParams* params) { + return nullptr; +} + +template +const void* GetDScalePointerFromParams(const GemmAndBiasParams* params) { + return nullptr; +} + +template +const void* GetDScalePointerFromParams(const GemmStridedBatchedParams* params) { + return nullptr; +} + +template +const void* GetDScalePointerFromParams(const ScaledGemmParams* params) { + return params->c_scale_ptr; +} + +template +const void* GetBiasPointerFromParams(const GemmParams* params) { + return nullptr; +} + +template +const void* GetBiasPointerFromParams(const GemmAndBiasParams* params) { + return params->bias; +} + +template +const void* GetBiasPointerFromParams(const GemmStridedBatchedParams* params) { + return nullptr; +} + +template +const void* GetBiasPointerFromParams(const ScaledGemmParams* params) { + return params->bias_ptr; +} + +template +hipDataType GetBiasTypeFromParams(const GemmParams* params) { + return HIP_R_32F; +} + +template +hipDataType GetBiasTypeFromParams(const GemmAndBiasParams* params) { + return HipDataTypeFor(); +} + +template +hipDataType GetBiasTypeFromParams(const GemmStridedBatchedParams* params) { + return HIP_R_32F; +} + +template +hipDataType GetBiasTypeFromParams(const ScaledGemmParams* params) { + return at::cuda::ScalarTypeToCudaDataType(params->bias_dtype); +} + +template +at::cuda::blas::GEMMAndBiasActivationEpilogue GetActivationFromParams(const GemmParams* params) { + return at::cuda::blas::GEMMAndBiasActivationEpilogue::None; +} + +template +at::cuda::blas::GEMMAndBiasActivationEpilogue GetActivationFromParams(const GemmAndBiasParams* params) { + return params->activation; +} + +template +at::cuda::blas::GEMMAndBiasActivationEpilogue GetActivationFromParams(const GemmStridedBatchedParams* params) { + return at::cuda::blas::GEMMAndBiasActivationEpilogue::None; +} + +template +at::cuda::blas::GEMMAndBiasActivationEpilogue GetActivationFromParams(const ScaledGemmParams* params) { + return at::cuda::blas::GEMMAndBiasActivationEpilogue::None; +} + +static hipblasOperation_t _hipblasOpFromChar(char op) { + switch (op) { + case 'n': + case 'N': + return HIPBLAS_OP_N; + case 't': + case 'T': + return HIPBLAS_OP_T; + case 'c': + case 'C': + return HIPBLAS_OP_C; + } + AT_ERROR( + "_hipblasOpFromChar input should be 't', 'n' or 'c' but got `", op, "`"); +} + +static char _charFromhipblasOp(hipblasOperation_t op) { + switch (op) { + case HIPBLAS_OP_N: + return 'N'; + case HIPBLAS_OP_T: + return 'T'; + case HIPBLAS_OP_C: + return 'C'; + } + AT_ERROR( + "_charFromhipblasOp input should be HIPBLAS_OP_N/T/C but got `", op, "`"); +} + +static hipblasOperation_t MapLayoutToHipBlasLt(BlasOp layout) { + if (layout == BlasOp::N) { + return HIPBLAS_OP_N; + } + return HIPBLAS_OP_T; +} + +static size_t GetHipblasltWorkspaceSize() { + static const char * env = getenv("HIPBLASLT_WORKSPACE_SIZE"); + // 256MB is max workspace size allowed for hipblaslt + // hipblaslt-bench uses 32MB + // recommendation from hipblaslt author was 76MB + size_t workspace_size = 32*1024; // going with 32MB + if (env) { + try { + workspace_size = std::stoi(env); + } catch(std::invalid_argument const& e) { + TORCH_WARN("invalid HIPBLASLT_WORKSPACE_SIZE,", + " using default workspace size of ", workspace_size, " KiB."); + } catch(std::out_of_range const& e) { + TORCH_WARN("HIPBLASLT_WORKSPACE_SIZE out of range,", + " using default workspace size of ", workspace_size, " KiB."); + } + } + return workspace_size * 1024; +} + +template +struct HipBlasLtDeleter { + void operator()(T* x) { + if (x != nullptr) { + TORCH_CUDABLAS_CHECK(destructor(x)); + } + } +}; + +template +class HipBlasLtDescriptor { + public: + T* descriptor() const { + return descriptor_.get(); + } + T* descriptor() { + return descriptor_.get(); + } + + protected: + std::unique_ptr> descriptor_; +}; + +class HipBlasLtMatmulDescriptor : public HipBlasLtDescriptor< + hipblasLtMatmulDescOpaque_t, + &hipblasLtMatmulDescDestroy> { + public: + HipBlasLtMatmulDescriptor( + hipblasComputeType_t compute_type, + hipDataType scale_type) { + hipblasLtMatmulDesc_t raw_descriptor = nullptr; + TORCH_HIPBLASLT_CHECK( + hipblasLtMatmulDescCreate(&raw_descriptor, compute_type, scale_type)); + descriptor_.reset(raw_descriptor); + } + template + inline void setAttribute(hipblasLtMatmulDescAttributes_t attr, const T value) { + TORCH_HIPBLASLT_CHECK(::hipblasLtMatmulDescSetAttribute(descriptor(), attr, &value, sizeof(T))); + } +}; + +template +class HipblasltGemmOp : public Callable { + public: + HipblasltGemmOp(hipblasLtMatmulAlgo_t algo) : algo_{algo} {} + + TuningStatus Call(const ParamsT* params) override { + hipblasOperation_t transa_outer = MapLayoutToHipBlasLt(ALayout); + hipblasOperation_t transb_outer = MapLayoutToHipBlasLt(BLayout); + auto a_datatype = HipDataTypeFor(); + auto b_datatype = HipDataTypeFor(); + auto in_out_datatype = HipDataTypeFor(); + auto opa = _hipblasOpFromChar(params->transa); + auto opb = _hipblasOpFromChar(params->transb); + + TORCH_CHECK(transa_outer == opa && transb_outer == opb, "trans mismatch, shouldn't happen"); + + float alpha = GetAlphaFromParams(params); + float beta = GetBetaFromParams(params); + + hipblasLtMatrixLayout_t mat_a, mat_b, mat_c; + if (opa == HIPBLAS_OP_N) { + TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_a, a_datatype, params->m, params->k, params->lda)); + } + else { + TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_a, a_datatype, params->k, params->m, params->lda)); + } + if (opb == HIPBLAS_OP_N) { + TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_b, b_datatype, params->k, params->n, params->ldb)); + } + else { + TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_b, b_datatype, params->n, params->k, params->ldb)); + } + TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_c, in_out_datatype, params->m, params->n, params->ldc)); + + // specific to batched gemmm + int batch = GetBatchFromParams(params); + if (batch > 1) { + int64_t stride_a = GetStrideAFromParams(params); + int64_t stride_b = GetStrideBFromParams(params); + int64_t stride_c = GetStrideCFromParams(params); + TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute( + mat_a, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch))); + TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute( + mat_a, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_a, sizeof(stride_a))); + TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute( + mat_b, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch))); + TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute( + mat_b, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_b, sizeof(stride_b))); + TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute( + mat_c, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch))); + TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute( + mat_c, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_c, sizeof(stride_c))); + } + + HipBlasLtMatmulDescriptor matmul(HIPBLAS_COMPUTE_32F, HIP_R_32F); + matmul.setAttribute(HIPBLASLT_MATMUL_DESC_TRANSA, opa); + matmul.setAttribute(HIPBLASLT_MATMUL_DESC_TRANSB, opb); + + // specific to scaled gemm + const void* mat1_scale_ptr = GetAScalePointerFromParams(params); + const void* mat2_scale_ptr = GetBScalePointerFromParams(params); + const void* result_scale_ptr = GetDScalePointerFromParams(params); + if (mat1_scale_ptr && mat2_scale_ptr) { + matmul.setAttribute(HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER, mat1_scale_ptr); + matmul.setAttribute(HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER, mat2_scale_ptr); + } + if (result_scale_ptr) { + matmul.setAttribute(HIPBLASLT_MATMUL_DESC_D_SCALE_POINTER, result_scale_ptr); + } + + const void* bias_ptr = GetBiasPointerFromParams(params); + auto bias_datatype = GetBiasTypeFromParams(params); + if (bias_ptr) { + matmul.setAttribute(HIPBLASLT_MATMUL_DESC_BIAS_POINTER, bias_ptr); + matmul.setAttribute(HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, bias_datatype); + auto activation = GetActivationFromParams(params); + if (activation == at::cuda::blas::GEMMAndBiasActivationEpilogue::RELU) { + matmul.setAttribute(HIPBLASLT_MATMUL_DESC_EPILOGUE, HIPBLASLT_EPILOGUE_RELU_BIAS); + } + else if (activation == at::cuda::blas::GEMMAndBiasActivationEpilogue::GELU) { + matmul.setAttribute(HIPBLASLT_MATMUL_DESC_EPILOGUE, HIPBLASLT_EPILOGUE_GELU_BIAS); + } + else { + matmul.setAttribute(HIPBLASLT_MATMUL_DESC_EPILOGUE, HIPBLASLT_EPILOGUE_BIAS); + } + } + + size_t workspace_size = GetHipblasltWorkspaceSize(); + + auto op_handle = at::cuda::getCurrentCUDABlasLtHandle(); + + size_t ret_workspace_size = 0; + auto status = hipblaslt_ext::matmulIsAlgoSupported(op_handle, + matmul.descriptor(), + &alpha, + mat_a, + mat_b, + &beta, + mat_c, + mat_c, + algo_, + ret_workspace_size); + + if (status == HIPBLAS_STATUS_SUCCESS) { + if (ret_workspace_size >= workspace_size) { + return FAIL; + } + } + else { + return FAIL; + } + + void* workspace_buffer = nullptr; + if (workspace_size > 0) { + workspace_buffer = c10::cuda::CUDACachingAllocator::raw_alloc(workspace_size); + } + + TORCH_HIPBLASLT_CHECK(hipblasLtMatmul(op_handle, + matmul.descriptor(), + &alpha, + params->a, + mat_a, + params->b, + mat_b, + &beta, + params->c, + mat_c, + params->c, + mat_c, + &algo_, + workspace_buffer, + workspace_size, + at::cuda::getCurrentCUDAStream())); + + //TORCH_HIPBLASLT_CHECK(hipblasLtMatmulDescDestroy(matmul)); + TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutDestroy(mat_a)); + TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutDestroy(mat_b)); + TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutDestroy(mat_c)); + if (workspace_size > 0) { + c10::cuda::CUDACachingAllocator::raw_delete(workspace_buffer); + } + return OK; + } + + private: + hipblasLtMatmulAlgo_t algo_; +}; + +template +auto GetHipBlasLtTypeStringAndOps() { + hipblasOperation_t transa_outer = MapLayoutToHipBlasLt(ALayout); + hipblasOperation_t transb_outer = MapLayoutToHipBlasLt(BLayout); + auto a_datatype = HipDataTypeFor(); + auto b_datatype = HipDataTypeFor(); + auto in_out_datatype = HipDataTypeFor(); + std::vector heuristic_result; + + hipblasLtHandle_t handle; + TORCH_HIPBLASLT_CHECK(hipblasLtCreate(&handle)); + TORCH_HIPBLASLT_CHECK(hipblaslt_ext::getAllAlgos(handle, + hipblaslt_ext::GemmType::HIPBLASLT_GEMM, + transa_outer, + transb_outer, + a_datatype, + b_datatype, + in_out_datatype, + in_out_datatype, + HIPBLAS_COMPUTE_32F, + heuristic_result)); + TORCH_HIPBLASLT_CHECK(hipblasLtDestroy(handle)); + + // Sort heuristic_result by algo index to make sure the order of returned algos is deterministic. + std::sort(heuristic_result.begin(), + heuristic_result.end(), + [](hipblasLtMatmulHeuristicResult_t& a, hipblasLtMatmulHeuristicResult_t& b) { + return hipblaslt_ext::getIndexFromAlgo(a.algo) < hipblaslt_ext::getIndexFromAlgo(b.algo); + }); + + int returned_algo_count = heuristic_result.size(); + std::vector>>> ret; + for (int i = 0; i < returned_algo_count; i++) { + auto algo = heuristic_result[i].algo; + int algo_index = hipblaslt_ext::getIndexFromAlgo(algo); + auto callable = std::make_unique>(algo); + std::string type_string = c10::str( + "Gemm_Hipblaslt_", _charFromhipblasOp(transa_outer), _charFromhipblasOp(transb_outer), "_", algo_index); + ret.emplace_back(type_string, std::move(callable)); + } + + return ret; +} + +template +auto GetHipBlasLtGemmTypeStringAndOps() { + return GetHipBlasLtTypeStringAndOps>(); +} + +template +auto GetHipBlasLtGemmAndBiasTypeStringAndOps() { + return GetHipBlasLtTypeStringAndOps>(); +} + +template +auto GetHipBlasLtGemmStridedBatchedTypeStringAndOps() { + return GetHipBlasLtTypeStringAndOps>(); +} + +template +auto GetHipBlasLtScaledGemmTypeStringAndOps() { + return GetHipBlasLtTypeStringAndOps>(); +} + +#undef TORCH_HIPBLASLT_CHECK + +} // namespace at::cuda::tunable diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/GemmRocblas.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/GemmRocblas.h new file mode 100644 index 0000000000000000000000000000000000000000..f096ff00fd9b49d109d1bea8589551a8af941a12 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/GemmRocblas.h @@ -0,0 +1,275 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include +#include + +#define ROCBLAS_BETA_FEATURES_API +#include + +#define TORCH_ROCBLAS_CHECK(EXPR) \ + do { \ + rocblas_status __err = EXPR; \ + TORCH_CHECK(__err == rocblas_status_success, \ + "rocblas error: ", \ + rocblas_status_to_string(__err), \ + " when calling `" #EXPR "`"); \ + } while (0) + +namespace at::cuda::tunable { + +template +constexpr rocblas_datatype RocBlasDataTypeFor(); + +template <> +constexpr rocblas_datatype RocBlasDataTypeFor() { + return rocblas_datatype_f32_r; +} + +template <> +constexpr rocblas_datatype RocBlasDataTypeFor() { + return rocblas_datatype_f64_r; +} + +template <> +constexpr rocblas_datatype RocBlasDataTypeFor() { + return rocblas_datatype_f16_r; +} + +template <> +constexpr rocblas_datatype RocBlasDataTypeFor() { + return rocblas_datatype_bf16_r; +} + +template <> +constexpr rocblas_datatype RocBlasDataTypeFor>() { + return rocblas_datatype_f32_c; +} + +template <> +constexpr rocblas_datatype RocBlasDataTypeFor>() { + return rocblas_datatype_f64_c; +} + +template +constexpr rocblas_datatype RocBlasComputeTypeFor(); + +template <> +constexpr rocblas_datatype RocBlasComputeTypeFor() { + return rocblas_datatype_f32_r; +} + +template <> +constexpr rocblas_datatype RocBlasComputeTypeFor() { + return rocblas_datatype_f64_r; +} + +template <> +constexpr rocblas_datatype RocBlasComputeTypeFor() { + // Note that we're returning the _compute_ type for a given datatype. + // As of 12/2022, using compute type FP16 for 16-bit floats was much + // slower than using compute type FP32. So we use FP32 compute even for + // FP16 datatypes. This is how GEMM is implemented even in the function + // rocblasGemmHelper (see fpgeneric.h) + return rocblas_datatype_f32_r; +} + +template <> +constexpr rocblas_datatype RocBlasComputeTypeFor() { + // Note that we're returning the _compute_ type for a given datatype. + // As of 12/2022, using compute type FP16 for 16-bit floats was much + // slower than using compute type FP32. So we use FP32 compute even for + // BF16 datatypes. This is how GEMM is implemented even in the function + // rocblasGemmHelper (see fpgeneric.h) + return rocblas_datatype_f32_r; +} + +template <> +constexpr rocblas_datatype RocBlasComputeTypeFor>() { + return rocblas_datatype_f32_c; +} + +template <> +constexpr rocblas_datatype RocBlasComputeTypeFor>() { + return rocblas_datatype_f64_c; +} + +template +auto DoCastForHalfOrBfloat16(const T fp) { + return fp; +} + +template <> +inline auto DoCastForHalfOrBfloat16(const Half fp) { + // alpha and beta should be the same as compute_type, in Half case it is float. + float h = fp; + return h; +} + +template <> +inline auto DoCastForHalfOrBfloat16(const BFloat16 fp) { + // alpha and beta should be the same as compute_type, in bfloat16 case it is float. + float h = fp; + return h; +} + +static rocblas_operation _rocblasOpFromChar(char op) { + switch (op) { + case 'n': + case 'N': + return rocblas_operation_none; + case 't': + case 'T': + return rocblas_operation_transpose; + case 'c': + case 'C': + return rocblas_operation_conjugate_transpose; + } + AT_ERROR( + "_rocblasOpFromChar input should be 't', 'n' or 'c' but got `", op, "`"); +} + +template +class RocblasGemmOp : public Callable> { + public: + RocblasGemmOp(int solution) : solution_{solution} {} + + TuningStatus Call(const GemmParams* params) override { + auto input_output_type = RocBlasDataTypeFor(); + auto compute_type = RocBlasComputeTypeFor(); + auto h_a = DoCastForHalfOrBfloat16(params->alpha); + auto h_b = DoCastForHalfOrBfloat16(params->beta); + auto status = rocblas_gemm_ex( + (rocblas_handle)at::cuda::getCurrentCUDABlasHandle(), + _rocblasOpFromChar(params->transa), + _rocblasOpFromChar(params->transb), + params->m, params->n, params->k, + &h_a, + params->a, input_output_type, params->lda, + params->b, input_output_type, params->ldb, + &h_b, + params->c, input_output_type, params->ldc, + params->c, input_output_type, params->ldc, + compute_type, + rocblas_gemm_algo_solution_index, + solution_, + rocblas_gemm_flags_none); + if (status != rocblas_status_success) { + return FAIL; + } + return OK; + } + + private: + int solution_; +}; + +template +auto GetRocBlasGemmTypeStringAndOps() { + rocblas_handle handle = (rocblas_handle)at::cuda::getCurrentCUDABlasHandle(); + int solution_size; + auto input_output_type = RocBlasDataTypeFor(); + auto compute_type = RocBlasComputeTypeFor(); + // Get the number of available solutions + TORCH_ROCBLAS_CHECK(rocblas_gemm_ex_get_solutions_by_type(handle, + input_output_type, + input_output_type, + compute_type, + rocblas_gemm_flags_none, + nullptr, + &solution_size)); + std::vector solutions(solution_size); + // Get the list of available solutions + TORCH_ROCBLAS_CHECK(rocblas_gemm_ex_get_solutions_by_type(handle, + input_output_type, + input_output_type, + compute_type, + rocblas_gemm_flags_none, + solutions.data(), + &solution_size)); + // Sort the solutions in ascending order to make the solution vector deterministic across runs + std::sort(solutions.begin(), solutions.end()); + + std::vector>>>> ret; + for (size_t i = 0; i < solutions.size(); ++i) { + auto callable = std::make_unique>(solutions[i]); + ret.emplace_back(std::make_pair(c10::str("Gemm_Rocblas_", solutions[i]), std::move(callable))); + } + return ret; +} + +template +class RocblasGemmStridedBatchedOp : public Callable> { + public: + RocblasGemmStridedBatchedOp(int solution) : solution_{solution} {} + + TuningStatus Call(const GemmStridedBatchedParams* params) override { + auto input_output_type = RocBlasDataTypeFor(); + auto compute_type = RocBlasComputeTypeFor(); + auto h_a = DoCastForHalfOrBfloat16(params->alpha); + auto h_b = DoCastForHalfOrBfloat16(params->beta); + auto status = rocblas_gemm_strided_batched_ex( + (rocblas_handle)at::cuda::getCurrentCUDABlasHandle(), + _rocblasOpFromChar(params->transa), + _rocblasOpFromChar(params->transb), + params->m, params->n, params->k, + &h_a, + params->a, input_output_type, params->lda, params->stride_a, + params->b, input_output_type, params->ldb, params->stride_b, + &h_b, + params->c, input_output_type, params->ldc, params->stride_c, + params->c, input_output_type, params->ldc, params->stride_c, + params->batch, + compute_type, + rocblas_gemm_algo_solution_index, + solution_, + rocblas_gemm_flags_none); + if (status != rocblas_status_success) { + return FAIL; + } + return OK; + } + + private: + int solution_; +}; + +template +auto GetRocBlasGemmStridedBatchedTypeStringAndOps() { + rocblas_handle handle = (rocblas_handle)at::cuda::getCurrentCUDABlasHandle(); + int solution_size; + auto input_output_type = RocBlasDataTypeFor(); + auto compute_type = RocBlasComputeTypeFor(); + // Get the number of available solutions + TORCH_ROCBLAS_CHECK(rocblas_gemm_ex_get_solutions_by_type(handle, + input_output_type, + input_output_type, + compute_type, + rocblas_gemm_flags_none, + nullptr, + &solution_size)); + std::vector solutions(solution_size); + // Get the list of available solutions + TORCH_ROCBLAS_CHECK(rocblas_gemm_ex_get_solutions_by_type(handle, + input_output_type, + input_output_type, + compute_type, + rocblas_gemm_flags_none, + solutions.data(), + &solution_size)); + // Sort the solutions in ascending order to make the solution vector deterministic across runs + std::sort(solutions.begin(), solutions.end()); + + std::vector>>>> ret; + for (size_t i = 0; i < solutions.size(); ++i) { + auto callable = std::make_unique>(solutions[i]); + ret.emplace_back(std::make_pair(c10::str("Gemm_Rocblas_", solutions[i]), std::move(callable))); + } + return ret; +} + +} // namespace at::cuda::tunable diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/StreamTimer.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/StreamTimer.h new file mode 100644 index 0000000000000000000000000000000000000000..c70cb1a908d9df2471dc49b382bee7a51e76884a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/StreamTimer.h @@ -0,0 +1,34 @@ +// Original TunableOp is from onnxruntime. +// https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h +// https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. +// +// Adapting TunableOp into PyTorch +// Copyright (c) Advanced Micro Devices, Inc. +// +#pragma once + +#include + +#include + +namespace at::cuda::tunable { + +class StreamTimer : public ITimer { + public: + StreamTimer(); + virtual ~StreamTimer() override; + + void Start() override; + + void End() override; + + float Duration() override; + + private: + cudaEvent_t start_; + cudaEvent_t end_; +}; + +} // namespace at::cuda::tunable diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/Tunable.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/Tunable.h new file mode 100644 index 0000000000000000000000000000000000000000..243031cf3da2d070d6ee723077cb80c8b6ebc361 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/Tunable.h @@ -0,0 +1,246 @@ +// Original TunableOp is from onnxruntime. +// https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h +// https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. +// +// Adapting TunableOp into PyTorch +// Copyright (c) Advanced Micro Devices, Inc. +// +#pragma once + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at::cuda::tunable { + +namespace detail { + +struct MaybeDelete { + bool owns_pointer; + void operator()(std::ostream* os) const { if (owns_pointer) delete os; } +}; + +using OstreamPtr = std::unique_ptr; + +static OstreamPtr get_stream(std::string filename) { + if (filename.compare("out") == 0) { + return OstreamPtr { &std::cout, MaybeDelete {false} }; + } + else if (filename.compare("err") == 0) { + return OstreamPtr { &std::cerr, MaybeDelete {false} }; + } + else { + return OstreamPtr { new std::ofstream {filename.c_str()}, MaybeDelete {true} }; + } +} + +} + +static void TunableLog(int level, const std::string& msg) { + static const char *env_file = getenv("PYTORCH_TUNABLEOP_VERBOSE_FILENAME"); + static const char *env_verbose = getenv("PYTORCH_TUNABLEOP_VERBOSE"); + static int level_user = env_verbose ? atoi(env_verbose) : 0; + static auto streamptr = detail::get_stream(env_file ? env_file : "err"); + if (level_user >= level) { + (*streamptr) << msg < KernelMap; +typedef std::unordered_map ResultsMap; + +struct TORCH_CUDA_CPP_API TuningResults { + // Validates if these results are compatible with the libraries + std::unordered_map validators; + + // Mapping from Callable signature to Callable's tuning result + ResultsMap results; +}; + +class TORCH_CUDA_CPP_API TuningResultsManager { + public: + TuningResultsManager() = default; + ~TuningResultsManager() = default; + + KernelMap Lookup(const std::string& op_signature); + + ResultEntry Lookup(const std::string& op_signature, const std::string& params_signature); + + inline void AddImpl(const std::string& op_signature, + const std::string& params_signature, + ResultEntry best, + KernelMap& kernel_map); + + void Add(const std::string& op_signature, + const std::string& params_signature, + ResultEntry best); + + void Delete(const std::string& op_signature, const std::string& params_signature); + + inline void DisjointMergeImpl( + const std::string& op_signature, + const KernelMap& kernel_map, + /*out*/ ResultsMap& results); + + void Load(const ResultsMap& results_to_load); + + ResultsMap Dump(); + + void DisjointMerge(const std::string& op_signature, const KernelMap& kernel_map); + + size_t GetSize(); + + private: + std::mutex lock_; + ResultsMap results_; +}; + +class TORCH_CUDA_CPP_API TuningResultsValidator { + public: + using GetFunc = std::function; + using ValidateFunc = std::function; + using GetValidateFuncs = std::unordered_map>; + + TuningResultsValidator(); + ~TuningResultsValidator() = default; + + std::unordered_map GetAllValidators() const; + TuningStatus ValidateAll(const std::unordered_map& to_validate) const; + void RegisterValidator(const std::string& key, const GetFunc& gf, const ValidateFunc& vf); + + protected: + std::string GetPyTorchVersion() const; + TuningStatus ValidatePyTorchVersion(const std::string& value) const; + + public: + static constexpr const std::array mandatory_keys{"PT_VERSION"}; + + private: + GetValidateFuncs validators_; +}; + +class TORCH_CUDA_CPP_API TuningContext { + public: + TuningContext(); + ~TuningContext(); + TuningContext(TuningContext &) = delete; + TuningContext(TuningContext &&) = delete; + TuningContext &operator=(TuningContext &) = delete; + TuningContext &operator=(TuningContext &&) = delete; + + void EnableTunableOp(bool value); + bool IsTunableOpEnabled() const; + + void EnableTuning(bool value); + bool IsTuningEnabled() const; + + void EnableNumericsCheck(bool value); + bool IsNumericsCheckEnabled() const; + + void SetMaxTuningDurationMs(int max_duration_ms); + int GetMaxTuningDurationMs() const; + + void SetMaxTuningIterations(int max_iter); + int GetMaxTuningIterations() const; + + void SetMaxWarmupDurationMs(int max_duration_ms); + int GetMaxWarmupDurationMs() const; + + void SetMaxWarmupIterations(int max_iter); + int GetMaxWarmupIterations() const; + + void EnableICacheFlush(bool value); + bool IsICacheFlushEnabled() const; + + void SetRotatingBufferSize(int size); + int GetRotatingBufferSize() const; + + TuningResultsManager& GetTuningResultsManager(); + + TuningResultsValidator& GetTuningResultsValidator(); + + TuningResults GetTuningResults(); + + TuningStatus LoadTuningResults(const TuningResults& tr); + + void SetFilename(const std::string& filename, bool insert_device_ordinal=false); + std::string GetFilename() const; + + void WriteFileOnExit(bool value); + + bool ReadFile(const std::string& filename={}); + bool WriteFile(const std::string& filename={}); + + private: + bool enable_; + bool tuning_enable_; + bool manager_initialized_; + bool write_file_on_exit_; + bool numerics_check_enable_; + int max_tuning_duration_ms_; + int max_tuning_iterations_; + int max_warmup_duration_ms_; + int max_warmup_iterations_; + bool icache_flush_; + int rotating_buffer_size_; + mutable TuningResultsManager manager_; + mutable c10::once_flag manager_init_once_; + TuningResultsValidator validator_; + std::string filename_; + size_t results_count_from_input_file_; +}; + +TORCH_CUDA_CPP_API TuningContext* getTuningContext(); + +class ITimer { + public: + ITimer() = default; + virtual ~ITimer() = default; + + virtual void Start() = 0; + virtual void End() = 0; + + /// Computes the elapsed time in milliseconds between Start() and End() + virtual float Duration() = 0; +}; + +} // namespace at::cuda::tunable diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/TunableGemm.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/TunableGemm.h new file mode 100644 index 0000000000000000000000000000000000000000..50a7344b0260c24ab432b3f0462244d3322a5d52 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/TunableGemm.h @@ -0,0 +1,307 @@ +// Original TunableOp is from onnxruntime. +// https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h +// https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. +// +// Adapting TunableOp into PyTorch +// Copyright (c) Advanced Micro Devices, Inc. +// +#pragma once + +#include +#ifdef USE_ROCM +#include +#include +#endif +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at::cuda::tunable { + +template +class DefaultGemmOp : public Callable> { + public: + TuningStatus Call(const GemmParams* params) override { + at::cuda::blas::gemm_internal( + params->transa, params->transb, + params->m, params->n, params->k, + params->alpha, + params->a, params->lda, + params->b, params->ldb, + params->beta, + params->c, params->ldc); + return OK; + } +}; + +static bool _transposeBoolFromChar(char op) { + return op == 't' || op == 'T'; +} + +template +class DefaultGemmAndBiasOp : public Callable> { + public: + TuningStatus Call(const GemmAndBiasParams* params) override { + at::cuda::blas::gemm_and_bias( + _transposeBoolFromChar(params->transa), + _transposeBoolFromChar(params->transb), + params->m, params->n, params->k, + params->alpha, + params->a, params->lda, + params->b, params->ldb, + params->bias, + params->c, params->ldc, + params->activation); + return OK; + } +}; + +template +class DefaultGemmStridedBatchedOp : public Callable> { + public: + TuningStatus Call(const GemmStridedBatchedParams* params) override { + at::cuda::blas::bgemm_internal( + params->transa, params->transb, + params->m, params->n, params->k, + params->alpha, + params->a, params->lda, params->stride_a, + params->b, params->ldb, params->stride_b, + params->beta, + params->c, params->ldc, params->stride_c, + params->batch); + return OK; + } +}; + +template +class DefaultScaledGemmOp : public Callable> { + public: + TuningStatus Call(const ScaledGemmParams* params) override { + at::cuda::blas::scaled_gemm( + params->transa, + params->transb, + params->m, + params->n, + params->k, + params->a, + params->a_scale_ptr, + params->lda, + params->a_dtype, + params->b, + params->b_scale_ptr, + params->ldb, + params->b_dtype, + params->bias_ptr, + params->bias_dtype, + params->c, + params->c_scale_ptr, + params->ldc, + params->c_dtype, + params->amax_ptr, + params->use_fast_accum); + return OK; + } +}; + +template +inline bool IsZero(T v) { + return v == 0.0f; +} + +template <> +inline bool IsZero(BFloat16 v) { + return v.x == 0; +} + +template <> +inline bool IsZero(Half v) { + return float(v) == 0.0f; +} + +template <> +inline bool IsZero(c10::complex v) { + return v == 0.0; +} + +template <> +inline bool IsZero(c10::complex v) { + return v == 0.0f; +} + +template +inline std::string TypeName(T v) { + return "unknown"; +} + +template <> +inline std::string TypeName(float v) { + return "float"; +} + +template <> +inline std::string TypeName(double v) { + return "double"; +} + +template <> +inline std::string TypeName(BFloat16 v) { + return "BFloat16"; +} + +template <> +inline std::string TypeName(Half v) { + return "Half"; +} + +template <> +inline std::string TypeName(Float8_e4m3fn v) { + return "Float8_e4m3fn"; +} + +template <> +inline std::string TypeName(Float8_e5m2 v) { + return "Float8_e5m2"; +} + +template <> +inline std::string TypeName(Float8_e4m3fnuz v) { + return "Float8_e4m3fnuz"; +} + +template <> +inline std::string TypeName(Float8_e5m2fnuz v) { + return "Float8_e5m2fnuz"; +} + +template <> +inline std::string TypeName(c10::complex v) { + return "c10::complex"; +} + +template <> +inline std::string TypeName(c10::complex v) { + return "c10::complex"; +} + +template +class GemmTunableOp : public TunableOp, StreamTimer> { + public: + GemmTunableOp() { + this->RegisterOp(std::string("Default"), std::make_unique>()); + +#ifdef USE_ROCM + static const char *env_rocblas = std::getenv("PYTORCH_TUNABLEOP_ROCBLAS_ENABLED"); + if (env_rocblas == nullptr || strcmp(env_rocblas, "1") == 0) { + for (auto&& [name, op] : GetRocBlasGemmTypeStringAndOps()) { + this->RegisterOp(std::move(name), std::move(op)); + } + } + + static const char *env_hipblaslt = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED"); + if (env_hipblaslt == nullptr || strcmp(env_hipblaslt, "1") == 0) { + // disallow tuning of hipblaslt with c10::complex + if constexpr ( + !std::is_same_v> && + !std::is_same_v>) { + for (auto&& [name, op] : GetHipBlasLtGemmTypeStringAndOps()) { + this->RegisterOp(std::move(name), std::move(op)); + } + } + } +#endif + } + + std::string Signature() override { + return c10::str("GemmTunableOp_", TypeName(T{}), "_", BlasOpToString(ALayout), BlasOpToString(BLayout)); + } +}; + +template +class GemmAndBiasTunableOp : public TunableOp, StreamTimer> { + public: + GemmAndBiasTunableOp() { + this->RegisterOp(std::string("Default"), std::make_unique>()); + +#ifdef USE_ROCM + static const char *env_hipblaslt = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED"); + if (env_hipblaslt == nullptr || strcmp(env_hipblaslt, "1") == 0) { + // disallow tuning of hipblaslt with c10::complex + if constexpr ( + !std::is_same_v> && + !std::is_same_v>) { + for (auto&& [name, op] : GetHipBlasLtGemmAndBiasTypeStringAndOps()) { + this->RegisterOp(std::move(name), std::move(op)); + } + } + } +#endif + } + + std::string Signature() override { + return c10::str("GemmAndBiasTunableOp_", TypeName(T{}), "_", BlasOpToString(ALayout), BlasOpToString(BLayout)); + } +}; + +template +class GemmStridedBatchedTunableOp : public TunableOp, StreamTimer> { + public: + GemmStridedBatchedTunableOp() { + this->RegisterOp(std::string("Default"), std::make_unique>()); + +#ifdef USE_ROCM + static const char *env_rocblas = std::getenv("PYTORCH_TUNABLEOP_ROCBLAS_ENABLED"); + if (env_rocblas == nullptr || strcmp(env_rocblas, "1") == 0) { + for (auto&& [name, op] : GetRocBlasGemmStridedBatchedTypeStringAndOps()) { + this->RegisterOp(std::move(name), std::move(op)); + } + } + + static const char *env_hipblaslt = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED"); + if (env_hipblaslt == nullptr || strcmp(env_hipblaslt, "1") == 0) { + // disallow tuning of hipblaslt with c10::complex + if constexpr ( + !std::is_same_v> && + !std::is_same_v>) { + for (auto&& [name, op] : GetHipBlasLtGemmStridedBatchedTypeStringAndOps()) { + this->RegisterOp(std::move(name), std::move(op)); + } + } + } +#endif + } + + std::string Signature() override { + return c10::str("GemmStridedBatchedTunableOp_", TypeName(T{}), "_", BlasOpToString(ALayout), BlasOpToString(BLayout)); + } +}; + +template +class ScaledGemmTunableOp : public TunableOp, StreamTimer> { + public: + ScaledGemmTunableOp() { + this->RegisterOp(std::string("Default"), std::make_unique>()); + +#ifdef USE_ROCM + for (auto&& [name, op] : GetHipBlasLtScaledGemmTypeStringAndOps()) { + this->RegisterOp(std::move(name), std::move(op)); + } +#endif + } + + std::string Signature() override { + return c10::str("ScaledGemmTunableOp", + "_", TypeName(AT{}), + "_", TypeName(BT{}), + "_", TypeName(CT{}), + "_", BlasOpToString(ALayout), BlasOpToString(BLayout)); + } +}; + +} // namespace at::cuda::tunable diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/TunableOp.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/TunableOp.h new file mode 100644 index 0000000000000000000000000000000000000000..9fb7afdb7627f43211dba2ee215cb20b30ef4453 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/tunable/TunableOp.h @@ -0,0 +1,286 @@ +// Original TunableOp is from onnxruntime. +// https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h +// https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. +// +// Adapting TunableOp into PyTorch +// Copyright (c) Advanced Micro Devices, Inc. +// +#pragma once + +#include +#include +#include + +#ifndef _WIN32 +#include +#endif + +#include +#include +#include +#include + +namespace at::cuda::tunable { + +template +class Callable { + public: + Callable() = default; + Callable(Callable&&) = default; + virtual ~Callable() = default; + virtual TuningStatus Call(const ParamsT*) { + return FAIL; + } + virtual TuningStatus IsSupported(const ParamsT* params) { + return Call(params); + } +}; + +template +class TunableOp { + public: + TunableOp() = default; + TunableOp(TunableOp&&) = default; + virtual ~TunableOp() = default; + + TuningStatus operator()(const ParamsT* params) { + ResultEntry result = ResultEntry::Null(); + TuningContext* ctx = getTuningContext(); + if (ctx->IsTunableOpEnabled()) { + auto& mgr = ctx->GetTuningResultsManager(); + auto op_sig = Signature(); + auto params_sig = params->Signature(); + result = mgr.Lookup(op_sig, params_sig); + // If there is not previous tuning result been found, we do the tuning iff tuning is enabled + if (result == ResultEntry::Null() && ctx->IsTuningEnabled()) { + result = FindFastest(params); + mgr.Add(op_sig, params_sig, result); + } + } + else { + result = ResultEntry::Default(); + } + if (result == ResultEntry::Null()) { + TUNABLE_LOG2("no result, using default"); + result = ResultEntry::Default(); + } + auto iter = ops_.find(result); + TORCH_CHECK(iter != ops_.end()); + return iter->second->Call(params); + } + + virtual std::string Signature() { + // According to C++17 standard https://wg21.link/n4659 section 15.7.4 + // > if the operand of typeid refers to the + // > object under construction or destruction, typeid yields the std::type_info object representing the constructor + // > or destructor’s class. + // So delay the op signature generation. + c10::call_once(signature_init_once_, [this]() { signature_ = CreateSignature(); }); + return signature_; + } + + protected: + void RegisterOp(const std::string& name, std::unique_ptr> op) { + this->op_names_.emplace_back(name); + this->ops_.emplace(name, std::move(op)); + } + + private: + static void WarmUp(Callable *op, const std::vector ¶m, size_t num_iter, size_t &offset) { + TuningContext* ctx = getTuningContext(); + bool do_flush = ctx->IsICacheFlushEnabled(); + for (size_t i = 0; i < num_iter; i++) { + if (do_flush) { + at::cuda::flush_icache(); + } + TORCH_CHECK(op->Call(param[(i+offset++)%param.size()]) == OK); + } + } + + static double Profile(Callable *op, const std::vector ¶m, size_t num_iter, size_t &offset) { + TuningContext* ctx = getTuningContext(); + bool do_flush = ctx->IsICacheFlushEnabled(); + TimerT timer{}; + timer.Start(); + for (size_t i = 0; i < num_iter; i++) { + if (do_flush) { + at::cuda::flush_icache(); + } + TORCH_CHECK(op->Call(param[(i+offset++)%param.size()]) == OK); + } + timer.End(); + return timer.Duration() / num_iter; + } + + protected: + virtual ResultEntry FindFastest(const ParamsT* params) { + TuningContext* ctx = getTuningContext(); + auto op_sig = Signature(); + auto params_sig = params->Signature(); + TUNABLE_LOG2("finding fastest for ", op_sig, '(', params_sig, ')', " out of ", op_names_.size(), " candidates"); + auto min_duration_ms = std::numeric_limits::infinity(); + std::string id_name = "Default"; + ParamsT* reference_params = nullptr; + + // numeric check option is controlled by non-static env var, so check it once per tuned operator + bool do_numerics_check = ctx->IsNumericsCheckEnabled(); + + // calcaulte a reference answer for numerical check + if (do_numerics_check) { + reference_params = params->DeepCopy(false); + TORCH_CHECK(ops_[ResultEntry::Default()]->Call(reference_params) == OK); + } + + // need copies of params to reuse + // make as many copies as will fill the requested rotating buffer size, if requested + // rotating_size guaranteed to be >= 0 even though GetRotatingBufferSize() returns int + size_t rotating_size = ctx->GetRotatingBufferSize(); + bool use_buffer_rotation = (rotating_size > 0); + size_t param_size = params->GetSize(use_buffer_rotation); + size_t param_count = (rotating_size / param_size) + 1; + constexpr size_t MB = 1024*1024; + if (use_buffer_rotation) { + TUNABLE_LOG2("Rotating buffer ", rotating_size/MB, " MiB. ", + "Needed Size: ", param_size/MB, " MiB. ", + "Needed number of param copies: ", param_count); + } + TORCH_CHECK(param_count > 0); + + std::vector reusable_params(param_count); + for (size_t i = 0; i < param_count; i++) { + reusable_params[i] = params->DeepCopy(use_buffer_rotation); + } + + // for rotating buffer + size_t offset = 0; + + for (size_t i = 0; i < op_names_.size(); i++) { + auto* candidate = ops_[op_names_[i]].get(); // borrow pointer + + if (do_numerics_check) { + ParamsT* numerical_params = params->DeepCopy(false); + auto status = candidate->Call(numerical_params); + if (status != OK) { + numerical_params->Delete(); + TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]); + continue; + } + status = reference_params->NumericalCheck(numerical_params); + numerical_params->Delete(); + if (status != OK) { + TUNABLE_LOG3("├──numerics check failed for id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]); + continue; + } + } + else { + auto status = candidate->Call(reusable_params[0]); + if (status != OK) { + TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]); + continue; + } + } + + // collect a small profile + constexpr const int approx_num_iter = 3; + auto approx_duration = Profile(candidate, reusable_params, approx_num_iter, offset); + // bail if too slow + if (approx_duration > 2 * min_duration_ms) { + TUNABLE_LOG3("├──skip slow instance id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]); + continue; + } + + // for warmup does user set max duration, max iters, or both? + // warmup is allowed to be skipped by setting either iterations or duration to 0 + double max_warmup_duration = ctx->GetMaxWarmupDurationMs(); + int max_warmup_iter = ctx->GetMaxWarmupIterations(); + int warmup_iter = 1; // default + if (max_warmup_duration >= 0) { + int duration_iters = max_warmup_duration / approx_duration; + if (max_warmup_iter >= 0) { + warmup_iter = std::min(max_warmup_iter, duration_iters); + } + else { + warmup_iter = duration_iters; + } + } + else if (max_warmup_iter >= 0) { + warmup_iter = max_warmup_iter; + } + + // for tuning does user set max duration, max iters, or both? + double max_tuning_duration = ctx->GetMaxTuningDurationMs(); + int max_tuning_iter = ctx->GetMaxTuningIterations(); + int tuning_iter = 100; // default + if (max_tuning_duration > 0) { + int duration_iters = max_tuning_duration / approx_duration; + if (max_tuning_iter > 0) { + tuning_iter = std::min(max_tuning_iter, duration_iters); + } + else { + tuning_iter = duration_iters; + } + } + else if (max_tuning_iter > 0) { + tuning_iter = max_tuning_iter; + } + // tuning must run at least 1 iteration + tuning_iter = std::max(1, tuning_iter); + + // do the full warmup followed by tuning + double warmup_ms = warmup_iter * approx_duration; + double tuning_ms = tuning_iter * approx_duration; + TUNABLE_LOG3("├──tuning using " + "warmup iters ", warmup_iter, " [", warmup_ms, " ms] " + "and tuning iters ", tuning_iter, " [", tuning_ms, " ms] ", + "instance id=", i, ", ", op_sig, "(", params_sig, ") ", op_names_[i]); + TUNABLE_LOG3("├──offset at ", offset); + WarmUp(candidate, reusable_params, warmup_iter, offset); + auto duration_ms = Profile(candidate, reusable_params, tuning_iter, offset); + if (duration_ms < min_duration_ms) { + TUNABLE_LOG3("├──found better instance id=", i, ". " , duration_ms, "ms. ", op_names_[i]); + min_duration_ms = duration_ms; + id_name = op_names_[i]; + } + } + + for (size_t i = 0; i < reusable_params.size(); i++) { + reusable_params[i]->Delete(); + } + if (reference_params) { + reference_params->Delete(); + } + + TUNABLE_LOG2("└──found fastest for ", op_sig, '(', params_sig, ") ", id_name); + return ResultEntry(id_name, min_duration_ms); + } + + private: + std::string CreateSignature() { +#ifndef _WIN32 + const auto* name = typeid(*this).name(); + char buf[256]; + size_t buf_len = 256; + abi::__cxa_demangle(name, buf, &buf_len, nullptr); + buf[255] = '\0'; + return buf; +#else + return typeid(*this).name(); +#endif + } + + mutable c10::once_flag signature_init_once_; + std::string signature_; + + std::unordered_map>> ops_; + std::vector op_names_; +}; + +struct OpParams { + OpParams() {} + virtual ~OpParams() = default; + virtual std::string Signature() const = 0; +}; + +} // namespace at::cuda::tunable diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Activation.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Activation.h new file mode 100644 index 0000000000000000000000000000000000000000..dc84547b7fe17521c202df9f81acd47af29d10d6 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Activation.h @@ -0,0 +1,98 @@ +#pragma once + +#include +#include +#include + +namespace c10 { +class Scalar; +} + +namespace at { +struct TensorIterator; +struct TensorIteratorBase; +class TensorBase; +} + +namespace at::native { + +// These constants control the approximation behavior of gelu function. +enum class GeluType { + None, // Baseline Gelu + Tanh, // Tahn Gelu Approximation + END +}; + +inline GeluType get_gelutype_enum(const c10::string_view approximate) { + if (approximate == "none") { + return GeluType::None; + } else if (approximate == "tanh") { + return GeluType::Tanh; + } else { + TORCH_CHECK(false, "approximate argument must be either none or tanh."); + } +} + +inline std::string gelutype_to_string(const GeluType type) { + switch(type) { + case GeluType::None: return "none"; + case GeluType::Tanh: return "tanh"; + default: TORCH_CHECK(false, "unknown GELU type: ", static_cast(type)); + } +} + +using structured_activation_fn = void (*)(TensorIteratorBase&); +using structured_activation_backward_fn = void (*)(TensorIteratorBase&); + +using activation_fn = void (*)(TensorIterator&); +using activation_backward_fn = void (*)(TensorIterator&); +using softplus_fn = void (*)(TensorIteratorBase&, const c10::Scalar&, const c10::Scalar&); +using softplus_backward_fn = void (*)(TensorIteratorBase&, const c10::Scalar&, const c10::Scalar&); +using threshold_fn = void (*)(TensorIteratorBase&, const c10::Scalar&, const c10::Scalar&); +using hardtanh_backward_fn = void (*)(TensorIterator&, const c10::Scalar&, const c10::Scalar&); +using hardsigmoid_fn = void(*)(TensorIteratorBase&); +using hardsigmoid_backward_fn = void(*)(TensorIteratorBase&); +using hardswish_fn = void(*)(TensorIterator&); +using hardswish_backward_fn = void(*)(TensorIterator&); +using shrink_fn = void (*)(TensorIteratorBase&, const c10::Scalar&); +using softshrink_fn = void (*)(TensorIteratorBase&, const c10::Scalar&); +using shrink_backward_fn = void (*)(TensorIteratorBase&, const c10::Scalar&); +using elu_fn = void (*)(TensorIteratorBase&, const c10::Scalar&, const c10::Scalar&, const c10::Scalar&); +using elu_backward_fn = void (*)(TensorIteratorBase&, const c10::Scalar&, const c10::Scalar&, const c10::Scalar&, bool); +using leaky_relu_fn = void (*)(TensorIteratorBase&, const c10::Scalar&); +using leaky_relu_backward_fn = void (*)(TensorIteratorBase&, const c10::Scalar&); +using log_sigmoid_cpu_fn = void (*)(TensorBase&, TensorBase&, const TensorBase&); +using gelu_fn = void (*)(TensorIteratorBase&, GeluType); +using gelu_backward_fn = void (*)(TensorIteratorBase&, GeluType); +using glu_jvp_fn = void (*)(TensorIteratorBase&); + +DECLARE_DISPATCH(elu_fn, elu_stub); +DECLARE_DISPATCH(elu_backward_fn, elu_backward_stub); +DECLARE_DISPATCH(softplus_fn, softplus_stub); +DECLARE_DISPATCH(softplus_backward_fn, softplus_backward_stub); +DECLARE_DISPATCH(log_sigmoid_cpu_fn, log_sigmoid_cpu_stub); +DECLARE_DISPATCH(activation_backward_fn, log_sigmoid_backward_stub); +DECLARE_DISPATCH(threshold_fn, threshold_stub); +DECLARE_DISPATCH(gelu_fn, GeluKernel); +DECLARE_DISPATCH(gelu_backward_fn, GeluBackwardKernel); +DECLARE_DISPATCH(hardtanh_backward_fn, hardtanh_backward_stub); +DECLARE_DISPATCH(hardsigmoid_fn, hardsigmoid_stub); +DECLARE_DISPATCH(hardsigmoid_backward_fn, hardsigmoid_backward_stub); +DECLARE_DISPATCH(hardswish_fn, hardswish_stub); +DECLARE_DISPATCH(hardswish_backward_fn, hardswish_backward_stub); +DECLARE_DISPATCH(shrink_fn, hardshrink_stub); +DECLARE_DISPATCH(softshrink_fn, softshrink_stub); +DECLARE_DISPATCH(shrink_backward_fn, shrink_backward_stub); +DECLARE_DISPATCH(leaky_relu_fn, leaky_relu_stub); +DECLARE_DISPATCH(leaky_relu_backward_fn, leaky_relu_backward_stub); +DECLARE_DISPATCH(structured_activation_fn, glu_stub); +DECLARE_DISPATCH(activation_backward_fn, glu_backward_stub); +DECLARE_DISPATCH(glu_jvp_fn, glu_jvp_stub); +DECLARE_DISPATCH(structured_activation_fn, silu_stub); +DECLARE_DISPATCH(structured_activation_backward_fn, silu_backward_stub); +DECLARE_DISPATCH(structured_activation_fn, mish_stub); +DECLARE_DISPATCH(activation_backward_fn, mish_backward_stub); +DECLARE_DISPATCH(activation_fn, prelu_stub); +DECLARE_DISPATCH(activation_backward_fn, prelu_backward_stub); + +} // namespace at::native diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/AdaptivePooling.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/AdaptivePooling.h new file mode 100644 index 0000000000000000000000000000000000000000..6c49fd38d940991ba217213a4b56a02517b81f95 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/AdaptivePooling.h @@ -0,0 +1,49 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace at::native { + +using adaptive_avg_pooling2d_fn = void(*)(Tensor& output, const Tensor& input, IntArrayRef output_size); +using adaptive_avg_pooling2d_backward_fn = void(*)(Tensor& grad_input, const Tensor& grad_output); +DECLARE_DISPATCH(adaptive_avg_pooling2d_fn, adaptive_avg_pool2d_kernel); +DECLARE_DISPATCH(adaptive_avg_pooling2d_backward_fn, adaptive_avg_pool2d_backward_kernel); + +using adaptive_max_pooling2d_fn = void(*)(const Tensor& output, const Tensor& indices, const Tensor& input, IntArrayRef output_size); +using adaptive_max_pooling2d_backward_fn = void(*)(const Tensor& grad_input, const Tensor& grad_output, const Tensor& indices); +DECLARE_DISPATCH(adaptive_max_pooling2d_fn, adaptive_max_pool2d_kernel); +DECLARE_DISPATCH(adaptive_max_pooling2d_backward_fn, adaptive_max_pool2d_backward_kernel); + +using adaptive_avg_pooling3d_fn = void(*)(Tensor& output, const Tensor& input, IntArrayRef output_size); +using adaptive_avg_pooling3d_backward_fn = void(*)(Tensor& grad_input, const Tensor& grad_output); +DECLARE_DISPATCH(adaptive_avg_pooling3d_fn, adaptive_avg_pool3d_kernel); +DECLARE_DISPATCH(adaptive_avg_pooling3d_backward_fn, adaptive_avg_pool3d_backward_kernel); + +using adaptive_max_pooling3d_fn = void(*)(const Tensor& output, const Tensor& indices, const Tensor& input, IntArrayRef output_size); +using adaptive_max_pooling3d_backward_fn = void(*)(const Tensor& grad_input, const Tensor& grad_output, const Tensor& indices); +DECLARE_DISPATCH(adaptive_max_pooling3d_fn, adaptive_max_pool3d_kernel); +DECLARE_DISPATCH(adaptive_max_pooling3d_backward_fn, adaptive_max_pool3d_backward_kernel); + +inline int64_t start_index(int64_t a, int64_t b, int64_t c) { + return (a / b) * c + ((a % b) * c) / b; +} + +inline int64_t end_index(int64_t a, int64_t b, int64_t c) { + return 1 + ((a + 1) * c - 1) / b; +} + +inline void adaptive_pool_empty_output_check(const Tensor& gradOutput_, const char* arg_name) { + int64_t ndim = gradOutput_.ndimension(); + for (const auto i : c10::irange(1, ndim)) { + TORCH_CHECK(gradOutput_.size(i) > 0, + arg_name, "(): Expected grad_output to have non-zero size for non-batch dimensions, " + "but grad_output has sizes ", gradOutput_.sizes(), " with dimension ", i, + " being empty"); + } +} + +} // namespace at::native diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/BatchLinearAlgebra.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/BatchLinearAlgebra.h new file mode 100644 index 0000000000000000000000000000000000000000..58d46aacd473147ec42ab03d90ddec77d3e7245d --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/BatchLinearAlgebra.h @@ -0,0 +1,321 @@ +#pragma once + +#include +#include +#include +#include + +// Forward declare TI +namespace at { +class Tensor; +struct TensorIterator; + +namespace native { +enum class TransposeType; +} + +} + +namespace at::native { + +enum class LapackLstsqDriverType : int64_t { Gels, Gelsd, Gelsy, Gelss}; + +#if AT_BUILD_WITH_LAPACK() +// Define per-batch functions to be used in the implementation of batched +// linear algebra operations + +template +void lapackCholesky(char uplo, int n, scalar_t *a, int lda, int *info); + +template +void lapackCholeskyInverse(char uplo, int n, scalar_t *a, int lda, int *info); + +template +void lapackEig(char jobvl, char jobvr, int n, scalar_t *a, int lda, scalar_t *w, scalar_t* vl, int ldvl, scalar_t *vr, int ldvr, scalar_t *work, int lwork, value_t *rwork, int *info); + +template +void lapackGeqrf(int m, int n, scalar_t *a, int lda, scalar_t *tau, scalar_t *work, int lwork, int *info); + +template +void lapackOrgqr(int m, int n, int k, scalar_t *a, int lda, scalar_t *tau, scalar_t *work, int lwork, int *info); + +template +void lapackOrmqr(char side, char trans, int m, int n, int k, scalar_t *a, int lda, scalar_t *tau, scalar_t *c, int ldc, scalar_t *work, int lwork, int *info); + +template +void lapackSyevd(char jobz, char uplo, int n, scalar_t* a, int lda, value_t* w, scalar_t* work, int lwork, value_t* rwork, int lrwork, int* iwork, int liwork, int* info); + +template +void lapackGels(char trans, int m, int n, int nrhs, + scalar_t *a, int lda, scalar_t *b, int ldb, + scalar_t *work, int lwork, int *info); + +template +void lapackGelsd(int m, int n, int nrhs, + scalar_t *a, int lda, scalar_t *b, int ldb, + value_t *s, value_t rcond, int *rank, + scalar_t* work, int lwork, + value_t *rwork, int* iwork, int *info); + +template +void lapackGelsy(int m, int n, int nrhs, + scalar_t *a, int lda, scalar_t *b, int ldb, + int *jpvt, value_t rcond, int *rank, + scalar_t *work, int lwork, value_t* rwork, int *info); + +template +void lapackGelss(int m, int n, int nrhs, + scalar_t *a, int lda, scalar_t *b, int ldb, + value_t *s, value_t rcond, int *rank, + scalar_t *work, int lwork, + value_t *rwork, int *info); + +template +struct lapackLstsq_impl; + +template +struct lapackLstsq_impl { + static void call( + char trans, int m, int n, int nrhs, + scalar_t *a, int lda, scalar_t *b, int ldb, + scalar_t *work, int lwork, int *info, // Gels flavor + int *jpvt, value_t rcond, int *rank, value_t* rwork, // Gelsy flavor + value_t *s, // Gelss flavor + int *iwork // Gelsd flavor + ) { + lapackGels( + trans, m, n, nrhs, + a, lda, b, ldb, + work, lwork, info); + } +}; + +template +struct lapackLstsq_impl { + static void call( + char trans, int m, int n, int nrhs, + scalar_t *a, int lda, scalar_t *b, int ldb, + scalar_t *work, int lwork, int *info, // Gels flavor + int *jpvt, value_t rcond, int *rank, value_t* rwork, // Gelsy flavor + value_t *s, // Gelss flavor + int *iwork // Gelsd flavor + ) { + lapackGelsy( + m, n, nrhs, + a, lda, b, ldb, + jpvt, rcond, rank, + work, lwork, rwork, info); + } +}; + +template +struct lapackLstsq_impl { + static void call( + char trans, int m, int n, int nrhs, + scalar_t *a, int lda, scalar_t *b, int ldb, + scalar_t *work, int lwork, int *info, // Gels flavor + int *jpvt, value_t rcond, int *rank, value_t* rwork, // Gelsy flavor + value_t *s, // Gelss flavor + int *iwork // Gelsd flavor + ) { + lapackGelsd( + m, n, nrhs, + a, lda, b, ldb, + s, rcond, rank, + work, lwork, + rwork, iwork, info); + } +}; + +template +struct lapackLstsq_impl { + static void call( + char trans, int m, int n, int nrhs, + scalar_t *a, int lda, scalar_t *b, int ldb, + scalar_t *work, int lwork, int *info, // Gels flavor + int *jpvt, value_t rcond, int *rank, value_t* rwork, // Gelsy flavor + value_t *s, // Gelss flavor + int *iwork // Gelsd flavor + ) { + lapackGelss( + m, n, nrhs, + a, lda, b, ldb, + s, rcond, rank, + work, lwork, + rwork, info); + } +}; + +template +void lapackLstsq( + char trans, int m, int n, int nrhs, + scalar_t *a, int lda, scalar_t *b, int ldb, + scalar_t *work, int lwork, int *info, // Gels flavor + int *jpvt, value_t rcond, int *rank, value_t* rwork, // Gelsy flavor + value_t *s, // Gelss flavor + int *iwork // Gelsd flavor + ) { + lapackLstsq_impl::call( + trans, m, n, nrhs, + a, lda, b, ldb, + work, lwork, info, + jpvt, rcond, rank, rwork, + s, + iwork); +} + +template +void lapackLuSolve(char trans, int n, int nrhs, scalar_t *a, int lda, int *ipiv, scalar_t *b, int ldb, int *info); + +template +void lapackLu(int m, int n, scalar_t *a, int lda, int *ipiv, int *info); + +template +void lapackLdlHermitian( + char uplo, + int n, + scalar_t* a, + int lda, + int* ipiv, + scalar_t* work, + int lwork, + int* info); + +template +void lapackLdlSymmetric( + char uplo, + int n, + scalar_t* a, + int lda, + int* ipiv, + scalar_t* work, + int lwork, + int* info); + +template +void lapackLdlSolveHermitian( + char uplo, + int n, + int nrhs, + scalar_t* a, + int lda, + int* ipiv, + scalar_t* b, + int ldb, + int* info); + +template +void lapackLdlSolveSymmetric( + char uplo, + int n, + int nrhs, + scalar_t* a, + int lda, + int* ipiv, + scalar_t* b, + int ldb, + int* info); + +template +void lapackSvd(char jobz, int m, int n, scalar_t *a, int lda, value_t *s, scalar_t *u, int ldu, scalar_t *vt, int ldvt, scalar_t *work, int lwork, value_t *rwork, int *iwork, int *info); +#endif + +#if AT_BUILD_WITH_BLAS() +template +void blasTriangularSolve(char side, char uplo, char trans, char diag, int n, int nrhs, scalar_t* a, int lda, scalar_t* b, int ldb); +#endif + +using cholesky_fn = void (*)(const Tensor& /*input*/, const Tensor& /*info*/, bool /*upper*/); +DECLARE_DISPATCH(cholesky_fn, cholesky_stub); + +using cholesky_inverse_fn = Tensor& (*)(Tensor& /*result*/, Tensor& /*infos*/, bool /*upper*/); + +DECLARE_DISPATCH(cholesky_inverse_fn, cholesky_inverse_stub); + +using linalg_eig_fn = void (*)(Tensor& /*eigenvalues*/, Tensor& /*eigenvectors*/, Tensor& /*infos*/, const Tensor& /*input*/, bool /*compute_eigenvectors*/); + +DECLARE_DISPATCH(linalg_eig_fn, linalg_eig_stub); + +using geqrf_fn = void (*)(const Tensor& /*input*/, const Tensor& /*tau*/); +DECLARE_DISPATCH(geqrf_fn, geqrf_stub); + +using orgqr_fn = Tensor& (*)(Tensor& /*result*/, const Tensor& /*tau*/); +DECLARE_DISPATCH(orgqr_fn, orgqr_stub); + +using ormqr_fn = void (*)(const Tensor& /*input*/, const Tensor& /*tau*/, const Tensor& /*other*/, bool /*left*/, bool /*transpose*/); +DECLARE_DISPATCH(ormqr_fn, ormqr_stub); + +using linalg_eigh_fn = void (*)( + const Tensor& /*eigenvalues*/, + const Tensor& /*eigenvectors*/, + const Tensor& /*infos*/, + bool /*upper*/, + bool /*compute_eigenvectors*/); +DECLARE_DISPATCH(linalg_eigh_fn, linalg_eigh_stub); + +using lstsq_fn = void (*)( + const Tensor& /*a*/, + Tensor& /*b*/, + Tensor& /*rank*/, + Tensor& /*singular_values*/, + Tensor& /*infos*/, + double /*rcond*/, + std::string /*driver_name*/); +DECLARE_DISPATCH(lstsq_fn, lstsq_stub); + +using triangular_solve_fn = void (*)( + const Tensor& /*A*/, + const Tensor& /*B*/, + bool /*left*/, + bool /*upper*/, + TransposeType /*transpose*/, + bool /*unitriangular*/); +DECLARE_DISPATCH(triangular_solve_fn, triangular_solve_stub); + +using lu_factor_fn = void (*)( + const Tensor& /*input*/, + const Tensor& /*pivots*/, + const Tensor& /*infos*/, + bool /*compute_pivots*/); +DECLARE_DISPATCH(lu_factor_fn, lu_factor_stub); + +using unpack_pivots_fn = void(*)( + TensorIterator& iter, + const int64_t dim_size, + const int64_t max_pivot); +DECLARE_DISPATCH(unpack_pivots_fn, unpack_pivots_stub); + +using lu_solve_fn = void (*)( + const Tensor& /*LU*/, + const Tensor& /*pivots*/, + const Tensor& /*B*/, + TransposeType /*trans*/); +DECLARE_DISPATCH(lu_solve_fn, lu_solve_stub); + +using ldl_factor_fn = void (*)( + const Tensor& /*LD*/, + const Tensor& /*pivots*/, + const Tensor& /*info*/, + bool /*upper*/, + bool /*hermitian*/); +DECLARE_DISPATCH(ldl_factor_fn, ldl_factor_stub); + +using svd_fn = void (*)( + const Tensor& /*A*/, + const bool /*full_matrices*/, + const bool /*compute_uv*/, + const std::optional& /*driver*/, + const Tensor& /*U*/, + const Tensor& /*S*/, + const Tensor& /*Vh*/, + const Tensor& /*info*/); +DECLARE_DISPATCH(svd_fn, svd_stub); + +using ldl_solve_fn = void (*)( + const Tensor& /*LD*/, + const Tensor& /*pivots*/, + const Tensor& /*result*/, + bool /*upper*/, + bool /*hermitian*/); +DECLARE_DISPATCH(ldl_solve_fn, ldl_solve_stub); +} // namespace at::native diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/BinaryOps.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/BinaryOps.h new file mode 100644 index 0000000000000000000000000000000000000000..8f3f8bcb7e68fb5f8cb77ffd003accd83801f0a8 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/BinaryOps.h @@ -0,0 +1,119 @@ +#pragma once + +#include +#include +#include +#include + + +namespace at { +struct TensorIterator; +struct TensorIteratorBase; +} + +namespace at::native { + +inline void alpha_check(const ScalarType dtype, const Scalar& alpha) { + TORCH_CHECK(! alpha.isBoolean() || dtype == ScalarType::Bool, + "Boolean alpha only supported for Boolean results."); + TORCH_CHECK(isFloatingType(dtype) || isComplexType(dtype) + || alpha.isIntegral(true), + "For integral input tensors, argument alpha must not be a floating point number."); + TORCH_CHECK(isComplexType(dtype) || !alpha.isComplex(), + "For non-complex input tensors, argument alpha must not be a complex number.") +} + +// Basic checking for all sub functions. +inline void sub_check(const TensorBase& self, const TensorBase& other) { + TORCH_CHECK(self.scalar_type() != kBool || other.scalar_type() != kBool, + "Subtraction, the `-` operator, with two bool tensors is not supported. " + "Use the `^` or `logical_xor()` operator instead.") + TORCH_CHECK(self.scalar_type() != kBool && other.scalar_type() != kBool, + "Subtraction, the `-` operator, with a bool tensor is not supported. " + "If you are trying to invert a mask, use the `~` or `logical_not()` operator instead."); +} + +inline void sub_check(const TensorBase& self, const Scalar& scalar) { + TORCH_CHECK(self.scalar_type() != kBool || !scalar.isBoolean(), + "Subtraction, the `-` operator, with two bool tensors is not supported. " + "Use the `^` or `logical_xor()` operator instead.") + TORCH_CHECK(self.scalar_type() != kBool && !scalar.isBoolean(), + "Subtraction, the `-` operator, with a bool tensor is not supported. " + "If you are trying to invert a mask, use the `~` or `logical_not()` operator instead."); +} + +using structured_binary_fn_alpha = void(*)(TensorIteratorBase&, const Scalar& alpha); +using structured_binary_fn_double = void(*)(TensorIteratorBase&, double); +using structured_binary_fn = void(*)(TensorIteratorBase&); + +using binary_fn_alpha = void(*)(TensorIteratorBase&, const Scalar& alpha); +using binary_fn_double = void(*)(TensorIterator&, double); +using binary_fn = void(*)(TensorIterator&); +using binary_clamp_fn_alpha = + void(*)(TensorIterator&, const Scalar& alpha, const Scalar& min_val, const Scalar& max_val); + +// NB: codegenned +DECLARE_DISPATCH(structured_binary_fn_alpha, add_stub); + +DECLARE_DISPATCH(binary_clamp_fn_alpha, add_clamp_stub); +DECLARE_DISPATCH(structured_binary_fn_alpha, sub_stub); +DECLARE_DISPATCH(structured_binary_fn, mul_stub); +DECLARE_DISPATCH(structured_binary_fn, div_true_stub); +DECLARE_DISPATCH(structured_binary_fn, div_floor_stub); +DECLARE_DISPATCH(structured_binary_fn, div_trunc_stub); +DECLARE_DISPATCH(structured_binary_fn, atan2_stub); +DECLARE_DISPATCH(structured_binary_fn, remainder_stub); +DECLARE_DISPATCH(structured_binary_fn, bitwise_and_stub); +DECLARE_DISPATCH(structured_binary_fn, bitwise_or_stub); +DECLARE_DISPATCH(structured_binary_fn, bitwise_xor_stub); +DECLARE_DISPATCH(structured_binary_fn, lshift_stub); +DECLARE_DISPATCH(structured_binary_fn, rshift_stub); +DECLARE_DISPATCH(binary_fn, logical_xor_stub); +DECLARE_DISPATCH(binary_fn, logical_and_stub); +DECLARE_DISPATCH(binary_fn, logical_or_stub); +DECLARE_DISPATCH(structured_binary_fn, lt_stub); +DECLARE_DISPATCH(structured_binary_fn, le_stub); +DECLARE_DISPATCH(structured_binary_fn, gt_stub); +DECLARE_DISPATCH(structured_binary_fn, ge_stub); +DECLARE_DISPATCH(structured_binary_fn, eq_stub); +DECLARE_DISPATCH(structured_binary_fn, ne_stub); +DECLARE_DISPATCH(binary_fn, max_elementwise_stub); +DECLARE_DISPATCH(binary_fn, min_elementwise_stub); +DECLARE_DISPATCH(structured_binary_fn, maximum_stub); +DECLARE_DISPATCH(structured_binary_fn, minimum_stub); +DECLARE_DISPATCH(structured_binary_fn, fmax_stub); +DECLARE_DISPATCH(structured_binary_fn, fmin_stub); +DECLARE_DISPATCH(structured_binary_fn_double, smooth_l1_stub); +DECLARE_DISPATCH(binary_fn_double, huber_stub); +DECLARE_DISPATCH(structured_binary_fn, sigmoid_backward_stub); +DECLARE_DISPATCH(binary_fn_alpha, logit_backward_stub); +DECLARE_DISPATCH(structured_binary_fn, tanh_backward_stub); +DECLARE_DISPATCH(structured_binary_fn, mse_stub); +DECLARE_DISPATCH(structured_binary_fn, fmod_stub); +DECLARE_DISPATCH(structured_binary_fn, logaddexp_stub); +DECLARE_DISPATCH(structured_binary_fn, logaddexp2_stub); +DECLARE_DISPATCH(structured_binary_fn, gcd_stub); +DECLARE_DISPATCH(structured_binary_fn, lcm_stub); +DECLARE_DISPATCH(structured_binary_fn, hypot_stub); +DECLARE_DISPATCH(structured_binary_fn, igamma_stub); +DECLARE_DISPATCH(structured_binary_fn, igammac_stub); +DECLARE_DISPATCH(structured_binary_fn, nextafter_stub); +DECLARE_DISPATCH(structured_binary_fn, heaviside_stub); +DECLARE_DISPATCH(structured_binary_fn, copysign_stub); +DECLARE_DISPATCH(structured_binary_fn, xlogy_stub); +DECLARE_DISPATCH(structured_binary_fn, xlog1py_stub); +DECLARE_DISPATCH(structured_binary_fn, zeta_stub); +DECLARE_DISPATCH(structured_binary_fn, chebyshev_polynomial_t_stub); +DECLARE_DISPATCH(structured_binary_fn, chebyshev_polynomial_u_stub); +DECLARE_DISPATCH(structured_binary_fn, chebyshev_polynomial_v_stub); +DECLARE_DISPATCH(structured_binary_fn, chebyshev_polynomial_w_stub); +DECLARE_DISPATCH(structured_binary_fn, hermite_polynomial_h_stub); +DECLARE_DISPATCH(structured_binary_fn, hermite_polynomial_he_stub); +DECLARE_DISPATCH(structured_binary_fn, laguerre_polynomial_l_stub); +DECLARE_DISPATCH(structured_binary_fn, legendre_polynomial_p_stub); +DECLARE_DISPATCH(structured_binary_fn, shifted_chebyshev_polynomial_t_stub); +DECLARE_DISPATCH(structured_binary_fn, shifted_chebyshev_polynomial_u_stub); +DECLARE_DISPATCH(structured_binary_fn, shifted_chebyshev_polynomial_v_stub); +DECLARE_DISPATCH(structured_binary_fn, shifted_chebyshev_polynomial_w_stub); + +} // namespace at::native diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/ComplexHelper.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/ComplexHelper.h new file mode 100644 index 0000000000000000000000000000000000000000..1b09350aef6ebbef4bd339cdf911b2a68991a17c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/ComplexHelper.h @@ -0,0 +1,97 @@ +#pragma once + +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#include + +#include +#endif + +// WARNING: this header contains non-inline functions and should be only +// included from ONE cpp file + +namespace at::native { + +// View tensor with new dtype, storage offset, sizes and strides +inline Tensor view_tensor( + const Tensor &tensor, ScalarType dtype, + c10::SymInt offset, SymIntArrayRef sizes, SymIntArrayRef strides) { + Storage storage = tensor.storage(); + auto key_set = tensor.key_set().remove(DispatchKey::Conjugate); + auto new_tensor = detail::make_tensor( + c10::TensorImpl::VIEW, std::move(storage), key_set, scalarTypeToTypeMeta(dtype)); + auto * impl = new_tensor.unsafeGetTensorImpl(); + impl->set_sizes_and_strides(sizes, strides, offset); + return new_tensor; +} + +inline SymDimVector computeStrideForViewAsReal(SymIntArrayRef oldstride) { + SymDimVector res(oldstride.size() + 1); + for (const auto i : c10::irange(oldstride.size())) { + res[i] = oldstride[i] * 2; + } + res.back() = 1; + return res; +} + +inline Tensor _view_as_real_physical(const Tensor& self) { + TORCH_CHECK(self.is_complex(), "view_as_real is only supported for complex tensors"); + auto old_sizes = self.sym_sizes(); + SymDimVector new_sizes(old_sizes.size() + 1); + std::copy(old_sizes.begin(), old_sizes.end(), new_sizes.begin()); + // last dimension will always have two elements containing the real and imag vals + new_sizes.back() = 2; + auto new_strides = computeStrideForViewAsReal(self.sym_strides()); + auto new_storage_offset = self.sym_storage_offset() * 2; + const auto float_type = c10::toRealValueType(self.scalar_type()); + auto real_tensor = view_tensor(self, float_type, std::move(new_storage_offset), new_sizes, new_strides); + return real_tensor; +} + +// expects as input a complex tensor and returns back a tensor +// with corresponding real dtype containing the complex values +// in the last two dimensions +Tensor view_as_real(const Tensor& self) { + TORCH_CHECK(!self.is_conj(), "view_as_real doesn't work on unresolved conjugated tensors. To resolve the conjugate tensor so you can view it as real, use self.resolve_conj(); however, be warned that the resulting tensor will NOT alias the original."); + return _view_as_real_physical(self); +} + +inline SymDimVector computeStrideForViewAsComplex(SymIntArrayRef oldstride) { + const auto dim = oldstride.size(); + TORCH_CHECK(dim > 0 && oldstride[dim - 1] == 1, "Tensor must have a last dimension with stride 1"); + + SymDimVector res(dim - 1); + for (const auto i : c10::irange(res.size())) { + TORCH_CHECK(oldstride[i] % 2 == 0, "Tensor must have a stride divisible by 2 for all but last dimension"); + res[i] = oldstride[i] / 2; + } + return res; +} + +// expects as input a float or double tensor with last dimension of size 2 +// and returns back a tensor with corresponding complex dtype +Tensor view_as_complex(const Tensor& self) { + TORCH_CHECK( + self.scalar_type() == kFloat || self.scalar_type() == kDouble || self.scalar_type() == kHalf, + "view_as_complex is only supported for half, float and double tensors, but got a tensor of scalar type: ", self.scalar_type()); + + auto old_sizes = self.sym_sizes(); + TORCH_CHECK(!old_sizes.empty(), "Input tensor must have one or more dimensions"); + TORCH_CHECK(old_sizes[old_sizes.size()-1] == 2, "Tensor must have a last dimension of size 2"); + SymDimVector new_sizes(old_sizes.begin(), old_sizes.end() - 1); + + const auto new_strides = computeStrideForViewAsComplex(self.sym_strides()); + const auto complex_type = c10::toComplexType(self.scalar_type()); + + TORCH_CHECK(self.sym_storage_offset() % 2 == 0, "Tensor must have a storage_offset divisible by 2"); + const auto new_storage_offset = self.sym_storage_offset() / 2; + + return view_tensor(self, complex_type, new_storage_offset, new_sizes, new_strides); +} + +} // namespace at::native diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Distance.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Distance.h new file mode 100644 index 0000000000000000000000000000000000000000..c2d881ae66f6af001c255d23cb1acd613af70d5f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Distance.h @@ -0,0 +1,20 @@ +#pragma once + +#include + +namespace at { +class Tensor; + +namespace native { + +using pdist_forward_fn = void(*)(Tensor&, const Tensor&, const double p); +using pdist_backward_fn = void(*)(Tensor&, const Tensor&, const Tensor&, const double p, const Tensor&); +using cdist_fn = void(*)(Tensor&, const Tensor&, const Tensor&, const double p); +using cdist_backward_fn = void(*)(Tensor&, const Tensor&, const Tensor&, const Tensor&, const double p, const Tensor&); + +DECLARE_DISPATCH(pdist_forward_fn, pdist_forward_stub); +DECLARE_DISPATCH(pdist_backward_fn, pdist_backward_stub); +DECLARE_DISPATCH(cdist_fn, cdist_stub); +DECLARE_DISPATCH(cdist_backward_fn, cdist_backward_stub); + +}} // namespace at::native diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Fill.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Fill.h new file mode 100644 index 0000000000000000000000000000000000000000..f6de9580ae7c33340d2929c4c5f743e4aaf42339 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Fill.h @@ -0,0 +1,21 @@ +// Functions that fill Tensors with constants. Implementations are in Fill.cpp. + +#pragma once + +#include + +namespace c10 { +class Scalar; +} + +namespace at { +class Tensor; +struct TensorIterator; + +namespace native { + +DECLARE_DISPATCH(void(*)(TensorIterator&, const c10::Scalar&), fill_stub); + +Tensor& fill_out(Tensor& self, const Scalar& value); + +}} // namespace at::native diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/FractionalMaxPooling.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/FractionalMaxPooling.h new file mode 100644 index 0000000000000000000000000000000000000000..58c07ac63d72e3eff7554513584e206aaa179978 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/FractionalMaxPooling.h @@ -0,0 +1,80 @@ +#pragma once +#include +#include +#include + +namespace at::native { + +template +inline std::vector generate_intervals( + scalar_t sample, + int64_t inputSize, + int64_t outputSize, + int64_t poolSize) { + std::vector sequence(outputSize); + if (outputSize > 1) { + scalar_t alpha = static_cast(inputSize - poolSize) / + static_cast(outputSize - 1); + + for (const auto i : c10::irange(outputSize - 1)) { + sequence[i] = + static_cast((i + sample) * alpha) - static_cast(sample * alpha); + } + } + if (outputSize > 0) { + sequence[outputSize - 1] = inputSize - poolSize; + } + return sequence; +} + +template +inline void fractional_max_pool_check_shape( + const Tensor& input, + const Tensor& randomSamples) { + + TORCH_CHECK( + input.scalar_type() == randomSamples.scalar_type(), + "Expect _random_samples to have the same dtype as input"); + + int64_t ndimension = randomSamples.ndimension(); + TORCH_CHECK( + ndimension == 3, + "Expect _random_samples to have 3 dimensions, got ", ndimension); + + int64_t N = randomSamples.size(0); + int64_t C = randomSamples.size(1); + int64_t D = randomSamples.size(2); + + int64_t input_batch = 0, input_channel = 0; + if (ndim == 2) { + // fractional_max_pool2d + if (input.ndimension() == 3) { + input_batch = 1; + input_channel = input.size(0); + } else { + input_batch = input.size(0); + input_channel = input.size(1); + } + } else { + // factional_max_pool3d + if (input.ndimension() == 4) { + input_batch = 1; + input_channel = input.size(0); + } else { + input_batch = input.size(0); + input_channel = input.size(1); + } + } + + TORCH_CHECK( + N >= input_batch, + "Expect _random_samples.size(0) no less then input batch size."); + TORCH_CHECK( + C == input_channel, + "Expect _random_samples.size(1) equals to input channel size."); + TORCH_CHECK( + D == ndim, + "Expect _random_samples.size(2) equals to ", ndim, "; got ", D, "."); +} + +} // namespace at::native diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/FunctionOfAMatrixUtils.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/FunctionOfAMatrixUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..68b26ed1381133db9de0ba7cb2187578fb7d680d --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/FunctionOfAMatrixUtils.h @@ -0,0 +1,20 @@ +#pragma once + +#include +#include + +namespace at { +struct TensorIterator; + +namespace native { + +using _compute_linear_combination_fn = void(*)( + TensorIterator& iter, + int64_t in_stride, + int64_t coeff_stride, + int64_t num_summations +); + +DECLARE_DISPATCH(_compute_linear_combination_fn, _compute_linear_combination_stub); + +}} // namespace at::native diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/GridSampler.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/GridSampler.h new file mode 100644 index 0000000000000000000000000000000000000000..509a305fe4b5ed33c128b06fec8473816eeca46a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/GridSampler.h @@ -0,0 +1,298 @@ +#pragma once + +#include +#include +#include +#include + +#include + +namespace at::native { + +using detail::GridSamplerInterpolation; +using detail::GridSamplerPadding; + +// Unnormalizes a coordinate from the -1 to +1 scale to its pixel index value, +// where we view each pixel as an area between (idx - 0.5) and (idx + 0.5). +// if align_corners: -1 and +1 get sent to the centers of the corner pixels +// -1 --> 0 +// +1 --> (size - 1) +// scale_factor = (size - 1) / 2 +// if not align_corners: -1 and +1 get sent to the image edges +// -1 --> -0.5 +// +1 --> (size - 1) + 0.5 == size - 0.5 +// scale_factor = size / 2 +template +static inline scalar_t grid_sampler_unnormalize(scalar_t coord, int64_t size, + bool align_corners) { + if (align_corners) { + // unnormalize coord from [-1, 1] to [0, size - 1] + return ((coord + 1) / 2) * (size - 1); + } else { + // unnormalize coord from [-1, 1] to [-0.5, size - 0.5] + return ((coord + 1) * size - 1) / 2; + } +} + +// grid_sampler_unnormalize_set_grad works the same as grid_sampler_unnormalize +// except that it also returns the `d output / d input` via pointer argument +// `grad_in`. +// This is useful in the backward pass of grid_sampler. +template +static inline scalar_t grid_sampler_unnormalize_set_grad(scalar_t coord, int64_t size, + bool align_corners, scalar_t *grad_in) { + if (align_corners) { + // unnormalize coord from [-1, 1] to [0, size - 1] + *grad_in = static_cast(size - 1) / 2; + return ((coord + 1) / 2) * (size - 1); + } else { + // unnormalize coord from [-1, 1] to [-0.5, size - 0.5] + *grad_in = static_cast(size) / 2; + return ((coord + 1) * size - 1) / 2; + } +} + +// Clips coordinates to between 0 and clip_limit - 1 +template +static inline scalar_t clip_coordinates(scalar_t in, int64_t clip_limit) { + return std::min(static_cast(clip_limit - 1), std::max(in, static_cast(0))); +} + +// clip_coordinates_set_grad works similarly to clip_coordinates except that +// it also returns the `d output / d input` via pointer argument `grad_in`. +// This is useful in the backward pass of grid_sampler. +template +static inline scalar_t clip_coordinates_set_grad(scalar_t in, int64_t clip_limit, + scalar_t *grad_in) { + // Note that it is important for the gradient calculation that borders + // are considered out of bounds. + if (in <= static_cast(0)) { + *grad_in = static_cast(0); + return static_cast(0); + } else { + scalar_t max = static_cast(clip_limit - 1); + if (in >= max) { + *grad_in = static_cast(0); + return max; + } else { + *grad_in = static_cast(1); + return in; + } + } +} + +// Reflects coordinates until they fall between low and high (inclusive). +// The bounds are passed as twice their value so that half-integer values +// can be represented as ints. +template +static inline scalar_t reflect_coordinates(scalar_t in, int64_t twice_low, + int64_t twice_high) { + if (twice_low == twice_high) { + return static_cast(0); + } + scalar_t min = static_cast(twice_low) / 2; + scalar_t span = static_cast(twice_high - twice_low) / 2; + in = std::fabs(in - min); + // `fmod` returns same sign as `in`, which is positive after the `fabs` above. + scalar_t extra = std::fmod(in, span); + int flips = static_cast(std::floor(in / span)); + if (flips % 2 == 0) { + return extra + min; + } else { + return span - extra + min; + } +} + +// reflect_coordinates_set_grad works similarly to reflect_coordinates except +// that it also returns the `d output / d input` via pointer argument +// `grad_in`. +// This is useful in the backward pass of grid_sampler. +template +static inline scalar_t reflect_coordinates_set_grad(scalar_t in, int64_t twice_low, + int64_t twice_high, scalar_t *grad_in) { + if (twice_low == twice_high) { + *grad_in = static_cast(0); + return static_cast(0); + } + int grad_in_mult_; + scalar_t min = static_cast(twice_low) / 2; + scalar_t span = static_cast(twice_high - twice_low) / 2; + in = in - min; + if (in < static_cast(0)) { + grad_in_mult_ = -1; + in = -in; + } else { + grad_in_mult_ = 1; + } + // `fmod` returns same sign as `in`, which is positive after the `if` above. + scalar_t extra = std::fmod(in, span); + int flips = static_cast(std::floor(in / span)); + if (flips % 2 == 0) { + *grad_in = static_cast(grad_in_mult_); + return extra + min; + } else { + *grad_in = static_cast(-grad_in_mult_); + return span - extra + min; + } +} + +// Mapping the out-of-boundary points back into boundary +// This would only affect padding_mode=border or reflection +template +static inline scalar_t compute_coordinates(scalar_t coord, int64_t size, + GridSamplerPadding padding_mode, + bool align_corners) { + if (padding_mode == GridSamplerPadding::Border) { + // clip coordinates to image borders + coord = clip_coordinates(coord, size); + } else if (padding_mode == GridSamplerPadding::Reflection) { + // reflect coordinates by image borders + if (align_corners) { + coord = reflect_coordinates(coord, 0, 2*(size - 1)); + } else { + coord = reflect_coordinates(coord, -1, 2*size - 1); + } + // clip coordinates to image borders + coord = clip_coordinates(coord, size); + } + return coord; +} + +// Computes the pixel source index value for a grid coordinate +template +static inline scalar_t grid_sampler_compute_source_index( + scalar_t coord, + int64_t size, + GridSamplerPadding padding_mode, + bool align_corners) { + coord = grid_sampler_unnormalize(coord, size, align_corners); + coord = compute_coordinates(coord, size, padding_mode, align_corners); + return coord; +} + +// grid_sampler_compute_source_index_set_grad works similarly to +// grid_sampler_compute_source_index except that it also returns the +// `d output / d input` via pointer argument `grad_in`. +// This is useful in the backward pass of grid_sampler. +template +static inline scalar_t grid_sampler_compute_source_index_set_grad( + scalar_t coord, + int64_t size, + GridSamplerPadding padding_mode, + bool align_corners, + scalar_t *grad_in) { + scalar_t grad_clip, grad_refl; + coord = grid_sampler_unnormalize_set_grad(coord, size, align_corners, grad_in); + if (padding_mode == GridSamplerPadding::Border) { + // clip coordinates to image borders + coord = clip_coordinates_set_grad(coord, size, &grad_clip); + *grad_in = (*grad_in) * grad_clip; + } else if (padding_mode == GridSamplerPadding::Reflection) { + // reflect coordinates by image borders + if (align_corners) { + coord = reflect_coordinates_set_grad(coord, 0, 2*(size - 1), &grad_refl); + } else { + coord = reflect_coordinates_set_grad(coord, -1, 2*size - 1, &grad_refl); + } + // clip coordinates to image borders + coord = clip_coordinates_set_grad(coord, size, &grad_clip); + *grad_in = (*grad_in) * grad_refl * grad_clip; + } + return coord; +} + +static inline bool within_bounds_2d(int64_t h, int64_t w, int64_t H, int64_t W) { + return h >= 0 && h < H && w >= 0 && w < W; +} + +static inline bool within_bounds_3d(int64_t d, int64_t h, int64_t w, int64_t D, int64_t H, int64_t W) { + return d >= 0 && d < D && h >= 0 && h < H && w >= 0 && w < W; +} + +template +static inline scalar_t get_value_bounded( + const scalar_t* data, + scalar_t x, + scalar_t y, + int64_t W, + int64_t H, + int64_t sW, + int64_t sH, + GridSamplerPadding padding_mode, + bool align_corners) { + + x = compute_coordinates(x, W, padding_mode, align_corners); + y = compute_coordinates(y, H, padding_mode, align_corners); + + int64_t ix = static_cast(x); + int64_t iy = static_cast(y); + + if (within_bounds_2d(iy, ix, H, W)) { + return data[iy * sH + ix * sW]; + } + return static_cast(0); +} + +template +static inline void safe_add_2d(scalar_t *data, int64_t h, int64_t w, + int64_t sH, int64_t sW, int64_t H, int64_t W, + scalar_t delta) { + if (within_bounds_2d(h, w, H, W)) { + data[h * sH + w * sW] += delta; + } +} + +template +static inline void safe_add_3d(scalar_t *data, int64_t d, int64_t h, int64_t w, + int64_t sD, int64_t sH, int64_t sW, + int64_t D, int64_t H, int64_t W, + scalar_t delta) { + if (within_bounds_3d(d, h, w, D, H, W)) { + data[d * sD + h * sH + w * sW] += delta; + } +} + +template +static inline void add_value_bounded( + scalar_t* data, + scalar_t x, + scalar_t y, + int64_t W, + int64_t H, + int64_t sW, + int64_t sH, + scalar_t delta, + GridSamplerPadding padding_mode, + bool align_corners) { + + x = compute_coordinates(x, W, padding_mode, align_corners); + y = compute_coordinates(y, H, padding_mode, align_corners); + + int64_t ix = static_cast(x); + int64_t iy = static_cast(y); + + safe_add_2d(data, iy, ix, sH, sW, H, W, delta); +} + +// Calculate the differential of the cubic convolution, i.e. `d coeff / d x` +template +static inline void get_cubic_coefficients_grad( + scalar_t coeffs[4], + scalar_t t) { + + // Must be the same as forward calculation in + // aten/src/ATen/native/UpSample.h:get_cubic_upsample_coefficients + scalar_t A = -0.75; + + scalar_t x; + x = -1 - t; // 1 < x = |-1 - tx| < 2 + coeffs[0] = (-3 * A * x - 10 * A ) * x - 8 * A; + x = -t; // x = |0 - tx| <= 1 + coeffs[1] = (-3 * (A + 2) * x - 2 * (A + 3)) * x; + x = 1 - t; // x = |1 - tx| <= 1 + coeffs[2] = (3 * (A + 2) * x - 2 * (A + 3)) * x; + x = 2 - t; // 1 < x = |2 - tx| < 2 + coeffs[3] = (3 * A * x - 10 * A) * x + 8 * A; +} + +} // namespace at::native diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/LossMulti.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/LossMulti.h new file mode 100644 index 0000000000000000000000000000000000000000..8877b05a54cc380e99bfce40ef9b9b05b0031c49 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/LossMulti.h @@ -0,0 +1,69 @@ +#pragma once +#include +#include +#include +#include + +namespace at::native { + inline void multilabel_margin_loss_shape_check( + int64_t& nframe, + int64_t& dim, + const int64_t& ndims, + const Tensor& input, + const Tensor& target) { + TORCH_CHECK( + (ndims == 2 && input.size(1) != 0) || (ndims == 1 && input.size(0) != 0) || ndims == 0, + "Expected non-empty vector or matrix with optional 0-dim batch size, but got: ", + input.sizes()); + + if (ndims <= 1) { + nframe = 1; + dim = ndims == 0 ? 1 : input.size(0); + TORCH_CHECK( + target.dim() <= 1 && target.numel() == dim, + "inconsistent target size: ", target.sizes(), " for input of size: ", + input.sizes()); + } else { + nframe = input.size(0); + dim = input.size(1); + TORCH_CHECK( + target.dim() == 2 && target.size(0) == nframe && + target.size(1) == dim, + "inconsistent target size: ", target.sizes(), " for input of size: ", + input.sizes()); + } + } + + inline void multi_margin_loss_shape_check( + int64_t& nframe, + int64_t& dim, + const int64_t& ndims, + const Tensor& input, + const Tensor& target, + const std::optional& weight) { + TORCH_CHECK( + (ndims == 2 && input.size(1) != 0) || (ndims == 1 && input.size(0) != 0) || ndims == 0, + "Expected non-empty vector or matrix with optional 0-dim batch size, but got: ", + input.sizes()); + + if (ndims <= 1) { + nframe = 1; + dim = ndims == 0 ? 1 : input.size(0); + } else { + nframe = input.size(0); + dim = input.size(1); + } + + TORCH_CHECK( + target.dim() <= 1 && target.numel() == nframe, + "inconsistent target size, expected ", nframe, " but got ", + target.sizes()); + if (weight && weight->defined()) { + TORCH_CHECK( + weight->dim() <= 1 && weight->numel() == dim, + "inconsistent weight size, expected ", dim, " but got ", + weight->sizes()); + } +} + +} // namespace at::native diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Math.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Math.h new file mode 100644 index 0000000000000000000000000000000000000000..e86a9aea411af523af3c94e8a3ec03f61da47d6e --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Math.h @@ -0,0 +1,3901 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +C10_CLANG_DIAGNOSTIC_PUSH() +#if C10_CLANG_HAS_WARNING("-Wimplicit-float-conversion") +C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-float-conversion") +#endif + +/* The next function is taken from https://github.com/antelopeusersgroup/antelope_contrib/blob/master/lib/location/libgenloc/erfinv.c. +Below is the copyright. +Output was modified to be inf or -inf when input is 1 or -1. */ + + +/* + Copyright (c) 2014 Indiana University + All rights reserved. + + Written by Prof. Gary L. Pavlis, Dept. of Geol. Sci., + Indiana University, Bloomington, IN + + This software is licensed under the New BSD license: + + Redistribution and use in source and binary forms, + with or without modification, are permitted provided + that the following conditions are met: + + Redistributions of source code must retain the above + copyright notice, this list of conditions and the + following disclaimer. + + Redistributions in binary form must reproduce the + above copyright notice, this list of conditions and + the following disclaimer in the documentation and/or + other materials provided with the distribution. + + Neither the name of Indiana University nor + the names of its contributors may be used to endorse + or promote products derived from this software without + specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND + CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED + WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A + PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL + THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY + DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF + USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER + IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE + USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + POSSIBILITY OF SUCH DAMAGE. +*/ + +namespace { +/* + * This function is derived from the implementation of the i0e function in the + * Cephes Math Library. See note [3-Clause BSD License for the Cephes Math + * Library]. + * + * Computes an approximation of the exponentially scaled zeroth order modified + * Bessel function of the first kind. The approximation is actually two + * (sub)approximations, both using a Chebyshev polynomial expansion. One + * approximates the function over [0, 8], and the other over (8, infinity). This + * function takes the absolute value of all inputs to convert them into the + * domain of the approximation. + */ +jiterator_also_stringify_as(jiterator_code( + template + JITERATOR_HOST_DEVICE T chbevl(T x, const T array[], const int len) { + T b0, b1, b2; + + b0 = array[0]; + b1 = 0; + + for (int i = 1; i < len; ++i) { + b2 = b1; + b1 = b0; + b0 = x * b1 - b2 + array[i]; + } + + return T{0.5} * (b0 - b2); + } + + template + JITERATOR_HOST_DEVICE T calc_i0e(T _x) { + T x = std::fabs(_x); + + if (x <= T{8.0}) { + static const T coefficients[] = { + -4.41534164647933937950E-18, 3.33079451882223809783E-17, + -2.43127984654795469359E-16, 1.71539128555513303061E-15, + -1.16853328779934516808E-14, 7.67618549860493561688E-14, + -4.85644678311192946090E-13, 2.95505266312963983461E-12, + -1.72682629144155570723E-11, 9.67580903537323691224E-11, + -5.18979560163526290666E-10, 2.65982372468238665035E-9, + -1.30002500998624804212E-8, 6.04699502254191894932E-8, + -2.67079385394061173391E-7, 1.11738753912010371815E-6, + -4.41673835845875056359E-6, 1.64484480707288970893E-5, + -5.75419501008210370398E-5, 1.88502885095841655729E-4, + -5.76375574538582365885E-4, 1.63947561694133579842E-3, + -4.32430999505057594430E-3, 1.05464603945949983183E-2, + -2.37374148058994688156E-2, 4.93052842396707084878E-2, + -9.49010970480476444210E-2, 1.71620901522208775349E-1, + -3.04682672343198398683E-1, 6.76795274409476084995E-1}; + + T y = (x / T{2.0}) - T{2.0}; + return chbevl(y, coefficients, int{30}); + } + + // x > 8 + static const T coefficients[] = { + -7.23318048787475395456E-18, -4.83050448594418207126E-18, + 4.46562142029675999901E-17, 3.46122286769746109310E-17, + -2.82762398051658348494E-16, -3.42548561967721913462E-16, + 1.77256013305652638360E-15, 3.81168066935262242075E-15, + -9.55484669882830764870E-15, -4.15056934728722208663E-14, + 1.54008621752140982691E-14, 3.85277838274214270114E-13, + 7.18012445138366623367E-13, -1.79417853150680611778E-12, + -1.32158118404477131188E-11, -3.14991652796324136454E-11, + 1.18891471078464383424E-11, 4.94060238822496958910E-10, + 3.39623202570838634515E-9, 2.26666899049817806459E-8, + 2.04891858946906374183E-7, 2.89137052083475648297E-6, + 6.88975834691682398426E-5, 3.36911647825569408990E-3, + 8.04490411014108831608E-1}; + + return chbevl(T{32.0} / x - T{2.0}, coefficients, int{25}) / std::sqrt(x); + }), + i0e_string); // i0e_string +} + +#define CENTRAL_RANGE 0.7 + +template +inline typename std::enable_if::value, T>::type +calc_erfinv(T y) { +/* Function to calculate inverse error function. Rational approximation +is used to generate an initial approximation, which is then improved to +full accuracy by two steps of Newton's method. Code is a direct +translation of the erfinv m file in matlab version 2.0. +Author: Gary L. Pavlis, Indiana University +Date: February 1996 +*/ + T x, z, num, dem; /*working variables */ + /* coefficients in rational expansion */ + T a[4] = { T(0.886226899), T(-1.645349621), T(0.914624893), T(-0.140543331) }; + T b[4] = { T(-2.118377725), T(1.442710462), T(-0.329097515), T(0.012229801) }; + T c[4] = { T(-1.970840454), T(-1.624906493), T(3.429567803), T(1.641345311) }; + T d[2] = { T(3.543889200), T(1.637067800) }; + T y_abs = std::abs(y); + if(y_abs > 1.0) return std::numeric_limits::quiet_NaN(); +#ifdef _WIN32 + // error C2039: '_copysign': is not a member of 'std' + if(y_abs == 1.0) return copysign(std::numeric_limits::infinity(), y); +#else + if(y_abs == 1.0) return std::copysign(std::numeric_limits::infinity(), y); +#endif + if(y_abs <= static_cast(CENTRAL_RANGE)) { + z = y * y; + num = (((a[3]*z + a[2])*z + a[1])*z + a[0]); + dem = ((((b[3]*z + b[2])*z + b[1])*z +b[0]) * z + static_cast(1.0)); + x = y * num / dem; + } + else{ + z = std::sqrt(-std::log((static_cast(1.0)-y_abs)/static_cast(2.0))); + num = ((c[3]*z + c[2])*z + c[1]) * z + c[0]; + dem = (d[1]*z + d[0])*z + static_cast(1.0); +#ifdef _WIN32 + // error C2039: '_copysign': is not a member of 'std' + x = copysign(num, y) / dem; +#else + x = std::copysign(num, y) / dem; +#endif + } + /* Two steps of Newton-Raphson correction */ + x = x - (std::erf(x) - y) / ((static_cast(2.0)/static_cast(std::sqrt(c10::pi)))*std::exp(-x*x)); + x = x - (std::erf(x) - y) / ((static_cast(2.0)/static_cast(std::sqrt(c10::pi)))*std::exp(-x*x)); + + return(x); +} + +#undef CENTRAL_RANGE + +/* + * Note [3-Clause BSD License for the Cephes Math Library] + * Code derived from implementations in the Cephes Math Library should mention its derivation and reference + * this note (ex. 'This function is derived from the implementation of X in the Cephes Math Library. See note + * [3-Clause BSD License for the Cephes Math Library]. The license is: + * Copyright (c) 2018, Steven Moshier + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL Steven Moshier BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +/* + * This function is derived from the implementation of the zeta function in the Cephes Math Library. + * See note [3-Clause BSD License for the Cephes Math Library]. + */ +template +C10_HOST_DEVICE inline scalar_t zeta(scalar_t x, scalar_t q) __ubsan_ignore_float_divide_by_zero__ { + using acc_t = at::acc_type; + const acc_t MACHEP = acc_t{1.11022302462515654042E-16}; + constexpr acc_t zero = acc_t{0.0}; + constexpr acc_t half = acc_t{0.5}; + constexpr acc_t one = acc_t{1.0}; + static const acc_t A[] = { + 12.0, + -720.0, + 30240.0, + -1209600.0, + 47900160.0, + -1.8924375803183791606e9, /*1.307674368e12/691*/ + 7.47242496e10, + -2.950130727918164224e12, /*1.067062284288e16/3617*/ + 1.1646782814350067249e14, /*5.109094217170944e18/43867*/ + -4.5979787224074726105e15, /*8.028576626982912e20/174611*/ + 1.8152105401943546773e17, /*1.5511210043330985984e23/854513*/ + -7.1661652561756670113e18 /*1.6938241367317436694528e27/236364091*/ + }; + + int i = 0; + acc_t a, b, k, s, t, w; + if (x == one) { + return std::numeric_limits::infinity(); + } + + if (x < one) { + return std::numeric_limits::quiet_NaN(); + } + + if (q <= zero) { + if (q == std::floor(q)) { + return std::numeric_limits::infinity(); + } + if (x != std::floor(x)) { + return std::numeric_limits::quiet_NaN(); + } + } + + s = std::pow(q, -x); + a = q; + i = 0; + b = zero; + while ((i < 9) || (a <= acc_t{9.0})) { + i += 1; + a += one; + b = ::pow(a, -x); + s += b; + if ((-MACHEP * s < b) && (b < MACHEP * s)) { + return static_cast(s); + } + }; + + w = a; + s += b * w / (x - one); + s -= half * b; + a = one; + k = zero; + for (int i = 0; i < 12; i++) { + a *= x + k; + b /= w; + t = a * b / A[i]; + s = s + t; + t = ::fabs(t / s); + if (t < MACHEP) { + return static_cast(s); + } + k += one; + a *= x + k; + b /= w; + k += one; + } + return static_cast(s); +} + +/* + * This function is derived from the implementation of the digamma function in the Cephes Math Library. + * See note [3-Clause BSD License for the Cephes Math Library]. + * + * Evaluates polynomial of degree N: + * + * 2 N + * y = C + C x + C x +...+ C x + * 0 1 2 N + * + * Coefficients are stored in reverse order: + * + * coef[0] = C , ..., coef[N] = C . + * N 0 + */ +template +C10_HOST_DEVICE inline T polevl(const T x, const T A[], size_t len) { + T result = 0; + for (size_t i = 0; i <= len; i++) { + result = result * x + A[i]; + } + return result; +} + +inline double trigamma(double x) __ubsan_ignore_float_divide_by_zero__ { + double sign = +1; + double result = 0; + if (x < 0.5) { + sign = -1; + const double sin_pi_x = sin(c10::pi * x); + result -= (c10::pi * c10::pi) / (sin_pi_x * sin_pi_x); + x = 1 - x; + } + for (int i = 0; i < 6; ++i) { + result += 1 / (x * x); + x += 1; + } + const double ixx = 1 / (x*x); + result += (1 + 1 / (2*x) + ixx * (1./6 - ixx * (1./30 - ixx * (1./42)))) / x; + return sign * result; +} + +inline float trigamma(float x) __ubsan_ignore_float_divide_by_zero__ { + float sign = +1; + float result = 0; + if (x < 0.5f) { + sign = -1; + const float sin_pi_x = sinf(c10::pi * x); + result -= (c10::pi * c10::pi) / (sin_pi_x * sin_pi_x); + x = 1 - x; + } + for (int i = 0; i < 6; ++i) { + result += 1 / (x * x); + x += 1; + } + const float ixx = 1 / (x*x); + result += (1 + 1 / (2*x) + ixx * (1.f/6 - ixx * (1.f/30 - ixx * (1.f/42)))) / x; + return sign * result; +} + +/* + * This function is derived from the implementation of the digamma function in the Cephes Math Library. + * See note [3-Clause BSD License for the Cephes Math Library]. + */ +inline double calc_digamma(double x) { + // [C++ Standard Reference: Gamma Function] https://en.cppreference.com/w/cpp/numeric/math/tgamma + static double PSI_10 = 2.25175258906672110764; + if (x == 0) { + // As per C++ standard for gamma related functions and SciPy, + // If the argument is ±0, ±∞ is returned + return std::copysign(INFINITY, -x); + } + + bool x_is_integer = x == trunc(x); + if (x < 0) { + if (x_is_integer) { + // As per C++ standard for gamma related functions and SciPy, + // If the argument is a negative integer, NaN is returned + return std::numeric_limits::quiet_NaN(); + } + // Extracts the fractional part of x as r, since tan(pi * r) is more numerically + // accurate than tan(pi * x). While these operations are mathematically equivalent + // since both x and r are in radians and tan() has a periodicity of pi, in practice + // the computation of pi * x is a source of error (when |x| > 1). + double q, r; + r = std::modf(x, &q); + return calc_digamma(1 - x) - c10::pi / tan(c10::pi * r); + } + + // Push x to be >= 10 + double result = 0; + while (x < 10) { + result -= 1 / x; + x += 1; + } + if (x == 10) { + return result + PSI_10; + } + + // Compute asymptotic digamma + static const double A[] = { + 8.33333333333333333333E-2, + -2.10927960927960927961E-2, + 7.57575757575757575758E-3, + -4.16666666666666666667E-3, + 3.96825396825396825397E-3, + -8.33333333333333333333E-3, + 8.33333333333333333333E-2, + }; + + double y = 0; + if (x < 1.0e17) { + double z = 1.0 / (x * x); + y = z * polevl(z, A, 6); + } + return result + log(x) - (0.5 / x) - y; +} + +/* + * This function is derived from the implementation of the digamma function in the Cephes Math Library. + * See note [3-Clause BSD License for the Cephes Math Library]. + */ +inline float calc_digamma(float x) { + // See [C++ Standard Reference: Gamma Function] + static float PSI_10 = 2.25175258906672110764f; + if (x == 0) { + // As per C++ standard for gamma related functions and SciPy, + // If the argument is ±0, ±∞ is returned + return std::copysign(INFINITY, -x); + } + + bool x_is_integer = x == truncf(x); + if (x < 0) { + if (x_is_integer) { + // As per C++ standard for gamma related functions and SciPy, + // If the argument is a negative integer, NaN is returned + return std::numeric_limits::quiet_NaN(); + } + // Extracts the fractional part of x as r, since tan(pi * r) is more numerically + // accurate than tan(pi * x). While these operations are mathematically equivalent + // since both x and r are in radians and tan() has a periodicity of pi, in practice + // the computation of pi * x is a source of error (when |x| > 1). + double q, r; + r = std::modf(x, &q); + float pi_over_tan_pi_x = (float)(c10::pi / tan(c10::pi * r)); + return calc_digamma(1 - x) - pi_over_tan_pi_x; + } + + // Push x to be >= 10 + float result = 0; + while (x < 10) { + result -= 1 / x; + x += 1; + } + if (x == 10) { + return result + PSI_10; + } + + // Compute asymptotic digamma + static const float A[] = { + 8.33333333333333333333E-2f, + -2.10927960927960927961E-2f, + 7.57575757575757575758E-3f, + -4.16666666666666666667E-3f, + 3.96825396825396825397E-3f, + -8.33333333333333333333E-3f, + 8.33333333333333333333E-2f, + }; + + float y = 0; + if (x < 1.0e17f) { + float z = 1 / (x * x); + y = z * polevl(z, A, 6); + } + return result + logf(x) - (0.5f / x) - y; +} + +inline c10::BFloat16 calc_digamma(c10::BFloat16 a) { + return calc_digamma(static_cast(a)); +} + +inline c10::Half calc_digamma(c10::Half a) { + return calc_digamma(static_cast(a)); +} + +template +inline C10_HOST_DEVICE scalar_t calc_polygamma(scalar_t x, int n) { + // already blocked if n <= 1 + const auto one = scalar_t{1}; + return ((n % 2) ? one : -one) * + std::exp(std::lgamma(static_cast(n) + one)) * + zeta(static_cast(n + 1), x); +} + +// regularized lower incomplete gamma +// the regularized lower, upper incomplete gamma, as well as their +// helper functions follow SciPy's implementation + +/* References + * [igam1] "The Digital Library of Mathematical Functions", dlmf.nist.gov + * [igam2] Maddock et al., "Incomplete Gamma Functions", + * https://www.boost.org/doc/libs/1_61_0/libs/math/doc/html/math_toolkit/sf_gamma/igamma.html + */ + +/* + * This implementation of the regularized incomplete gamma functions and + * their helper functions are derived from the implementation of SciPy's + * gammainc, Cephes's igam and igamc, and Boost's Lanczos approximations. + * See NOTICE for the licenses. + */ +template +scalar_t ratevl(scalar_t x, const scalar_t num[], int64_t M, + const scalar_t denom[], int64_t N) { + // evaluating rational function, i.e., the ratio of two polynomials + // the coefficients for numerator are given by `num` while coeffs for + // denumerator are given by `denom` + + int64_t i, dir; + scalar_t y, num_ans, denom_ans; + scalar_t absx = std::fabs(x); + const scalar_t *p; + + if (absx > 1) { + /* Evaluate as a polynomial in 1/x. */ + dir = -1; + p = num + M; + y = 1 / x; + } + else { + dir = 1; + p = num; + y = x; + } + + /* Evaluate the numerator */ + num_ans = *p; + p += dir; + for (i = 1; i <= M; i++) { + num_ans = num_ans * y + *p; + p += dir; + } + /* Evaluate the denominator */ + if (absx > 1) { + p = denom + N; + } + else { + p = denom; + } + + denom_ans = *p; + p += dir; + for (i = 1; i <= N; i++) { + denom_ans = denom_ans * y + *p; + p += dir; + } + if (absx > 1) { + i = N - M; + return std::pow(x, i) * num_ans / denom_ans; + } + else { + return num_ans / denom_ans; + } +} + +// SciPy's lanczos implementation is taken from Boost +/* (C) Copyright John Maddock 2006. + * Use, modification and distribution are subject to the + * Boost Software License, Version 1.0. See + * https://www.boost.org/LICENSE_1_0.txt or see NOTICE. + */ +template +static scalar_t lanczos_sum_expg_scaled(scalar_t x) { + // lanczos approximation + static const scalar_t lanczos_sum_expg_scaled_num[13] = { + 0.006061842346248906525783753964555936883222, + 0.5098416655656676188125178644804694509993, + 19.51992788247617482847860966235652136208, + 449.9445569063168119446858607650988409623, + 6955.999602515376140356310115515198987526, + 75999.29304014542649875303443598909137092, + 601859.6171681098786670226533699352302507, + 3481712.15498064590882071018964774556468, + 14605578.08768506808414169982791359218571, + 43338889.32467613834773723740590533316085, + 86363131.28813859145546927288977868422342, + 103794043.1163445451906271053616070238554, + 56906521.91347156388090791033559122686859 + }; + static const scalar_t lanczos_sum_expg_scaled_denom[13] = { + 1., + 66., + 1925., + 32670., + 357423., + 2637558., + 13339535., + 45995730., + 105258076., + 150917976., + 120543840., + 39916800., + 0. + }; + return ratevl(x, lanczos_sum_expg_scaled_num, + sizeof(lanczos_sum_expg_scaled_num) / sizeof(lanczos_sum_expg_scaled_num[0]) - 1, + lanczos_sum_expg_scaled_denom, + sizeof(lanczos_sum_expg_scaled_denom) / sizeof(lanczos_sum_expg_scaled_denom[0]) - 1); +} + +template +static scalar_t _igam_helper_fac(scalar_t a, scalar_t x) { + // compute x^a * exp(-a) / gamma(a) + // corrected from (15) and (16) in [igam2] by replacing exp(x - a) with + // exp(a - x). + + scalar_t ax, fac, res, num, numfac; + static scalar_t MAXLOG = std::is_same::value ? + 7.09782712893383996843E2 : 88.72283905206835; + static scalar_t EXP1 = 2.718281828459045; + static scalar_t lanczos_g = 6.024680040776729583740234375; + + if (std::fabs(a - x) > 0.4 * std::fabs(a)) { + ax = a * std::log(x) - x - std::lgamma(a); + if (ax < -MAXLOG) { + return 0.0; + } + return std::exp(ax); + } + + fac = a + lanczos_g - 0.5; + res = std::sqrt(fac / EXP1) / lanczos_sum_expg_scaled(a); + + if ((a < 200) && (x < 200)) { + res *= std::exp(a - x) * std::pow(x / fac, a); + } + else { + num = x - a - lanczos_g + 0.5; + numfac = num / fac; + res *= std::exp(a * (std::log1p(numfac) - numfac) + x * (0.5 - lanczos_g) / fac); + } + return res; +} + +template +static scalar_t _igam_helper_series(scalar_t a, scalar_t x) { + // Compute igam using DLMF 8.11.4. [igam1] + static scalar_t MACHEP = std::is_same::value ? + 1.11022302462515654042E-16 : 5.9604644775390625E-8; + static int MAXITER = 2000; + + int i; + scalar_t ans, ax, c, r; + + ax = _igam_helper_fac(a, x); + if (ax == 0.0) { + return 0.0; + } + + /* power series */ + r = a; + c = 1.0; + ans = 1.0; + + for (i = 0; i < MAXITER; i++) { + r += 1.0; + c *= x / r; + ans += c; + if (c <= MACHEP * ans) { + break; + } + } + return (ans * ax / a); +} + +template +static scalar_t _igamc_helper_series(scalar_t a, scalar_t x) { + // Compute igamc using DLMF 8.7.3 [igam1]. This is related to the series in + // _igam_helper_series but extra care is taken to avoid cancellation. + + int n; + scalar_t fac = 1; + scalar_t sum = 0; + scalar_t term, logx; + static scalar_t MAXITER = 2000; + static scalar_t MACHEP = std::is_same::value ? + 1.11022302462515654042E-16 : 5.9604644775390625E-8; + + for (n = 1; n < MAXITER; n++) { + fac *= -x / n; + term = fac / (a + n); + sum += term; + if (std::fabs(term) <= MACHEP * std::fabs(sum)) { + break; + } + } + + logx = std::log(x); + term = -std::expm1(a * logx - std::lgamma(1+a)); + return term - std::exp(a * logx - std::lgamma(a)) * sum; +} + +template +static scalar_t _igam_helper_asymptotic_series(scalar_t a, scalar_t x, bool igam) { + // Compute igam/igamc using DLMF 8.12.3/8.12.4 [igam1] + static const scalar_t d[25][25] = + {{-3.3333333333333333e-1, 8.3333333333333333e-2, -1.4814814814814815e-2, + 1.1574074074074074e-3, 3.527336860670194e-4, -1.7875514403292181e-4, + 3.9192631785224378e-5, -2.1854485106799922e-6, -1.85406221071516e-6, + 8.296711340953086e-7, -1.7665952736826079e-7, 6.7078535434014986e-9, + 1.0261809784240308e-8, -4.3820360184533532e-9, 9.1476995822367902e-10, + -2.551419399494625e-11, -5.8307721325504251e-11, 2.4361948020667416e-11, + -5.0276692801141756e-12, 1.1004392031956135e-13, 3.3717632624009854e-13, + -1.3923887224181621e-13, 2.8534893807047443e-14, -5.1391118342425726e-16, + -1.9752288294349443e-15}, + {-1.8518518518518519e-3, -3.4722222222222222e-3, 2.6455026455026455e-3, + -9.9022633744855967e-4, 2.0576131687242798e-4, -4.0187757201646091e-7, + -1.8098550334489978e-5, 7.6491609160811101e-6, -1.6120900894563446e-6, + 4.6471278028074343e-9, 1.378633446915721e-7, -5.752545603517705e-8, + 1.1951628599778147e-8, -1.7543241719747648e-11, -1.0091543710600413e-9, + 4.1627929918425826e-10, -8.5639070264929806e-11, 6.0672151016047586e-14, + 7.1624989648114854e-12, -2.9331866437714371e-12, 5.9966963656836887e-13, + -2.1671786527323314e-16, -4.9783399723692616e-14, 2.0291628823713425e-14, + -4.13125571381061e-15}, + {4.1335978835978836e-3, -2.6813271604938272e-3, 7.7160493827160494e-4, + 2.0093878600823045e-6, -1.0736653226365161e-4, 5.2923448829120125e-5, + -1.2760635188618728e-5, 3.4235787340961381e-8, 1.3721957309062933e-6, + -6.298992138380055e-7, 1.4280614206064242e-7, -2.0477098421990866e-10, + -1.4092529910867521e-8, 6.228974084922022e-9, -1.3670488396617113e-9, + 9.4283561590146782e-13, 1.2872252400089318e-10, -5.5645956134363321e-11, + 1.1975935546366981e-11, -4.1689782251838635e-15, -1.0940640427884594e-12, + 4.6622399463901357e-13, -9.905105763906906e-14, 1.8931876768373515e-17, + 8.8592218725911273e-15}, + {6.4943415637860082e-4, 2.2947209362139918e-4, -4.6918949439525571e-4, + 2.6772063206283885e-4, -7.5618016718839764e-5, -2.3965051138672967e-7, + 1.1082654115347302e-5, -5.6749528269915966e-6, 1.4230900732435884e-6, + -2.7861080291528142e-11, -1.6958404091930277e-7, 8.0994649053880824e-8, + -1.9111168485973654e-8, 2.3928620439808118e-12, 2.0620131815488798e-9, + -9.4604966618551322e-10, 2.1541049775774908e-10, -1.388823336813903e-14, + -2.1894761681963939e-11, 9.7909989511716851e-12, -2.1782191880180962e-12, + 6.2088195734079014e-17, 2.126978363279737e-13, -9.3446887915174333e-14, + 2.0453671226782849e-14}, + {-8.618882909167117e-4, 7.8403922172006663e-4, -2.9907248030319018e-4, + -1.4638452578843418e-6, 6.6414982154651222e-5, -3.9683650471794347e-5, + 1.1375726970678419e-5, 2.5074972262375328e-10, -1.6954149536558306e-6, + 8.9075075322053097e-7, -2.2929348340008049e-7, 2.956794137544049e-11, + 2.8865829742708784e-8, -1.4189739437803219e-8, 3.4463580499464897e-9, + -2.3024517174528067e-13, -3.9409233028046405e-10, 1.8602338968504502e-10, + -4.356323005056618e-11, 1.2786001016296231e-15, 4.6792750266579195e-12, + -2.1492464706134829e-12, 4.9088156148096522e-13, -6.3385914848915603e-18, + -5.0453320690800944e-14}, + {-3.3679855336635815e-4, -6.9728137583658578e-5, 2.7727532449593921e-4, + -1.9932570516188848e-4, 6.7977804779372078e-5, 1.419062920643967e-7, + -1.3594048189768693e-5, 8.0184702563342015e-6, -2.2914811765080952e-6, + -3.252473551298454e-10, 3.4652846491085265e-7, -1.8447187191171343e-7, + 4.8240967037894181e-8, -1.7989466721743515e-14, -6.3061945000135234e-9, + 3.1624176287745679e-9, -7.8409242536974293e-10, 5.1926791652540407e-15, + 9.3589442423067836e-11, -4.5134262161632782e-11, 1.0799129993116827e-11, + -3.661886712685252e-17, -1.210902069055155e-12, 5.6807435849905643e-13, + -1.3249659916340829e-13}, + {5.3130793646399222e-4, -5.9216643735369388e-4, 2.7087820967180448e-4, + 7.9023532326603279e-7, -8.1539693675619688e-5, 5.6116827531062497e-5, + -1.8329116582843376e-5, -3.0796134506033048e-9, 3.4651553688036091e-6, + -2.0291327396058604e-6, 5.7887928631490037e-7, 2.338630673826657e-13, + -8.8286007463304835e-8, 4.7435958880408128e-8, -1.2545415020710382e-8, + 8.6496488580102925e-14, 1.6846058979264063e-9, -8.5754928235775947e-10, + 2.1598224929232125e-10, -7.6132305204761539e-16, -2.6639822008536144e-11, + 1.3065700536611057e-11, -3.1799163902367977e-12, 4.7109761213674315e-18, + 3.6902800842763467e-13}, + {3.4436760689237767e-4, 5.1717909082605922e-5, -3.3493161081142236e-4, + 2.812695154763237e-4, -1.0976582244684731e-4, -1.2741009095484485e-7, + 2.7744451511563644e-5, -1.8263488805711333e-5, 5.7876949497350524e-6, + 4.9387589339362704e-10, -1.0595367014026043e-6, 6.1667143761104075e-7, + -1.7562973359060462e-7, -1.2974473287015439e-12, 2.695423606288966e-8, + -1.4578352908731271e-8, 3.887645959386175e-9, -3.8810022510194121e-17, + -5.3279941738772867e-10, 2.7437977643314845e-10, -6.9957960920705679e-11, + 2.5899863874868481e-17, 8.8566890996696381e-12, -4.403168815871311e-12, + 1.0865561947091654e-12}, + {-6.5262391859530942e-4, 8.3949872067208728e-4, -4.3829709854172101e-4, + -6.969091458420552e-7, 1.6644846642067548e-4, -1.2783517679769219e-4, + 4.6299532636913043e-5, 4.5579098679227077e-9, -1.0595271125805195e-5, + 6.7833429048651666e-6, -2.1075476666258804e-6, -1.7213731432817145e-11, + 3.7735877416110979e-7, -2.1867506700122867e-7, 6.2202288040189269e-8, + 6.5977038267330006e-16, -9.5903864974256858e-9, 5.2132144922808078e-9, + -1.3991589583935709e-9, 5.382058999060575e-16, 1.9484714275467745e-10, + -1.0127287556389682e-10, 2.6077347197254926e-11, -5.0904186999932993e-18, + -3.3721464474854592e-12}, + {-5.9676129019274625e-4, -7.2048954160200106e-5, 6.7823088376673284e-4, + -6.4014752602627585e-4, 2.7750107634328704e-4, 1.8197008380465151e-7, + -8.4795071170685032e-5, 6.105192082501531e-5, -2.1073920183404862e-5, + -8.8585890141255994e-10, 4.5284535953805377e-6, -2.8427815022504408e-6, + 8.7082341778646412e-7, 3.6886101871706965e-12, -1.5344695190702061e-7, + 8.862466778790695e-8, -2.5184812301826817e-8, -1.0225912098215092e-14, + 3.8969470758154777e-9, -2.1267304792235635e-9, 5.7370135528051385e-10, + -1.887749850169741e-19, -8.0931538694657866e-11, 4.2382723283449199e-11, + -1.1002224534207726e-11}, + {1.3324454494800656e-3, -1.9144384985654775e-3, 1.1089369134596637e-3, + 9.932404122642299e-7, -5.0874501293093199e-4, 4.2735056665392884e-4, + -1.6858853767910799e-4, -8.1301893922784998e-9, 4.5284402370562147e-5, + -3.127053674781734e-5, 1.044986828530338e-5, 4.8435226265680926e-11, + -2.1482565873456258e-6, 1.329369701097492e-6, -4.0295693092101029e-7, + -1.7567877666323291e-13, 7.0145043163668257e-8, -4.040787734999483e-8, + 1.1474026743371963e-8, 3.9642746853563325e-18, -1.7804938269892714e-9, + 9.7480262548731646e-10, -2.6405338676507616e-10, 5.794875163403742e-18, + 3.7647749553543836e-11}, + {1.579727660730835e-3, 1.6251626278391582e-4, -2.0633421035543276e-3, + 2.1389686185689098e-3, -1.0108559391263003e-3, -3.9912705529919201e-7, + 3.6235025084764691e-4, -2.8143901463712154e-4, 1.0449513336495887e-4, + 2.1211418491830297e-9, -2.5779417251947842e-5, 1.7281818956040463e-5, + -5.6413773872904282e-6, -1.1024320105776174e-11, 1.1223224418895175e-6, + -6.8693396379526735e-7, 2.0653236975414887e-7, 4.6714772409838506e-14, + -3.5609886164949055e-8, 2.0470855345905963e-8, -5.8091738633283358e-9, + -1.332821287582869e-16, 9.0354604391335133e-10, -4.9598782517330834e-10, + 1.3481607129399749e-10}, + {-4.0725121195140166e-3, 6.4033628338080698e-3, -4.0410161081676618e-3, + -2.183732802866233e-6, 2.1740441801254639e-3, -1.9700440518418892e-3, + 8.3595469747962458e-4, 1.9445447567109655e-8, -2.5779387120421696e-4, + 1.9009987368139304e-4, -6.7696499937438965e-5, -1.4440629666426572e-10, + 1.5712512518742269e-5, -1.0304008744776893e-5, 3.304517767401387e-6, + 7.9829760242325709e-13, -6.4097794149313004e-7, 3.8894624761300056e-7, + -1.1618347644948869e-7, -2.816808630596451e-15, 1.9878012911297093e-8, + -1.1407719956357511e-8, 3.2355857064185555e-9, 4.1759468293455945e-20, + -5.0423112718105824e-10}, + {-5.9475779383993003e-3, -5.4016476789260452e-4, 8.7910413550767898e-3, + -9.8576315587856125e-3, 5.0134695031021538e-3, 1.2807521786221875e-6, + -2.0626019342754683e-3, 1.7109128573523058e-3, -6.7695312714133799e-4, + -6.9011545676562133e-9, 1.8855128143995902e-4, -1.3395215663491969e-4, + 4.6263183033528039e-5, 4.0034230613321351e-11, -1.0255652921494033e-5, + 6.612086372797651e-6, -2.0913022027253008e-6, -2.0951775649603837e-13, + 3.9756029041993247e-7, -2.3956211978815887e-7, 7.1182883382145864e-8, + 8.925574873053455e-16, -1.2101547235064676e-8, 6.9350618248334386e-9, + -1.9661464453856102e-9}, + {1.7402027787522711e-2, -2.9527880945699121e-2, 2.0045875571402799e-2, + 7.0289515966903407e-6, -1.2375421071343148e-2, 1.1976293444235254e-2, + -5.4156038466518525e-3, -6.3290893396418616e-8, 1.8855118129005065e-3, + -1.473473274825001e-3, 5.5515810097708387e-4, 5.2406834412550662e-10, + -1.4357913535784836e-4, 9.9181293224943297e-5, -3.3460834749478311e-5, + -3.5755837291098993e-12, 7.1560851960630076e-6, -4.5516802628155526e-6, + 1.4236576649271475e-6, 1.8803149082089664e-14, -2.6623403898929211e-7, + 1.5950642189595716e-7, -4.7187514673841102e-8, -6.5107872958755177e-17, + 7.9795091026746235e-9}, + {3.0249124160905891e-2, 2.4817436002649977e-3, -4.9939134373457022e-2, + 5.9915643009307869e-2, -3.2483207601623391e-2, -5.7212968652103441e-6, + 1.5085251778569354e-2, -1.3261324005088445e-2, 5.5515262632426148e-3, + 3.0263182257030016e-8, -1.7229548406756723e-3, 1.2893570099929637e-3, + -4.6845138348319876e-4, -1.830259937893045e-10, 1.1449739014822654e-4, + -7.7378565221244477e-5, 2.5625836246985201e-5, 1.0766165333192814e-12, + -5.3246809282422621e-6, 3.349634863064464e-6, -1.0381253128684018e-6, + -5.608909920621128e-15, 1.9150821930676591e-7, -1.1418365800203486e-7, + 3.3654425209171788e-8}, + {-9.9051020880159045e-2, 1.7954011706123486e-1, -1.2989606383463778e-1, + -3.1478872752284357e-5, 9.0510635276848131e-2, -9.2828824411184397e-2, + 4.4412112839877808e-2, 2.7779236316835888e-7, -1.7229543805449697e-2, + 1.4182925050891573e-2, -5.6214161633747336e-3, -2.39598509186381e-9, + 1.6029634366079908e-3, -1.1606784674435773e-3, 4.1001337768153873e-4, + 1.8365800754090661e-11, -9.5844256563655903e-5, 6.3643062337764708e-5, + -2.076250624489065e-5, -1.1806020912804483e-13, 4.2131808239120649e-6, + -2.6262241337012467e-6, 8.0770620494930662e-7, 6.0125912123632725e-16, + -1.4729737374018841e-7}, + {-1.9994542198219728e-1, -1.5056113040026424e-2, 3.6470239469348489e-1, + -4.6435192311733545e-1, 2.6640934719197893e-1, 3.4038266027147191e-5, + -1.3784338709329624e-1, 1.276467178337056e-1, -5.6213828755200985e-2, + -1.753150885483011e-7, 1.9235592956768113e-2, -1.5088821281095315e-2, + 5.7401854451350123e-3, 1.0622382710310225e-9, -1.5335082692563998e-3, + 1.0819320643228214e-3, -3.7372510193945659e-4, -6.6170909729031985e-12, + 8.4263617380909628e-5, -5.5150706827483479e-5, 1.7769536448348069e-5, + 3.8827923210205533e-14, -3.53513697488768e-6, 2.1865832130045269e-6, + -6.6812849447625594e-7}, + {7.2438608504029431e-1, -1.3918010932653375, 1.0654143352413968, + 1.876173868950258e-4, -8.2705501176152696e-1, 8.9352433347828414e-1, + -4.4971003995291339e-1, -1.6107401567546652e-6, 1.9235590165271091e-1, + -1.6597702160042609e-1, 6.8882222681814333e-2, 1.3910091724608687e-8, + -2.146911561508663e-2, 1.6228980898865892e-2, -5.9796016172584256e-3, + -1.1287469112826745e-10, 1.5167451119784857e-3, -1.0478634293553899e-3, + 3.5539072889126421e-4, 8.1704322111801517e-13, -7.7773013442452395e-5, + 5.0291413897007722e-5, -1.6035083867000518e-5, 1.2469354315487605e-14, + 3.1369106244517615e-6}, + {1.6668949727276811, 1.165462765994632e-1, -3.3288393225018906, + 4.4692325482864037, -2.6977693045875807, -2.600667859891061e-4, + 1.5389017615694539, -1.4937962361134612, 6.8881964633233148e-1, + 1.3077482004552385e-6, -2.5762963325596288e-1, 2.1097676102125449e-1, + -8.3714408359219882e-2, -7.7920428881354753e-9, 2.4267923064833599e-2, + -1.7813678334552311e-2, 6.3970330388900056e-3, 4.9430807090480523e-11, + -1.5554602758465635e-3, 1.0561196919903214e-3, -3.5277184460472902e-4, + 9.3002334645022459e-14, 7.5285855026557172e-5, -4.8186515569156351e-5, + 1.5227271505597605e-5}, + {-6.6188298861372935, 1.3397985455142589e+1, -1.0789350606845146e+1, + -1.4352254537875018e-3, 9.2333694596189809, -1.0456552819547769e+1, + 5.5105526029033471, 1.2024439690716742e-5, -2.5762961164755816, + 2.3207442745387179, -1.0045728797216284, -1.0207833290021914e-7, + 3.3975092171169466e-1, -2.6720517450757468e-1, 1.0235252851562706e-1, + 8.4329730484871625e-10, -2.7998284958442595e-2, 2.0066274144976813e-2, + -7.0554368915086242e-3, 1.9402238183698188e-12, 1.6562888105449611e-3, + -1.1082898580743683e-3, 3.654545161310169e-4, -5.1290032026971794e-11, + -7.6340103696869031e-5}, + {-1.7112706061976095e+1, -1.1208044642899116, 3.7131966511885444e+1, + -5.2298271025348962e+1, 3.3058589696624618e+1, 2.4791298976200222e-3, + -2.061089403411526e+1, 2.088672775145582e+1, -1.0045703956517752e+1, + -1.2238783449063012e-5, 4.0770134274221141, -3.473667358470195, + 1.4329352617312006, 7.1359914411879712e-8, -4.4797257159115612e-1, + 3.4112666080644461e-1, -1.2699786326594923e-1, -2.8953677269081528e-10, + 3.3125776278259863e-2, -2.3274087021036101e-2, 8.0399993503648882e-3, + -1.177805216235265e-9, -1.8321624891071668e-3, 1.2108282933588665e-3, + -3.9479941246822517e-4}, + {7.389033153567425e+1, -1.5680141270402273e+2, 1.322177542759164e+2, + 1.3692876877324546e-2, -1.2366496885920151e+2, 1.4620689391062729e+2, + -8.0365587724865346e+1, -1.1259851148881298e-4, 4.0770132196179938e+1, + -3.8210340013273034e+1, 1.719522294277362e+1, 9.3519707955168356e-7, + -6.2716159907747034, 5.1168999071852637, -2.0319658112299095, + -4.9507215582761543e-9, 5.9626397294332597e-1, -4.4220765337238094e-1, + 1.6079998700166273e-1, -2.4733786203223402e-8, -4.0307574759979762e-2, + 2.7849050747097869e-2, -9.4751858992054221e-3, 6.419922235909132e-6, + 2.1250180774699461e-3}, + {2.1216837098382522e+2, 1.3107863022633868e+1, -4.9698285932871748e+2, + 7.3121595266969204e+2, -4.8213821720890847e+2, -2.8817248692894889e-2, + 3.2616720302947102e+2, -3.4389340280087117e+2, 1.7195193870816232e+2, + 1.4038077378096158e-4, -7.52594195897599e+1, 6.651969984520934e+1, + -2.8447519748152462e+1, -7.613702615875391e-7, 9.5402237105304373, + -7.5175301113311376, 2.8943997568871961, -4.6612194999538201e-7, + -8.0615149598794088e-1, 5.8483006570631029e-1, -2.0845408972964956e-1, + 1.4765818959305817e-4, 5.1000433863753019e-2, -3.3066252141883665e-2, + 1.5109265210467774e-2}, + {-9.8959643098322368e+2, 2.1925555360905233e+3, -1.9283586782723356e+3, + -1.5925738122215253e-1, 1.9569985945919857e+3, -2.4072514765081556e+3, + 1.3756149959336496e+3, 1.2920735237496668e-3, -7.525941715948055e+2, + 7.3171668742208716e+2, -3.4137023466220065e+2, -9.9857390260608043e-6, + 1.3356313181291573e+2, -1.1276295161252794e+2, 4.6310396098204458e+1, + -7.9237387133614756e-6, -1.4510726927018646e+1, 1.1111771248100563e+1, + -4.1690817945270892, 3.1008219800117808e-3, 1.1220095449981468, + -7.6052379926149916e-1, 3.6262236505085254e-1, 2.216867741940747e-1, + 4.8683443692930507e-1}}; + + int k, n, sgn; + int maxpow = 0; + static scalar_t MACHEP = std::is_same::value ? + 1.11022302462515654042E-16 : 5.9604644775390625E-8; + scalar_t lambda = x / a; + scalar_t sigma = (x - a) / a; + scalar_t eta, res, ck, ckterm, term, absterm; + scalar_t absoldterm = INFINITY; + scalar_t etapow[25] = {1}; + scalar_t sum = 0; + scalar_t afac = 1; + + if (igam) { + sgn = -1; + } + else { + sgn = 1; + } + + if (lambda > 1) { + eta = std::sqrt(-2 * (std::log1p(sigma) - sigma)); + } + else if (lambda < 1) { + eta = -std::sqrt(-2 * (std::log1p(sigma) - sigma)); + } + else { + eta = 0; + } + res = 0.5 * std::erfc(sgn * eta * std::sqrt(a / 2)); + + for (k = 0; k < 25; k++) { + ck = d[k][0]; + for (n = 1; n < 25; n++) { + if (n > maxpow) { + etapow[n] = eta * etapow[n-1]; + maxpow += 1; + } + ckterm = d[k][n]*etapow[n]; + ck += ckterm; + if (std::fabs(ckterm) < MACHEP * std::fabs(ck)) { + break; + } + } + term = ck * afac; + absterm = std::fabs(term); + if (absterm > absoldterm) { + break; + } + sum += term; + if (absterm < MACHEP * std::fabs(sum)) { + break; + } + absoldterm = absterm; + afac /= a; + } + res += sgn * std::exp(-0.5 * a * eta * eta) * sum / std::sqrt(2 * c10::pi * a); + + return res; +} + +template +static scalar_t _igamc_helper_continued_fraction(scalar_t a, scalar_t x) { + // Compute igamc using DLMF 8.9.2. [igam1] + int i; + scalar_t ans, ax, c, yc, r, t, y, z; + scalar_t pk, pkm1, pkm2, qk, qkm1, qkm2; + int MAXITER = 2000; + static scalar_t MACHEP = std::is_same::value ? + 1.11022302462515654042E-16 : 5.9604644775390625E-8; + static scalar_t BIG = std::is_same::value ? + 4.503599627370496e15 : 16777216.; + static scalar_t BIGINV = std::is_same::value ? + 2.22044604925031308085e-16 : 5.9604644775390625E-8; + + ax = _igam_helper_fac(a, x); + if (ax == 0.0) { + return 0.0; + } + + /* continued fraction */ + y = 1.0 - a; + z = x + y + 1.0; + c = 0.0; + pkm2 = 1.0; + qkm2 = x; + pkm1 = x + 1.0; + qkm1 = z * x; + ans = pkm1 / qkm1; + + for (i = 0; i < MAXITER; i++) { + c += 1.0; + y += 1.0; + z += 2.0; + yc = y * c; + pk = pkm1 * z - pkm2 * yc; + qk = qkm1 * z - qkm2 * yc; + if (qk != 0) { + r = pk / qk; + t = std::fabs((ans - r) / r); + ans = r; + } + else { + t = 1.0; + } + pkm2 = pkm1; + pkm1 = pk; + qkm2 = qkm1; + qkm1 = qk; + if (std::fabs(pk) > BIG) { + pkm2 *= BIGINV; + pkm1 *= BIGINV; + qkm2 *= BIGINV; + qkm1 *= BIGINV; + } + if (t <= MACHEP) { + break; + } + } + return ans * ax; +} + +template +inline scalar_t calc_igammac(scalar_t a, scalar_t x) { + /* the calculation of the regularized upper incomplete gamma function + * is done differently based on the values of a and x: + * - if x and/or a is at the boundary of defined region, then assign the + * result at the boundary + * - if a is large and a ~ x, then using Uniform Asymptotic Expansions for + * Large Parameter (see DLMF 8.12.4 [igam1]) + * - if x > 1.1 and x < a, using the substraction from the regularized lower + * incomplete gamma + * - otherwise, calculate the series from [igam2] eq (5) + */ + scalar_t absxma_a; + + static scalar_t SMALL = 20.0; + static scalar_t LARGE = 200.0; + static scalar_t SMALLRATIO = 0.3; + static scalar_t LARGERATIO = 4.5; + + // note that in SciPy, a and x are non-negative, with exclusive 0s (i.e., + // at most 1 of them can be 0), where igammac(0, x) = 0.0 iff x > 0. + if ((x < 0) || (a < 0)) { + // out of defined-region of the function + return std::numeric_limits::quiet_NaN(); + } + else if (a == 0) { + if (x > 0) { + return 0.0; + } + else { + return std::numeric_limits::quiet_NaN(); + } + } + else if (x == 0) { + return 1.0; + } + else if (std::isinf(a)) { + if (std::isinf(x)) { + return std::numeric_limits::quiet_NaN(); + } + return 1.0; + } + else if (std::isinf(x)) { + return 0.0; + } + + absxma_a = std::fabs(x - a) / a; + if ((a > SMALL) && (a < LARGE) && (absxma_a < SMALLRATIO)) { + return _igam_helper_asymptotic_series(a, x, 0); + } + else if ((a > LARGE) && (absxma_a < LARGERATIO / std::sqrt(a))) { + return _igam_helper_asymptotic_series(a, x, 0); + } + + if (x > 1.1) { + if (x < a) { + return 1.0 - _igam_helper_series(a, x); + } + else { + return _igamc_helper_continued_fraction(a, x); + } + } + else if (x <= 0.5) { + if (-0.4 / std::log(x) < a) { + return 1.0 - _igam_helper_series(a, x); + } + else { + return _igamc_helper_series(a, x); + } + } + else { + if (x * 1.1 < a) { + return 1.0 - _igam_helper_series(a, x); + } + else { + return _igamc_helper_series(a, x); + } + } +} + +template +scalar_t calc_igamma(scalar_t a, scalar_t x) { + /* the calculation of the regularized lower incomplete gamma function + * is done differently based on the values of a and x: + * - if x and/or a is at the boundary of defined region, then assign the + * result at the boundary + * - if a is large and a ~ x, then using Uniform Asymptotic Expansions for + * Large Parameter (see DLMF 8.12.3 [igam1]) + * - if x > 1 and x > a, using the substraction from the regularized upper + * incomplete gamma + * - otherwise, calculate the series from [igam2] eq (4) + */ + scalar_t absxma_a; + static scalar_t SMALL = 20.0; + static scalar_t LARGE = 200.0; + static scalar_t SMALLRATIO = 0.3; + static scalar_t LARGERATIO = 4.5; + + // boundary values following SciPy + // note that in SciPy, a and x are non-negative, with exclusive 0s (i.e., + // at most 1 of them can be 0), where igamma(0, x) = 1.0 iff x > 0. + if ((x < 0) || (a < 0)) { + // out of defined-region of the function + return std::numeric_limits::quiet_NaN(); + } + else if (a == 0) { + if (x > 0) { + return 1.0; + } + else { + return std::numeric_limits::quiet_NaN(); + } + } + else if (x == 0) { + return 0.0; // zero integration limit + } + else if (std::isinf(a)) { + if (std::isinf(x)) { + return std::numeric_limits::quiet_NaN(); + } + return 0.0; + } + else if (std::isinf(x)) { + return 1.0; + } + + /* Asymptotic regime where a ~ x. See [igam2] */ + absxma_a = std::fabs(x - a) / a; + if ((a > SMALL) && (a < LARGE) && (absxma_a < SMALLRATIO)) { + return _igam_helper_asymptotic_series(a, x, 1); + } + else if ((a > LARGE) && (absxma_a < LARGERATIO / std::sqrt(a))) { + return _igam_helper_asymptotic_series(a, x, 1); + } + + if ((x > 1.0) && (x > a)) { + return 1.0 - calc_igammac(a, x); + } + + return _igam_helper_series(a, x); +} + +template <> +C10_UNUSED inline c10::BFloat16 calc_igamma(c10::BFloat16 a, c10::BFloat16 x) { + return calc_igamma(float(a), float(x)); +} + +template <> +C10_UNUSED inline c10::Half calc_igamma(c10::Half a, c10::Half x) { + return calc_igamma(float(a), float(x)); +} + +template <> +C10_UNUSED inline c10::BFloat16 calc_igammac(c10::BFloat16 a, c10::BFloat16 x) { + return calc_igammac(float(a), float(x)); +} + +template <> +C10_UNUSED inline c10::Half calc_igammac(c10::Half a, c10::Half x) { + return calc_igammac(float(a), float(x)); +} + +inline c10::BFloat16 calc_erfinv(c10::BFloat16 a) { return calc_erfinv(float(a)); } + +template +inline T abs_impl(T v) { + return std::abs(v); +} + +template <> +C10_UNUSED inline uint8_t abs_impl(uint8_t v) { + return v; +} + +template +inline typename std::enable_if::value, T>::type +calc_gcd(T a, T b) { + a = abs_impl(a); + b = abs_impl(b); + while (a != 0) { + T c = a; + a = b % a; + b = c; + } + return b; +} + +template +C10_HOST_DEVICE T exp2_impl(T x) { + return std::exp2(x); +} + +template +C10_HOST_DEVICE c10::complex exp2_impl(c10::complex x) { + // There is no std::exp2 overload for complex, so instead + // use the identity 2^x = e^(ln(2) * x) + constexpr auto ln2 = c10::ln_2; + return std::exp(ln2 * x); +} + +/* + * This function is derived from the implementation of the chbevl function in the Cephes Math Library. + * See note [3-Clause BSD License for the Cephes Math Library]. + * + * Evaluates the series + * + * len-1 + * - ' + * y = > array[i] T (x/2) + * - i + * i=0 + * + * of Chebyshev polynomials Ti at argument x/2. + * + * Coefficients are stored in reverse order, i.e. the zero order term is last in the array. Note len is the number of + * coefficients, not the order. + * + * If coefficients are for the interval a to b, x must have been transformed to x -> 2(2x - b - a)/(b-a) before + * entering the routine. This maps x from (a, b) to (-1, 1), over which the Chebyshev polynomials are defined. + * + * If the coefficients are for the inverted interval, in which (a, b) is mapped to (1/b, 1/a), the transformation + * required is x -> 2(2ab/x - b - a)/(b-a). If b is infinity, this becomes x -> 4a/x - 1. + */ +template +inline typename std::enable_if::value, T>::type +chbevl(const T x, const T array[], size_t len) { + T b0, b1, b2; + + b0 = array[0]; + b1 = static_cast(0.0); + + for (size_t i = 1; i < len; ++i) { + b2 = b1; + b1 = b0; + b0 = x * b1 - b2 + array[i]; + } + + return (static_cast(0.5) * (b0 - b2)); +} + +/* + * This function is derived from the implementation of the i0 function in the Cephes Math Library. + * See note [3-Clause BSD License for the Cephes Math Library]. + * + * Computes an approximation of the zeroth order modified Bessel function of the first kind. + * The approximation is actually two (sub)approximations, both using a Chebyshev polynomial expansion. + * One approximates the function over [0, 8], and the other over (8, infinity). This function takes the absolute value + * of all inputs to convert them into the domain of the approximation. + */ +template +inline std::tuple chebyshev_coefficients_i0e_A() { + /* Chebyshev coefficients for exp(-x) I0(x) + * in the interval [0,8]. + * + * lim(x->0){ exp(-x) I0(x) } = 1. + */ + static const T coeff[] = { + -4.41534164647933937950E-18, 3.33079451882223809783E-17, + -2.43127984654795469359E-16, 1.71539128555513303061E-15, + -1.16853328779934516808E-14, 7.67618549860493561688E-14, + -4.85644678311192946090E-13, 2.95505266312963983461E-12, + -1.72682629144155570723E-11, 9.67580903537323691224E-11, + -5.18979560163526290666E-10, 2.65982372468238665035E-9, + -1.30002500998624804212E-8, 6.04699502254191894932E-8, + -2.67079385394061173391E-7, 1.11738753912010371815E-6, + -4.41673835845875056359E-6, 1.64484480707288970893E-5, + -5.75419501008210370398E-5, 1.88502885095841655729E-4, + -5.76375574538582365885E-4, 1.63947561694133579842E-3, + -4.32430999505057594430E-3, 1.05464603945949983183E-2, + -2.37374148058994688156E-2, 4.93052842396707084878E-2, + -9.49010970480476444210E-2, 1.71620901522208775349E-1, + -3.04682672343198398683E-1, 6.76795274409476084995E-1}; + return std::make_tuple(coeff, 30); +}; + +template +inline std::tuple chebyshev_coefficients_i0e_B() { + /* Chebyshev coefficients for exp(-x) sqrt(x) I0(x) + * in the inverted interval [8,infinity]. + * + * lim(x->inf){ exp(-x) sqrt(x) I0(x) } = 1/sqrt(2pi). + */ + static const T coeff[] = { + -7.23318048787475395456E-18, -4.83050448594418207126E-18, + 4.46562142029675999901E-17, 3.46122286769746109310E-17, + -2.82762398051658348494E-16, -3.42548561967721913462E-16, + 1.77256013305652638360E-15, 3.81168066935262242075E-15, + -9.55484669882830764870E-15, -4.15056934728722208663E-14, + 1.54008621752140982691E-14, 3.85277838274214270114E-13, + 7.18012445138366623367E-13, -1.79417853150680611778E-12, + -1.32158118404477131188E-11, -3.14991652796324136454E-11, + 1.18891471078464383424E-11, 4.94060238822496958910E-10, + 3.39623202570838634515E-9, 2.26666899049817806459E-8, + 2.04891858946906374183E-7, 2.89137052083475648297E-6, + 6.88975834691682398426E-5, 3.36911647825569408990E-3, + 8.04490411014108831608E-1}; + + return std::make_tuple(coeff, 25); +}; + +template +inline typename std::enable_if::value, std::tuple>::type +chebyshev_coefficients_i1e_A() { + /* Chebyshev coefficients for exp(-x) I1(x) + * in the interval [0,8]. + * + * lim(x->0){ exp(-x) I1(x) / x } = 1/2. + */ + static const T coeff[] = { + 2.77791411276104639959E-18, -2.11142121435816608115E-17, + 1.55363195773620046921E-16, -1.10559694773538630805E-15, + 7.60068429473540693410E-15, -5.04218550472791168711E-14, + 3.22379336594557470981E-13, -1.98397439776494371520E-12, + 1.17361862988909016308E-11, -6.66348972350202774223E-11, + 3.62559028155211703701E-10, -1.88724975172282928790E-9, + 9.38153738649577178388E-9, -4.44505912879632808065E-8, + 2.00329475355213526229E-7, -8.56872026469545474066E-7, + 3.47025130813767847674E-6, -1.32731636560394358279E-5, + 4.78156510755005422638E-5, -1.61760815825896745588E-4, + 5.12285956168575772895E-4, -1.51357245063125314899E-3, + 4.15642294431288815669E-3, -1.05640848946261981558E-2, + 2.47264490306265168283E-2, -5.29459812080949914269E-2, + 1.02643658689847095384E-1, -1.76416518357834055153E-1, + 2.52587186443633654823E-1}; + return std::make_tuple(coeff, 29); +}; + +template +inline typename std::enable_if::value, std::tuple>::type +chebyshev_coefficients_i1e_A() { + /* Chebyshev coefficients for exp(-x) I1(x) + * in the interval [0,8]. + * + * lim(x->0){ exp(-x) I1(x) / x } = 1/2. + */ + static const T coeff[] = { + 9.38153738649577178388E-9f, + -4.44505912879632808065E-8f, + 2.00329475355213526229E-7f, + -8.56872026469545474066E-7f, + 3.47025130813767847674E-6f, + -1.32731636560394358279E-5f, + 4.78156510755005422638E-5f, + -1.61760815825896745588E-4f, + 5.12285956168575772895E-4f, + -1.51357245063125314899E-3f, + 4.15642294431288815669E-3f, + -1.05640848946261981558E-2f, + 2.47264490306265168283E-2f, + -5.29459812080949914269E-2f, + 1.02643658689847095384E-1f, + -1.76416518357834055153E-1f, + 2.52587186443633654823E-1f}; + return std::make_tuple(coeff, 17); +}; + +template +inline typename std::enable_if::value, std::tuple>::type +chebyshev_coefficients_i1e_B() { + /* Chebyshev coefficients for exp(-x) sqrt(x) I1(x) + * in the inverted interval [8,infinity]. + * + * lim(x->inf){ exp(-x) sqrt(x) I1(x) } = 1/sqrt(2pi). + */ + static const T coeff[] = { + 7.51729631084210481353E-18, 4.41434832307170791151E-18, + -4.65030536848935832153E-17, -3.20952592199342395980E-17, + 2.96262899764595013876E-16, 3.30820231092092828324E-16, + -1.88035477551078244854E-15, -3.81440307243700780478E-15, + 1.04202769841288027642E-14, 4.27244001671195135429E-14, + -2.10154184277266431302E-14, -4.08355111109219731823E-13, + -7.19855177624590851209E-13, 2.03562854414708950722E-12, + 1.41258074366137813316E-11, 3.25260358301548823856E-11, + -1.89749581235054123450E-11, -5.58974346219658380687E-10, + -3.83538038596423702205E-9, -2.63146884688951950684E-8, + -2.51223623787020892529E-7, -3.88256480887769039346E-6, + -1.10588938762623716291E-4, -9.76109749136146840777E-3, + 7.78576235018280120474E-1}; + + return std::make_tuple(coeff, 25); +}; + +template +inline typename std::enable_if::value, std::tuple>::type +chebyshev_coefficients_i1e_B() { + /* Chebyshev coefficients for exp(-x) sqrt(x) I1(x) + * in the inverted interval [8,infinity]. + * + * lim(x->inf){ exp(-x) sqrt(x) I1(x) } = 1/sqrt(2pi). + */ + static const T coeff[] = { + -3.83538038596423702205E-9f, + -2.63146884688951950684E-8f, + -2.51223623787020892529E-7f, + -3.88256480887769039346E-6f, + -1.10588938762623716291E-4f, + -9.76109749136146840777E-3f, + 7.78576235018280120474E-1f}; + + return std::make_tuple(coeff, 7); +}; + +template +inline typename std::enable_if::value, T>::type +calc_i0(T _x) { + T x = std::abs(_x); + + if (x <= T{8.0}) { + auto coeff_pair = chebyshev_coefficients_i0e_A(); + auto A = std::get<0>(coeff_pair); + auto len = std::get<1>(coeff_pair); + T y = (x / T{2.0}) - T{2.0}; + return static_cast(std::exp(x) * chbevl(y, A, len)); + } + auto coeff_pair = chebyshev_coefficients_i0e_B(); + auto B = std::get<0>(coeff_pair); + auto len = std::get<1>(coeff_pair); + return std::exp(x) * chbevl(T{32.0} / x - T{2.0}, B, len) / std::sqrt(x); +} + +// Upcast bfloat16 input to float for numerical accuracy purposes +inline c10::BFloat16 calc_i0(c10::BFloat16 a) { return calc_i0(static_cast(a)); } + +/* + * This function is derived from the implementation of the i1 function in the Cephes Math Library. + * See note [3-Clause BSD License for the Cephes Math Library]. + * + * Computes an approximation of the first order modified Bessel function of the first kind. + * The approximation is actually two (sub)approximations, both using a Chebyshev polynomial expansion. + * One approximates the function over [0, 8], and the other over (8, infinity). This function takes the absolute value + * of all inputs to convert them into the domain of the approximation. + */ +template +inline typename std::enable_if::value, T>::type +calc_i1(T _x) { + T x = std::abs(_x); + + if (x <= T{8.0}) { + auto coeff_pair = chebyshev_coefficients_i1e_A(); + auto A = std::get<0>(coeff_pair); + auto len = std::get<1>(coeff_pair); + T y = (x / T{2.0}) - T{2.0}; + const T out = std::exp(x) * x * chbevl(y, A, len); + return (_x < T{0.0}) ? -out : out; + } + auto coeff_pair = chebyshev_coefficients_i1e_B(); + auto B = std::get<0>(coeff_pair); + auto len = std::get<1>(coeff_pair); + const T out = (std::exp(x) * chbevl(T{32.0} / x - T{2.0}, B, len)) / std::sqrt(x); + return (_x < T{0.0}) ? -out : out; +} + +/* + * This function is derived from the implementation of the i1e function in the Cephes Math Library. + * See note [3-Clause BSD License for the Cephes Math Library]. + * + * Computes an approximation of the exponentially scaled first order modified Bessel function of the first kind. + * The approximation is actually two (sub)approximations, both using a Chebyshev polynomial expansion. + * One approximates the function over [0, 8], and the other over (8, infinity). This function takes the absolute value + * of all inputs to convert them into the domain of the approximation. + */ +template +inline typename std::enable_if::value, T>::type +calc_i1e(T _x) { + T x = std::abs(_x); + + if (x <= T{8.0}) { + auto coeff_pair = chebyshev_coefficients_i1e_A(); + auto A = std::get<0>(coeff_pair); + auto len = std::get<1>(coeff_pair); + T y = (x / T{2.0}) - T{2.0}; + const T out = chbevl(y, A, len) * x; + return (_x < T{0.0}) ? -out : out; + } + auto coeff_pair = chebyshev_coefficients_i1e_B(); + auto B = std::get<0>(coeff_pair); + auto len = std::get<1>(coeff_pair); + const auto out = chbevl(T{32.0} / x - T{2.0}, B, len) / std::sqrt(x); + return (_x < T{0.0}) ? -out : out; +} + +/* + * This function is derived from the implementation of the i1e function in the Cephes Math Library. + * See note [3-Clause BSD License for the Cephes Math Library]. + * + * Computes the argument, x, for which the area under the Gaussian probability density function + * (integrated from minus infinity to x) is equal to y. + */ +template +inline C10_HOST_DEVICE T calc_ndtri(T y0) { + + /* sqrt(2pi) */ + constexpr T s2pi = 2.50662827463100050242E0; + constexpr T one = 1; + constexpr T zero = 0; + + /* approximation for 0 <= |y - 0.5| <= 3/8 */ + static const T P0[5] = { + -5.99633501014107895267E1, + 9.80010754185999661536E1, + -5.66762857469070293439E1, + 1.39312609387279679503E1, + -1.23916583867381258016E0, + }; + + static const T Q0[9] = { + 1.00000000000000000000E0, + 1.95448858338141759834E0, + 4.67627912898881538453E0, + 8.63602421390890590575E1, + -2.25462687854119370527E2, + 2.00260212380060660359E2, + -8.20372256168333339912E1, + 1.59056225126211695515E1, + -1.18331621121330003142E0, + }; + + /* Approximation for interval z = sqrt(-2 log y ) between 2 and 8 + * i.e., y between exp(-2) = .135 and exp(-32) = 1.27e-14. + */ + static const T P1[9] = { + 4.05544892305962419923E0, + 3.15251094599893866154E1, + 5.71628192246421288162E1, + 4.40805073893200834700E1, + 1.46849561928858024014E1, + 2.18663306850790267539E0, + -1.40256079171354495875E-1, + -3.50424626827848203418E-2, + -8.57456785154685413611E-4, + }; + + static const T Q1[9] = { + 1.00000000000000000000E0, + 1.57799883256466749731E1, + 4.53907635128879210584E1, + 4.13172038254672030440E1, + 1.50425385692907503408E1, + 2.50464946208309415979E0, + -1.42182922854787788574E-1, + -3.80806407691578277194E-2, + -9.33259480895457427372E-4, + }; + + /* Approximation for interval z = sqrt(-2 log y ) between 8 and 64 + * i.e., y between exp(-32) = 1.27e-14 and exp(-2048) = 3.67e-890. + */ + + static const T P2[9] = { + 3.23774891776946035970E0, + 6.91522889068984211695E0, + 3.93881025292474443415E0, + 1.33303460815807542389E0, + 2.01485389549179081538E-1, + 1.23716634817820021358E-2, + 3.01581553508235416007E-4, + 2.65806974686737550832E-6, + 6.23974539184983293730E-9, + }; + + static const T Q2[9] = { + 1.00000000000000000000E0, + 6.02427039364742014255E0, + 3.67983563856160859403E0, + 1.37702099489081330271E0, + 2.16236993594496635890E-1, + 1.34204006088543189037E-2, + 3.28014464682127739104E-4, + 2.89247864745380683936E-6, + 6.79019408009981274425E-9, + }; + + if (y0 == zero) { + return -std::numeric_limits::infinity(); + } + if (y0 == one) { + return std::numeric_limits::infinity(); + } + if (y0 < zero || y0 > one) { + return std::numeric_limits::quiet_NaN(); + } + bool code = true; + T y = y0; + if (y > one - T{0.13533528323661269189}) { /* 0.135... = exp(-2) */ + y = one - y; + code = false; + } + + if (y > T{0.13533528323661269189}) { + y = y - T{0.5}; + const T y2 = y * y; + T x = y + y * (y2 * polevl(y2, P0, 4) / polevl(y2, Q0, 8)); + return (x * s2pi); + } + + T x = ::sqrt(T{-2.0} * ::log(y)); + const T x0 = x - ::log(x) / x; + + const T z = one / x; + T x1; + if (x < T{8.0}) /* y > exp(-32) = 1.2664165549e-14 */ + { + x1 = z * polevl(z, P1, 8) / polevl(z, Q1, 8); + } else { + x1 = z * polevl(z, P2, 8) / polevl(z, Q2, 8); + } + x = x0 - x1; + if (code) { + x = -x; + } + return x; +} + +/* The next function is taken from http://ab-initio.mit.edu/Faddeev */ + +/* Copyright (c) 2012 Massachusetts Institute of Technology + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files (the + * "Software"), to deal in the Software without restriction, including + * without limitation the rights to use, copy, modify, merge, publish, + * distribute, sublicense, and/or sell copies of the Software, and to + * permit persons to whom the Software is furnished to do so, subject to + * the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE + * LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION + * OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION + * WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +/* erfcx(x) = exp(x^2) erfc(x) function, for real x, written by + Steven G. Johnson, October 2012. + + This function combines a few different ideas. + + First, for x > 50, it uses a continued-fraction expansion (same as + for the Faddeeva function, but with algebraic simplifications for z=i*x). + + Second, for 0 <= x <= 50, it uses Chebyshev polynomial approximations, + but with two twists: + + a) It maps x to y = 4 / (4+x) in [0,1]. This simple transformation, + inspired by a similar transformation in the octave-forge/specfun + erfcx by Soren Hauberg, results in much faster Chebyshev convergence + than other simple transformations I have examined. + + b) Instead of using a single Chebyshev polynomial for the entire + [0,1] y interval, we break the interval up into 100 equal + subintervals, with a switch/lookup table, and use much lower + degree Chebyshev polynomials in each subinterval. This greatly + improves performance in my tests. + + For x < 0, we use the relationship erfcx(-x) = 2 exp(x^2) - erfc(x), + with the usual checks for overflow etcetera. + + Performance-wise, it seems to be substantially faster than either + the SLATEC DERFC function [or an erfcx function derived therefrom] + or Cody's CALERF function (from netlib.org/specfun), while + retaining near machine precision in accuracy. */ + +/* Given y100=100*y, where y = 4/(4+x) for x >= 0, compute erfc(x). + + Uses a look-up table of 100 different Chebyshev polynomials + for y intervals [0,0.01], [0.01,0.02], ...., [0.99,1], generated + with the help of Maple and a little shell script. This allows + the Chebyshev polynomials to be of significantly lower degree (about 1/4) + compared to fitting the whole [0,1] interval with a single polynomial. */ + + +template +C10_HOST_DEVICE inline typename std::enable_if::value, T>::type +erfcx_y100(T y100) +{ + switch (static_cast(y100)) { +case 0: { +T t = 2*y100 - 1; +return 0.70878032454106438663e-3 + (0.71234091047026302958e-3 + (0.35779077297597742384e-5 + (0.17403143962587937815e-7 + (0.81710660047307788845e-10 + (0.36885022360434957634e-12 + 0.15917038551111111111e-14 * t) * t) * t) * t) * t) * t; +} +case 1: { +T t = 2*y100 - 3; +return 0.21479143208285144230e-2 + (0.72686402367379996033e-3 + (0.36843175430938995552e-5 + (0.18071841272149201685e-7 + (0.85496449296040325555e-10 + (0.38852037518534291510e-12 + 0.16868473576888888889e-14 * t) * t) * t) * t) * t) * t; +} +case 2: { +T t = 2*y100 - 5; +return 0.36165255935630175090e-2 + (0.74182092323555510862e-3 + (0.37948319957528242260e-5 + (0.18771627021793087350e-7 + (0.89484715122415089123e-10 + (0.40935858517772440862e-12 + 0.17872061464888888889e-14 * t) * t) * t) * t) * t) * t; +} +case 3: { +T t = 2*y100 - 7; +return 0.51154983860031979264e-2 + (0.75722840734791660540e-3 + (0.39096425726735703941e-5 + (0.19504168704300468210e-7 + (0.93687503063178993915e-10 + (0.43143925959079664747e-12 + 0.18939926435555555556e-14 * t) * t) * t) * t) * t) * t; +} +case 4: { +T t = 2*y100 - 9; +return 0.66457513172673049824e-2 + (0.77310406054447454920e-3 + (0.40289510589399439385e-5 + (0.20271233238288381092e-7 + (0.98117631321709100264e-10 + (0.45484207406017752971e-12 + 0.20076352213333333333e-14 * t) * t) * t) * t) * t) * t; +} +case 5: { +T t = 2*y100 - 11; +return 0.82082389970241207883e-2 + (0.78946629611881710721e-3 + (0.41529701552622656574e-5 + (0.21074693344544655714e-7 + (0.10278874108587317989e-9 + (0.47965201390613339638e-12 + 0.21285907413333333333e-14 * t) * t) * t) * t) * t) * t; +} +case 6: { +T t = 2*y100 - 13; +return 0.98039537275352193165e-2 + (0.80633440108342840956e-3 + (0.42819241329736982942e-5 + (0.21916534346907168612e-7 + (0.10771535136565470914e-9 + (0.50595972623692822410e-12 + 0.22573462684444444444e-14 * t) * t) * t) * t) * t) * t; +} +case 7: { +T t = 2*y100 - 15; +return 0.11433927298290302370e-1 + (0.82372858383196561209e-3 + (0.44160495311765438816e-5 + (0.22798861426211986056e-7 + (0.11291291745879239736e-9 + (0.53386189365816880454e-12 + 0.23944209546666666667e-14 * t) * t) * t) * t) * t) * t; +} +case 8: { +T t = 2*y100 - 17; +return 0.13099232878814653979e-1 + (0.84167002467906968214e-3 + (0.45555958988457506002e-5 + (0.23723907357214175198e-7 + (0.11839789326602695603e-9 + (0.56346163067550237877e-12 + 0.25403679644444444444e-14 * t) * t) * t) * t) * t) * t; +} +case 9: { +T t = 2*y100 - 19; +return 0.14800987015587535621e-1 + (0.86018092946345943214e-3 + (0.47008265848816866105e-5 + (0.24694040760197315333e-7 + (0.12418779768752299093e-9 + (0.59486890370320261949e-12 + 0.26957764568888888889e-14 * t) * t) * t) * t) * t) * t; +} +case 10: { +T t = 2*y100 - 21; +return 0.16540351739394069380e-1 + (0.87928458641241463952e-3 + (0.48520195793001753903e-5 + (0.25711774900881709176e-7 + (0.13030128534230822419e-9 + (0.62820097586874779402e-12 + 0.28612737351111111111e-14 * t) * t) * t) * t) * t) * t; +} +case 11: { +T t = 2*y100 - 23; +return 0.18318536789842392647e-1 + (0.89900542647891721692e-3 + (0.50094684089553365810e-5 + (0.26779777074218070482e-7 + (0.13675822186304615566e-9 + (0.66358287745352705725e-12 + 0.30375273884444444444e-14 * t) * t) * t) * t) * t) * t; +} +case 12: { +T t = 2*y100 - 25; +return 0.20136801964214276775e-1 + (0.91936908737673676012e-3 + (0.51734830914104276820e-5 + (0.27900878609710432673e-7 + (0.14357976402809042257e-9 + (0.70114790311043728387e-12 + 0.32252476000000000000e-14 * t) * t) * t) * t) * t) * t; +} +case 13: { +T t = 2*y100 - 27; +return 0.21996459598282740954e-1 + (0.94040248155366777784e-3 + (0.53443911508041164739e-5 + (0.29078085538049374673e-7 + (0.15078844500329731137e-9 + (0.74103813647499204269e-12 + 0.34251892320000000000e-14 * t) * t) * t) * t) * t) * t; +} +case 14: { +T t = 2*y100 - 29; +return 0.23898877187226319502e-1 + (0.96213386835900177540e-3 + (0.55225386998049012752e-5 + (0.30314589961047687059e-7 + (0.15840826497296335264e-9 + (0.78340500472414454395e-12 + 0.36381553564444444445e-14 * t) * t) * t) * t) * t) * t; +} +case 15: { +T t = 2*y100 - 31; +return 0.25845480155298518485e-1 + (0.98459293067820123389e-3 + (0.57082915920051843672e-5 + (0.31613782169164830118e-7 + (0.16646478745529630813e-9 + (0.82840985928785407942e-12 + 0.38649975768888888890e-14 * t) * t) * t) * t) * t) * t; +} +case 16: { +T t = 2*y100 - 33; +return 0.27837754783474696598e-1 + (0.10078108563256892757e-2 + (0.59020366493792212221e-5 + (0.32979263553246520417e-7 + (0.17498524159268458073e-9 + (0.87622459124842525110e-12 + 0.41066206488888888890e-14 * t) * t) * t) * t) * t) * t; +} +case 17: { +T t = 2*y100 - 35; +return 0.29877251304899307550e-1 + (0.10318204245057349310e-2 + (0.61041829697162055093e-5 + (0.34414860359542720579e-7 + (0.18399863072934089607e-9 + (0.92703227366365046533e-12 + 0.43639844053333333334e-14 * t) * t) * t) * t) * t) * t; +} +case 18: { +T t = 2*y100 - 37; +return 0.31965587178596443475e-1 + (0.10566560976716574401e-2 + (0.63151633192414586770e-5 + (0.35924638339521924242e-7 + (0.19353584758781174038e-9 + (0.98102783859889264382e-12 + 0.46381060817777777779e-14 * t) * t) * t) * t) * t) * t; +} +case 19: { +T t = 2*y100 - 39; +return 0.34104450552588334840e-1 + (0.10823541191350532574e-2 + (0.65354356159553934436e-5 + (0.37512918348533521149e-7 + (0.20362979635817883229e-9 + (0.10384187833037282363e-11 + 0.49300625262222222221e-14 * t) * t) * t) * t) * t) * t; +} +case 20: { +T t = 2*y100 - 41; +return 0.36295603928292425716e-1 + (0.11089526167995268200e-2 + (0.67654845095518363577e-5 + (0.39184292949913591646e-7 + (0.21431552202133775150e-9 + (0.10994259106646731797e-11 + 0.52409949102222222221e-14 * t) * t) * t) * t) * t) * t; +} +case 21: { +T t = 2*y100 - 43; +return 0.38540888038840509795e-1 + (0.11364917134175420009e-2 + (0.70058230641246312003e-5 + (0.40943644083718586939e-7 + (0.22563034723692881631e-9 + (0.11642841011361992885e-11 + 0.55721092871111111110e-14 * t) * t) * t) * t) * t) * t; +} +case 22: { +T t = 2*y100 - 45; +return 0.40842225954785960651e-1 + (0.11650136437945673891e-2 + (0.72569945502343006619e-5 + (0.42796161861855042273e-7 + (0.23761401711005024162e-9 + (0.12332431172381557035e-11 + 0.59246802364444444445e-14 * t) * t) * t) * t) * t) * t; +} +case 23: { +T t = 2*y100 - 47; +return 0.43201627431540222422e-1 + (0.11945628793917272199e-2 + (0.75195743532849206263e-5 + (0.44747364553960993492e-7 + (0.25030885216472953674e-9 + (0.13065684400300476484e-11 + 0.63000532853333333334e-14 * t) * t) * t) * t) * t) * t; +} +case 24: { +T t = 2*y100 - 49; +return 0.45621193513810471438e-1 + (0.12251862608067529503e-2 + (0.77941720055551920319e-5 + (0.46803119830954460212e-7 + (0.26375990983978426273e-9 + (0.13845421370977119765e-11 + 0.66996477404444444445e-14 * t) * t) * t) * t) * t) * t; +} +case 25: { +T t = 2*y100 - 51; +return 0.48103121413299865517e-1 + (0.12569331386432195113e-2 + (0.80814333496367673980e-5 + (0.48969667335682018324e-7 + (0.27801515481905748484e-9 + (0.14674637611609884208e-11 + 0.71249589351111111110e-14 * t) * t) * t) * t) * t) * t; +} +case 26: { +T t = 2*y100 - 53; +return 0.50649709676983338501e-1 + (0.12898555233099055810e-2 + (0.83820428414568799654e-5 + (0.51253642652551838659e-7 + (0.29312563849675507232e-9 + (0.15556512782814827846e-11 + 0.75775607822222222221e-14 * t) * t) * t) * t) * t) * t; +} +case 27: { +T t = 2*y100 - 55; +return 0.53263363664388864181e-1 + (0.13240082443256975769e-2 + (0.86967260015007658418e-5 + (0.53662102750396795566e-7 + (0.30914568786634796807e-9 + (0.16494420240828493176e-11 + 0.80591079644444444445e-14 * t) * t) * t) * t) * t) * t; +} +case 28: { +T t = 2*y100 - 57; +return 0.55946601353500013794e-1 + (0.13594491197408190706e-2 + (0.90262520233016380987e-5 + (0.56202552975056695376e-7 + (0.32613310410503135996e-9 + (0.17491936862246367398e-11 + 0.85713381688888888890e-14 * t) * t) * t) * t) * t) * t; +} +case 29: { +T t = 2*y100 - 59; +return 0.58702059496154081813e-1 + (0.13962391363223647892e-2 + (0.93714365487312784270e-5 + (0.58882975670265286526e-7 + (0.34414937110591753387e-9 + (0.18552853109751857859e-11 + 0.91160736711111111110e-14 * t) * t) * t) * t) * t) * t; +} +case 30: { +T t = 2*y100 - 61; +return 0.61532500145144778048e-1 + (0.14344426411912015247e-2 + (0.97331446201016809696e-5 + (0.61711860507347175097e-7 + (0.36325987418295300221e-9 + (0.19681183310134518232e-11 + 0.96952238400000000000e-14 * t) * t) * t) * t) * t) * t; +} +case 31: { +T t = 2*y100 - 63; +return 0.64440817576653297993e-1 + (0.14741275456383131151e-2 + (0.10112293819576437838e-4 + (0.64698236605933246196e-7 + (0.38353412915303665586e-9 + (0.20881176114385120186e-11 + 0.10310784480000000000e-13 * t) * t) * t) * t) * t) * t; +} +case 32: { +T t = 2*y100 - 65; +return 0.67430045633130393282e-1 + (0.15153655418916540370e-2 + (0.10509857606888328667e-4 + (0.67851706529363332855e-7 + (0.40504602194811140006e-9 + (0.22157325110542534469e-11 + 0.10964842115555555556e-13 * t) * t) * t) * t) * t) * t; +} +case 33: { +T t = 2*y100 - 67; +return 0.70503365513338850709e-1 + (0.15582323336495709827e-2 + (0.10926868866865231089e-4 + (0.71182482239613507542e-7 + (0.42787405890153386710e-9 + (0.23514379522274416437e-11 + 0.11659571751111111111e-13 * t) * t) * t) * t) * t) * t; +} +case 34: { +T t = 2*y100 - 69; +return 0.73664114037944596353e-1 + (0.16028078812438820413e-2 + (0.11364423678778207991e-4 + (0.74701423097423182009e-7 + (0.45210162777476488324e-9 + (0.24957355004088569134e-11 + 0.12397238257777777778e-13 * t) * t) * t) * t) * t) * t; +} +case 35: { +T t = 2*y100 - 71; +return 0.76915792420819562379e-1 + (0.16491766623447889354e-2 + (0.11823685320041302169e-4 + (0.78420075993781544386e-7 + (0.47781726956916478925e-9 + (0.26491544403815724749e-11 + 0.13180196462222222222e-13 * t) * t) * t) * t) * t) * t; +} +case 36: { +T t = 2*y100 - 73; +return 0.80262075578094612819e-1 + (0.16974279491709504117e-2 + (0.12305888517309891674e-4 + (0.82350717698979042290e-7 + (0.50511496109857113929e-9 + (0.28122528497626897696e-11 + 0.14010889635555555556e-13 * t) * t) * t) * t) * t) * t; +} +case 37: { +T t = 2*y100 - 75; +return 0.83706822008980357446e-1 + (0.17476561032212656962e-2 + (0.12812343958540763368e-4 + (0.86506399515036435592e-7 + (0.53409440823869467453e-9 + (0.29856186620887555043e-11 + 0.14891851591111111111e-13 * t) * t) * t) * t) * t) * t; +} +case 38: { +T t = 2*y100 - 77; +return 0.87254084284461718231e-1 + (0.17999608886001962327e-2 + (0.13344443080089492218e-4 + (0.90900994316429008631e-7 + (0.56486134972616465316e-9 + (0.31698707080033956934e-11 + 0.15825697795555555556e-13 * t) * t) * t) * t) * t) * t; +} +case 39: { +T t = 2*y100 - 79; +return 0.90908120182172748487e-1 + (0.18544478050657699758e-2 + (0.13903663143426120077e-4 + (0.95549246062549906177e-7 + (0.59752787125242054315e-9 + (0.33656597366099099413e-11 + 0.16815130613333333333e-13 * t) * t) * t) * t) * t) * t; +} +case 40: { +T t = 2*y100 - 81; +return 0.94673404508075481121e-1 + (0.19112284419887303347e-2 + (0.14491572616545004930e-4 + (0.10046682186333613697e-6 + (0.63221272959791000515e-9 + (0.35736693975589130818e-11 + 0.17862931591111111111e-13 * t) * t) * t) * t) * t) * t; +} +case 41: { +T t = 2*y100 - 83; +return 0.98554641648004456555e-1 + (0.19704208544725622126e-2 + (0.15109836875625443935e-4 + (0.10567036667675984067e-6 + (0.66904168640019354565e-9 + (0.37946171850824333014e-11 + 0.18971959040000000000e-13 * t) * t) * t) * t) * t) * t; +} +case 42: { +T t = 2*y100 - 85; +return 0.10255677889470089531e0 + (0.20321499629472857418e-2 + (0.15760224242962179564e-4 + (0.11117756071353507391e-6 + (0.70814785110097658502e-9 + (0.40292553276632563925e-11 + 0.20145143075555555556e-13 * t) * t) * t) * t) * t) * t; +} +case 43: { +T t = 2*y100 - 87; +return 0.10668502059865093318e0 + (0.20965479776148731610e-2 + (0.16444612377624983565e-4 + (0.11700717962026152749e-6 + (0.74967203250938418991e-9 + (0.42783716186085922176e-11 + 0.21385479360000000000e-13 * t) * t) * t) * t) * t) * t; +} +case 44: { +T t = 2*y100 - 89; +return 0.11094484319386444474e0 + (0.21637548491908170841e-2 + (0.17164995035719657111e-4 + (0.12317915750735938089e-6 + (0.79376309831499633734e-9 + (0.45427901763106353914e-11 + 0.22696025653333333333e-13 * t) * t) * t) * t) * t) * t; +} +case 45: { +T t = 2*y100 - 91; +return 0.11534201115268804714e0 + (0.22339187474546420375e-2 + (0.17923489217504226813e-4 + (0.12971465288245997681e-6 + (0.84057834180389073587e-9 + (0.48233721206418027227e-11 + 0.24079890062222222222e-13 * t) * t) * t) * t) * t) * t; +} +case 46: { +T t = 2*y100 - 93; +return 0.11988259392684094740e0 + (0.23071965691918689601e-2 + (0.18722342718958935446e-4 + (0.13663611754337957520e-6 + (0.89028385488493287005e-9 + (0.51210161569225846701e-11 + 0.25540227111111111111e-13 * t) * t) * t) * t) * t) * t; +} +case 47: { +T t = 2*y100 - 95; +return 0.12457298393509812907e0 + (0.23837544771809575380e-2 + (0.19563942105711612475e-4 + (0.14396736847739470782e-6 + (0.94305490646459247016e-9 + (0.54366590583134218096e-11 + 0.27080225920000000000e-13 * t) * t) * t) * t) * t) * t; +} +case 48: { +T t = 2*y100 - 97; +return 0.12941991566142438816e0 + (0.24637684719508859484e-2 + (0.20450821127475879816e-4 + (0.15173366280523906622e-6 + (0.99907632506389027739e-9 + (0.57712760311351625221e-11 + 0.28703099555555555556e-13 * t) * t) * t) * t) * t) * t; +} +case 49: { +T t = 2*y100 - 99; +return 0.13443048593088696613e0 + (0.25474249981080823877e-2 + (0.21385669591362915223e-4 + (0.15996177579900443030e-6 + (0.10585428844575134013e-8 + (0.61258809536787882989e-11 + 0.30412080142222222222e-13 * t) * t) * t) * t) * t) * t; +} +case 50: { +T t = 2*y100 - 101; +return 0.13961217543434561353e0 + (0.26349215871051761416e-2 + (0.22371342712572567744e-4 + (0.16868008199296822247e-6 + (0.11216596910444996246e-8 + (0.65015264753090890662e-11 + 0.32210394506666666666e-13 * t) * t) * t) * t) * t) * t; +} +case 51: { +T t = 2*y100 - 103; +return 0.14497287157673800690e0 + (0.27264675383982439814e-2 + (0.23410870961050950197e-4 + (0.17791863939526376477e-6 + (0.11886425714330958106e-8 + (0.68993039665054288034e-11 + 0.34101266222222222221e-13 * t) * t) * t) * t) * t) * t; +} +case 52: { +T t = 2*y100 - 105; +return 0.15052089272774618151e0 + (0.28222846410136238008e-2 + (0.24507470422713397006e-4 + (0.18770927679626136909e-6 + (0.12597184587583370712e-8 + (0.73203433049229821618e-11 + 0.36087889048888888890e-13 * t) * t) * t) * t) * t) * t; +} +case 53: { +T t = 2*y100 - 107; +return 0.15626501395774612325e0 + (0.29226079376196624949e-2 + (0.25664553693768450545e-4 + (0.19808568415654461964e-6 + (0.13351257759815557897e-8 + (0.77658124891046760667e-11 + 0.38173420035555555555e-13 * t) * t) * t) * t) * t) * t; +} +case 54: { +T t = 2*y100 - 109; +return 0.16221449434620737567e0 + (0.30276865332726475672e-2 + (0.26885741326534564336e-4 + (0.20908350604346384143e-6 + (0.14151148144240728728e-8 + (0.82369170665974313027e-11 + 0.40360957457777777779e-13 * t) * t) * t) * t) * t) * t; +} +case 55: { +T t = 2*y100 - 111; +return 0.16837910595412130659e0 + (0.31377844510793082301e-2 + (0.28174873844911175026e-4 + (0.22074043807045782387e-6 + (0.14999481055996090039e-8 + (0.87348993661930809254e-11 + 0.42653528977777777779e-13 * t) * t) * t) * t) * t) * t; +} +case 56: { +T t = 2*y100 - 113; +return 0.17476916455659369953e0 + (0.32531815370903068316e-2 + (0.29536024347344364074e-4 + (0.23309632627767074202e-6 + (0.15899007843582444846e-8 + (0.92610375235427359475e-11 + 0.45054073102222222221e-13 * t) * t) * t) * t) * t) * t; +} +case 57: { +T t = 2*y100 - 115; +return 0.18139556223643701364e0 + (0.33741744168096996041e-2 + (0.30973511714709500836e-4 + (0.24619326937592290996e-6 + (0.16852609412267750744e-8 + (0.98166442942854895573e-11 + 0.47565418097777777779e-13 * t) * t) * t) * t) * t) * t; +} +case 58: { +T t = 2*y100 - 117; +return 0.18826980194443664549e0 + (0.35010775057740317997e-2 + (0.32491914440014267480e-4 + (0.26007572375886319028e-6 + (0.17863299617388376116e-8 + (0.10403065638343878679e-10 + 0.50190265831111111110e-13 * t) * t) * t) * t) * t) * t; +} +case 59: { +T t = 2*y100 - 119; +return 0.19540403413693967350e0 + (0.36342240767211326315e-2 + (0.34096085096200907289e-4 + (0.27479061117017637474e-6 + (0.18934228504790032826e-8 + (0.11021679075323598664e-10 + 0.52931171733333333334e-13 * t) * t) * t) * t) * t) * t; +} +case 60: { +T t = 2*y100 - 121; +return 0.20281109560651886959e0 + (0.37739673859323597060e-2 + (0.35791165457592409054e-4 + (0.29038742889416172404e-6 + (0.20068685374849001770e-8 + (0.11673891799578381999e-10 + 0.55790523093333333334e-13 * t) * t) * t) * t) * t) * t; +} +case 61: { +T t = 2*y100 - 123; +return 0.21050455062669334978e0 + (0.39206818613925652425e-2 + (0.37582602289680101704e-4 + (0.30691836231886877385e-6 + (0.21270101645763677824e-8 + (0.12361138551062899455e-10 + 0.58770520160000000000e-13 * t) * t) * t) * t) * t) * t; +} +case 62: { +T t = 2*y100 - 125; +return 0.21849873453703332479e0 + (0.40747643554689586041e-2 + (0.39476163820986711501e-4 + (0.32443839970139918836e-6 + (0.22542053491518680200e-8 + (0.13084879235290858490e-10 + 0.61873153262222222221e-13 * t) * t) * t) * t) * t) * t; +} +case 63: { +T t = 2*y100 - 127; +return 0.22680879990043229327e0 + (0.42366354648628516935e-2 + (0.41477956909656896779e-4 + (0.34300544894502810002e-6 + (0.23888264229264067658e-8 + (0.13846596292818514601e-10 + 0.65100183751111111110e-13 * t) * t) * t) * t) * t) * t; +} +case 64: { +T t = 2*y100 - 129; +return 0.23545076536988703937e0 + (0.44067409206365170888e-2 + (0.43594444916224700881e-4 + (0.36268045617760415178e-6 + (0.25312606430853202748e-8 + (0.14647791812837903061e-10 + 0.68453122631111111110e-13 * t) * t) * t) * t) * t) * t; +} +case 65: { +T t = 2*y100 - 131; +return 0.24444156740777432838e0 + (0.45855530511605787178e-2 + (0.45832466292683085475e-4 + (0.38352752590033030472e-6 + (0.26819103733055603460e-8 + (0.15489984390884756993e-10 + 0.71933206364444444445e-13 * t) * t) * t) * t) * t) * t; +} +case 66: { +T t = 2*y100 - 133; +return 0.25379911500634264643e0 + (0.47735723208650032167e-2 + (0.48199253896534185372e-4 + (0.40561404245564732314e-6 + (0.28411932320871165585e-8 + (0.16374705736458320149e-10 + 0.75541379822222222221e-13 * t) * t) * t) * t) * t) * t; +} +case 67: { +T t = 2*y100 - 135; +return 0.26354234756393613032e0 + (0.49713289477083781266e-2 + (0.50702455036930367504e-4 + (0.42901079254268185722e-6 + (0.30095422058900481753e-8 + (0.17303497025347342498e-10 + 0.79278273368888888890e-13 * t) * t) * t) * t) * t) * t; +} +case 68: { +T t = 2*y100 - 137; +return 0.27369129607732343398e0 + (0.51793846023052643767e-2 + (0.53350152258326602629e-4 + (0.45379208848865015485e-6 + (0.31874057245814381257e-8 + (0.18277905010245111046e-10 + 0.83144182364444444445e-13 * t) * t) * t) * t) * t) * t; +} +case 69: { +T t = 2*y100 - 139; +return 0.28426714781640316172e0 + (0.53983341916695141966e-2 + (0.56150884865255810638e-4 + (0.48003589196494734238e-6 + (0.33752476967570796349e-8 + (0.19299477888083469086e-10 + 0.87139049137777777779e-13 * t) * t) * t) * t) * t) * t; +} +case 70: { +T t = 2*y100 - 141; +return 0.29529231465348519920e0 + (0.56288077305420795663e-2 + (0.59113671189913307427e-4 + (0.50782393781744840482e-6 + (0.35735475025851713168e-8 + (0.20369760937017070382e-10 + 0.91262442613333333334e-13 * t) * t) * t) * t) * t) * t; +} +case 71: { +T t = 2*y100 - 143; +return 0.30679050522528838613e0 + (0.58714723032745403331e-2 + (0.62248031602197686791e-4 + (0.53724185766200945789e-6 + (0.37827999418960232678e-8 + (0.21490291930444538307e-10 + 0.95513539182222222221e-13 * t) * t) * t) * t) * t) * t; +} +case 72: { +T t = 2*y100 - 145; +return 0.31878680111173319425e0 + (0.61270341192339103514e-2 + (0.65564012259707640976e-4 + (0.56837930287837738996e-6 + (0.40035151353392378882e-8 + (0.22662596341239294792e-10 + 0.99891109760000000000e-13 * t) * t) * t) * t) * t) * t; +} +case 73: { +T t = 2*y100 - 147; +return 0.33130773722152622027e0 + (0.63962406646798080903e-2 + (0.69072209592942396666e-4 + (0.60133006661885941812e-6 + (0.42362183765883466691e-8 + (0.23888182347073698382e-10 + 0.10439349811555555556e-12 * t) * t) * t) * t) * t) * t; +} +case 74: { +T t = 2*y100 - 149; +return 0.34438138658041336523e0 + (0.66798829540414007258e-2 + (0.72783795518603561144e-4 + (0.63619220443228800680e-6 + (0.44814499336514453364e-8 + (0.25168535651285475274e-10 + 0.10901861383111111111e-12 * t) * t) * t) * t) * t) * t; +} +case 75: { +T t = 2*y100 - 151; +return 0.35803744972380175583e0 + (0.69787978834882685031e-2 + (0.76710543371454822497e-4 + (0.67306815308917386747e-6 + (0.47397647975845228205e-8 + (0.26505114141143050509e-10 + 0.11376390933333333333e-12 * t) * t) * t) * t) * t) * t; +} +case 76: { +T t = 2*y100 - 153; +return 0.37230734890119724188e0 + (0.72938706896461381003e-2 + (0.80864854542670714092e-4 + (0.71206484718062688779e-6 + (0.50117323769745883805e-8 + (0.27899342394100074165e-10 + 0.11862637614222222222e-12 * t) * t) * t) * t) * t) * t; +} +case 77: { +T t = 2*y100 - 155; +return 0.38722432730555448223e0 + (0.76260375162549802745e-2 + (0.85259785810004603848e-4 + (0.75329383305171327677e-6 + (0.52979361368388119355e-8 + (0.29352606054164086709e-10 + 0.12360253370666666667e-12 * t) * t) * t) * t) * t) * t; +} +case 78: { +T t = 2*y100 - 157; +return 0.40282355354616940667e0 + (0.79762880915029728079e-2 + (0.89909077342438246452e-4 + (0.79687137961956194579e-6 + (0.55989731807360403195e-8 + (0.30866246101464869050e-10 + 0.12868841946666666667e-12 * t) * t) * t) * t) * t) * t; +} +case 79: { +T t = 2*y100 - 159; +return 0.41914223158913787649e0 + (0.83456685186950463538e-2 + (0.94827181359250161335e-4 + (0.84291858561783141014e-6 + (0.59154537751083485684e-8 + (0.32441553034347469291e-10 + 0.13387957943111111111e-12 * t) * t) * t) * t) * t) * t; +} +case 80: { +T t = 2*y100 - 161; +return 0.43621971639463786896e0 + (0.87352841828289495773e-2 + (0.10002929142066799966e-3 + (0.89156148280219880024e-6 + (0.62480008150788597147e-8 + (0.34079760983458878910e-10 + 0.13917107176888888889e-12 * t) * t) * t) * t) * t) * t; +} +case 81: { +T t = 2*y100 - 163; +return 0.45409763548534330981e0 + (0.91463027755548240654e-2 + (0.10553137232446167258e-3 + (0.94293113464638623798e-6 + (0.65972492312219959885e-8 + (0.35782041795476563662e-10 + 0.14455745872000000000e-12 * t) * t) * t) * t) * t) * t; +} +case 82: { +T t = 2*y100 - 165; +return 0.47282001668512331468e0 + (0.95799574408860463394e-2 + (0.11135019058000067469e-3 + (0.99716373005509038080e-6 + (0.69638453369956970347e-8 + (0.37549499088161345850e-10 + 0.15003280712888888889e-12 * t) * t) * t) * t) * t) * t; +} +case 83: { +T t = 2*y100 - 167; +return 0.49243342227179841649e0 + (0.10037550043909497071e-1 + (0.11750334542845234952e-3 + (0.10544006716188967172e-5 + (0.73484461168242224872e-8 + (0.39383162326435752965e-10 + 0.15559069118222222222e-12 * t) * t) * t) * t) * t) * t; +} +case 84: { +T t = 2*y100 - 169; +return 0.51298708979209258326e0 + (0.10520454564612427224e-1 + (0.12400930037494996655e-3 + (0.11147886579371265246e-5 + (0.77517184550568711454e-8 + (0.41283980931872622611e-10 + 0.16122419680000000000e-12 * t) * t) * t) * t) * t) * t; +} +case 85: { +T t = 2*y100 - 171; +return 0.53453307979101369843e0 + (0.11030120618800726938e-1 + (0.13088741519572269581e-3 + (0.11784797595374515432e-5 + (0.81743383063044825400e-8 + (0.43252818449517081051e-10 + 0.16692592640000000000e-12 * t) * t) * t) * t) * t) * t; +} +case 86: { +T t = 2*y100 - 173; +return 0.55712643071169299478e0 + (0.11568077107929735233e-1 + (0.13815797838036651289e-3 + (0.12456314879260904558e-5 + (0.86169898078969313597e-8 + (0.45290446811539652525e-10 + 0.17268801084444444444e-12 * t) * t) * t) * t) * t) * t; +} +case 87: { +T t = 2*y100 - 175; +return 0.58082532122519320968e0 + (0.12135935999503877077e-1 + (0.14584223996665838559e-3 + (0.13164068573095710742e-5 + (0.90803643355106020163e-8 + (0.47397540713124619155e-10 + 0.17850211608888888889e-12 * t) * t) * t) * t) * t) * t; +} +case 88: { +T t = 2*y100 - 177; +return 0.60569124025293375554e0 + (0.12735396239525550361e-1 + (0.15396244472258863344e-3 + (0.13909744385382818253e-5 + (0.95651595032306228245e-8 + (0.49574672127669041550e-10 + 0.18435945564444444444e-12 * t) * t) * t) * t) * t) * t; +} +case 89: { +T t = 2*y100 - 179; +return 0.63178916494715716894e0 + (0.13368247798287030927e-1 + (0.16254186562762076141e-3 + (0.14695084048334056083e-5 + (0.10072078109604152350e-7 + (0.51822304995680707483e-10 + 0.19025081422222222222e-12 * t) * t) * t) * t) * t) * t; +} +case 90: { +T t = 2*y100 - 181; +return 0.65918774689725319200e0 + (0.14036375850601992063e-1 + (0.17160483760259706354e-3 + (0.15521885688723188371e-5 + (0.10601827031535280590e-7 + (0.54140790105837520499e-10 + 0.19616655146666666667e-12 * t) * t) * t) * t) * t) * t; +} +case 91: { +T t = 2*y100 - 183; +return 0.68795950683174433822e0 + (0.14741765091365869084e-1 + (0.18117679143520433835e-3 + (0.16392004108230585213e-5 + (0.11155116068018043001e-7 + (0.56530360194925690374e-10 + 0.20209663662222222222e-12 * t) * t) * t) * t) * t) * t; +} +case 92: { +T t = 2*y100 - 185; +return 0.71818103808729967036e0 + (0.15486504187117112279e-1 + (0.19128428784550923217e-3 + (0.17307350969359975848e-5 + (0.11732656736113607751e-7 + (0.58991125287563833603e-10 + 0.20803065333333333333e-12 * t) * t) * t) * t) * t) * t; +} +case 93: { +T t = 2*y100 - 187; +return 0.74993321911726254661e0 + (0.16272790364044783382e-1 + (0.20195505163377912645e-3 + (0.18269894883203346953e-5 + (0.12335161021630225535e-7 + (0.61523068312169087227e-10 + 0.21395783431111111111e-12 * t) * t) * t) * t) * t) * t; +} +case 94: { +T t = 2*y100 - 189; +return 0.78330143531283492729e0 + (0.17102934132652429240e-1 + (0.21321800585063327041e-3 + (0.19281661395543913713e-5 + (0.12963340087354341574e-7 + (0.64126040998066348872e-10 + 0.21986708942222222222e-12 * t) * t) * t) * t) * t) * t; +} +case 95: { +T t = 2*y100 - 191; +return 0.81837581041023811832e0 + (0.17979364149044223802e-1 + (0.22510330592753129006e-3 + (0.20344732868018175389e-5 + (0.13617902941839949718e-7 + (0.66799760083972474642e-10 + 0.22574701262222222222e-12 * t) * t) * t) * t) * t) * t; +} +case 96: { +T t = 2*y100 - 193; +return 0.85525144775685126237e0 + (0.18904632212547561026e-1 + (0.23764237370371255638e-3 + (0.21461248251306387979e-5 + (0.14299555071870523786e-7 + (0.69543803864694171934e-10 + 0.23158593688888888889e-12 * t) * t) * t) * t) * t) * t; +} +case 97: { +T t = 2*y100 - 195; +return 0.89402868170849933734e0 + (0.19881418399127202569e-1 + (0.25086793128395995798e-3 + (0.22633402747585233180e-5 + (0.15008997042116532283e-7 + (0.72357609075043941261e-10 + 0.23737194737777777778e-12 * t) * t) * t) * t) * t) * t; +} +case 98: { +T t = 2*y100 - 197; +return 0.93481333942870796363e0 + (0.20912536329780368893e-1 + (0.26481403465998477969e-3 + (0.23863447359754921676e-5 + (0.15746923065472184451e-7 + (0.75240468141720143653e-10 + 0.24309291271111111111e-12 * t) * t) * t) * t) * t) * t; +} +case 99: { +T t = 2*y100 - 199; +return 0.97771701335885035464e0 + (0.22000938572830479551e-1 + (0.27951610702682383001e-3 + (0.25153688325245314530e-5 + (0.16514019547822821453e-7 + (0.78191526829368231251e-10 + 0.24873652355555555556e-12 * t) * t) * t) * t) * t) * t; +} + } + // we only get here if y = 1, i.e. |x| < 4*eps, in which case + // erfcx is within 1e-15 of 1.. + return 1.0; +} + +template +C10_HOST_DEVICE inline typename std::enable_if::value, T>::type +calc_erfcx(T x) +{ + if (at::_isnan(x)) { + return x; + } + + if (x >= 0) { + if (x > 50) { // continued-fraction expansion is faster + const T ispi = 0.56418958354775628694807945156; // 1 / sqrt(pi) + if (x > 5e7) { // 1-term expansion, important to avoid overflow + return ispi / x; + } + /* 5-term expansion (rely on compiler for CSE), simplified from: + ispi / (x+0.5/(x+1/(x+1.5/(x+2/x)))) */ + return ispi*((x*x) * (x*x+4.5) + 2) / (x * ((x*x) * (x*x+5) + 3.75)); + } + return erfcx_y100(400/(4+x)); + } + else { + if (x < -26.7) { + return std::numeric_limits::infinity(); + } + else if (x < -6.1) { + return 2*exp(x*x); + } + else { + return 2*exp(x*x) - erfcx_y100(400/(4-x)); + } + } +} + +/* + * Logarithm of Gaussian cumulative distribution function. + + * This implementation of log_ndtr and its helper functions + * follow SciPy's implementation + * See NOTICE for the licenses. + */ +template +inline C10_HOST_DEVICE T calc_log_ndtr(T x) { + T t = x * c10::frac_sqrt_2; + if (x < T{-1.0}) { + return std::log(calc_erfcx(-t) / 2) - t * t; + } else { + return std::log1p(-std::erfc(t) / 2); + } +} + +template +inline C10_HOST_DEVICE T airy_ai_forward(T x) { + static const T AN[] = { + +3.46538101525629032477e-01, + +1.20075952739645805542e+01, + +7.62796053615234516538e+01, + +1.68089224934630576269e+02, + +1.59756391350164413639e+02, + +7.05360906840444183113e+01, + +1.40264691163389668864e+01, + +9.99999999999999995305e-01, + }; + + static const T AD[] = { + +5.67594532638770212846e-01, + +1.47562562584847203173e+01, + +8.45138970141474626562e+01, + +1.77318088145400459522e+02, + +1.64234692871529701831e+02, + +7.14778400825575695274e+01, + +1.40959135607834029598e+01, + +1.00000000000000000470e+00, + }; + + static const T AFN[] = { + -1.31696323418331795333e-01, + -6.26456544431912369773e-01, + -6.93158036036933542233e-01, + -2.79779981545119124951e-01, + -4.91900132609500318020e-02, + -4.06265923594885404393e-03, + -1.59276496239262096340e-04, + -2.77649108155232920844e-06, + -1.67787698489114633780e-08, + }; + + static const T AFD[] = { + +1.33560420706553243746e+01, + +3.26825032795224613948e+01, + +2.67367040941499554804e+01, + +9.18707402907259625840e+00, + +1.47529146771666414581e+00, + +1.15687173795188044134e-01, + +4.40291641615211203805e-03, + +7.54720348287414296618e-05, + +4.51850092970580378464e-07, + }; + + static const T AGN[] = { + +1.97339932091685679179e-02, + +3.91103029615688277255e-01, + +1.06579897599595591108e+00, + +9.39169229816650230044e-01, + +3.51465656105547619242e-01, + +6.33888919628925490927e-02, + +5.85804113048388458567e-03, + +2.82851600836737019778e-04, + +6.98793669997260967291e-06, + +8.11789239554389293311e-08, + +3.41551784765923618484e-10, + }; + + static const T AGD[] = { + +9.30892908077441974853e+00, + +1.98352928718312140417e+01, + +1.55646628932864612953e+01, + +5.47686069422975497931e+00, + +9.54293611618961883998e-01, + +8.64580826352392193095e-02, + +4.12656523824222607191e-03, + +1.01259085116509135510e-04, + +1.17166733214413521882e-06, + +4.91834570062930015649e-09, + }; + + int domain_flag = 0; + + T ai; + + if (std::isinf(x)) { + return std::numeric_limits::quiet_NaN(); + } + + if (x > T(103.892)) { + return T(0.0); + } + + T f; + T g; + T k; + + if (x < T(-2.09)) { + T z = T(1.0) / (T(-2.0) * x * std::sqrt(-x) / T(3.0)); + + T afn = 0.0; + + for (uint8_t index = 0; index <= 8; index++) { + afn = afn * (z * z) + AFN[index]; + } + + T afd = 0.0; + + for (uint8_t index = 0; index <= 8; index++) { + afd = afd * (z * z) + AFD[index]; + } + + T agn = 0.0; + + for (uint8_t index = 0; index <= 10 + 0; index++) { + agn = agn * (z * z) + AGN[index]; + } + + T agd = 0.0; + + for (uint8_t index = 0; index <= 10 - 1; index++) { + agd = agd * (z * z) + AGD[index]; + } + + T t = T(-2.0) * x * std::sqrt(-x) / T(3.0) + T(0.25) * c10::pi; + + return T(5.64189583547756286948e-01) / std::sqrt(std::sqrt(-x)) * (std::sin(t) * (T(1.0) + z * z * afn / afd) - std::cos(t) * (z * agn / agd)); + } + + if (x >= T(2.09)) { + domain_flag = 5; + + T zeta = T(2.0) * x * std::sqrt(x) / T(3.0); + + T an = 0.0; + + for (uint8_t index = 0; index <= 7; index++) { + an = an * (T(1.0) / zeta) + AN[index]; + } + + T ad = 0.0; + + for (uint8_t index = 0; index <= 7; index++) { + ad = ad * (T(1.0) / zeta) + AD[index]; + } + + ai = T(5.64189583547756286948e-01) * (an / ad) / (T(2.0) * std::sqrt(std::sqrt(x)) * std::exp(zeta)); + + if (x > T(8.3203353)) { + return ai; + } + } + + f = 1.0; + g = x; + k = 1.0; + + T m = 1.0; + T n = x; + T t = 1.0; + T z = x * x * x; + + while (t > std::numeric_limits::epsilon()) { + m *= z; + k += T(1.0); + m /= k; + n *= z; + k += T(1.0); + n /= k; + m /= k; + f += m; + k += T(1.0); + n /= k; + g += n; + + t = std::abs(m / f); + } + + if ((domain_flag & 1) == 0) { + return T(0.355028053887817239260) * f - T(0.258819403792806798405) * g; + } + + return ai; +} // T airy_ai(T x) + +template +inline C10_HOST_DEVICE T bessel_j0_forward(T x) { + static const T PP[] = { + +7.96936729297347051624e-04, + +8.28352392107440799803e-02, + +1.23953371646414299388e+00, + +5.44725003058768775090e+00, + +8.74716500199817011941e+00, + +5.30324038235394892183e+00, + +9.99999999999999997821e-01, + }; + + static const T PQ[] = { + +9.24408810558863637013e-04, + +8.56288474354474431428e-02, + +1.25352743901058953537e+00, + +5.47097740330417105182e+00, + +8.76190883237069594232e+00, + +5.30605288235394617618e+00, + +1.00000000000000000218e+00, + }; + + static const T QP[] = { + -1.13663838898469149931e-02, + -1.28252718670509318512e+00, + -1.95539544257735972385e+01, + -9.32060152123768231369e+01, + -1.77681167980488050595e+02, + -1.47077505154951170175e+02, + -5.14105326766599330220e+01, + -6.05014350600728481186e+00, + }; + + static const T QQ[] = { + +6.43178256118178023184e+01, + +8.56430025976980587198e+02, + +3.88240183605401609683e+03, + +7.24046774195652478189e+03, + +5.93072701187316984827e+03, + +2.06209331660327847417e+03, + +2.42005740240291393179e+02, + }; + + static const T RP[] = { + -4.79443220978201773821e+09, + +1.95617491946556577543e+12, + -2.49248344360967716204e+14, + +9.70862251047306323952e+15, + }; + + static const T RQ[] = { + +4.99563147152651017219e+02, + +1.73785401676374683123e+05, + +4.84409658339962045305e+07, + +1.11855537045356834862e+10, + +2.11277520115489217587e+12, + +3.10518229857422583814e+14, + +3.18121955943204943306e+16, + +1.71086294081043136091e+18, + }; + + if (x < T(0)) { + x = -x; + } + + if (x <= T(5.0)) { + if (x < T(0.00001)) { + return T(1.0) - x * x / T(4.0); + } + + T rp = 0.0; + + for (uint8_t index = 0; index <= 3; index++) { + rp = rp * (x * x) + RP[index]; + } + + T rq = 0.0; + + for (uint8_t index = 0; index <= 7; index++) { + rq = rq * (x * x) + RQ[index]; + } + + return (x * x - T(5.78318596294678452118e+00)) * (x * x - T(3.04712623436620863991e+01)) * rp / rq; + } + + T pp = 0.0; + + for (uint8_t index = 0; index <= 6; index++) { + pp = pp * (T(25.0) / (x * x)) + PP[index]; + } + + T pq = 0.0; + + for (uint8_t index = 0; index <= 6; index++) { + pq = pq * (T(25.0) / (x * x)) + PQ[index]; + } + + T qp = 0.0; + + for (uint8_t index = 0; index <= 7; index++) { + qp = qp * (T(25.0) / (x * x)) + QP[index]; + } + + T qq = 0.0; + + for (uint8_t index = 0; index <= 6; index++) { + qq = qq * (T(25.0) / (x * x)) + QQ[index]; + } + + return (pp / pq * std::cos(x - T(0.785398163397448309615660845819875721)) - T(5.0) / x * (qp / qq) * std::sin(x - T(0.785398163397448309615660845819875721))) * T(0.797884560802865355879892119868763737) / std::sqrt(x); +} // bessel_j0_forward(T x) + +template +inline C10_HOST_DEVICE T bessel_j1_forward(T x) { + static const T PP[] = { + +7.62125616208173112003e-04, + +7.31397056940917570436e-02, + +1.12719608129684925192e+00, + +5.11207951146807644818e+00, + +8.42404590141772420927e+00, + +5.21451598682361504063e+00, + +1.00000000000000000254e+00, + }; + + static const T PQ[] = { + +5.71323128072548699714e-04, + +6.88455908754495404082e-02, + +1.10514232634061696926e+00, + +5.07386386128601488557e+00, + +8.39985554327604159757e+00, + +5.20982848682361821619e+00, + +9.99999999999999997461e-01, + }; + + static const T QP[] = { + +5.10862594750176621635e-02, + +4.98213872951233449420e+00, + +7.58238284132545283818e+01, + +3.66779609360150777800e+02, + +7.10856304998926107277e+02, + +5.97489612400613639965e+02, + +2.11688757100572135698e+02, + +2.52070205858023719784e+01, + }; + + static const T QQ[] = { + +7.42373277035675149943e+01, + +1.05644886038262816351e+03, + +4.98641058337653607651e+03, + +9.56231892404756170795e+03, + +7.99704160447350683650e+03, + +2.82619278517639096600e+03, + +3.36093607810698293419e+02, + }; + + static const T RP[] = { + -8.99971225705559398224e+08, + +4.52228297998194034323e+11, + -7.27494245221818276015e+13, + +3.68295732863852883286e+15, + }; + + static const T RQ[] = { + +6.20836478118054335476e+02, + +2.56987256757748830383e+05, + +8.35146791431949253037e+07, + +2.21511595479792499675e+10, + +4.74914122079991414898e+12, + +7.84369607876235854894e+14, + +8.95222336184627338078e+16, + +5.32278620332680085395e+18, + }; + + if (x < T(0.0)) { + return -bessel_j1_forward(-x); + } + + if (x <= T(5.0)) { + T rp = 0.0; + + for (uint8_t index = 0; index <= 3; index++) { + rp = rp * (x * x) + RP[index]; + } + + T rq = 0.0; + + for (uint8_t index = 0; index <= 7; index++) { + rq = rq * (x * x) + RQ[index]; + } + + return rp / rq * x * (x * x - T(1.46819706421238932572e+01)) * (x * x - T(4.92184563216946036703e+01)); + } + + T pp = 0.0; + + for (uint8_t index = 0; index <= 6; index++) { + pp = pp * (T(5.0) / x * (T(5.0) / x)) + PP[index]; + } + + T pq = 0.0; + + for (uint8_t index = 0; index <= 6; index++) { + pq = pq * (T(5.0) / x * (T(5.0) / x)) + PQ[index]; + } + + T qp = 0.0; + + for (uint8_t index = 0; index <= 7; index++) { + qp = qp * (T(5.0) / x * (T(5.0) / x)) + QP[index]; + } + + T qq = 0.0; + + for (uint8_t index = 0; index <= 6; index++) { + qq = qq * (T(5.0) / x * (T(5.0) / x)) + QQ[index]; + } + + return (pp / pq * std::cos(x - T(2.356194490192344928846982537459627163)) - T(5.0) / x * (qp / qq) * std::sin(x - T(2.356194490192344928846982537459627163))) * T(0.797884560802865355879892119868763737) / std::sqrt(x); +} // bessel_j1_forward(T x) + +template +inline C10_HOST_DEVICE T bessel_y0_forward(T x) { + static const T PP[] = { + +7.96936729297347051624e-04, + +8.28352392107440799803e-02, + +1.23953371646414299388e+00, + +5.44725003058768775090e+00, + +8.74716500199817011941e+00, + +5.30324038235394892183e+00, + +9.99999999999999997821e-01, + }; + + static const T PQ[] = { + +9.24408810558863637013e-04, + +8.56288474354474431428e-02, + +1.25352743901058953537e+00, + +5.47097740330417105182e+00, + +8.76190883237069594232e+00, + +5.30605288235394617618e+00, + +1.00000000000000000218e+00, + }; + + static const T QP[] = { + -1.13663838898469149931e-02, + -1.28252718670509318512e+00, + -1.95539544257735972385e+01, + -9.32060152123768231369e+01, + -1.77681167980488050595e+02, + -1.47077505154951170175e+02, + -5.14105326766599330220e+01, + -6.05014350600728481186e+00, + }; + + static const T QQ[] = { + +6.43178256118178023184e+01, + +8.56430025976980587198e+02, + +3.88240183605401609683e+03, + +7.24046774195652478189e+03, + +5.93072701187316984827e+03, + +2.06209331660327847417e+03, + +2.42005740240291393179e+02, + }; + + static const T YP[] = { + +1.55924367855235737965e+04, + -1.46639295903971606143e+07, + +5.43526477051876500413e+09, + -9.82136065717911466409e+11, + +8.75906394395366999549e+13, + -3.46628303384729719441e+15, + +4.42733268572569800351e+16, + -1.84950800436986690637e+16, + }; + + static const T YQ[] = { + +1.04128353664259848412e+03, + +6.26107330137134956842e+05, + +2.68919633393814121987e+08, + +8.64002487103935000337e+10, + +2.02979612750105546709e+13, + +3.17157752842975028269e+15, + +2.50596256172653059228e+17, + }; + + if (x <= T(5.0)) { + if (x == T(0.0)) { + return -std::numeric_limits::infinity(); + } + + if (x < T(0.0)) { + return std::numeric_limits::quiet_NaN(); + } + + T yp = 0.0; + + for (uint8_t index = 0; index <= 7; index++) { + yp = yp * (x * x) + YP[index]; + } + + T yq = 0.0; + + for (uint8_t index = 0; index <= 6; index++) { + yq = yq * (x * x) + YQ[index]; + } + + return yp / yq + (T(0.636619772367581343075535053490057448) * std::log(x) * bessel_j0_forward(x)); + } + + T pp = 0.0; + + for (uint8_t index = 0; index <= 6; index++) { + pp = pp * (T(25.0) / (x * x)) + PP[index]; + } + + T pq = 0.0; + + for (uint8_t index = 0; index <= 6; index++) { + pq = pq * (T(25.0) / (x * x)) + PQ[index]; + } + + T qp = 0.0; + + for (uint8_t index = 0; index <= 7; index++) { + qp = qp * (T(25.0) / (x * x)) + QP[index]; + } + + T qq = 0.0; + + for (uint8_t index = 0; index <= 6; index++) { + qq = qq * (T(25.0) / (x * x)) + QQ[index]; + } + + return (pp / pq * std::sin(x - T(0.785398163397448309615660845819875721)) + T(5.0) / x * (qp / qq) * std::cos(x - T(0.785398163397448309615660845819875721))) * T(0.797884560802865355879892119868763737) / std::sqrt(x); +} // bessel_y0_forward(T x) + +template +inline C10_HOST_DEVICE T bessel_y1_forward(T x) { + static const T PP[] = { + +7.62125616208173112003e-04, + +7.31397056940917570436e-02, + +1.12719608129684925192e+00, + +5.11207951146807644818e+00, + +8.42404590141772420927e+00, + +5.21451598682361504063e+00, + +1.00000000000000000254e+00, + }; + + static const T PQ[] = { + +5.71323128072548699714e-04, + +6.88455908754495404082e-02, + +1.10514232634061696926e+00, + +5.07386386128601488557e+00, + +8.39985554327604159757e+00, + +5.20982848682361821619e+00, + +9.99999999999999997461e-01, + }; + + static const T QP[] = { + +5.10862594750176621635e-02, + +4.98213872951233449420e+00, + +7.58238284132545283818e+01, + +3.66779609360150777800e+02, + +7.10856304998926107277e+02, + +5.97489612400613639965e+02, + +2.11688757100572135698e+02, + +2.52070205858023719784e+01, + }; + + static const T QQ[] = { + +7.42373277035675149943e+01, + +1.05644886038262816351e+03, + +4.98641058337653607651e+03, + +9.56231892404756170795e+03, + +7.99704160447350683650e+03, + +2.82619278517639096600e+03, + +3.36093607810698293419e+02, + }; + + static const T YP[] = { + +1.26320474790178026440e+09, + -6.47355876379160291031e+11, + +1.14509511541823727583e+14, + -8.12770255501325109621e+15, + +2.02439475713594898196e+17, + -7.78877196265950026825e+17, + }; + + static const T YQ[] = { + +5.94301592346128195359e+02, + +2.35564092943068577943e+05, + +7.34811944459721705660e+07, + +1.87601316108706159478e+10, + +3.88231277496238566008e+12, + +6.20557727146953693363e+14, + +6.87141087355300489866e+16, + +3.97270608116560655612e+18, + }; + + if (x <= T(5.0)) { + if (x == T(0.0)) { + return -std::numeric_limits::infinity(); + } + + if (x <= T(0.0)) { + return std::numeric_limits::quiet_NaN(); + } + + T yp = 0.0; + + for (uint8_t index = 0; index <= 5; index++) { + yp = yp * (x * x) + YP[index]; + } + + T yq = 0.0; + + for (uint8_t index = 0; index <= 7; index++) { + yq = yq * (x * x) + YQ[index]; + } + + return x * (yp / yq) + (T(0.636619772367581343075535053490057448) * (bessel_j1_forward(x) * std::log(x) - T(1.0) / x)); + } + + T pp = 0.0; + + for (uint8_t index = 0; index <= 6; index++) { + pp = pp * (T(5.0) / x * (T(5.0) / x)) + PP[index]; + } + + T pq = 0.0; + + for (uint8_t index = 0; index <= 6; index++) { + pq = pq * (T(5.0) / x * (T(5.0) / x)) + PQ[index]; + } + + T qp = 0.0; + + for (uint8_t index = 0; index <= 7; index++) { + qp = qp * (T(5.0) / x * (T(5.0) / x)) + QP[index]; + } + + T qq = 0.0; + + for (uint8_t index = 0; index <= 6; index++) { + qq = qq * (T(5.0) / x * (T(5.0) / x)) + QQ[index]; + } + + return (pp / pq * std::sin(x - T(2.356194490192344928846982537459627163)) + T(5.0) / x * (qp / qq) * std::cos(x - T(2.356194490192344928846982537459627163))) * T(0.797884560802865355879892119868763737) / std::sqrt(x); +} // bessel_y1_forward(T x) + +template +inline C10_HOST_DEVICE T chebyshev_polynomial_t_forward(T x, int64_t n) { + if (n < 0) { + return T(0.0); + } + + if (std::abs(x) == T(1.0)) { + if (x > T(0.0) || n % 2 == 0) { + return T(1.0); + } + + return T(-1.0); + } + + if ((n > 6) && (std::abs(x) < T(1.0))) { + return std::cos(n * std::acos(x)); + } + + if (n == 0) { + return T(1.0); + } + + if (n == 1) { + return x; + } + + T p = T(1.0); + T q = x; + T r; + + for (int64_t k = 2; k <= n; k++) { + r = (x + x) * q - p; + p = q; + q = r; + } + + return r; +} // chebyshev_polynomial_t_forward(T x, int64_t n) + +template +inline C10_HOST_DEVICE T chebyshev_polynomial_t_forward(T x, T n) { + return chebyshev_polynomial_t_forward(x, static_cast(n)); +} // chebyshev_polynomial_t_forward(T x, T n) + +template +inline C10_HOST_DEVICE T chebyshev_polynomial_u_forward(T x, int64_t n) { + if (n < 0) { + return T(0.0); + } + + if (std::abs(x) == T(1.0)) { + if (x > T(0.0) || n % 2 == 0) { + return n + 1; + } + + return -(n + 1); + } + + if ((n > 8) && (std::abs(x) < T(1.0))) { + if (std::sin(std::acos(x)) != T(0.0)) { + return std::sin((n + 1) * std::acos(x)) / std::sin(std::acos(x)); + } + + return (n + 1) * std::cos((n + 1) * std::acos(x)) / x; + } + + if (n == 0) { + return T(1.0); + } + + if (n == 1) { + return x + x; + } + + T p = T(1.0); + T q = x + x; + T r; + + for (int64_t k = 2; k <= n; k++) { + r = (x + x) * q - p; + p = q; + q = r; + } + + return r; +} // chebyshev_polynomial_u_forward(T x, int64_t n) + +template +inline C10_HOST_DEVICE T chebyshev_polynomial_u_forward(T x, T n) { + return chebyshev_polynomial_u_forward(x, static_cast(n)); +} // chebyshev_polynomial_u_forward(T x, T n) + +template +inline C10_HOST_DEVICE T chebyshev_polynomial_v_forward(T x, int64_t n) { + if (n < 0) { + return T(0.0); + } + + if (std::abs(x) == T(1.0)) { + if (x > T(0.0)) { + return T(1.0); + } + + if (n % 2 == 0) { + return n + n + 1; + } + + return -(n + n + 1); + } + + if ((n > 8) && (std::abs(x) < T(1.0))) { + if (std::sin(std::acos(x) / T(2.0)) != T(1.0)) { + return std::cos((n + T(0.5)) * std::acos(x)) / std::cos(std::acos(x) / T(2.0)); + } + + if (n % 2 == 0) { + return n + n + 1; + } + + return -(n + n + 1); + } + + if (n == 0) { + return T(1.0); + } + + if (n == 1) { + return x + x - T(1.0); + } + + T p = T(1.0); + T q = x + x - T(1.0); + T r; + + for (int64_t k = 2; k <= n; k++) { + r = (x + x) * q - p; + p = q; + q = r; + } + + return r; +} // chebyshev_polynomial_v_forward(T x, int64_t n) + +template +inline C10_HOST_DEVICE T chebyshev_polynomial_v_forward(T x, T n) { + return chebyshev_polynomial_v_forward(x, static_cast(n)); +} // chebyshev_polynomial_v_forward(T x, T n) + +template +inline C10_HOST_DEVICE T chebyshev_polynomial_w_forward(T x, int64_t n) { + if (n < 0) { + return T(0.0); + } + + if (std::abs(x) == T(1.0)) { + if (x > T(0.0)) { + return n + n + 1; + } + + if (n % 2 == 0) { + return T(1.0); + } + + return T(-1.0); + } + + if ((n > 8) && (std::abs(x) < T(1.0))) { + if (std::cos(std::acos(x) / T(2.0)) != T(1.0)) { + return std::sin((n + T(0.5)) * std::acos(x)) / std::sin(std::acos(x) / T(2.0)); + } + + if (x > T(0.0)) { + return n + n + 1; + } + + if (n % 2 == 0) { + return T(1.0); + } + + return T(-1.0); + } + + if (n == 0) { + return T(1.0); + } + + if (n == 1) { + return x + x + T(1.0); + } + + T p = T(1.0); + T q = x + x + T(1.0); + T r; + + for (int64_t k = 2; k <= n; k++) { + r = (x + x) * q - p; + p = q; + q = r; + } + + return r; +} // chebyshev_polynomial_w_forward(T x, int64_t n) + +template +inline C10_HOST_DEVICE T chebyshev_polynomial_w_forward(T x, T n) { + return chebyshev_polynomial_w_forward(x, static_cast(n)); +} // chebyshev_polynomial_w_forward(T x, T n) + +template +inline C10_HOST_DEVICE T hermite_polynomial_h_forward(T x, int64_t n) { + if (n < 0) { + return T(0.0); + } + + if (n == 0) { + return T(1.0); + } + + if (n == 1) { + return x + x; + } + + T p = T(1.0); + T q = x + x; + T r = T(0.0); + + for (int64_t k = 2; k < n + n; k += 2) { + r = (x + x) * q - k * p; + p = q; + q = r; + } + + return r; +} // hermite_polynomial_h_forward(T x, int64_t n) + +template::value, int> = 0> +inline C10_HOST_DEVICE T hermite_polynomial_h_forward(T x, T n) { + return hermite_polynomial_h_forward(x, static_cast(n)); +} // hermite_polynomial_h_forward(T x, T n) + +template::value, int> = 0> +inline C10_HOST_DEVICE T hermite_polynomial_h_forward(T x, T n) { + return hermite_polynomial_h_forward(x, ((!std::isinf(n)) && (!std::isnan(n))) ? static_cast(n) : static_cast(-1)); +} // hermite_polynomial_h_forward(T x, T n) + +template +inline C10_HOST_DEVICE T hermite_polynomial_he_forward(T x, int64_t n) { + if (n < 0) { + return T(0.0); + } + + if (n == 0) { + return T(1.0); + } + + if (n == 1) { + return x; + } + + T p = T(1.0); + T q = x; + T r; + + for (int64_t k = 1; k < n; k++) { + r = x * q - k * p; + p = q; + q = r; + } + + return r; +} // hermite_polynomial_he_forward(T x, int64_t n) + +template +inline C10_HOST_DEVICE T hermite_polynomial_he_forward(T x, T n) { + return hermite_polynomial_he_forward(x, static_cast(n)); +} // hermite_polynomial_he_forward(T x, T n) + +template +inline C10_HOST_DEVICE T laguerre_polynomial_l_forward(T x, int64_t n) { + if (n < 0) { + return T(0.0); + } + + if (std::abs(x) == T(0.0)) { + return T(1.0); + } + + if (n == 0) { + return T(1.0); + } + + if (n == 1) { + return T(1.0) - x; + } + + T p = T(1.0); + T q = T(1.0) - x; + T r; + + for (int64_t k = 1; k < n; k++) { + r = (((k + k) + (T(1.0) - x)) * q - k * p) / (k + 1); + p = q; + q = r; + } + + return r; +} // laguerre_polynomial_l_forward(T x, int64_t n) + +template +inline C10_HOST_DEVICE T laguerre_polynomial_l_forward(T x, T n) { + return laguerre_polynomial_l_forward(x, static_cast(n)); +} // laguerre_polynomial_l_forward(T x, T n) + +template +inline C10_HOST_DEVICE T legendre_polynomial_p_forward(T x, int64_t n) { + if (n < 0) { + return T(0.0); + } + + if (std::abs(x) == T(1.0)) { + if (x > T(0.0) || n % 2 == 0) { + return T(1.0); + } + + return T(-1.0); + } + + if (n == 0) { + return T(1.0); + } + + if (n == 1) { + return x; + } + + T p = T(1.0); + T q = x; + T r; + + for (int64_t k = 1; k < n; k++) { + r = ((k + k + 1) * x * q - k * p) / (k + 1); + p = q; + q = r; + } + + return r; +} // legendre_polynomial_p_forward(T x, int64_t n) + +template +inline C10_HOST_DEVICE T legendre_polynomial_p_forward(T x, T n) { + return legendre_polynomial_p_forward(x, static_cast(n)); +} // legendre_polynomial_p_forward(T x, T n) + +template +inline C10_HOST_DEVICE T modified_bessel_i0_forward(T x) { + static const T A[] = { + -4.41534164647933937950e-18, + +3.33079451882223809783e-17, + -2.43127984654795469359e-16, + +1.71539128555513303061e-15, + -1.16853328779934516808e-14, + +7.67618549860493561688e-14, + -4.85644678311192946090e-13, + +2.95505266312963983461e-12, + -1.72682629144155570723e-11, + +9.67580903537323691224e-11, + -5.18979560163526290666e-10, + +2.65982372468238665035e-09, + -1.30002500998624804212e-08, + +6.04699502254191894932e-08, + -2.67079385394061173391e-07, + +1.11738753912010371815e-06, + -4.41673835845875056359e-06, + +1.64484480707288970893e-05, + -5.75419501008210370398e-05, + +1.88502885095841655729e-04, + -5.76375574538582365885e-04, + +1.63947561694133579842e-03, + -4.32430999505057594430e-03, + +1.05464603945949983183e-02, + -2.37374148058994688156e-02, + +4.93052842396707084878e-02, + -9.49010970480476444210e-02, + +1.71620901522208775349e-01, + -3.04682672343198398683e-01, + +6.76795274409476084995e-01, + }; + + static const T B[] = { + -7.23318048787475395456e-18, + -4.83050448594418207126e-18, + +4.46562142029675999901e-17, + +3.46122286769746109310e-17, + -2.82762398051658348494e-16, + -3.42548561967721913462e-16, + +1.77256013305652638360e-15, + +3.81168066935262242075e-15, + -9.55484669882830764870e-15, + -4.15056934728722208663e-14, + +1.54008621752140982691e-14, + +3.85277838274214270114e-13, + +7.18012445138366623367e-13, + -1.79417853150680611778e-12, + -1.32158118404477131188e-11, + -3.14991652796324136454e-11, + +1.18891471078464383424e-11, + +4.94060238822496958910e-10, + +3.39623202570838634515e-09, + +2.26666899049817806459e-08, + +2.04891858946906374183e-07, + +2.89137052083475648297e-06, + +6.88975834691682398426e-05, + +3.36911647825569408990e-03, + +8.04490411014108831608e-01, + }; + + T p; + T q = 0.0; + + if (std::abs(x) <= T(8.0)) { + T a = A[0]; + + for (uint8_t index = 1; index < 30; index++) { + p = q; + q = a; + a = ((std::abs(x) / T(2.0)) - T(2.0)) * q - p + A[index]; + } + + return std::exp(std::abs(x)) * (T(0.5) * (a - p)); + } + + T b = B[0]; + + for (uint8_t index = 1; index < 25; index++) { + p = q; + q = b; + b = (T(32.0) / std::abs(x) - T(2.0)) * q - p + B[index]; + } + + return std::exp(std::abs(x)) * (T(0.5) * (b - p)) / std::sqrt(std::abs(x)); +} // modified_bessel_i0_forward(T x) + +template +inline C10_HOST_DEVICE T modified_bessel_i1_forward(T x) { + static const T A[] = { + +2.77791411276104639959e-18, + -2.11142121435816608115e-17, + +1.55363195773620046921e-16, + -1.10559694773538630805e-15, + +7.60068429473540693410e-15, + -5.04218550472791168711e-14, + +3.22379336594557470981e-13, + -1.98397439776494371520e-12, + +1.17361862988909016308e-11, + -6.66348972350202774223e-11, + +3.62559028155211703701e-10, + -1.88724975172282928790e-09, + +9.38153738649577178388e-09, + -4.44505912879632808065e-08, + +2.00329475355213526229e-07, + -8.56872026469545474066e-07, + +3.47025130813767847674e-06, + -1.32731636560394358279e-05, + +4.78156510755005422638e-05, + -1.61760815825896745588e-04, + +5.12285956168575772895e-04, + -1.51357245063125314899e-03, + +4.15642294431288815669e-03, + -1.05640848946261981558e-02, + +2.47264490306265168283e-02, + -5.29459812080949914269e-02, + +1.02643658689847095384e-01, + -1.76416518357834055153e-01, + +2.52587186443633654823e-01, + }; + + static const T B[] = { + +7.51729631084210481353e-18, + +4.41434832307170791151e-18, + -4.65030536848935832153e-17, + -3.20952592199342395980e-17, + +2.96262899764595013876e-16, + +3.30820231092092828324e-16, + -1.88035477551078244854e-15, + -3.81440307243700780478e-15, + +1.04202769841288027642e-14, + +4.27244001671195135429e-14, + -2.10154184277266431302e-14, + -4.08355111109219731823e-13, + -7.19855177624590851209e-13, + +2.03562854414708950722e-12, + +1.41258074366137813316e-11, + +3.25260358301548823856e-11, + -1.89749581235054123450e-11, + -5.58974346219658380687e-10, + -3.83538038596423702205e-09, + -2.63146884688951950684e-08, + -2.51223623787020892529e-07, + -3.88256480887769039346e-06, + -1.10588938762623716291e-04, + -9.76109749136146840777e-03, + +7.78576235018280120474e-01, + }; + + T p; + T q = 0.0; + + if (std::abs(x) <= T(8.0)) { + T a = A[0]; + + for (uint8_t index = 1; index < 29; index++) { + p = q; + q = a; + a = ((std::abs(x) / T(2.0)) - T(2.0)) * q - p + A[index]; + } + + if (x < T(0.0)) { + return -(T(0.5) * (a - p) * std::abs(x) * std::exp(std::abs(x))); + } + + return T(0.5) * (a - p) * std::abs(x) * std::exp(std::abs(x)); + } + + T b = B[0]; + + for (uint8_t index = 1; index < 25; index++) { + p = q; + q = b; + b = (T(32.0) / std::abs(x) - T(2.0)) * q - p + B[index]; + } + + if (x < T(0.0)) { + return -(std::exp(std::abs(x)) * (T(0.5) * (b - p)) / std::sqrt(std::abs(x))); + } + + return std::exp(std::abs(x)) * (T(0.5) * (b - p)) / std::sqrt(std::abs(x)); +} // modified_bessel_i1_forward(T x) + +template +inline C10_HOST_DEVICE T modified_bessel_k0_forward(T x) { + static const T A[] = { + +1.37446543561352307156e-16, + +4.25981614279661018399e-14, + +1.03496952576338420167e-11, + +1.90451637722020886025e-09, + +2.53479107902614945675e-07, + +2.28621210311945178607e-05, + +1.26461541144692592338e-03, + +3.59799365153615016266e-02, + +3.44289899924628486886e-01, + -5.35327393233902768720e-01, + }; + + static const T B[] = { + +5.30043377268626276149e-18, + -1.64758043015242134646e-17, + +5.21039150503902756861e-17, + -1.67823109680541210385e-16, + +5.51205597852431940784e-16, + -1.84859337734377901440e-15, + +6.34007647740507060557e-15, + -2.22751332699166985548e-14, + +8.03289077536357521100e-14, + -2.98009692317273043925e-13, + +1.14034058820847496303e-12, + -4.51459788337394416547e-12, + +1.85594911495471785253e-11, + -7.95748924447710747776e-11, + +3.57739728140030116597e-10, + -1.69753450938905987466e-09, + +8.57403401741422608519e-09, + -4.66048989768794782956e-08, + +2.76681363944501510342e-07, + -1.83175552271911948767e-06, + +1.39498137188764993662e-05, + -1.28495495816278026384e-04, + +1.56988388573005337491e-03, + -3.14481013119645005427e-02, + +2.44030308206595545468e+00, + }; + + if (x == T(0.0)) { + return std::numeric_limits::infinity(); + } + + if (x < T(0.0)) { + return std::numeric_limits::quiet_NaN(); + } + + T p; + T q = 0.0; + + if (x <= T(2.0)) { + T a = A[0]; + + for (uint8_t index = 1; index < 10; index++) { + p = q; + q = a; + a = (x * x - T(2.0)) * q - p + A[index]; + } + + return T(0.5) * (a - p) - std::log(0.5 * x) * modified_bessel_i0_forward(x); + } + + T b = B[0]; + + for (uint8_t index = 1; index < 25; index++) { + p = q; + q = b; + b = (T(8.0) / x - T(2.0)) * q - p + B[index]; + } + + return std::exp(-x) * (T(0.5) * (b - p)) / std::sqrt(x); +} // modified_bessel_k0_forward(T x) + +template +inline C10_HOST_DEVICE T modified_bessel_k1_forward(T x) { + static const T A[] = { + -7.02386347938628759343e-18, + -2.42744985051936593393e-15, + -6.66690169419932900609e-13, + -1.41148839263352776110e-10, + -2.21338763073472585583e-08, + -2.43340614156596823496e-06, + -1.73028895751305206302e-04, + -6.97572385963986435018e-03, + -1.22611180822657148235e-01, + -3.53155960776544875667e-01, + +1.52530022733894777053e+00, + }; + + static const T B[] = { + -5.75674448366501715755e-18, + +1.79405087314755922667e-17, + -5.68946255844285935196e-17, + +1.83809354436663880070e-16, + -6.05704724837331885336e-16, + +2.03870316562433424052e-15, + -7.01983709041831346144e-15, + +2.47715442448130437068e-14, + -8.97670518232499435011e-14, + +3.34841966607842919884e-13, + -1.28917396095102890680e-12, + +5.13963967348173025100e-12, + -2.12996783842756842877e-11, + +9.21831518760500529508e-11, + -4.19035475934189648750e-10, + +2.01504975519703286596e-09, + -1.03457624656780970260e-08, + +5.74108412545004946722e-08, + -3.50196060308781257119e-07, + +2.40648494783721712015e-06, + -1.93619797416608296024e-05, + +1.95215518471351631108e-04, + -2.85781685962277938680e-03, + +1.03923736576817238437e-01, + +2.72062619048444266945e+00, + }; + + if (x == T(0.0)) { + return std::numeric_limits::infinity(); + } + + if (x < T(0.0)) { + return std::numeric_limits::quiet_NaN(); + } + + T p; + T q = 0.0; + + if (x <= T(2.0)) { + T a = A[0]; + + for (uint8_t index = 1; index < 11; index++) { + p = q; + q = a; + a = (x * x - T(2.0)) * q - p + A[index]; + } + + return std::log(T(0.5) * x) * modified_bessel_i1_forward(x) + T(0.5) * (a - p) / x; + } + + T b = B[0]; + + for (uint8_t index = 1; index < 25; index++) { + p = q; + q = b; + b = (T(8.0) / x - T(2.0)) * q - p + B[index]; + } + + return std::exp(-x) * (T(0.5) * (b - p)) / std::sqrt(x); +} // modified_bessel_k1_forward(T x) + +template +inline C10_HOST_DEVICE T scaled_modified_bessel_k0_forward(T x) { + static const T A[] = { + +1.37446543561352307156e-16, + +4.25981614279661018399e-14, + +1.03496952576338420167e-11, + +1.90451637722020886025e-09, + +2.53479107902614945675e-07, + +2.28621210311945178607e-05, + +1.26461541144692592338e-03, + +3.59799365153615016266e-02, + +3.44289899924628486886e-01, + -5.35327393233902768720e-01, + }; + + static const T B[] = { + +5.30043377268626276149e-18, + -1.64758043015242134646e-17, + +5.21039150503902756861e-17, + -1.67823109680541210385e-16, + +5.51205597852431940784e-16, + -1.84859337734377901440e-15, + +6.34007647740507060557e-15, + -2.22751332699166985548e-14, + +8.03289077536357521100e-14, + -2.98009692317273043925e-13, + +1.14034058820847496303e-12, + -4.51459788337394416547e-12, + +1.85594911495471785253e-11, + -7.95748924447710747776e-11, + +3.57739728140030116597e-10, + -1.69753450938905987466e-09, + +8.57403401741422608519e-09, + -4.66048989768794782956e-08, + +2.76681363944501510342e-07, + -1.83175552271911948767e-06, + +1.39498137188764993662e-05, + -1.28495495816278026384e-04, + +1.56988388573005337491e-03, + -3.14481013119645005427e-02, + +2.44030308206595545468e+00, + }; + + if (x == T(0.0)) { + return std::numeric_limits::infinity(); + } + + if (x < T(0.0)) { + return std::numeric_limits::quiet_NaN(); + } + + T p; + T q = 0.0; + + if (x <= T(2.0)) { + T a = A[0]; + + for (uint64_t index = 1; index < 10; index++) { + p = q; + q = a; + a = (x * x - T(2.0)) * q - p + A[index]; + } + + return (T(0.5) * (a - p) - std::log(T(0.5) * x) * modified_bessel_i0_forward(x)) * std::exp(x); + } + + T b = B[0]; + + for (uint64_t index = 1; index < 25; index++) { + p = q; + q = b; + b = (T(8.0) / x - T(2.0)) * q - p + B[index]; + } + + return T(0.5) * (b - p) / std::sqrt(x); +} // T scaled_modified_bessel_k0_forward(T x) + +template +inline C10_HOST_DEVICE T scaled_modified_bessel_k1_forward(T x) { + static const T A[] = { + -7.02386347938628759343e-18, + -2.42744985051936593393e-15, + -6.66690169419932900609e-13, + -1.41148839263352776110e-10, + -2.21338763073472585583e-08, + -2.43340614156596823496e-06, + -1.73028895751305206302e-04, + -6.97572385963986435018e-03, + -1.22611180822657148235e-01, + -3.53155960776544875667e-01, + +1.52530022733894777053e+00, + }; + + static const T B[] = { + -5.75674448366501715755e-18, + +1.79405087314755922667e-17, + -5.68946255844285935196e-17, + +1.83809354436663880070e-16, + -6.05704724837331885336e-16, + +2.03870316562433424052e-15, + -7.01983709041831346144e-15, + +2.47715442448130437068e-14, + -8.97670518232499435011e-14, + +3.34841966607842919884e-13, + -1.28917396095102890680e-12, + +5.13963967348173025100e-12, + -2.12996783842756842877e-11, + +9.21831518760500529508e-11, + -4.19035475934189648750e-10, + +2.01504975519703286596e-09, + -1.03457624656780970260e-08, + +5.74108412545004946722e-08, + -3.50196060308781257119e-07, + +2.40648494783721712015e-06, + -1.93619797416608296024e-05, + +1.95215518471351631108e-04, + -2.85781685962277938680e-03, + +1.03923736576817238437e-01, + +2.72062619048444266945e+00, + }; + + if (x == T(0.0)) { + return std::numeric_limits::infinity(); + } + + if (x < T(0.0)) { + return std::numeric_limits::quiet_NaN(); + } + + T p; + T q = 0.0; + + if (x <= T(2.0)) { + T a = A[0]; + + for (uint64_t index = 1; index < 11; index++) { + p = q; + q = a; + a = (x * x - T(2.0)) * q - p + A[index]; + } + + return (std::log(T(0.5) * x) * modified_bessel_i1_forward(x) + T(0.5) * (a - p) / x) * std::exp(x); + } + + T b = B[0]; + + for (uint64_t index = 1; index < 25; index++) { + p = q; + q = b; + b = (T(8.0) / x - T(2.0)) * q - p + B[index]; + } + + return (T(0.5) * (b - p) / std::sqrt(x)); +} // T scaled_modified_bessel_k1_forward(T x) + +template +inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_t_forward(T x, int64_t n) { + if (n < 0) { + return T(0.0); + } + + if (x == T(1.0)) { + return T(1.0); + } + + if (x == T(0.0)) { + if (n % 2 == 0) { + return T(1.0); + } + + return T(-1.0); + } + + if ((n > 6) && (std::abs(x + x - T(1.0)) < T(1.0))) { + return std::cos(n * std::acos(x + x - T(1.0))); + } + + if (n == 0) { + return T(1.0); + } + + if (n == 1) { + return x + x - T(1.0); + } + + T p = T(1.0); + T q = x + x - T(1.0); + T r; + + for (int64_t k = 2; k <= n; k++) { + r = (x + x - T(1.0) + (x + x - T(1.0))) * q - p; + p = q; + q = r; + } + + return r; +} // shifted_chebyshev_polynomial_t_forward(T x, int64_t n) + +template +inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_t_forward(T x, T n) { + return shifted_chebyshev_polynomial_t_forward(x, static_cast(n)); +} // shifted_chebyshev_polynomial_t_forward(T x, T n) + +template +inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_u_forward(T x, int64_t n) { + if (n < 0) { + return T(0.0); + } + + if (x == T(1.0)) { + return n + 1; + } + + if (x == T(0.0)) { + if (n % 2 == 0) { + return n + 1; + } + + return -(n + 1); + } + + if ((n > 6) && (std::abs(x + x - T(1.0)) < T(1.0))) { + if (std::sin(std::acos(x + x - T(1.0))) != T(0.0)) { + return std::sin((n + 1) * std::acos(x + x - T(1.0))) / std::sin(std::acos(x + x - T(1.0))); + } + + return (n + 1) * std::cos((n + 1) * std::acos(x + x - T(1.0))) / (x + x - T(1.0)); + } + + if (n == 0) { + return T(1.0); + } + + if (n == 1) { + return x + x - T(1.0) + (x + x - T(1.0)); + } + + T p = T(1.0); + T q = x + x - T(1.0) + (x + x - T(1.0)); + T r; + + for (int64_t k = 2; k <= n; k++) { + r = (x + x - T(1.0) + (x + x - T(1.0))) * q - p; + p = q; + q = r; + } + + return r; +} // shifted_chebyshev_polynomial_u_forward(T x, int64_t n) + +template +inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_u_forward(T x, T n) { + return shifted_chebyshev_polynomial_u_forward(x, static_cast(n)); +} // shifted_chebyshev_polynomial_u_forward(T x, T n) + +template +inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_v_forward(T x, int64_t n) { + if (n < 0) { + return T(0.0); + } + + if (x == T(1.0)) { + return T(1.0); + } + + if (x == T(0.0)) { + if (n % 2 == 0) { + return (n + n + 1); + } + + return -(n + n + 1); + } + + if ((n > 6) && (std::abs(x + x - T(1.0)) < T(1.0))) { + if (std::sin(std::acos(x + x - T(1.0)) / T(2.0)) != T(1.0)) { + return std::cos(((n) + T(0.5)) * std::acos(x + x - T(1.0))) / std::cos(std::acos(x + x - T(1.0)) / T(2.0)); + } + + if (n % 2 == 0) { + return n + n + 1; + } + + return -(n + n + 1); + } + + if (n == 0) { + return T(1.0); + } + + if (n == 1) { + return x + x - T(1.0) + (x + x - T(1.0)) - T(1.0); + } + + T p = T(1.0); + T q = x + x - T(1.0) + (x + x - T(1.0)) - T(1.0); + T r; + + for (int64_t k = 2; k <= n; k++) { + r = (x + x - T(1.0) + (x + x - T(1.0))) * q - p; + p = q; + q = r; + } + + return r; +} // shifted_chebyshev_polynomial_v_forward(T x, int64_t n) + +template +inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_v_forward(T x, T n) { + return shifted_chebyshev_polynomial_v_forward(x, static_cast(n)); +} // shifted_chebyshev_polynomial_v_forward(T x, T n) + +template +inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_w_forward(T x, int64_t n) { + if (n < 0) { + return T(0.0); + } + + if (x == T(1.0)) { + return n + n + 1; + } + + if (x == T(0.0)) { + if (n % 2 == 0) { + return T(1.0); + } + + return T(-1.0); + } + + if ((n > 4) && (std::abs(x + x - T(1.0)) < T(1.0))) { + if (std::cos(std::acos(x + x - T(1.0)) / T(2.0)) != T(1.0)) { + return std::sin((n + T(0.5)) * std::acos(x + x - T(1.0))) / std::sin(std::acos(x + x - T(1.0)) / T(2.0)); + } + + if (n % 2 == 0) { + return T(1.0); + } + + return T(-1.0); + } + + if (n == 0) { + return T(1.0); + } + + if (n == 1) { + return x + x - T(1.0) + (x + x - T(1.0)) + T(1.0); + } + + T p = T(1.0); + T q = x + x - T(1.0) + (x + x - T(1.0)) + T(1.0); + T r; + + for (int64_t k = 2; k <= n; k++) { + r = (x + x - T(1.0) + (x + x - T(1.0))) * q - p; + p = q; + q = r; + } + + return r; +} // shifted_chebyshev_polynomial_w_forward(T x, int64_t n) + +template +inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_w_forward(T x, T n) { + return shifted_chebyshev_polynomial_w_forward(x, static_cast(n)); +} // shifted_chebyshev_polynomial_w_forward(T x, T n) + +template +inline C10_HOST_DEVICE T spherical_bessel_j0_forward(T x) { + if (std::isinf(x)) { + return T(0.0); + } + + if (std::abs(x) < T(0.5)) { + return T(1.0) + x * x * (T(-1.0) / T(6.0) + x * x * (T(1.0) / T(120.0) + x * x * (T(-1.0) / T(5040.0) + x * x * (T(1.0) / T(362880.0) + x * x * (T(-1.0) / T(39916800.0) + x * x * (T(1.0) / T(6227020800.0))))))); + } + + return std::sin(x) / x; +} // T spherical_bessel_j0_forward(T x) + +C10_CLANG_DIAGNOSTIC_POP() diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/MathBitFallThroughLists.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/MathBitFallThroughLists.h new file mode 100644 index 0000000000000000000000000000000000000000..97b0854d82d0a2fec6bb708db767d81273ec7bcc --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/MathBitFallThroughLists.h @@ -0,0 +1,71 @@ +#pragma once + +namespace at { +// views and their in-place version ops +#define TORCH_VIEW_FNS(m) \ + m.impl("as_strided_", torch::CppFunction::makeFallthrough()); \ + m.impl("detach", torch::CppFunction::makeFallthrough()); \ + m.impl("detach_", torch::CppFunction::makeFallthrough()); \ + m.impl("diagonal", torch::CppFunction::makeFallthrough()); \ + m.impl("expand", torch::CppFunction::makeFallthrough()); \ + m.impl("expand_as", torch::CppFunction::makeFallthrough()); \ + m.impl("movedim.int", torch::CppFunction::makeFallthrough()); \ + m.impl("movedim.intlist", torch::CppFunction::makeFallthrough()); \ + m.impl("narrow", torch::CppFunction::makeFallthrough()); \ + m.impl("permute", torch::CppFunction::makeFallthrough()); \ + m.impl("select.Dimname", torch::CppFunction::makeFallthrough()); \ + m.impl("select.int", torch::CppFunction::makeFallthrough()); \ + m.impl("squeeze", torch::CppFunction::makeFallthrough()); \ + m.impl("squeeze_", torch::CppFunction::makeFallthrough()); \ + m.impl("transpose.int", torch::CppFunction::makeFallthrough()); \ + m.impl("transpose.Dimname", torch::CppFunction::makeFallthrough()); \ + m.impl("transpose_", torch::CppFunction::makeFallthrough()); \ + m.impl("t", torch::CppFunction::makeFallthrough()); \ + m.impl("t_", torch::CppFunction::makeFallthrough()); \ + m.impl("real", torch::CppFunction::makeFallthrough()); \ + m.impl("imag", torch::CppFunction::makeFallthrough()); \ + m.impl("view_as_real", torch::CppFunction::makeFallthrough()); \ + m.impl("unflatten.int", torch::CppFunction::makeFallthrough()); \ + m.impl("unflatten.Dimname", torch::CppFunction::makeFallthrough()); \ + m.impl("unfold", torch::CppFunction::makeFallthrough()); \ + m.impl("unsqueeze", torch::CppFunction::makeFallthrough()); \ + m.impl("unsqueeze_", torch::CppFunction::makeFallthrough()); \ + m.impl("view_as", torch::CppFunction::makeFallthrough()); \ + m.impl("unbind.int", torch::CppFunction::makeFallthrough()); \ + m.impl("unbind.Dimname", torch::CppFunction::makeFallthrough()); \ + m.impl("split.Tensor", torch::CppFunction::makeFallthrough()); \ + m.impl("split_with_sizes", torch::CppFunction::makeFallthrough()); \ + m.impl("swapaxes", torch::CppFunction::makeFallthrough()); \ + m.impl("swapdims", torch::CppFunction::makeFallthrough()); \ + m.impl("chunk", torch::CppFunction::makeFallthrough()); \ + m.impl("reshape", torch::CppFunction::makeFallthrough()); \ + m.impl("alias", torch::CppFunction::makeFallthrough()); \ + m.impl("hsplit.int", torch::CppFunction::makeFallthrough()); \ + m.impl("hsplit.array", torch::CppFunction::makeFallthrough()); \ + m.impl("dsplit.int", torch::CppFunction::makeFallthrough()); \ + m.impl("dsplit.array", torch::CppFunction::makeFallthrough()); \ + m.impl("vsplit.int", torch::CppFunction::makeFallthrough()); \ + m.impl("vsplit.array", torch::CppFunction::makeFallthrough()); \ + m.impl("conj", torch::CppFunction::makeFallthrough()); \ + m.impl("_conj", torch::CppFunction::makeFallthrough()); \ + m.impl("_unsafe_view", torch::CppFunction::makeFallthrough()); \ + m.impl("resize_", torch::CppFunction::makeFallthrough()); + +#define TENSOR_UTILITIES_AND_CONSTRUCTORS(m) \ + m.impl("empty_like", torch::CppFunction::makeFallthrough()); \ + m.impl("empty.memory_format", torch::CppFunction::makeFallthrough()); \ + m.impl("empty.out", torch::CppFunction::makeFallthrough()); \ + m.impl("empty_strided", torch::CppFunction::makeFallthrough()); \ + m.impl("full_like", torch::CppFunction::makeFallthrough()); \ + m.impl("stride.int", torch::CppFunction::makeFallthrough()); \ + m.impl("stride.Dimname", torch::CppFunction::makeFallthrough()); \ + m.impl("size.int", torch::CppFunction::makeFallthrough()); \ + m.impl("size.Dimname", torch::CppFunction::makeFallthrough()); \ + m.impl("is_complex", torch::CppFunction::makeFallthrough()); \ + m.impl("is_floating_point", torch::CppFunction::makeFallthrough()); \ + m.impl("requires_grad_", torch::CppFunction::makeFallthrough()); +} + +#define TORCH_VIEW_FNS_NATIVE_FN_REGISTRATION(m) \ + m.impl("as_strided", torch::CppFunction::makeFallthrough()); \ + m.impl("view", torch::CppFunction::makeFallthrough()); diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/NonEmptyUtils.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/NonEmptyUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..3d18bc5e1525bacaf27d97a86024540236ce6220 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/NonEmptyUtils.h @@ -0,0 +1,27 @@ +#include +#include +#include + +namespace at::native { + +inline int64_t ensure_nonempty_dim(int64_t dim) { + return std::max(dim, 1); +} + +inline int64_t ensure_nonempty_size(const TensorBase &t, int64_t dim) { + return t.dim() == 0 ? 1 : t.size(dim); +} + +inline int64_t ensure_nonempty_stride(const TensorBase &t, int64_t dim) { + return t.dim() == 0 ? 1 : t.stride(dim); +} + +using IdxVec = std::vector; +inline IdxVec ensure_nonempty_vec(IdxVec vec) { + if (vec.empty()) { + vec.push_back(1); + } + return vec; +} + +} // namespace at::native diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Normalization.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Normalization.h new file mode 100644 index 0000000000000000000000000000000000000000..1ba99e77b65c8d73bcab5a1ace75882682bb96ca --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Normalization.h @@ -0,0 +1,19 @@ +#pragma once + +#include +#include + +namespace at::native { + +using renorm_scale_factor_fn = void (*) (TensorIteratorBase& iter, double maxnorm); +DECLARE_DISPATCH(renorm_scale_factor_fn, renorm_scale_factor_stub); + +enum class BatchNormBackend { + Native, + Cudnn, + Miopen, +}; + +TORCH_API BatchNormBackend _select_batch_norm_backend(const Tensor& input, const Tensor& weight, const Tensor& bias, const Tensor& running_mean, const Tensor& running_var, bool training, double eps); + +} // namespace at::native diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Padding.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Padding.h new file mode 100644 index 0000000000000000000000000000000000000000..53a054027f33d9de4ceb285bb3f95bd16027c402 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Padding.h @@ -0,0 +1,62 @@ +#pragma once + +#include +#include + +namespace at::native { + +using padding_fn = void (*)(const Tensor&, const Tensor&, IntArrayRef); + +// reflection padding +DECLARE_DISPATCH(padding_fn, reflection_pad1d_kernel); +DECLARE_DISPATCH(padding_fn, reflection_pad1d_backward_kernel); +DECLARE_DISPATCH(padding_fn, reflection_pad2d_kernel); +DECLARE_DISPATCH(padding_fn, reflection_pad2d_backward_kernel); +DECLARE_DISPATCH(padding_fn, reflection_pad3d_kernel); +DECLARE_DISPATCH(padding_fn, reflection_pad3d_backward_kernel); + +// replication padding +DECLARE_DISPATCH(padding_fn, replication_pad1d_kernel); +DECLARE_DISPATCH(padding_fn, replication_pad1d_backward_kernel); +DECLARE_DISPATCH(padding_fn, replication_pad2d_kernel); +DECLARE_DISPATCH(padding_fn, replication_pad2d_backward_kernel); +DECLARE_DISPATCH(padding_fn, replication_pad3d_kernel); +DECLARE_DISPATCH(padding_fn, replication_pad3d_backward_kernel); + +namespace padding { + +template +inline void check_valid_input(const Tensor& input, IntArrayRef padding) { + + TORCH_CHECK(padding.size() == 2 * dim, + "padding size is expected to be ", 2 * dim, + ", but got: ", padding.size()); + + int input_dim = input.dim(); + + bool is_batch_mode = input_dim == (dim + 2); + + bool valid_batch_mode = is_batch_mode; + bool valid_non_batch_mode = !is_batch_mode; + + if (is_batch_mode) { + // allow batch size of 0-dim. + for (const auto d : c10::irange(1, input_dim)) { + valid_batch_mode = valid_batch_mode && input.size(d) != 0; + } + } else { + for (const auto d : c10::irange(0, input_dim)) { + valid_non_batch_mode = valid_non_batch_mode && input.size(d) != 0; + } + } + + // allow empty batch size but not other dimensions. + TORCH_CHECK(valid_batch_mode || valid_non_batch_mode, + "Expected ", dim + 1, "D or ", dim + 2, + "D (batch mode) tensor with possibly 0 batch size and other non-zero dimensions for input, but got: ", + input.sizes()); +} + +} // namespace padding + +} // at::native diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/PixelShuffle.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/PixelShuffle.h new file mode 100644 index 0000000000000000000000000000000000000000..a9a66a3dbb9d1dfbe8b8a7b926d70bee0a645258 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/PixelShuffle.h @@ -0,0 +1,47 @@ +#include +#include + +namespace at { +namespace native { + +inline void check_pixel_shuffle_shapes(const Tensor& self, int64_t upscale_factor) { + TORCH_CHECK(self.dim() >= 3, + "pixel_shuffle expects input to have at least 3 dimensions, but got input with ", + self.dim(), " dimension(s)"); + TORCH_CHECK(upscale_factor > 0, + "pixel_shuffle expects a positive upscale_factor, but got ", + upscale_factor); + int64_t c = self.size(-3); + int64_t upscale_factor_squared = upscale_factor * upscale_factor; + TORCH_CHECK(c % upscale_factor_squared == 0, + "pixel_shuffle expects its input's 'channel' dimension to be divisible by the square of " + "upscale_factor, but input.size(-3)=", c, " is not divisible by ", upscale_factor_squared); +} + +inline void check_pixel_unshuffle_shapes(const Tensor& self, int64_t downscale_factor) { + TORCH_CHECK( + self.dim() >= 3, + "pixel_unshuffle expects input to have at least 3 dimensions, but got input with ", + self.dim(), + " dimension(s)"); + TORCH_CHECK( + downscale_factor > 0, + "pixel_unshuffle expects a positive downscale_factor, but got ", + downscale_factor); + int64_t h = self.size(-2); + int64_t w = self.size(-1); + TORCH_CHECK( + h % downscale_factor == 0, + "pixel_unshuffle expects height to be divisible by downscale_factor, but input.size(-2)=", + h, + " is not divisible by ", + downscale_factor); + TORCH_CHECK( + w % downscale_factor == 0, + "pixel_unshuffle expects width to be divisible by downscale_factor, but input.size(-1)=", + w, + " is not divisible by ", + downscale_factor); +} + +}} // namespace at::native diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/PointwiseOps.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/PointwiseOps.h new file mode 100644 index 0000000000000000000000000000000000000000..d2e2d44db2af1e74eda272afa6fa80729ab2a2db --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/PointwiseOps.h @@ -0,0 +1,28 @@ +// Ternary and higher-order pointwise operations +#pragma once + +#include + +namespace c10 { +class Scalar; +} + +namespace at { + +struct TensorIterator; +struct TensorIteratorBase; + +namespace native { + +using pointwise_fn = void (*)(TensorIterator&, const Scalar& scalar); +using structured_pointwise_fn = void (*)(TensorIteratorBase&, const Scalar& scalar); +using pointwise_fn_double = void (*)(TensorIterator&, const Scalar&, double); + +DECLARE_DISPATCH(structured_pointwise_fn, addcmul_stub); +DECLARE_DISPATCH(structured_pointwise_fn, addcdiv_stub); +DECLARE_DISPATCH(pointwise_fn_double, smooth_l1_backward_stub); +DECLARE_DISPATCH(pointwise_fn_double, huber_backward_stub); +DECLARE_DISPATCH(pointwise_fn, mse_backward_stub); + +} // namespace native +} // namespace at diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Pool.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Pool.h new file mode 100644 index 0000000000000000000000000000000000000000..896570e3a18f28030c7285551dc9a01b6879c806 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Pool.h @@ -0,0 +1,355 @@ +#include +#include +#include +#include +#include + +#include + +#pragma once + +namespace at::native { + +using max_pool2d_fn = void(*)(const Tensor& output, const Tensor& indices, const Tensor& input, + int kW, int kH, int dW, int dH, int padW, int padH, int dilationW, int dilationH); +using max_pool2d_backward_fn = void(*)(const Tensor& grad_input, const Tensor& grad_output, const Tensor& indices); + +DECLARE_DISPATCH(max_pool2d_fn, max_pool2d_kernel); +DECLARE_DISPATCH(max_pool2d_backward_fn, max_pool2d_backward_kernel); + +// averge pooling has same signature for forward and backward +using avg_pool2d_fn = void(*)(const Tensor& output, const Tensor& input, int64_t kW, int64_t kH, + int64_t dW, int64_t dH, int64_t padW, int64_t padH, bool count_include_pad, std::optional divisor_override); +using avg_pool2d_backward_fn = void(*)(const Tensor& output, const Tensor& input, int kW, int kH, + int dW, int dH, int padW, int padH, bool count_include_pad, std::optional divisor_override); + +DECLARE_DISPATCH(avg_pool2d_fn, avg_pool2d_kernel); +DECLARE_DISPATCH(avg_pool2d_backward_fn, avg_pool2d_backward_kernel); + +// averge pooling has same signature for forward and backward +using avg_pool3d_fn = void(*)(const Tensor& output, const Tensor& input, + int64_t kW, int64_t kH, int64_t kD, int64_t dW, int64_t dH, int64_t dD, + int64_t padW, int64_t padH, int64_t padD, bool count_include_pad, + std::optional divisor_override); +using avg_pool3d_backward_fn = void(*)(const Tensor& output, const Tensor& input, + int kW, int kH, int kD, int dW, int dH, int dD, + int padW, int padH, int padD, bool count_include_pad, + std::optional divisor_override); + +DECLARE_DISPATCH(avg_pool3d_fn, avg_pool3d_kernel); +DECLARE_DISPATCH(avg_pool3d_backward_fn, avg_pool3d_backward_kernel); + +using max_pool3d_fn = void(*)(Tensor& output, Tensor& indices, const Tensor& input, + int kW, int kH, int kD, int dW, int dH, int dD, int pW, int pH, int pD, int dilationW, int dilationH, int dilationD); +using max_pool3d_backward_fn = void(*)(Tensor& grad_input, const Tensor& grad_output, const Tensor& indices); + +DECLARE_DISPATCH(max_pool3d_fn, max_pool3d_kernel); +DECLARE_DISPATCH(max_pool3d_backward_fn, max_pool3d_backward_kernel); +namespace { + +template +inline dest_t +safe_downcast(src_t v) +{ + TORCH_CHECK(std::numeric_limits::min() <= v && v <= std::numeric_limits::max(), + "integer out of range"); + + return static_cast(v); +} + +template +inline T pooling_output_shape_pad_lr( + T inputSize, T kernelSize, T pad_l, T pad_r, T stride, T dilation, + bool ceil_mode) { + T outputSize = div_rtn( + inputSize + pad_l + pad_r - dilation * (kernelSize - 1) - 1 + + (ceil_mode ? stride - 1 : 0), stride) + 1; + if (ceil_mode) { + // ensure that the last pooling starts inside the image + // needed to avoid problems in ceil mode + if ((outputSize - 1) * stride >= inputSize + pad_l) { + --outputSize; + } + } + return outputSize; +} + +template +inline T pooling_output_shape( + T inputSize, T kernelSize, T pad, T stride, T dilation, bool ceil_mode) { + TORCH_CHECK(stride != 0, "stride should not be zero"); + TORCH_CHECK(pad >= 0, + "pad must be non-negative, but got pad: ", pad); + TORCH_CHECK(pad <= ((kernelSize - 1) * dilation + 1) / 2, + "pad should be at most half of effective kernel size, but got pad=", + pad, ", kernel_size=", kernelSize, " and dilation=", dilation) + return pooling_output_shape_pad_lr( + inputSize, kernelSize, pad, pad, stride, dilation, ceil_mode); +} + +template +std::pair _pooling_same_mode_padding_lr( + T inputSize, T kernelSize, T stride, T dilation) { + // NOTE: with strides, the output shape is ceil(inputSize/stride) + auto total_padding = T(dilation) * (kernelSize - 1); + + // Prefer symmetric padding if possible + if (stride > 2 && (total_padding % 2 == 1)) { + // The floor in the output size calculation gives us a little wiggle room + auto wiggle_room = inputSize % stride - 1; + if (wiggle_room > 0) { + total_padding = total_padding - 1; + } + } + + auto left = total_padding / 2; + return {left, total_padding - left}; +} + +inline std::pair pooling_same_mode_padding_lr( + int64_t inputSize, int64_t kernelSize, int64_t stride, int64_t dilation) { + return _pooling_same_mode_padding_lr(inputSize, kernelSize, stride, dilation); +} + +inline std::pair pooling_same_mode_padding_lr( + c10::SymInt inputSize, c10::SymInt kernelSize, c10::SymInt stride, c10::SymInt dilation) { + return _pooling_same_mode_padding_lr(std::move(inputSize), std::move(kernelSize), std::move(stride), std::move(dilation)); +} + +// AveragePool2d/DilatedMaxPool2d (forward) +inline void +pool2d_shape_check( + const Tensor& input, + int64_t kH, int64_t kW, int64_t dH, int64_t dW, int64_t padH, int64_t padW, int64_t dilationH, int64_t dilationW, + int64_t nInputPlane, + int64_t inputHeight, int64_t inputWidth, + int64_t outputHeight, int64_t outputWidth, MemoryFormat memory_format) +{ + const int64_t ndim = input.ndimension(); +#ifndef STRIP_ERROR_MESSAGES + const int64_t nOutputPlane = nInputPlane; +#endif + + TORCH_CHECK(kW > 0 && kH > 0, + "kernel size should be greater than zero, but got ", + "kH: ", kH, " kW: ", kW); + TORCH_CHECK(dW > 0 && dH > 0, + "stride should be greater than zero, but got " + "dH: ", dH, " dW: ", dW); + TORCH_CHECK(dilationH > 0 && dilationW > 0, + "dilation should be greater than zero, but got ", + "dilationH: ", dilationH, " dilationW: ", dilationW); + + bool valid_dims = input.size(1) != 0 && input.size(2) != 0; + if (memory_format == at::MemoryFormat::ChannelsLast){ + // Expect tensor in NHWC format and allow 0-dim only for N. + TORCH_CHECK((ndim == 4 && valid_dims && input.size(3) != 0), + "Expected 4D (batch mode) tensor expected for input with channels_last layout" + " with optional 0 dim batch size for input, but got: ", input.sizes()); + } else { + TORCH_CHECK((ndim == 3 && input.size(0) != 0 && valid_dims) || + (ndim == 4 && valid_dims && input.size(3) != 0), + "Expected 3D or 4D (batch mode) tensor with optional 0 dim batch size for input, but got:", + input.sizes()); + } + + TORCH_CHECK(kW/2 >= padW && kH/2 >= padH, + "pad should be smaller than or equal to half of kernel size, but got ", + "padW = ", padW, ", padH = ", padH, ", kW = ", kW, ", kH = ", kH); + + TORCH_CHECK(outputWidth >= 1 && outputHeight >= 1, + "Given input size: (", + nInputPlane, "x", inputHeight, "x", inputWidth, "). ", + "Calculated output size: (", + nOutputPlane, "x", outputHeight, "x", outputWidth, "). ", + "Output size is too small"); +} + +// DilatedMaxPool2d (backward) +inline void +max_pool2d_backward_shape_check( + const Tensor& input, + const Tensor& gradOutput, + const Tensor& indices, + int kH, int kW, int dH, int dW, int padH, int padW, int dilationH, int dilationW, + int64_t nInputPlane, + int64_t inputHeight, int64_t inputWidth, + int64_t outputHeight, int64_t outputWidth, MemoryFormat memory_format) +{ + pool2d_shape_check( + input, + kH, kW, dH, dW, padH, padW, dilationH, dilationW, + nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth, memory_format); + + const int64_t ndim = input.ndimension(); + const int64_t nOutputPlane = nInputPlane; + + check_dim_size(gradOutput, ndim, ndim-3, nOutputPlane); + check_dim_size(gradOutput, ndim, ndim-2, outputHeight); + check_dim_size(gradOutput, ndim, ndim-1, outputWidth); + + check_dim_size(indices, ndim, ndim-3, nOutputPlane); + check_dim_size(indices, ndim, ndim-2, outputHeight); + check_dim_size(indices, ndim, ndim-1, outputWidth); +} + +// AveragePool2d (backward) +inline void +avg_pool2d_backward_shape_check( + const Tensor& input, + const Tensor& gradOutput, + int64_t /*nbatch*/, + int kH, int kW, int dH, int dW, int padH, int padW, + int64_t nInputPlane, + int64_t inputHeight, int64_t inputWidth, + int64_t outputHeight, int64_t outputWidth, + MemoryFormat memory_format) +{ + pool2d_shape_check( + input, + kH, kW, dH, dW, padH, padW, 1, 1, + nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth, + memory_format); + + const int64_t ndim = input.ndimension(); + const int64_t nOutputPlane = nInputPlane; + + check_dim_size(gradOutput, ndim, ndim-3, nOutputPlane); + check_dim_size(gradOutput, ndim, ndim-2, outputHeight); + check_dim_size(gradOutput, ndim, ndim-1, outputWidth); +} + +// AveragePool3d/DilatedMaxPool3d (forward) +inline void +pool3d_shape_check( + const Tensor& input, + int64_t nslices, + int kT, int kH, int kW, + int dT, int dH, int dW, + int pT, int pH, int pW, + int dilationT, int dilationH, int dilationW, + int64_t itime, int64_t iheight, int64_t iwidth, + int64_t otime, int64_t oheight, int64_t owidth, + const char *fn_name, + bool check_input_size=false) +{ + const int64_t ndim = input.ndimension(); + + TORCH_CHECK(kT > 0 && kW > 0 && kH > 0, + "kernel size should be greater than zero, but got ", + "kT: ", kT, " kH: ", kH, " kW: ", kW); + TORCH_CHECK(dT > 0 && dW > 0 && dH > 0, + "stride should be greater than zero, but got ", + "dT: ", dT, " dH: ", dH, " dW: ", dW); + TORCH_CHECK(dilationT > 0 && dilationW > 0 && dilationH > 0, + "dilation should be greater than zero, but got ", + "dilationT: ", dilationT, " dilationH: ", dilationH, " dilationW: ", dilationW); + + TORCH_CHECK(ndim == 4 || ndim == 5, + fn_name, ": Expected 4D or 5D tensor for input, but got: ", input.sizes()); + + for (const auto i : c10::irange(ndim)) { + if (ndim == 5 && i == 0) { + // size of batch-dim can be 0. + continue; + } + TORCH_CHECK( + input.size(i) > 0, + fn_name, + ": Expected input's non-batch dimensions to have positive length," + " but input has a shape of ", + input.sizes(), + " and non-batch dimension ", + input.size(i), + " has length zero!") + } + + if (check_input_size) { // AveragePool3d + TORCH_CHECK(itime >= kT && iheight >= kH && iwidth >= kW, + "input image ", "(T: ", itime, " H: ", iheight, " W: ", iwidth, ") smaller than ", + "kernel size ", "(kT: ", kT, " kH: ", kH, " kW: ", kW, ")"); + } + + TORCH_CHECK(kT/2 >= pT && kW/2 >= pW && kH/2 >= pH, + "pad should be smaller than or equal to half of kernel size, but got " + "kT: ", kT, " kW: ", kW, " kH: ", kH, " padT: ", pT, " padW: ", pW, " padH: ", pH); + + TORCH_CHECK(otime >= 1 && owidth >= 1 && oheight >= 1, + "Given input size: (", + nslices,"x", itime, "x", iheight, "x", iwidth, "). ", + "Calculated output size: (", + nslices, "x", otime, "x", oheight, "x", owidth, "). ", + "Output size is too small"); +} + +inline void +max_pool3d_backward_shape_check( + const Tensor& input, + const Tensor& gradOutput, + const Tensor& indices, + int64_t nslices, + int kT, int kH, int kW, + int dT, int dH, int dW, + int pT, int pH, int pW, + int dilationT, int dilationH, int dilationW, + int64_t itime, int64_t iheight, int64_t iwidth, + int64_t otime, int64_t oheight, int64_t owidth, + const char* fn_name) +{ + const int64_t ndim = input.ndimension(); + + pool3d_shape_check( + input, + nslices, + kT, kH, kW, + dT, dH, dW, + pT, pH, pW, + dilationT, dilationH, dilationW, + itime, iheight, iwidth, + otime, oheight, owidth, fn_name); + + check_dim_size(gradOutput, ndim, ndim-4, nslices); + check_dim_size(gradOutput, ndim, ndim-3, otime); + check_dim_size(gradOutput, ndim, ndim-2, oheight); + check_dim_size(gradOutput, ndim, ndim-1, owidth); + + check_dim_size(indices, ndim, ndim-4, nslices); + check_dim_size(indices, ndim, ndim-3, otime); + check_dim_size(indices, ndim, ndim-2, oheight); + check_dim_size(indices, ndim, ndim-1, owidth); +} + +inline void +avg_pool3d_backward_shape_check( + const Tensor& input, + const Tensor& gradOutput, + int64_t nslices, + int kT, int kH, int kW, + int dT, int dH, int dW, + int pT, int pH, int pW, + int64_t itime, int64_t iheight, int64_t iwidth, + int64_t otime, int64_t oheight, int64_t owidth, + const char *fn_name) +{ + const int64_t ndim = input.ndimension(); + + pool3d_shape_check( + input, + nslices, + kT, kH, kW, + dT, dH, dW, + pT, pH, pW, + 1, 1, 1, + itime, iheight, iwidth, + otime, oheight, owidth, + fn_name, true); + + check_dim_size(gradOutput, ndim, ndim-4, nslices); + check_dim_size(gradOutput, ndim, ndim-3, otime); + check_dim_size(gradOutput, ndim, ndim-2, oheight); + check_dim_size(gradOutput, ndim, ndim-1, owidth); +} + +} // anonymous namespace + +} // namespace at::native diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Pow.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Pow.h new file mode 100644 index 0000000000000000000000000000000000000000..76ddda846a59a6796ac42a5c480fbddc08e65d4f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Pow.h @@ -0,0 +1,69 @@ +#pragma once + +#include + +namespace c10 { +class Scalar; +} + +namespace at { + +struct TensorIterator; +struct TensorIteratorBase; + +namespace native { + +#if defined(__CUDACC__) || defined(__HIPCC__) +#define HOST_DEVICE __host__ __device__ +#else +#define HOST_DEVICE +#endif + +// integral power in pytorch allows for negative exponents, giving truncated integral results. +// e.g. since 2**-1==0.5, the truncated integral result is zero. 1**negative_exponent is the +// only non-zero result. +template ::value, T>::type* = nullptr> +inline HOST_DEVICE __ubsan_ignore_signed_int_overflow__ T powi_impl(T a, T b) { + T result = 1; + while (b) { + if (b & 1) { + result *= a; + } + b /= 2; + a *= a; + } + return result; +} + +template ::value && !std::is_signed::value, T>::type* = nullptr> +inline HOST_DEVICE T powi(T a, T b) { + return powi_impl(a, b); +} + +template ::value && std::is_signed::value, T>::type* = nullptr> +inline HOST_DEVICE T powi(T a, T b) { + if ( b < 0 ) { + if ( a == 1 ) { + return 1; + } else if ( a == -1 ) { + auto negative = (-b) % static_cast(2); + return negative ? -1 : 1; + } else { + return 0; + } + } + return powi_impl(a, b); +} + +using pow_tensor_tensor_fn = void (*)(TensorIteratorBase&); +using pow_tensor_scalar_fn = void (*)(TensorIteratorBase&, const c10::Scalar&); + +DECLARE_DISPATCH(pow_tensor_tensor_fn, pow_tensor_tensor_stub); +DECLARE_DISPATCH(pow_tensor_scalar_fn, pow_tensor_scalar_stub); + +} // namespace native + +} // namespace at diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/RNN.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/RNN.h new file mode 100644 index 0000000000000000000000000000000000000000..f3e54c2a40b425eb07fdecdb1164690bf78b4996 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/RNN.h @@ -0,0 +1,53 @@ +#pragma once + +#include +#include + +namespace at::native { + +using lstm_fn = void(*)(Tensor&, Tensor&, Tensor&, const Tensor&, TensorList, TensorList, bool, int64_t, double, bool, bool, bool); +using rnn_fn = void(*)(Tensor&, Tensor&, const Tensor&, const Tensor&, TensorList, bool, int64_t, double, bool, bool, bool); +using lstm_packed_fn = void(*)(Tensor&, Tensor&, Tensor&, const Tensor&, const Tensor&, TensorList, TensorList, bool, int64_t, double, bool, bool); +using rnn_packed_fn = void(*)(Tensor&, Tensor&, const Tensor&, const Tensor&, const Tensor&, TensorList, bool, int64_t, double, bool, bool); + +DECLARE_DISPATCH(lstm_fn, lstm_cudnn_stub); +DECLARE_DISPATCH(lstm_fn, lstm_miopen_stub); +DECLARE_DISPATCH(lstm_fn, lstm_mkldnn_stub); +DECLARE_DISPATCH(rnn_fn, gru_cudnn_stub); +DECLARE_DISPATCH(rnn_fn, gru_miopen_stub); +DECLARE_DISPATCH(rnn_fn, rnn_tanh_cudnn_stub); +DECLARE_DISPATCH(rnn_fn, rnn_tanh_miopen_stub); +DECLARE_DISPATCH(rnn_fn, rnn_relu_cudnn_stub); +DECLARE_DISPATCH(rnn_fn, rnn_relu_miopen_stub); +DECLARE_DISPATCH(lstm_packed_fn, lstm_packed_cudnn_stub); +DECLARE_DISPATCH(lstm_packed_fn, lstm_packed_miopen_stub); +DECLARE_DISPATCH(rnn_packed_fn, gru_packed_cudnn_stub); +DECLARE_DISPATCH(rnn_packed_fn, gru_packed_miopen_stub); +DECLARE_DISPATCH(rnn_packed_fn, rnn_tanh_packed_cudnn_stub); +DECLARE_DISPATCH(rnn_packed_fn, rnn_tanh_packed_miopen_stub); +DECLARE_DISPATCH(rnn_packed_fn, rnn_relu_packed_cudnn_stub); +DECLARE_DISPATCH(rnn_packed_fn, rnn_relu_packed_miopen_stub); + +inline void check_attributes(const Tensor& input, const TensorList& params, const TensorList& hiddens, bool check_dtype=false) { + auto input_device = input.device(); + auto input_dtype = input.scalar_type(); + + auto check_tensors = [&](const std::string& name, const Tensor& t) { + if (!t.defined()) return; + auto t_device = t.device(); + TORCH_CHECK(input_device == t_device, + "Input and ", name, " tensors are not at the same device, found input tensor at ", + input_device, " and ", name, " tensor at ", t_device); + if (check_dtype) { + auto t_dtype = t.scalar_type(); + TORCH_CHECK(input_dtype == t_dtype, + "Input and ", name, " tensors are not the same dtype, found input tensor with ", + input_dtype, " and ", name, " tensor with ", t_dtype); + } + }; + + for (const auto& h : hiddens) check_tensors("hidden", h); + for (const auto& p : params) check_tensors("parameter", p); +} + +} // namespace at::native diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/ReduceOps.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/ReduceOps.h new file mode 100644 index 0000000000000000000000000000000000000000..af68dee7709a4728af767b65b0b0ef80cd28b89f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/ReduceOps.h @@ -0,0 +1,56 @@ +#pragma once + +#include +#include +#include + +namespace c10 { +class Scalar; +} + +namespace at { +struct TensorIterator; +class Tensor; +} + +namespace at::native { + +using reduce_fn = void(*)(TensorIterator &); + +DECLARE_DISPATCH(reduce_fn, sum_stub); +DECLARE_DISPATCH(reduce_fn, nansum_stub); +DECLARE_DISPATCH(reduce_fn, prod_stub); +DECLARE_DISPATCH(reduce_fn, mean_stub); +DECLARE_DISPATCH(reduce_fn, and_stub); +DECLARE_DISPATCH(reduce_fn, or_stub); +DECLARE_DISPATCH(reduce_fn, min_values_stub); +DECLARE_DISPATCH(reduce_fn, max_values_stub); +DECLARE_DISPATCH(reduce_fn, argmax_stub); +DECLARE_DISPATCH(reduce_fn, argmin_stub); + +using reduce_std_var_function = + void (*)(TensorIterator&, double correction, bool take_sqrt); +DECLARE_DISPATCH(reduce_std_var_function, std_var_stub); + +using reduce_norm_fn = + void (*)(Tensor&, const Tensor&, const c10::Scalar&, std::optional); +DECLARE_DISPATCH(reduce_norm_fn, norm_kernel); + +using reduce_fn_flag = void(*)(TensorIterator &, const c10::Scalar&); +DECLARE_DISPATCH(reduce_fn_flag, norm_stub); + +using structured_cum_fn = void (*)(const Tensor&, const Tensor&, int64_t); +using cum_fn = void (*)(Tensor&, const Tensor&, int64_t); +DECLARE_DISPATCH(structured_cum_fn, cumsum_stub); +DECLARE_DISPATCH(structured_cum_fn, cumprod_stub); +DECLARE_DISPATCH(cum_fn, logcumsumexp_stub); + +DECLARE_DISPATCH(void (*)(const Tensor&, int64_t, bool, Tensor&, Tensor&), aminmax_stub); +DECLARE_DISPATCH(void (*)(const Tensor&, Tensor&, Tensor&), aminmax_allreduce_stub); + +// Used in cuda/Normalization.cu +TORCH_API std::tuple var_mean_out( + Tensor &result1, Tensor &result2, const Tensor &self, IntArrayRef dim, + int64_t correction, bool keepdim); + +} // namespace at::native diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/ReduceOpsUtils.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/ReduceOpsUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..928853ed44ca54aeed79b991d5feadee590d9cad --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/ReduceOpsUtils.h @@ -0,0 +1,455 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#include +#endif + +namespace at::native { + +// Maximum and minimum possible scalar values, including infinities +template +constexpr scalar_t upper_bound() { + using lim = std::numeric_limits; + return lim::has_infinity ? lim::infinity() : lim::max(); +} + +template +constexpr scalar_t lower_bound() { + using lim = std::numeric_limits; + return lim::has_infinity ? -lim::infinity() : lim::lowest(); +} + +inline Tensor restride_dim( + const Tensor& src, int64_t dim, + IntArrayRef replacement_shape +) { + auto strides = ensure_nonempty_vec(src.strides().vec()); + strides[dim] = 0; + return src.as_strided(replacement_shape, strides); +} + +inline void _dimreduce_setup(const Tensor &result, const Tensor &self, + int64_t dim) { + IntArrayRef self_sizes = self.sizes(); + std::vector result_sizes; + result_sizes.insert(result_sizes.end(), self_sizes.begin(), self_sizes.end()); + result_sizes[dim] = 1; + result.resize_(result_sizes); +} + +inline bool _dimreduce_return_trivial(const Tensor &result, const Tensor &self, + const Scalar& ident, int64_t dim, bool keepdim) { + if (self.numel() == 1 && self.ndimension() == 0) { + result.resize_({}); + result.fill_(self); + return true; + } + // Return identity + if (self.numel() == 0) { + _dimreduce_setup(result, self, dim); + result.fill_(ident); + if (!keepdim) result.squeeze_(dim); + return true; + } + return false; +} + +inline bool _dimreduce_return_trivial_no_ident(Tensor &result, const Tensor &self, + int64_t /*dim*/, bool /*keepdim*/, const char* /*fn_name*/) { + if (self.numel() == 1 && self.ndimension() == 0) { + result.resize_({}); + result.fill_(self); + return true; + } + + return false; +} + +inline std::optional _allreduce_return_trivial( + const Tensor& self, + const Scalar& ident) { + // Return identity + if (self.numel() == 0) { + return at::scalar_tensor(ident, self.options()); + } + return std::nullopt; +} + +#define OPTION_TYPE_EQUALITY_CHECK(option, out, self) \ +{ \ + TORCH_CHECK(\ + out.option() == self.option(),\ + "expected ", #option, " ",\ + self.option(),\ + " but found ", out.option())\ +} + +inline void check_scalar_type_device_layout_equal(const Tensor& out, const Tensor& self) { + OPTION_TYPE_EQUALITY_CHECK(scalar_type, out, self); + OPTION_TYPE_EQUALITY_CHECK(device, out.options(), self.options()); + OPTION_TYPE_EQUALITY_CHECK(layout, out.options(), self.options()); +} + +inline Tensor integer_upcast(const Tensor& self, std::optional dtype) { + ScalarType scalarType = self.scalar_type(); + TORCH_CHECK(!isBarebonesUnsignedType(scalarType), "integer upcasting for uint16, uint32 and uint64 is not currently implemented"); + ScalarType upcast_scalarType = dtype.value_or(at::isIntegralType(scalarType, /*includeBool=*/true) ? ScalarType::Long : scalarType); + return self.toType(upcast_scalarType); +} + +using DimMask = TensorIterator::DimMask; + +inline DimVector make_dim_vector(OptionalIntArrayRef opt_dims, int64_t ndim) { + if (opt_dims.has_value()) { + return DimVector(opt_dims.value()); + } else { + std::vector all_dims(ndim); + std::iota(all_dims.begin(), all_dims.end(), 0); + return DimVector(all_dims); + } +} + +inline DimMask make_dim_mask(OptionalIntArrayRef opt_dims, int64_t ndim, bool allow_empty_dims=false) { + DimMask mask; + if (opt_dims.has_value()) { + auto dims = opt_dims.value(); + if (dims.empty() && !allow_empty_dims) { + mask = DimMask().flip(); + } else { + mask = at::dim_list_to_bitset(dims, ndim); + } + } else { + mask = DimMask().flip(); + } + return mask; +} + +inline DimVector shape_from_dim_mask(const Tensor& self, DimMask mask, bool keepdim) { + auto shape = DimVector(self.sizes()); + for (int dim = shape.size() - 1; dim >= 0; dim--) { + if (mask[dim]) { + if (keepdim) { + shape[dim] = 1; + } else { + shape.erase(shape.begin() + dim); + } + } + } + return shape; +} + +inline void resize_reduction_result( + Tensor& result, const Tensor& self, DimMask mask, bool keepdim, + ScalarType /*dtype*/) +{ + auto shape = shape_from_dim_mask(self, mask, keepdim); + TORCH_CHECK(result.defined(), "Cannot create a new tensor inside a reduction op. You likely tried to call an operator with an out argument but the out argument was an undefined tensor."); + at::native::resize_output(result, shape); +} + +inline Tensor create_reduction_result( + const Tensor& self, at::OptionalIntArrayRef dim, bool keepdim, ScalarType dtype +) { + DimMask mask = make_dim_mask(dim, self.dim()); + auto shape = shape_from_dim_mask(self, mask, keepdim); + return at::empty(shape, self.options().dtype(dtype)); +} + +inline Tensor review_reduce_result(const Tensor& result, int ndim, DimMask mask, bool keepdim) { + if (keepdim) { + return result; + } + auto shape = DimVector(result.sizes()); + auto stride = DimVector(result.strides()); + for (const auto dim : c10::irange(ndim)) { + if (mask[dim]) { + shape.insert(shape.begin() + dim, 1); + stride.insert(stride.begin() + dim, 0); + } + } + return result.as_strided(shape, stride); +} + +inline TensorIterator make_reduction( + const char* name, Tensor& result, const Tensor& self, + at::OptionalIntArrayRef dim_opt, + bool keepdim, ScalarType in_dtype, ScalarType out_dtype) { + // check that result type and dtype match if provided + TORCH_CHECK( + !result.defined() || result.scalar_type() == out_dtype, + name, ": provided dtype must match dtype of result. Got ", + toString(result.scalar_type()), + " and ", + toString(out_dtype), + "."); + // dim={} performs an all-reduce, same as dim=None + IntArrayRef dim = dim_opt.value_or(IntArrayRef{}); + int64_t ndim = self.dim(); + auto mask = make_dim_mask(dim, ndim); + resize_reduction_result(result, self, mask, keepdim, out_dtype); + auto viewed_result = review_reduce_result(result, ndim, mask, keepdim); + namedinference::propagate_names_for_reduction(result, self, dim, keepdim); + if (self.scalar_type() == in_dtype) { + return TensorIterator::reduce_op(viewed_result, self); + } + return TensorIterator::reduce_op(viewed_result, self.to(in_dtype)); +} + +inline C10_UNUSED TensorIterator make_reduction( + const char* name, Tensor& result, const Tensor& self, + at::OptionalIntArrayRef dim, bool keepdim, ScalarType out_dtype) { + // special case for type promotion in mixed precision, improves computational + // efficiency. + // not generalize this to common mismatched input/output types to avoid cross + // product of templated kernel launches. + const bool gpu_lowp_to_f32 = ( + self.is_cuda() && (self.scalar_type() == kHalf || self.scalar_type() == kBFloat16) && out_dtype == kFloat); + auto in_dtype = gpu_lowp_to_f32 ? self.scalar_type() + : self.is_complex() ? c10::toComplexType(out_dtype) + : out_dtype; + return make_reduction(name, result, self, dim, keepdim, in_dtype, out_dtype); +} + +inline TensorIterator make_reduction( + const char* name, Tensor& result1, Tensor& result2, const Tensor& self, + at::OptionalIntArrayRef dim_opt, bool keepdim, ScalarType dtype1, + ScalarType dtype2) { + // check that result type and dtype match if provided + TORCH_CHECK( + (!result1.defined() || result1.scalar_type() == dtype1) && (!result2.defined() || result2.scalar_type() == dtype2), + name, ": provided dtype must match dtype of result. Got ", + toString(result1.scalar_type()), toString(result2.scalar_type()), + " and ", + toString(dtype1), toString(dtype2), + "."); + + // dim={} performs an all-reduce, same as dim=None + auto dim = dim_opt.value_or(IntArrayRef{}); + int64_t ndim = self.dim(); + DimMask mask = make_dim_mask(dim, ndim); + resize_reduction_result(result1, self, mask, keepdim, dtype1); + auto viewed_result1 = review_reduce_result(result1, ndim, mask, keepdim); + + resize_reduction_result(result2, self, mask, keepdim, dtype2); + auto viewed_result2 = review_reduce_result(result2, ndim, mask, keepdim); + + namedinference::propagate_names_for_reduction(result1, self, dim, keepdim); + namedinference::propagate_names_for_reduction(result2, self, dim, keepdim); + + // special case for type promotion in mixed precision, improves computational + // efficiency. + // We don't generalize this to common mismatched input/output types to avoid cross + // product of templated kernel launches. + if (self.scalar_type() == dtype1 || + (self.is_cuda() && self.scalar_type() == kHalf && dtype1 == kFloat)) { + return TensorIterator::reduce_op(viewed_result1, viewed_result2, self); + } + return TensorIterator::reduce_op(viewed_result1, viewed_result2, self.to(dtype1)); +} + +inline C10_UNUSED TensorIterator make_reduction( + const char* name, Tensor& result1, Tensor& result2, const Tensor& self, + at::OptionalIntArrayRef dim, bool keepdim, ScalarType dtype) { + return make_reduction(name, result1, result2, self, dim, keepdim, dtype, dtype); +} + +inline void zero_numel_check_dims(const Tensor& self, const int64_t dim, const char *fn_name) { + if (self.ndimension() == 0) { + TORCH_CHECK_INDEX(dim == 0 || dim == -1, fn_name, + ": Expected reduction dim -1 or 0 for scalar but got ", dim); + } + else { + TORCH_CHECK_INDEX(self.size(dim) != 0, fn_name, + ": Expected reduction dim ", dim, " to have non-zero size."); + } +} + +inline void zero_numel_check_dims(const Tensor& self, const IntArrayRef dim, const char *fn_name) { + TORCH_CHECK( + !dim.empty(), + fn_name, ": Expected reduction dim to be specified for input.numel() == 0. ", + "Specify the reduction dim with the 'dim' argument."); + for (const int64_t d : dim) { + zero_numel_check_dims(self, d, fn_name); + } +} + +inline std::vector get_zero_numel_tensor_size( + const Tensor& self, + const int64_t dim, + const bool keepdim, + const char* fn_name) { + TORCH_INTERNAL_ASSERT(self.numel() == 0, fn_name, ": Expected self.numel() == 0."); + zero_numel_check_dims(self, dim, fn_name); + std::vector sizes; + if (keepdim) { + sizes = self.sizes().vec(); + sizes[dim] = 1; + } + else { + for (const auto d : c10::irange(self.dim())) { + if (d != dim) { + sizes.push_back(self.sizes()[d]); + } + } + } + return sizes; +} + +// Resize the result tensor and indices when result.numel() == 0 depending on values of +// dim and keepdim for returning tensors containing reduction results. +// This function should be called when you are reducing a zero-numel tensor and want to +// resize the output and return it. This function exists for resizing zero-numel +// tensors when the size of the reduction dimension is non-zero. +inline C10_UNUSED void zero_numel_tensor_resize(Tensor& result, Tensor& result_indices, + const Tensor& self, const int64_t dim, + const bool keepdim, const char *fn_name) { + auto sizes = get_zero_numel_tensor_size(self, dim, keepdim, fn_name); + at::native::resize_output(result, sizes); + at::native::resize_output(result_indices, sizes); +} + +inline ScalarType get_dtype_from_self( + const Tensor& self, + const std::optional& dtype, + bool promote_integers) { + if (dtype.has_value()) { + return dtype.value(); + } + ScalarType src_type = self.scalar_type(); + if (promote_integers && at::isIntegralType(src_type, /*includeBool=*/true)) { + return kLong; + } + return src_type; +} + +inline ScalarType get_dtype_from_result(Tensor& result, std::optional dtype) { + TORCH_CHECK(result.defined(), "Cannot create a new tensor inside a reduction op. You likely tried to call an operator with an out argument but the out argument was an undefined tensor."); + if (dtype.has_value()) { + return dtype.value(); + } else { + return result.scalar_type(); + } +} + + +} // namespace at::native + +namespace at::meta { + +inline C10_UNUSED DimVector get_reduction_shape( + const Tensor& self, + IntArrayRef dims, + bool keepdim, + bool allow_empty_dims=false) { + auto mask = native::make_dim_mask(dims, self.dim(), allow_empty_dims); + return native::shape_from_dim_mask(self, mask, keepdim); +} + +inline void resize_reduction( + impl::MetaBase& meta, + const Tensor& self, + OptionalIntArrayRef opt_dims, + bool keepdim, + ScalarType out_dtype, + bool allow_empty_dims=false) { + DimVector dims_ = at::native::make_dim_vector(opt_dims, self.dim()); + maybe_wrap_dims(dims_, self.dim()); + auto shape = get_reduction_shape(self, dims_, keepdim, allow_empty_dims); + if (self.layout() == kStrided) { + meta.set_output_raw_strided(0, shape, {}, self.options().dtype(out_dtype)); + } else if (shape.empty()) { + meta.set_output_raw_strided(0, shape, {}, self.options().dtype(out_dtype).layout(kStrided)); + } else { + TORCH_CHECK(false, "resize_reduction: support for output with ", self.layout(), " layout is not implemented yet"); + } + namedinference::propagate_names_for_reduction( + meta.maybe_get_output(), self, dims_, keepdim); +} + +inline void resize_reduction_with_indices( + impl::MetaBase& meta, + const Tensor& self, + IntArrayRef dims, + bool keepdim, + ScalarType out_dtype) { + DimVector dims_(dims); + maybe_wrap_dims(dims_, self.dim()); + auto shape = get_reduction_shape(self, dims_, keepdim); + meta.set_output_raw_strided(0, shape, {}, self.options().dtype(out_dtype)); + meta.set_output_raw_strided(1, shape, {}, self.options().dtype(kLong)); + namedinference::propagate_names_for_reduction( + meta.maybe_get_output(0), self, dims_, keepdim); + namedinference::propagate_names_for_reduction( + meta.maybe_get_output(1), self, dims_, keepdim); +} + +inline TensorIterator make_reduction( + const Tensor& self, + const Tensor& result, + OptionalIntArrayRef opt_dims, + bool keepdim, + ScalarType in_dtype) { + int64_t ndim = self.dim(); + auto mask = at::native::make_dim_mask(opt_dims, ndim); + auto viewed_result = + at::native::review_reduce_result(result, ndim, mask, keepdim); + if (self.scalar_type() == in_dtype) { + return TensorIterator::reduce_op(viewed_result, self); + } + return TensorIterator::reduce_op(viewed_result, self.to(in_dtype)); +} + +inline TensorIterator make_reduction( + const Tensor& self, + const Tensor& result1, + const Tensor& result2, + IntArrayRef dims, + bool keepdim, + ScalarType dtype1, + ScalarType /*dtype2*/) { + int64_t ndim = self.dim(); + auto mask = at::native::make_dim_mask(dims, ndim); + auto viewed_result1 = at::native::review_reduce_result(result1, ndim, mask, keepdim); + auto viewed_result2 = at::native::review_reduce_result(result2, ndim, mask, keepdim); + // special case for type promotion in mixed precision, improves computational efficiency. + // We don't generalize this to common mismatched input/output types to avoid cross product + // of templated kernel launches. + if (self.scalar_type() == dtype1 || + (self.is_cuda() && self.scalar_type() == kHalf && dtype1 == kFloat)) { + return TensorIterator::reduce_op(viewed_result1, viewed_result2, self); + } + return TensorIterator::reduce_op(viewed_result1, viewed_result2, self.to(dtype1)); +} + +inline C10_UNUSED TensorIterator make_reduction_from_out_ty( + const Tensor& self, + const Tensor& result, + OptionalIntArrayRef opt_dims, + bool keepdim, + ScalarType out_dtype) { + // special case for type promotion in mixed precision, improves computational + // efficiency. + // not generalize this to common mismatched input/output types to avoid cross + // product of templated kernel launches. + const bool gpu_lowp_to_f32 = + (self.is_cuda() && + (self.scalar_type() == kHalf || self.scalar_type() == kBFloat16) && + out_dtype == kFloat); + auto in_dtype = gpu_lowp_to_f32 ? self.scalar_type() : out_dtype; + return make_reduction(self, result, opt_dims, keepdim, in_dtype); +} + +} // namespace at::meta diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/SegmentReduce.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/SegmentReduce.h new file mode 100644 index 0000000000000000000000000000000000000000..25cd469ff5270a891e460f55239801a3fcf44c81 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/SegmentReduce.h @@ -0,0 +1,50 @@ +#pragma once + +#include +#include +#include +#include + +namespace at { +class Tensor; + +namespace native { + +using segment_reduce_lengths_fn = Tensor (*)( + ReductionType, + const Tensor&, + const Tensor&, + int64_t, + const std::optional&); +DECLARE_DISPATCH(segment_reduce_lengths_fn, _segment_reduce_lengths_stub); + +using segment_reduce_offsets_fn = Tensor (*)( + ReductionType, + const Tensor&, + const Tensor&, + int64_t, + const std::optional&); +DECLARE_DISPATCH(segment_reduce_offsets_fn, _segment_reduce_offsets_stub); + +using segment_reduce_lengths_backward_fn = Tensor (*)( + const Tensor&, + const Tensor&, + const Tensor&, + ReductionType, + const Tensor&, + int64_t, + const std::optional&); +DECLARE_DISPATCH(segment_reduce_lengths_backward_fn, _segment_reduce_lengths_backward_stub); + +using segment_reduce_offsets_backward_fn = Tensor (*)( + const Tensor&, + const Tensor&, + const Tensor&, + ReductionType, + const Tensor&, + int64_t, + const std::optional&); +DECLARE_DISPATCH(segment_reduce_offsets_backward_fn, _segment_reduce_offsets_backward_stub); + +} // namespace native +} // namespace at diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/SobolEngineOpsUtils.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/SobolEngineOpsUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..17e42ebe84a0e8b0906a76ba9c937c6c46027caa --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/SobolEngineOpsUtils.h @@ -0,0 +1,55 @@ +/// This file contains some tensor-agnostic operations to be used in the +/// core functions of the `SobolEngine` +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#include +#include +#endif + +namespace at::native::sobol_utils { + +/// Function to return the minimum of number of bits to represent the integer `n` +inline int64_t bit_length(const int64_t n) { + int64_t nbits, nloc; + for (nloc = n, nbits = 0; nloc > 0; nloc /= 2, nbits++); + return nbits; +} + +/// Function to get the position of the rightmost zero in the bit representation of an integer +/// This value is the zero-indexed position +inline int64_t rightmost_zero(const int64_t n) { + int64_t z, i; + for (z = n, i = 0; z % 2 == 1; z /= 2, i++); + return i; +} + +/// Function to get a subsequence of bits in the representation of an integer starting from +/// `pos` and of length `length` +inline int64_t bitsubseq(const int64_t n, const int64_t pos, const int64_t length) { + return (n >> pos) & ((1 << length) - 1); +} + +/// Function to perform the inner product between a batched square matrix and a power of 2 vector +inline at::Tensor cdot_pow2(const at::Tensor& bmat) { + at::Tensor inter = at::arange(bmat.size(-1) - 1, -1, -1, bmat.options()); + inter = at::pow(2, inter).expand_as(bmat); + return at::mul(inter, bmat).sum(-1); +} + +/// All definitions below this point are data. These are constant, and should not be modified +/// without notice + +constexpr int64_t MAXDIM = 21201; +constexpr int64_t MAXDEG = 18; +constexpr int64_t MAXBIT = 30; +constexpr int64_t LARGEST_NUMBER = 1 << MAXBIT; +constexpr float RECIPD = 1.0 / LARGEST_NUMBER; + +extern const int64_t poly[MAXDIM]; +extern const int64_t initsobolstate[MAXDIM][MAXDEG]; + +} // namespace at::native::sobol_utils diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/SpectralOpsUtils.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/SpectralOpsUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..279e4ff59556793709e864ef79352275f1d148cf --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/SpectralOpsUtils.h @@ -0,0 +1,84 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at::native { + +// Normalization types used in _fft_with_size +enum class fft_norm_mode { + none, // No normalization + by_root_n, // Divide by sqrt(signal_size) + by_n, // Divide by signal_size +}; + +// NOTE [ Fourier Transform Conjugate Symmetry ] +// +// Real-to-complex Fourier transform satisfies the conjugate symmetry. That is, +// assuming X is the transformed K-dimensionsal signal, we have +// +// X[i_1, ..., i_K] = X[j_i, ..., j_K]*, +// +// where j_k = (N_k - i_k) mod N_k, N_k being the signal size at dim k, +// * is the conjugate operator. +// +// Therefore, in such cases, FFT libraries return only roughly half of the +// values to avoid redundancy: +// +// X[:, :, ..., :floor(N / 2) + 1] +// +// This is also the assumption in cuFFT and MKL. In ATen SpectralOps, such +// halved signal will also be returned by default (flag onesided=True). +// The following infer_ft_real_to_complex_onesided_size function calculates the +// onesided size from the twosided size. +// +// Note that this loses some information about the size of signal at last +// dimension. E.g., both 11 and 10 maps to 6. Hence, the following +// infer_ft_complex_to_real_onesided_size function takes in optional parameter +// to infer the twosided size from given onesided size. +// +// cuFFT doc: http://docs.nvidia.com/cuda/cufft/index.html#multi-dimensional +// MKL doc: https://software.intel.com/en-us/mkl-developer-reference-c-dfti-complex-storage-dfti-real-storage-dfti-conjugate-even-storage#CONJUGATE_EVEN_STORAGE + +inline int64_t infer_ft_real_to_complex_onesided_size(int64_t real_size) { + return (real_size / 2) + 1; +} + +inline int64_t infer_ft_complex_to_real_onesided_size(int64_t complex_size, + int64_t expected_size=-1) { + int64_t base = (complex_size - 1) * 2; + if (expected_size < 0) { + return base + 1; + } else if (base == expected_size) { + return base; + } else if (base + 1 == expected_size) { + return base + 1; + } else { + std::ostringstream ss; + ss << "expected real signal size " << expected_size << " is incompatible " + << "with onesided complex frequency size " << complex_size; + AT_ERROR(ss.str()); + } +} + +using fft_fill_with_conjugate_symmetry_fn = + void (*)(ScalarType dtype, IntArrayRef mirror_dims, IntArrayRef half_sizes, + IntArrayRef in_strides, const void* in_data, + IntArrayRef out_strides, void* out_data); +DECLARE_DISPATCH(fft_fill_with_conjugate_symmetry_fn, fft_fill_with_conjugate_symmetry_stub); + +// In real-to-complex transform, cuFFT and MKL only fill half of the values +// due to conjugate symmetry. This function fills in the other half of the full +// fft by using the Hermitian symmetry in the signal. +// self should be the shape of the full signal and dims.back() should be the +// one-sided dimension. +// See NOTE [ Fourier Transform Conjugate Symmetry ] +TORCH_API void _fft_fill_with_conjugate_symmetry_(const Tensor& self, IntArrayRef dims); + +} // namespace at::native diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/StridedRandomAccessor.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/StridedRandomAccessor.h new file mode 100644 index 0000000000000000000000000000000000000000..ad8f6d7b8830f10575e75c65758cc846cc800ae6 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/StridedRandomAccessor.h @@ -0,0 +1,301 @@ +#pragma once + +namespace at::native { + +// (Const)StridedRandomAccessor is a +// (const) random access iterator defined over +// a strided array. + +// The traits below are to introduce __restrict__ +// modifier on different platforms. + +template +struct DefaultPtrTraits { + using PtrType = T*; +}; + +#if (defined(_WIN32) || defined(_WIN64)) +#define RESTRICT __restrict +#else +#define RESTRICT __restrict__ +#endif + +template +struct RestrictPtrTraits { + using PtrType = T* RESTRICT; +}; + +template < + typename T, + typename index_t = int64_t, + template class PtrTraits = DefaultPtrTraits +> +class ConstStridedRandomAccessor { +public: + using difference_type = index_t; + using value_type = const T; + using pointer = const typename PtrTraits::PtrType; + using reference = const value_type&; + using iterator_category = std::random_access_iterator_tag; + + using PtrType = typename PtrTraits::PtrType; + using index_type = index_t; + + // Constructors { + C10_HOST_DEVICE + ConstStridedRandomAccessor(PtrType ptr, index_t stride) + : ptr{ptr}, stride{stride} + {} + + C10_HOST_DEVICE + explicit ConstStridedRandomAccessor(PtrType ptr) + : ptr{ptr}, stride{static_cast(1)} + {} + + C10_HOST_DEVICE + ConstStridedRandomAccessor() + : ptr{nullptr}, stride{static_cast(1)} + {} + // } + + // Pointer-like operations { + C10_HOST_DEVICE + reference operator*() const { + return *ptr; + } + + C10_HOST_DEVICE + const value_type* operator->() const { + return reinterpret_cast(ptr); + } + + C10_HOST_DEVICE + reference operator[](index_t idx) const { + return ptr[idx * stride]; + } + // } + + // Prefix/postfix increment/decrement { + C10_HOST_DEVICE + ConstStridedRandomAccessor& operator++() { + ptr += stride; + return *this; + } + + C10_HOST_DEVICE + ConstStridedRandomAccessor operator++(int) { + ConstStridedRandomAccessor copy(*this); + ++*this; + return copy; + } + + C10_HOST_DEVICE + ConstStridedRandomAccessor& operator--() { + ptr -= stride; + return *this; + } + + C10_HOST_DEVICE + ConstStridedRandomAccessor operator--(int) { + ConstStridedRandomAccessor copy(*this); + --*this; + return copy; + } + // } + + // Arithmetic operations { + C10_HOST_DEVICE + ConstStridedRandomAccessor& operator+=(index_t offset) { + ptr += offset * stride; + return *this; + } + + C10_HOST_DEVICE + ConstStridedRandomAccessor operator+(index_t offset) const { + return ConstStridedRandomAccessor(ptr + offset * stride, stride); + } + + C10_HOST_DEVICE + friend ConstStridedRandomAccessor operator+( + index_t offset, + const ConstStridedRandomAccessor& accessor + ) { + return accessor + offset; + } + + C10_HOST_DEVICE + ConstStridedRandomAccessor& operator-=(index_t offset) { + ptr -= offset * stride; + return *this; + } + + C10_HOST_DEVICE + ConstStridedRandomAccessor operator-(index_t offset) const { + return ConstStridedRandomAccessor(ptr - offset * stride, stride); + } + + // Note that this operator is well-defined when `this` and `other` + // represent the same sequences, i.e. when + // 1. this.stride == other.stride, + // 2. |other - this| / this.stride is an Integer. + C10_HOST_DEVICE + difference_type operator-(const ConstStridedRandomAccessor& other) const { + return (ptr - other.ptr) / stride; + } + // } + + // Comparison operators { + C10_HOST_DEVICE + bool operator==(const ConstStridedRandomAccessor& other) const { + return (ptr == other.ptr) && (stride == other.stride); + } + + C10_HOST_DEVICE + bool operator!=(const ConstStridedRandomAccessor& other) const { + return !(*this == other); + } + + C10_HOST_DEVICE + bool operator<(const ConstStridedRandomAccessor& other) const { + return ptr < other.ptr; + } + + C10_HOST_DEVICE + bool operator<=(const ConstStridedRandomAccessor& other) const { + return (*this < other) || (*this == other); + } + + C10_HOST_DEVICE + bool operator>(const ConstStridedRandomAccessor& other) const { + return !(*this <= other); + } + + C10_HOST_DEVICE + bool operator>=(const ConstStridedRandomAccessor& other) const { + return !(*this < other); + } + // } + +protected: + PtrType ptr; + index_t stride; +}; + +template < + typename T, + typename index_t = int64_t, + template class PtrTraits = DefaultPtrTraits +> +class StridedRandomAccessor + : public ConstStridedRandomAccessor { +public: + using difference_type = index_t; + using value_type = T; + using pointer = typename PtrTraits::PtrType; + using reference = value_type&; + + using BaseType = ConstStridedRandomAccessor; + using PtrType = typename PtrTraits::PtrType; + + // Constructors { + C10_HOST_DEVICE + StridedRandomAccessor(PtrType ptr, index_t stride) + : BaseType(ptr, stride) + {} + + C10_HOST_DEVICE + explicit StridedRandomAccessor(PtrType ptr) + : BaseType(ptr) + {} + + C10_HOST_DEVICE + StridedRandomAccessor() + : BaseType() + {} + // } + + // Pointer-like operations { + C10_HOST_DEVICE + reference operator*() const { + return *this->ptr; + } + + C10_HOST_DEVICE + value_type* operator->() const { + return reinterpret_cast(this->ptr); + } + + C10_HOST_DEVICE + reference operator[](index_t idx) const { + return this->ptr[idx * this->stride]; + } + // } + + // Prefix/postfix increment/decrement { + C10_HOST_DEVICE + StridedRandomAccessor& operator++() { + this->ptr += this->stride; + return *this; + } + + C10_HOST_DEVICE + StridedRandomAccessor operator++(int) { + StridedRandomAccessor copy(*this); + ++*this; + return copy; + } + + C10_HOST_DEVICE + StridedRandomAccessor& operator--() { + this->ptr -= this->stride; + return *this; + } + + C10_HOST_DEVICE + StridedRandomAccessor operator--(int) { + StridedRandomAccessor copy(*this); + --*this; + return copy; + } + // } + + // Arithmetic operations { + C10_HOST_DEVICE + StridedRandomAccessor& operator+=(index_t offset) { + this->ptr += offset * this->stride; + return *this; + } + + C10_HOST_DEVICE + StridedRandomAccessor operator+(index_t offset) const { + return StridedRandomAccessor(this->ptr + offset * this->stride, this->stride); + } + + C10_HOST_DEVICE + friend StridedRandomAccessor operator+( + index_t offset, + const StridedRandomAccessor& accessor + ) { + return accessor + offset; + } + + C10_HOST_DEVICE + StridedRandomAccessor& operator-=(index_t offset) { + this->ptr -= offset * this->stride; + return *this; + } + + C10_HOST_DEVICE + StridedRandomAccessor operator-(index_t offset) const { + return StridedRandomAccessor(this->ptr - offset * this->stride, this->stride); + } + + // Note that here we call BaseType::operator- version + C10_HOST_DEVICE + difference_type operator-(const BaseType& other) const { + return (static_cast(*this) - other); + } + // } +}; + +} // namespace at::native diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorCompare.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorCompare.h new file mode 100644 index 0000000000000000000000000000000000000000..b4dfa689b1d216cb697076781935afb81a587fae --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorCompare.h @@ -0,0 +1,49 @@ +#pragma once + +#include + +namespace c10 { +class Scalar; +} + +namespace at { +class Tensor; +struct TensorIterator; +struct TensorIteratorBase; +} + +namespace at::native { + +using reduce_minmax_fn = + void (*)(Tensor&, Tensor&, const Tensor&, int64_t, bool); +using structured_reduce_minmax_fn = + void (*)(const Tensor&, const Tensor&, const Tensor&, int64_t, bool); + +DECLARE_DISPATCH(structured_reduce_minmax_fn, max_stub); +DECLARE_DISPATCH(structured_reduce_minmax_fn, min_stub); + +using where_fn = void (*)(TensorIterator &); +DECLARE_DISPATCH(where_fn, where_kernel); + +using is_infinity_op_fn = void (*)(TensorIteratorBase &); +DECLARE_DISPATCH(is_infinity_op_fn, isposinf_stub); +DECLARE_DISPATCH(is_infinity_op_fn, isneginf_stub); + +using mode_fn = void (*)(Tensor&, Tensor&, const Tensor&, int64_t, bool); +DECLARE_DISPATCH(mode_fn, mode_stub); + +using clamp_tensor_fn = void (*)(TensorIteratorBase &); +DECLARE_DISPATCH(clamp_tensor_fn, clamp_stub); + +namespace detail { + enum class ClampLimits {Min, Max, MinMax}; +} + +DECLARE_DISPATCH(void (*)(TensorIteratorBase &, const c10::Scalar&, const c10::Scalar&), clamp_scalar_stub); +DECLARE_DISPATCH(void (*)(TensorIteratorBase &, c10::Scalar), clamp_min_scalar_stub); +DECLARE_DISPATCH(void (*)(TensorIteratorBase &, c10::Scalar), clamp_max_scalar_stub); + +using isin_default_fn = void (*)(const Tensor&, const Tensor&, bool, const Tensor&); +DECLARE_DISPATCH(isin_default_fn, isin_default_stub); + +} // namespace at::native diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorConversions.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorConversions.h new file mode 100644 index 0000000000000000000000000000000000000000..8a3853230b15ca77fb7e148e2a24a87adf629856 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorConversions.h @@ -0,0 +1,26 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace at { + class Tensor; +namespace native { +bool to_will_alias( + const Tensor& self, + std::optional dtype, + std::optional layout, + std::optional device, + bool copy, + std::optional optional_memory_format); + +Tensor to_meta(const Tensor& tensor); +std::optional to_meta(const std::optional& tensor); +std::vector to_meta(at::ITensorListRef t_list); +Tensor dense_to_sparse_with_mask(const Tensor& self, const Tensor& mask, std::optional layout, OptionalIntArrayRef blocksize, std::optional dense_dim_opt); + +} // namespace native +} // namespace at diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorIterator.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorIterator.h new file mode 100644 index 0000000000000000000000000000000000000000..e55d2a58d709926a24467a0056323096e0890fa9 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorIterator.h @@ -0,0 +1,2 @@ +#pragma once +#include diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorProperties.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorProperties.h new file mode 100644 index 0000000000000000000000000000000000000000..87aca85fb3af1da0db523300a8cb3b310a0a88ad --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorProperties.h @@ -0,0 +1,12 @@ +#pragma once + +// See NOTE: [Tensor vs. TensorBase] +namespace at { +class TensorBase; +} + +namespace at::native { + +TORCH_API bool cudnn_is_acceptable(const TensorBase& self); + +} // namespace at::native diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorTransformations.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorTransformations.h new file mode 100644 index 0000000000000000000000000000000000000000..f69c27edb976a4157dca1b0e0a38d748cdef9848 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorTransformations.h @@ -0,0 +1,30 @@ +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#endif + +#include + +namespace at::native { + +static inline Tensor roll_common(const Tensor& self, IntArrayRef shifts, IntArrayRef dims) { + TORCH_CHECK(!shifts.empty(), "`shifts` required"); + if (dims.empty() && shifts.size() == 1) { + auto flattened = self.contiguous().view(self.numel()); + return roll(flattened, shifts[0], 0).view(self.sizes()); + } + TORCH_CHECK( + shifts.size() == dims.size(), + "shifts and dimensions must align. shifts: ", shifts.size(), ", dims:", dims.size() + ); + AT_ASSERT(dims.size() > 1); + auto tail_shifts = shifts.slice(1); + auto tail_dims = dims.slice(1); + auto first_dim_rolled = roll(self, shifts[0], dims[0]); + return at::roll(first_dim_rolled, tail_shifts, tail_dims); +} + +} // namespace at::native diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TopKImpl.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TopKImpl.h new file mode 100644 index 0000000000000000000000000000000000000000..0a11f5f4087536c928c6294e92ce6cae03bd1378 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TopKImpl.h @@ -0,0 +1,98 @@ +#pragma once +#include +#include + +namespace at::native { + +#ifdef CPU_CAPABILITY +inline namespace CPU_CAPABILITY { +#else +inline namespace DEFAULT { +#endif + +// Core topk loop, shared between CPU and QuantizedCPU +template +void topk_impl_loop( + const int64_t mode_values_stride, + const int64_t mode_indices_stride, + const int64_t tmp_values_stride, + const int64_t k, + const int64_t dim_size, + const bool largest, + const bool sorted, + char** data, const int64_t* strides, const int64_t n) { + + // If k is zero, then output values and indices are empty tensors + // So iterating over other dims is pointless + if (k == 0) { + return; + } + using elem_t = std::pair; + std::vector queue(dim_size); + for (const auto i : c10::irange(n)) { + TensorAccessor mode_values( + reinterpret_cast(data[0] + i * strides[0]), + &k, &mode_values_stride); + TensorAccessor mode_indices( + reinterpret_cast(data[1] + i * strides[1]), + &k, &mode_indices_stride); + TensorAccessor tmp_values( + reinterpret_cast(data[2] + i * strides[2]), + &dim_size, &tmp_values_stride); + + auto n_2 = dim_size; + auto use_partial_sort = k * 64 <= n_2; + + for (const auto j : c10::irange(n_2)) { + queue[j].first = tmp_values[j]; + queue[j].second = j; + } + + // we want nan to be sorted as top for numpy compatibility + if (use_partial_sort) { + if (largest) { + std::partial_sort(queue.begin(), queue.begin() + k, queue.end(), + [](const elem_t& x, const elem_t& y) -> bool { + return ((_isnan(x.first) && !_isnan(y.first)) || (x.first > y.first)); + }); + } else { + std::partial_sort(queue.begin(), queue.begin() + k, queue.end(), + [](const elem_t& x, const elem_t& y) -> bool { + return ((!_isnan(x.first) && _isnan(y.first)) || (x.first < y.first)); + }); + } + } else { + if (largest) { + std::nth_element(queue.begin(), queue.begin() + k - 1, queue.end(), + [](const elem_t& x, const elem_t& y) -> bool { + return ((_isnan(x.first) && !_isnan(y.first)) || (x.first > y.first)); + }); + if (sorted) { + std::sort(queue.begin(), queue.begin() + k - 1, + [](const elem_t& x, const elem_t& y) -> bool { + return ((_isnan(x.first) && !_isnan(y.first)) || (x.first > y.first)); + }); + } + } else { + std::nth_element(queue.begin(), queue.begin() + k -1, queue.end(), + [](const elem_t& x, const elem_t& y) -> bool { + return ((!_isnan(x.first) && _isnan(y.first)) || (x.first < y.first)); + }); + if (sorted) { + std::sort(queue.begin(), queue.begin() + k -1, + [](const elem_t& x, const elem_t& y) -> bool { + return ((!_isnan(x.first) && _isnan(y.first)) || (x.first < y.first)); + }); + } + } + } + + for (const auto j : c10::irange(k)) { + mode_values[j] = queue[j].first; + mode_indices[j] = queue[j].second; + } + } +} + +} // namespace CPU_CAPABILITY +} // namespace at::native diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TransposeType.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TransposeType.h new file mode 100644 index 0000000000000000000000000000000000000000..603bf6fee60aa2bc1850fae3eb0dac73345d7fb9 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TransposeType.h @@ -0,0 +1,23 @@ +#pragma once +#include + +namespace at::native { + +// Used as an interface between the different BLAS-like libraries +enum class TransposeType { + NoTranspose, + Transpose, + ConjTranspose, +}; + +// Transforms TransposeType into the BLAS / LAPACK format +static inline char to_blas(TransposeType trans) { + switch (trans) { + case TransposeType::Transpose: return 'T'; + case TransposeType::NoTranspose: return 'N'; + case TransposeType::ConjTranspose: return 'C'; + } + TORCH_INTERNAL_ASSERT(false, "Invalid transpose type"); +} + +} // namespace at::native diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TypeProperties.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TypeProperties.h new file mode 100644 index 0000000000000000000000000000000000000000..2d4845c758461c3435c83eaf7cafa3ddd6c9d784 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TypeProperties.h @@ -0,0 +1,20 @@ +#pragma once + +#include +#include + +namespace at::native { + +struct ResultTypeState { + c10::ScalarType dimResult = ScalarType::Undefined; + c10::ScalarType wrappedResult = ScalarType::Undefined; + c10::ScalarType zeroResult = ScalarType::Undefined; +}; + +TORCH_API ResultTypeState update_result_type_state(const Tensor& tensor, const ResultTypeState& in_state); +TORCH_API ResultTypeState update_result_type_state(const Scalar& scalar, const ResultTypeState& in_state); +TORCH_API ScalarType result_type(const ResultTypeState& state); + +TORCH_API ScalarType result_type(ITensorListRef tensors); + +} // namespace at::native diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/group_norm.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/group_norm.h new file mode 100644 index 0000000000000000000000000000000000000000..1673df9253eec782328db0f673e6526b698641d0 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/group_norm.h @@ -0,0 +1,42 @@ +#pragma once + +#include +#include + +namespace at { +class Tensor; + +namespace native { + +using forward_fn = void (*)( + const Tensor& /* X */, + const Tensor& /* gamma */, + const Tensor& /* beta */, + int64_t /* N */, + int64_t /* C */, + int64_t /* HxW */, + int64_t /* group */, + double /* eps */, + Tensor& /* Y */, + Tensor& /* mean */, + Tensor& /* rstd */); + +using backward_fn = void (*)( + const Tensor& /* dY */, + const Tensor& /* X */, + const Tensor& /* mean */, + const Tensor& /* rstd */, + const Tensor& /* gamma */, + int64_t /* N */, + int64_t /* C */, + int64_t /* HxW */, + int64_t /* group */, + Tensor& /* dX */, + Tensor& /* dgamma */, + Tensor& /* dbeta */); + +DECLARE_DISPATCH(forward_fn, GroupNormKernel); +DECLARE_DISPATCH(backward_fn, GroupNormBackwardKernel); + +} // namespace native +} // namespace at diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/im2col.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/im2col.h new file mode 100644 index 0000000000000000000000000000000000000000..df94723ab2a216dcc2f98a6a373b46cb239a42b3 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/im2col.h @@ -0,0 +1,149 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include + +namespace at::native { + +template +static void im2col( + const T* data_im, + const int64_t channels, + const int64_t height, + const int64_t width, + const int64_t output_height, + const int64_t output_width, + const int64_t kernel_h, + const int64_t kernel_w, + const int64_t pad_h, + const int64_t pad_w, + const int64_t stride_h, + const int64_t stride_w, + const int64_t dilation_h, + const int64_t dilation_w, + T* data_col, + bool is_channels_last = false) { + const int64_t height_col = output_height; + const int64_t width_col = output_width; + const int64_t channels_col = channels * kernel_h * kernel_w; + + if (is_channels_last) { + at::parallel_for(0, height_col * width_col, 0, [&](int64_t begin, int64_t end) { + int64_t h_col{0}, w_col{0}; + data_index_init(begin, h_col, height_col, w_col, width_col); + + for (const auto i_col : c10::irange(begin, end)) { + for (const auto h_offset : c10::irange(kernel_h)) { + int64_t h_im = h_col * stride_h - pad_h + h_offset * dilation_h; + for (const auto w_offset : c10::irange(kernel_w)) { + int64_t w_im = w_col * stride_w - pad_w + w_offset * dilation_w; + + const T* slice_im = data_im + (h_im * width + w_im) * channels; + T* slice_col = data_col + (i_col * kernel_h * kernel_w + h_offset * kernel_w + w_offset) * channels; + + if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) { + std::copy_n(slice_im, channels, slice_col); + } else { + std::fill_n(slice_col, channels, T(0)); + } + } + } + + // move the next index + data_index_step(h_col, height_col, w_col, width_col); + } + }); + } else { + at::parallel_for(0, channels_col, 0, [&](int64_t begin, int64_t end) { + int64_t c_im{0}, h_offset{0}, w_offset{0}; + data_index_init(begin, c_im, channels, h_offset, kernel_h, w_offset, kernel_w); + + for (const auto c_col : c10::irange(begin, end)) { + for (const auto h_col : c10::irange(height_col)) { + int64_t h_im = h_col * stride_h - pad_h + h_offset * dilation_h; + for (const auto w_col : c10::irange(width_col)) { + int64_t w_im = w_col * stride_w - pad_w + w_offset * dilation_w; + data_col[(c_col * height_col + h_col) * width_col + w_col] = + (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) + ? data_im[(c_im * height + h_im) * width + w_im] + : static_cast(0); + } + } + + // move to the next index + data_index_step(c_im, channels, h_offset, kernel_h, w_offset, kernel_w); + } + }); + } +} + +template +static void col2im( + const T* data_col, + const int64_t channels, + const int64_t height, + const int64_t width, + const int64_t output_height, + const int64_t output_width, + const int64_t kernel_h, + const int64_t kernel_w, + const int64_t pad_h, + const int64_t pad_w, + const int64_t stride_h, + const int64_t stride_w, + const int64_t dilation_h, + const int64_t dilation_w, + T* data_im, + bool is_channels_last = false) { + std::fill_n(data_im, height * width * channels, T(0)); + + const int64_t height_col = output_height; + const int64_t width_col = output_width; + const int64_t channels_col = channels * kernel_h * kernel_w; + + if (is_channels_last) { + for (const auto h_col : c10::irange(height_col)) { + for (const auto w_col : c10::irange(width_col)) { + for (const auto h_offset : c10::irange(kernel_h)) { + int64_t h_im = h_col * stride_h - pad_h + h_offset * dilation_h; + for (const auto w_offset : c10::irange(kernel_w)) { + int64_t w_im = w_col * stride_w - pad_w + w_offset * dilation_w; + + T* slice_im = data_im + (h_im * width + w_im) * channels; + const T* slice_col = data_col + ((h_col * width_col + w_col) * kernel_h * kernel_w + + h_offset * kernel_w + w_offset) * channels; + + if (h_im >= 0 && h_im < height && w_im >= 0 && w_im < width) { + std::transform(slice_col, slice_col + channels, slice_im, slice_im, std::plus()); + } + } + } + } + } + } else { + for (const auto c_col : c10::irange(channels_col)) { + int64_t w_offset = c_col % kernel_w; + int64_t h_offset = (c_col / kernel_w) % kernel_h; + int64_t c_im = c_col / kernel_h / kernel_w; + + for (const auto h_col : c10::irange(height_col)) { + int64_t h_im = h_col * stride_h - pad_h + h_offset * dilation_h; + for (const auto w_col : c10::irange(width_col)) { + int64_t w_im = w_col * stride_w - pad_w + w_offset * dilation_w; + + if (h_im >= 0 && h_im < height && w_im >= 0 && w_im < width) + data_im[(c_im * height + h_im) * width + w_im] += + data_col[(c_col * height_col + h_col) * width_col + w_col]; + } + } + } + } +} + +} // namespace at::native diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/layer_norm.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/layer_norm.h new file mode 100644 index 0000000000000000000000000000000000000000..e35ccf8634bccb3acaf40b39e169fc7b881a5e11 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/layer_norm.h @@ -0,0 +1,106 @@ +#pragma once + +#include +#include +#include + +namespace at::native { + +namespace { + +C10_ALWAYS_INLINE std::pair _check_layer_norm_inputs( + const Tensor& input, + IntArrayRef normalized_shape, + const Tensor& weight /* optional */, + const Tensor& bias /* optional */) { + + const int normalized_ndim = normalized_shape.size(); + TORCH_CHECK( + normalized_ndim >= 1, + "Expected normalized_shape to be at least 1-dimensional, i.e., ", + "containing at least one element, but got normalized_shape = ", + normalized_shape); + TORCH_CHECK( + !weight.defined() || weight.sizes().equals(normalized_shape), + "Expected weight to be of same shape as normalized_shape, but got ", + "weight of shape ", + weight.sizes(), + " and normalized_shape = ", + normalized_shape); + TORCH_CHECK( + !bias.defined() || bias.sizes().equals(normalized_shape), + "Expected bias to be of same shape as normalized_shape, but got ", + "bias of shape ", + bias.sizes(), + " and normalized_shape = ", + normalized_shape); + + const auto input_shape = input.sizes(); + const auto input_ndim = input.dim(); + + if (input_ndim < normalized_ndim || + !input_shape.slice(input_ndim - normalized_ndim) + .equals(normalized_shape)) { + std::stringstream ss; + ss << "Given normalized_shape=" << normalized_shape + << ", expected input with shape [*"; + for (auto size : normalized_shape) { + ss << ", " << size; + } + ss << "], but got input of size" << input_shape; + AT_ERROR(ss.str()); + } + + const int axis = input_ndim - normalized_ndim; + const int64_t M = + c10::multiply_integers(input_shape.cbegin(), input_shape.cbegin() + axis); + const int64_t N = + c10::multiply_integers(input_shape.cbegin() + axis, input_shape.cend()); + + return std::make_pair(M, N); +} + +} // namespace + +void layer_norm_cpu_out( + at::Tensor& out, + const at::Tensor& input, + const Tensor& gamma, + const Tensor& beta, + double eps, + int64_t M, + int64_t N); + +Tensor rms_norm( + const Tensor& input, + IntArrayRef normalized_shape, + const std::optional& weight_opt /* optional */, + std::optional eps); + +using forward_fn = void (*)( + const Tensor& /* X */, + const Tensor& /* gamma */, + const Tensor& /* beta */, + int64_t /* M */, + int64_t /* N */, + double /* eps */, + Tensor* /* Y */, + Tensor* /* mean */, + Tensor* /* rstd */); + +using backward_fn = void (*)( + const Tensor& /* dY */, + const Tensor& /* X */, + const Tensor& /* mean */, + const Tensor& /* rstd */, + const Tensor& /* gamma */, + int64_t /* M */, + int64_t /* N */, + Tensor* /* dX */, + Tensor* /* dgamma */, + Tensor* /* dbeta */); + +DECLARE_DISPATCH(forward_fn, LayerNormKernel); +DECLARE_DISPATCH(backward_fn, LayerNormBackwardKernel); + +} // namespace at::native diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/verbose_wrapper.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/verbose_wrapper.h new file mode 100644 index 0000000000000000000000000000000000000000..59d9682e345b4440e103a1f95c6da42208764aba --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/verbose_wrapper.h @@ -0,0 +1,8 @@ +#pragma once + +#include + +namespace torch::verbose { +TORCH_API int _mkl_set_verbose(int enable); +TORCH_API int _mkldnn_set_verbose(int level); +} // namespace torch::verbose diff --git a/.venv/lib/python3.11/site-packages/torch/include/ATen/native/vol2col.h b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/vol2col.h new file mode 100644 index 0000000000000000000000000000000000000000..fa5c46b8c52e874791337a30fc9d4f1e5ff3db1d --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/include/ATen/native/vol2col.h @@ -0,0 +1,109 @@ +#pragma once + +#include + +namespace at::native { + +template +void vol2col( + const T* data_vol, + const int64_t channels, + const int64_t depth, + const int64_t height, + const int64_t width, + const int64_t depth_col, + const int64_t height_col, + const int64_t width_col, + const int64_t kT, + const int64_t kernel_height, + const int64_t kernel_width, + const int64_t pT, + const int64_t pH, + const int64_t pW, + const int64_t dT, + const int64_t dH, + const int64_t dW, + const int64_t dilationT, + const int64_t dilationH, + const int64_t dilationW, + T* data_col) { + int64_t c, t, h, w; + int64_t channels_col = channels * kT * kernel_height * kernel_width; + for (c = 0; c < channels_col; ++c) { + int64_t w_offset = c % kernel_width; + int64_t h_offset = (c / kernel_width) % kernel_height; + int64_t t_offset = (c / kernel_width / kernel_height) % kT; + int64_t c_vol = c / kT / kernel_height / kernel_width; + for (t = 0; t < depth_col; ++t) { + int64_t t_pad = t * dT - pT + t_offset * dilationT; + for (h = 0; h < height_col; ++h) { + int64_t h_pad = h * dH - pH + h_offset * dilationH; + for (w = 0; w < width_col; ++w) { + int64_t w_pad = w * dW - pW + w_offset * dilationW; + if (t_pad >= 0 && t_pad < depth && h_pad >= 0 && h_pad < height && + w_pad >= 0 && w_pad < width) + data_col[((c * depth_col + t) * height_col + h) * width_col + w] = + data_vol + [((c_vol * depth + t_pad) * height + h_pad) * width + + w_pad]; + else + data_col[((c * depth_col + t) * height_col + h) * width_col + w] = + 0; + } + } + } + } +} + +template +void col2vol( + const T* data_col, + const int64_t channels, + const int64_t depth, + const int64_t height, + const int64_t width, + const int64_t out_depth, + const int64_t out_height, + const int64_t out_width, + const int64_t kT, + const int64_t kernel_height, + const int64_t kernel_width, + const int64_t pT, + const int64_t pH, + const int64_t pW, + const int64_t dT, + const int64_t dH, + const int64_t dW, + const int64_t dilationT, + const int64_t dilationH, + const int64_t dilationW, + T* data_vol) { + memset(data_vol, 0, sizeof(T) * depth * height * width * channels); + int64_t depth_col = out_depth; + int64_t height_col = out_height; + int64_t width_col = out_width; + int64_t channels_col = channels * kT * kernel_height * kernel_width; + for (int64_t c = 0; c < channels_col; ++c) { + int64_t w_offset = c % kernel_width; + int64_t h_offset = (c / kernel_width) % kernel_height; + int64_t t_offset = (c / kernel_width / kernel_height) % kT; + int64_t c_vol = c / kT / kernel_height / kernel_width; + for (int64_t t = 0; t < depth_col; ++t) { + int64_t t_pad = t * dT - pT + t_offset * dilationT; + for (int64_t h = 0; h < height_col; ++h) { + int64_t h_pad = h * dH - pH + h_offset * dilationH; + for (int64_t w = 0; w < width_col; ++w) { + int64_t w_pad = w * dW - pW + w_offset * dilationW; + if (t_pad >= 0 && t_pad < depth && h_pad >= 0 && h_pad < height && + w_pad >= 0 && w_pad < width) + data_vol + [((c_vol * depth + t_pad) * height + h_pad) * width + w_pad] += + data_col + [((c * depth_col + t) * height_col + h) * width_col + w]; + } + } + } + } +} + +} // namespace at::native